summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--lib/model.py37
1 files changed, 30 insertions, 7 deletions
diff --git a/lib/model.py b/lib/model.py
index 99626bb..451a39a 100644
--- a/lib/model.py
+++ b/lib/model.py
@@ -2,6 +2,7 @@
import logging
import numpy as np
+import os
from scipy import optimize
from sklearn.metrics import r2_score
from multiprocessing import Pool
@@ -470,7 +471,7 @@ class ModelAttribute:
return split_param_index
def get_data_for_paramfit(self, safe_functions_enabled=False):
- if self.split and 0:
+ if self.split:
return self.get_data_for_paramfit_split(
safe_functions_enabled=safe_functions_enabled
)
@@ -519,15 +520,27 @@ class ModelAttribute:
return ret
def set_data_from_paramfit(self, paramfit, prefix=tuple()):
- if self.split and 0:
+ if self.split:
self.set_data_from_paramfit_split(paramfit, prefix)
else:
self.set_data_from_paramfit_this(paramfit, prefix)
def set_data_from_paramfit_split(self, paramfit, prefix):
split_param_index, child_by_param_value = self.split
+ function_map = {
+ "split_by": split_param_index,
+ "child": dict(),
+ "child_static": dict(),
+ }
+ info_map = {"split_by": split_param_index, "child": dict()}
for param_value, child in child_by_param_value.items():
child.set_data_from_paramfit(paramfit, prefix + (param_value,))
+ function_map["child"][param_value], info_map["child"][
+ param_value
+ ] = child.get_fitted()
+ function_map["child_static"][param_value] = child.get_static()
+
+ self.param_model = function_map, info_map
def set_data_from_paramfit_this(self, paramfit, prefix):
fit_result = paramfit.get_result((self.name, self.attr) + prefix)
@@ -687,10 +700,10 @@ class AnalyticModel:
paramstats.compute()
- np.seterr("raise")
- for name in self.names:
- for attr in self.attr_by_name[name].values():
- attr.build_dtree()
+ if not os.getenv("DFATOOL_NO_DECISIONTREES"):
+ for name in self.names:
+ for attr in self.attr_by_name[name].values():
+ attr.build_dtree()
def attributes(self, name):
return self.attr_by_name[name].keys()
@@ -795,7 +808,7 @@ class AnalyticModel:
static_model[name][k] = v.get_static(use_mean=use_mean)
def model_getter(name, key, **kwargs):
- param_function, _ = self.attr_by_name[name][key].get_fitted()
+ param_function, param_info = self.attr_by_name[name][key].get_fitted()
if param_function is None:
return static_model[name][key]
@@ -803,6 +816,16 @@ class AnalyticModel:
if "arg" in kwargs and "param" in kwargs:
kwargs["param"].extend(map(soft_cast_int, kwargs["arg"]))
+ while type(param_function) is dict and "split_by" in param_function:
+ split_param_value = kwargs["param"][param_function["split_by"]]
+ split_static = param_function["child_static"][split_param_value]
+ param_function = param_function["child"][split_param_value]
+ param_info = param_info["child"][split_param_value]
+
+ if param_function is None:
+ # TODO return static model of child
+ return split_static
+
if param_function.is_predictable(kwargs["param"]):
return param_function.eval(kwargs["param"])