From eff4f40655a7f9e7f5c8d82f548ebf284d26b01c Mon Sep 17 00:00:00 2001 From: Daniel Friesel Date: Thu, 25 Feb 2021 15:23:47 +0100 Subject: kinda proper dtree support (todo: refactoring) --- lib/model.py | 37 ++++++++++++++++++++++++++++++------- 1 file changed, 30 insertions(+), 7 deletions(-) (limited to 'lib/model.py') diff --git a/lib/model.py b/lib/model.py index 99626bb..451a39a 100644 --- a/lib/model.py +++ b/lib/model.py @@ -2,6 +2,7 @@ import logging import numpy as np +import os from scipy import optimize from sklearn.metrics import r2_score from multiprocessing import Pool @@ -470,7 +471,7 @@ class ModelAttribute: return split_param_index def get_data_for_paramfit(self, safe_functions_enabled=False): - if self.split and 0: + if self.split: return self.get_data_for_paramfit_split( safe_functions_enabled=safe_functions_enabled ) @@ -519,15 +520,27 @@ class ModelAttribute: return ret def set_data_from_paramfit(self, paramfit, prefix=tuple()): - if self.split and 0: + if self.split: self.set_data_from_paramfit_split(paramfit, prefix) else: self.set_data_from_paramfit_this(paramfit, prefix) def set_data_from_paramfit_split(self, paramfit, prefix): split_param_index, child_by_param_value = self.split + function_map = { + "split_by": split_param_index, + "child": dict(), + "child_static": dict(), + } + info_map = {"split_by": split_param_index, "child": dict()} for param_value, child in child_by_param_value.items(): child.set_data_from_paramfit(paramfit, prefix + (param_value,)) + function_map["child"][param_value], info_map["child"][ + param_value + ] = child.get_fitted() + function_map["child_static"][param_value] = child.get_static() + + self.param_model = function_map, info_map def set_data_from_paramfit_this(self, paramfit, prefix): fit_result = paramfit.get_result((self.name, self.attr) + prefix) @@ -687,10 +700,10 @@ class AnalyticModel: paramstats.compute() - np.seterr("raise") - for name in self.names: - for attr in self.attr_by_name[name].values(): - attr.build_dtree() + if not os.getenv("DFATOOL_NO_DECISIONTREES"): + for name in self.names: + for attr in self.attr_by_name[name].values(): + attr.build_dtree() def attributes(self, name): return self.attr_by_name[name].keys() @@ -795,7 +808,7 @@ class AnalyticModel: static_model[name][k] = v.get_static(use_mean=use_mean) def model_getter(name, key, **kwargs): - param_function, _ = self.attr_by_name[name][key].get_fitted() + param_function, param_info = self.attr_by_name[name][key].get_fitted() if param_function is None: return static_model[name][key] @@ -803,6 +816,16 @@ class AnalyticModel: if "arg" in kwargs and "param" in kwargs: kwargs["param"].extend(map(soft_cast_int, kwargs["arg"])) + while type(param_function) is dict and "split_by" in param_function: + split_param_value = kwargs["param"][param_function["split_by"]] + split_static = param_function["child_static"][split_param_value] + param_function = param_function["child"][split_param_value] + param_info = param_info["child"][split_param_value] + + if param_function is None: + # TODO return static model of child + return split_static + if param_function.is_predictable(kwargs["param"]): return param_function.eval(kwargs["param"]) -- cgit v1.2.3