summaryrefslogtreecommitdiff
path: root/lib/model.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/model.py')
-rw-r--r--lib/model.py45
1 files changed, 17 insertions, 28 deletions
diff --git a/lib/model.py b/lib/model.py
index 83c31b1..cddfe27 100644
--- a/lib/model.py
+++ b/lib/model.py
@@ -7,8 +7,7 @@ from scipy import optimize
from sklearn.metrics import r2_score
from multiprocessing import Pool
from .automata import PTA
-from .functions import analytic
-from .functions import AnalyticFunction
+import dfatool.functions as df
from .parameters import ParallelParamStats, ParamStats
from .utils import is_numeric, soft_cast_int, param_slice_eq, remove_index_from_tuple
from .utils import (
@@ -211,7 +210,7 @@ def _try_fits(
:param param_filter: Only use measurements whose parameters match param_filter for fitting.
"""
- functions = analytic.functions(safe_functions_enabled=safe_functions_enabled)
+ functions = df.analytic.functions(safe_functions_enabled=safe_functions_enabled)
for param_key in n_by_param.keys():
# We might remove elements from 'functions' while iterating over
@@ -532,31 +531,33 @@ class ModelAttribute:
"child": dict(),
"child_static": dict(),
}
- info_map = {"split_by": split_param_index, "child": dict()}
+ function_child = dict()
+ info_child = dict()
for param_value, child in child_by_param_value.items():
child.set_data_from_paramfit(paramfit, prefix + (param_value,))
- function_map["child"][param_value], info_map["child"][
- param_value
- ] = child.get_fitted()
- function_map["child_static"][param_value] = child.get_static()
+ function_child[param_value], info_child[param_value] = child.get_fitted()
+ function_map = df.SplitFunction(split_param_index, function_child)
+ info_map = df.SplitInfo(split_param_index, info_child)
self.param_model = function_map, info_map
def set_data_from_paramfit_this(self, paramfit, prefix):
fit_result = paramfit.get_result((self.name, self.attr) + prefix)
- param_model = (None, None)
+ param_model = (df.StaticFunction(np.median(self.data)), None)
if self.function_override is not None:
function_str = self.function_override
- x = AnalyticFunction(function_str, self.param_names, self.arg_count)
+ x = df.AnalyticFunction(function_str, self.param_names, self.arg_count)
x.fit(self.by_param)
if x.fit_success:
- param_model = (x, fit_result)
+ param_model = (x, df.AnalyticInfo(fit_result, x))
elif len(fit_result.keys()):
- x = analytic.function_powerset(fit_result, self.param_names, self.arg_count)
+ x = df.analytic.function_powerset(
+ fit_result, self.param_names, self.arg_count
+ )
x.fit(self.by_param)
if x.fit_success:
- param_model = (x, fit_result)
+ param_model = (x, df.AnalyticInfo(fit_result, x))
self.param_model = param_model
@@ -810,22 +811,12 @@ class AnalyticModel:
def model_getter(name, key, **kwargs):
param_function, param_info = self.attr_by_name[name][key].get_fitted()
- if param_function is None:
+ if param_info is None:
return static_model[name][key]
if "arg" in kwargs and "param" in kwargs:
kwargs["param"].extend(map(soft_cast_int, kwargs["arg"]))
- while type(param_function) is dict and "split_by" in param_function:
- split_param_value = kwargs["param"][param_function["split_by"]]
- split_static = param_function["child_static"][split_param_value]
- param_function = param_function["child"][split_param_value]
- param_info = param_info["child"][split_param_value]
-
- if param_function is None:
- # TODO return static model of child
- return split_static
-
if param_function.is_predictable(kwargs["param"]):
return param_function.eval(kwargs["param"])
@@ -833,12 +824,10 @@ class AnalyticModel:
def info_getter(name, key):
try:
- model_function, fit_result = self.attr_by_name[name][key].get_fitted()
+ model_function, model_info = self.attr_by_name[name][key].get_fitted()
except KeyError:
return None
- if model_function is None:
- return None
- return {"function": model_function, "fit_result": fit_result}
+ return model_info
return model_getter, info_getter