summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDaniel Friesel <daniel.friesel@uos.de>2021-03-02 15:12:48 +0100
committerDaniel Friesel <daniel.friesel@uos.de>2021-03-02 15:12:48 +0100
commita1a48721448e54eb2816e014bd01403aa26a6cdf (patch)
tree0038eff21d5bec0c079f698d7615b12b3c889586
parenta4ec801cba35f2f4eab720a3ec1b6df4a15a146f (diff)
ModelAttribute: remove get_fitted(), use .model_function, .model_info instead
-rw-r--r--lib/functions.py43
-rw-r--r--lib/model.py13
-rw-r--r--lib/parameters.py63
3 files changed, 62 insertions, 57 deletions
diff --git a/lib/functions.py b/lib/functions.py
index 7af1d22..8e43dcb 100644
--- a/lib/functions.py
+++ b/lib/functions.py
@@ -155,26 +155,7 @@ class NormalizationFunction:
class ModelInfo:
def __init__(self):
- pass
-
-
-class StaticInfo:
- def __init__(self, data):
- self.mean = np.mean(data)
- self.median = np.median(data)
- self.std = np.std(data)
-
-
-class AnalyticInfo(ModelInfo):
- def __init__(self, fit_result, function):
- self.fit_result = fit_result
- self.function = function
-
-
-class SplitInfo(ModelInfo):
- def __init__(self, param_index, child):
- self.param_index = param_index
- self.child = child
+ self.error = None
class ModelFunction:
@@ -210,6 +191,14 @@ class StaticFunction(ModelFunction):
return self.value
+class StaticInfo(ModelInfo):
+ def __init__(self, data):
+ super()
+ self.mean = np.mean(data)
+ self.median = np.median(data)
+ self.std = np.std(data)
+
+
class SplitFunction(ModelFunction):
def __init__(self, param_index, child):
self.param_index = param_index
@@ -235,6 +224,13 @@ class SplitFunction(ModelFunction):
return self.child[param_value].eval(param_list, arg_list)
+class SplitInfo(ModelInfo):
+ def __init__(self, param_index, child):
+ super()
+ self.param_index = param_index
+ self.child = child
+
+
class AnalyticFunction(ModelFunction):
"""
A multi-dimensional model function, generated from a string, which can be optimized using regression.
@@ -420,6 +416,13 @@ class AnalyticFunction(ModelFunction):
return self._function(self.model_args, param_list)
+class AnalyticInfo(ModelInfo):
+ def __init__(self, fit_result, function):
+ super()
+ self.fit_result = fit_result
+ self.function = function
+
+
class analytic:
"""
Utilities for analytic description of parameter-dependent model attributes and regression analysis.
diff --git a/lib/model.py b/lib/model.py
index e97cfbb..4b0f46d 100644
--- a/lib/model.py
+++ b/lib/model.py
@@ -254,25 +254,26 @@ class AnalyticModel:
static_model[name][k] = v.get_static(use_mean=use_mean)
def model_getter(name, key, **kwargs):
- param_function, param_info = self.attr_by_name[name][key].get_fitted()
+ model_function = self.attr_by_name[name][key].model_function
+ model_info = self.attr_by_name[name][key].model_info
- if type(param_info) is StaticInfo:
+ # shortcut
+ if type(model_info) is StaticInfo:
return static_model[name][key]
if "arg" in kwargs and "param" in kwargs:
kwargs["param"].extend(map(soft_cast_int, kwargs["arg"]))
- if param_function.is_predictable(kwargs["param"]):
- return param_function.eval(kwargs["param"])
+ if model_function.is_predictable(kwargs["param"]):
+ return model_function.eval(kwargs["param"])
return static_model[name][key]
def info_getter(name, key):
try:
- model_function, model_info = self.attr_by_name[name][key].get_fitted()
+ return self.attr_by_name[name][key].model_info
except KeyError:
return None
- return model_info
return model_getter, info_getter
diff --git a/lib/parameters.py b/lib/parameters.py
index cf76f00..368b24c 100644
--- a/lib/parameters.py
+++ b/lib/parameters.py
@@ -568,25 +568,42 @@ class ParamStats:
class ModelAttribute:
def __init__(self, name, attr, data, param_values, param_names, arg_count=0):
+
+ # Data for model generation
+ self.data = np.array(data)
+
+ # Meta data
self.name = name
self.attr = attr
- self.data = np.array(data)
self.param_values = param_values
self.param_names = sorted(param_names)
self.arg_count = arg_count
+
+ # Static model used as lower bound of model accuracy
+ self.mean = np.mean(data)
+ self.median = np.median(data)
+
+ # LUT model used as upper bound of model accuracy
self.by_param = None # set via ParallelParamStats
- self.function_override = None
- self.param_model = None
+
+ # Split (decision tree) information
self.split = None
+ # param model override
+ self.function_override = None
+
+ # The best model we have. May be Static, Split, or Param (and later perhaps Substate)
+ self.model_function = None
+ self.model_info = None
+
def __repr__(self):
mean = np.mean(self.data)
return f"ModelAttribute<{self.name}, {self.attr}, mean={mean}>"
def get_static(self, use_mean=False):
if use_mean:
- return np.mean(self.data)
- return np.median(self.data)
+ return self.mean
+ return self.median
def get_lut(self, param, use_mean=False):
if use_mean:
@@ -757,24 +774,22 @@ class ModelAttribute:
info_child = dict()
for param_value, child in child_by_param_value.items():
child.set_data_from_paramfit(paramfit, prefix + (param_value,))
- function_child[param_value], info_child[param_value] = child.get_fitted()
- function_map = df.SplitFunction(split_param_index, function_child)
- info_map = df.SplitInfo(split_param_index, info_child)
-
- self.param_model = function_map, info_map
+ function_child[param_value] = child.model_function
+ info_child[param_value] = child.model_info
+ self.model_function = df.SplitFunction(split_param_index, function_child)
+ self.model_info = df.SplitInfo(split_param_index, info_child)
def set_data_from_paramfit_this(self, paramfit, prefix):
fit_result = paramfit.get_result((self.name, self.attr) + prefix)
- param_model = (
- df.StaticFunction(np.median(self.data)),
- df.StaticInfo(self.data),
- )
+ self.model_function = df.StaticFunction(self.median)
+ self.model_info = df.StaticInfo(self.data)
if self.function_override is not None:
function_str = self.function_override
x = df.AnalyticFunction(function_str, self.param_names, self.arg_count)
x.fit(self.by_param)
if x.fit_success:
- param_model = (x, df.AnalyticInfo(fit_result, x))
+ self.model_function = x
+ self.model_info = df.AnalyticInfo(fit_result, x)
elif os.getenv("DFATOOL_NO_PARAM"):
pass
elif len(fit_result.keys()):
@@ -784,19 +799,5 @@ class ModelAttribute:
x.fit(self.by_param)
if x.fit_success:
- param_model = (x, df.AnalyticInfo(fit_result, x))
-
- self.param_model = param_model
-
- def get_fitted(self):
- """
- Get paramete-aware model function and model information function.
- They must have been set via get_data_for_paramfit -> ParallelParamFit -> set-data_from_paramfit first.
-
- Returns a tuple (function, info):
- function -> AnalyticFunction for model. function(param=parameter values) -> model value.
- info -> {'fit_result' : ..., 'function' : ... }
-
- Returns (None, None) if fitting failed. Returns None if ParamFit has not been performed yet.
- """
- return self.param_model
+ self.model_function = x
+ self.model_info = df.AnalyticInfo(fit_result, x)