diff options
author | Daniel Friesel <daniel.friesel@uos.de> | 2021-03-02 15:12:48 +0100 |
---|---|---|
committer | Daniel Friesel <daniel.friesel@uos.de> | 2021-03-02 15:12:48 +0100 |
commit | a1a48721448e54eb2816e014bd01403aa26a6cdf (patch) | |
tree | 0038eff21d5bec0c079f698d7615b12b3c889586 | |
parent | a4ec801cba35f2f4eab720a3ec1b6df4a15a146f (diff) |
ModelAttribute: remove get_fitted(), use .model_function, .model_info instead
-rw-r--r-- | lib/functions.py | 43 | ||||
-rw-r--r-- | lib/model.py | 13 | ||||
-rw-r--r-- | lib/parameters.py | 63 |
3 files changed, 62 insertions, 57 deletions
diff --git a/lib/functions.py b/lib/functions.py index 7af1d22..8e43dcb 100644 --- a/lib/functions.py +++ b/lib/functions.py @@ -155,26 +155,7 @@ class NormalizationFunction: class ModelInfo: def __init__(self): - pass - - -class StaticInfo: - def __init__(self, data): - self.mean = np.mean(data) - self.median = np.median(data) - self.std = np.std(data) - - -class AnalyticInfo(ModelInfo): - def __init__(self, fit_result, function): - self.fit_result = fit_result - self.function = function - - -class SplitInfo(ModelInfo): - def __init__(self, param_index, child): - self.param_index = param_index - self.child = child + self.error = None class ModelFunction: @@ -210,6 +191,14 @@ class StaticFunction(ModelFunction): return self.value +class StaticInfo(ModelInfo): + def __init__(self, data): + super() + self.mean = np.mean(data) + self.median = np.median(data) + self.std = np.std(data) + + class SplitFunction(ModelFunction): def __init__(self, param_index, child): self.param_index = param_index @@ -235,6 +224,13 @@ class SplitFunction(ModelFunction): return self.child[param_value].eval(param_list, arg_list) +class SplitInfo(ModelInfo): + def __init__(self, param_index, child): + super() + self.param_index = param_index + self.child = child + + class AnalyticFunction(ModelFunction): """ A multi-dimensional model function, generated from a string, which can be optimized using regression. @@ -420,6 +416,13 @@ class AnalyticFunction(ModelFunction): return self._function(self.model_args, param_list) +class AnalyticInfo(ModelInfo): + def __init__(self, fit_result, function): + super() + self.fit_result = fit_result + self.function = function + + class analytic: """ Utilities for analytic description of parameter-dependent model attributes and regression analysis. diff --git a/lib/model.py b/lib/model.py index e97cfbb..4b0f46d 100644 --- a/lib/model.py +++ b/lib/model.py @@ -254,25 +254,26 @@ class AnalyticModel: static_model[name][k] = v.get_static(use_mean=use_mean) def model_getter(name, key, **kwargs): - param_function, param_info = self.attr_by_name[name][key].get_fitted() + model_function = self.attr_by_name[name][key].model_function + model_info = self.attr_by_name[name][key].model_info - if type(param_info) is StaticInfo: + # shortcut + if type(model_info) is StaticInfo: return static_model[name][key] if "arg" in kwargs and "param" in kwargs: kwargs["param"].extend(map(soft_cast_int, kwargs["arg"])) - if param_function.is_predictable(kwargs["param"]): - return param_function.eval(kwargs["param"]) + if model_function.is_predictable(kwargs["param"]): + return model_function.eval(kwargs["param"]) return static_model[name][key] def info_getter(name, key): try: - model_function, model_info = self.attr_by_name[name][key].get_fitted() + return self.attr_by_name[name][key].model_info except KeyError: return None - return model_info return model_getter, info_getter diff --git a/lib/parameters.py b/lib/parameters.py index cf76f00..368b24c 100644 --- a/lib/parameters.py +++ b/lib/parameters.py @@ -568,25 +568,42 @@ class ParamStats: class ModelAttribute: def __init__(self, name, attr, data, param_values, param_names, arg_count=0): + + # Data for model generation + self.data = np.array(data) + + # Meta data self.name = name self.attr = attr - self.data = np.array(data) self.param_values = param_values self.param_names = sorted(param_names) self.arg_count = arg_count + + # Static model used as lower bound of model accuracy + self.mean = np.mean(data) + self.median = np.median(data) + + # LUT model used as upper bound of model accuracy self.by_param = None # set via ParallelParamStats - self.function_override = None - self.param_model = None + + # Split (decision tree) information self.split = None + # param model override + self.function_override = None + + # The best model we have. May be Static, Split, or Param (and later perhaps Substate) + self.model_function = None + self.model_info = None + def __repr__(self): mean = np.mean(self.data) return f"ModelAttribute<{self.name}, {self.attr}, mean={mean}>" def get_static(self, use_mean=False): if use_mean: - return np.mean(self.data) - return np.median(self.data) + return self.mean + return self.median def get_lut(self, param, use_mean=False): if use_mean: @@ -757,24 +774,22 @@ class ModelAttribute: info_child = dict() for param_value, child in child_by_param_value.items(): child.set_data_from_paramfit(paramfit, prefix + (param_value,)) - function_child[param_value], info_child[param_value] = child.get_fitted() - function_map = df.SplitFunction(split_param_index, function_child) - info_map = df.SplitInfo(split_param_index, info_child) - - self.param_model = function_map, info_map + function_child[param_value] = child.model_function + info_child[param_value] = child.model_info + self.model_function = df.SplitFunction(split_param_index, function_child) + self.model_info = df.SplitInfo(split_param_index, info_child) def set_data_from_paramfit_this(self, paramfit, prefix): fit_result = paramfit.get_result((self.name, self.attr) + prefix) - param_model = ( - df.StaticFunction(np.median(self.data)), - df.StaticInfo(self.data), - ) + self.model_function = df.StaticFunction(self.median) + self.model_info = df.StaticInfo(self.data) if self.function_override is not None: function_str = self.function_override x = df.AnalyticFunction(function_str, self.param_names, self.arg_count) x.fit(self.by_param) if x.fit_success: - param_model = (x, df.AnalyticInfo(fit_result, x)) + self.model_function = x + self.model_info = df.AnalyticInfo(fit_result, x) elif os.getenv("DFATOOL_NO_PARAM"): pass elif len(fit_result.keys()): @@ -784,19 +799,5 @@ class ModelAttribute: x.fit(self.by_param) if x.fit_success: - param_model = (x, df.AnalyticInfo(fit_result, x)) - - self.param_model = param_model - - def get_fitted(self): - """ - Get paramete-aware model function and model information function. - They must have been set via get_data_for_paramfit -> ParallelParamFit -> set-data_from_paramfit first. - - Returns a tuple (function, info): - function -> AnalyticFunction for model. function(param=parameter values) -> model value. - info -> {'fit_result' : ..., 'function' : ... } - - Returns (None, None) if fitting failed. Returns None if ParamFit has not been performed yet. - """ - return self.param_model + self.model_function = x + self.model_info = df.AnalyticInfo(fit_result, x) |