diff options
Diffstat (limited to 'lib/dfatool.py')
-rwxr-xr-x | lib/dfatool.py | 16 |
1 files changed, 15 insertions, 1 deletions
diff --git a/lib/dfatool.py b/lib/dfatool.py index e39f58f..3387de0 100755 --- a/lib/dfatool.py +++ b/lib/dfatool.py @@ -426,6 +426,15 @@ class AnalyticFunction: return self._regression_args = res.x + def is_predictable(self, param_list): + for i, param in enumerate(param_list): + if self._dependson[i] and not is_numeric(param): + return False + return True + + def eval(self, param_list): + return self._function(self._regression_args, param_list) + class analytic: _num0_8 = np.vectorize(lambda x: 8 - bin(int(x)).count("1")) _num0_16 = np.vectorize(lambda x: 16 - bin(int(x)).count("1")) @@ -742,7 +751,7 @@ class EnergyModel: static_model = self._get_model_from_dict(self.by_name, np.median) def get_fitted(self): - static_model = self._get_model_from_dict(self.by_name, np.mean) + static_model = self._get_model_from_dict(self.by_name, np.median) param_model = dict([[state_or_tran, {}] for state_or_tran in self.by_name.keys()]) for state_or_tran in self.by_name.keys(): if self.by_name[state_or_tran]['isa'] == 'state': @@ -764,6 +773,11 @@ class EnergyModel: } def model_getter(name, key, **kwargs): + if key in param_model[name]: + param_list = kwargs['param'] + param_function = param_model[name][key]['function'] + if param_function.is_predictable(param_list): + return param_function.eval(param_list) return static_model[name][key] def info_getter(name, key): |