diff options
author | Daniel Friesel <daniel.friesel@uos.de> | 2021-02-26 16:02:19 +0100 |
---|---|---|
committer | Daniel Friesel <daniel.friesel@uos.de> | 2021-02-26 16:02:19 +0100 |
commit | 32bcad3482781e7e2e42c5de10d938c1567b8390 (patch) | |
tree | 3bbb58740d04c789f549de50dce1f0cc2a45480d /lib | |
parent | 21698b9915f02216a1afa5afb36b56f65f30b8ca (diff) |
refactor param_info, show splits in analyze-archive output
Diffstat (limited to 'lib')
-rw-r--r-- | lib/functions.py | 77 | ||||
-rw-r--r-- | lib/model.py | 45 |
2 files changed, 93 insertions, 29 deletions
diff --git a/lib/functions.py b/lib/functions.py index 0bdea45..067514f 100644 --- a/lib/functions.py +++ b/lib/functions.py @@ -152,7 +152,82 @@ class NormalizationFunction: return self._function(param_value) -class AnalyticFunction: +class ModelInfo: + def __init__(self): + pass + + +class AnalyticInfo(ModelInfo): + def __init__(self, fit_result, function): + self.fit_result = fit_result + self.function = function + + +class SplitInfo(ModelInfo): + def __init__(self, param_index, child): + self.param_index = param_index + self.child = child + + +class ModelFunction: + def __init__(self): + pass + + def is_predictable(self, param_list): + raise NotImplementedError + + def eval(self, param_list, arg_list): + raise NotImplementedError + + +class StaticFunction(ModelFunction): + def __init__(self, value): + self.value = value + + def is_predictable(self, param_list=None): + """ + Return whether the model function can be evaluated on the given parameter values. + + For a StaticFunction, this is always the case (i.e., this function always returns true). + """ + return True + + def eval(self, param_list=None, arg_list=None): + """ + Evaluate model function with specified param/arg values. + + Far a Staticfunction, this is just the static value + + """ + return self.value + + +class SplitFunction(ModelFunction): + def __init__(self, param_index, child): + self.param_index = param_index + self.child = child + + def is_predictable(self, param_list): + """ + Return whether the model function can be evaluated on the given parameter values. + + The first value corresponds to the lexically first model parameter, etc. + All parameters must be set, not just the ones this function depends on. + + Returns False iff a parameter the function depends on is not numeric + (e.g. None). + """ + param_value = param_list[self.param_index] + if param_value not in self.child: + return False + return self.child[param_value].is_predictable(param_list) + + def eval(self, param_list, arg_list=list()): + param_value = param_list[self.param_index] + return self.child[param_value].eval(param_list, arg_list) + + +class AnalyticFunction(ModelFunction): """ A multi-dimensional model function, generated from a string, which can be optimized using regression. diff --git a/lib/model.py b/lib/model.py index 83c31b1..cddfe27 100644 --- a/lib/model.py +++ b/lib/model.py @@ -7,8 +7,7 @@ from scipy import optimize from sklearn.metrics import r2_score from multiprocessing import Pool from .automata import PTA -from .functions import analytic -from .functions import AnalyticFunction +import dfatool.functions as df from .parameters import ParallelParamStats, ParamStats from .utils import is_numeric, soft_cast_int, param_slice_eq, remove_index_from_tuple from .utils import ( @@ -211,7 +210,7 @@ def _try_fits( :param param_filter: Only use measurements whose parameters match param_filter for fitting. """ - functions = analytic.functions(safe_functions_enabled=safe_functions_enabled) + functions = df.analytic.functions(safe_functions_enabled=safe_functions_enabled) for param_key in n_by_param.keys(): # We might remove elements from 'functions' while iterating over @@ -532,31 +531,33 @@ class ModelAttribute: "child": dict(), "child_static": dict(), } - info_map = {"split_by": split_param_index, "child": dict()} + function_child = dict() + info_child = dict() for param_value, child in child_by_param_value.items(): child.set_data_from_paramfit(paramfit, prefix + (param_value,)) - function_map["child"][param_value], info_map["child"][ - param_value - ] = child.get_fitted() - function_map["child_static"][param_value] = child.get_static() + function_child[param_value], info_child[param_value] = child.get_fitted() + function_map = df.SplitFunction(split_param_index, function_child) + info_map = df.SplitInfo(split_param_index, info_child) self.param_model = function_map, info_map def set_data_from_paramfit_this(self, paramfit, prefix): fit_result = paramfit.get_result((self.name, self.attr) + prefix) - param_model = (None, None) + param_model = (df.StaticFunction(np.median(self.data)), None) if self.function_override is not None: function_str = self.function_override - x = AnalyticFunction(function_str, self.param_names, self.arg_count) + x = df.AnalyticFunction(function_str, self.param_names, self.arg_count) x.fit(self.by_param) if x.fit_success: - param_model = (x, fit_result) + param_model = (x, df.AnalyticInfo(fit_result, x)) elif len(fit_result.keys()): - x = analytic.function_powerset(fit_result, self.param_names, self.arg_count) + x = df.analytic.function_powerset( + fit_result, self.param_names, self.arg_count + ) x.fit(self.by_param) if x.fit_success: - param_model = (x, fit_result) + param_model = (x, df.AnalyticInfo(fit_result, x)) self.param_model = param_model @@ -810,22 +811,12 @@ class AnalyticModel: def model_getter(name, key, **kwargs): param_function, param_info = self.attr_by_name[name][key].get_fitted() - if param_function is None: + if param_info is None: return static_model[name][key] if "arg" in kwargs and "param" in kwargs: kwargs["param"].extend(map(soft_cast_int, kwargs["arg"])) - while type(param_function) is dict and "split_by" in param_function: - split_param_value = kwargs["param"][param_function["split_by"]] - split_static = param_function["child_static"][split_param_value] - param_function = param_function["child"][split_param_value] - param_info = param_info["child"][split_param_value] - - if param_function is None: - # TODO return static model of child - return split_static - if param_function.is_predictable(kwargs["param"]): return param_function.eval(kwargs["param"]) @@ -833,12 +824,10 @@ class AnalyticModel: def info_getter(name, key): try: - model_function, fit_result = self.attr_by_name[name][key].get_fitted() + model_function, model_info = self.attr_by_name[name][key].get_fitted() except KeyError: return None - if model_function is None: - return None - return {"function": model_function, "fit_result": fit_result} + return model_info return model_getter, info_getter |