From 9957a7c1f29c842f31ba7adb2506df021f13388a Mon Sep 17 00:00:00 2001 From: Daniel Friesel Date: Wed, 14 Feb 2018 08:51:38 +0100 Subject: proper paramfuncton support (no crossvalidation yet for assessment) --- lib/dfatool.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) (limited to 'lib') diff --git a/lib/dfatool.py b/lib/dfatool.py index e39f58f..3387de0 100755 --- a/lib/dfatool.py +++ b/lib/dfatool.py @@ -426,6 +426,15 @@ class AnalyticFunction: return self._regression_args = res.x + def is_predictable(self, param_list): + for i, param in enumerate(param_list): + if self._dependson[i] and not is_numeric(param): + return False + return True + + def eval(self, param_list): + return self._function(self._regression_args, param_list) + class analytic: _num0_8 = np.vectorize(lambda x: 8 - bin(int(x)).count("1")) _num0_16 = np.vectorize(lambda x: 16 - bin(int(x)).count("1")) @@ -742,7 +751,7 @@ class EnergyModel: static_model = self._get_model_from_dict(self.by_name, np.median) def get_fitted(self): - static_model = self._get_model_from_dict(self.by_name, np.mean) + static_model = self._get_model_from_dict(self.by_name, np.median) param_model = dict([[state_or_tran, {}] for state_or_tran in self.by_name.keys()]) for state_or_tran in self.by_name.keys(): if self.by_name[state_or_tran]['isa'] == 'state': @@ -764,6 +773,11 @@ class EnergyModel: } def model_getter(name, key, **kwargs): + if key in param_model[name]: + param_list = kwargs['param'] + param_function = param_model[name][key]['function'] + if param_function.is_predictable(param_list): + return param_function.eval(param_list) return static_model[name][key] def info_getter(name, key): -- cgit v1.2.3