summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
authorDaniel Friesel <derf@finalrewind.org>2018-02-14 08:51:38 +0100
committerDaniel Friesel <derf@finalrewind.org>2018-02-14 08:51:38 +0100
commit9957a7c1f29c842f31ba7adb2506df021f13388a (patch)
tree5746c47a214e1dd16196ffffe14868fe5216302f /lib
parent7c26be2cad657e9a08e229bb541a7d926079f6a9 (diff)
proper paramfuncton support (no crossvalidation yet for assessment)
Diffstat (limited to 'lib')
-rwxr-xr-xlib/dfatool.py16
1 files changed, 15 insertions, 1 deletions
diff --git a/lib/dfatool.py b/lib/dfatool.py
index e39f58f..3387de0 100755
--- a/lib/dfatool.py
+++ b/lib/dfatool.py
@@ -426,6 +426,15 @@ class AnalyticFunction:
return
self._regression_args = res.x
+ def is_predictable(self, param_list):
+ for i, param in enumerate(param_list):
+ if self._dependson[i] and not is_numeric(param):
+ return False
+ return True
+
+ def eval(self, param_list):
+ return self._function(self._regression_args, param_list)
+
class analytic:
_num0_8 = np.vectorize(lambda x: 8 - bin(int(x)).count("1"))
_num0_16 = np.vectorize(lambda x: 16 - bin(int(x)).count("1"))
@@ -742,7 +751,7 @@ class EnergyModel:
static_model = self._get_model_from_dict(self.by_name, np.median)
def get_fitted(self):
- static_model = self._get_model_from_dict(self.by_name, np.mean)
+ static_model = self._get_model_from_dict(self.by_name, np.median)
param_model = dict([[state_or_tran, {}] for state_or_tran in self.by_name.keys()])
for state_or_tran in self.by_name.keys():
if self.by_name[state_or_tran]['isa'] == 'state':
@@ -764,6 +773,11 @@ class EnergyModel:
}
def model_getter(name, key, **kwargs):
+ if key in param_model[name]:
+ param_list = kwargs['param']
+ param_function = param_model[name][key]['function']
+ if param_function.is_predictable(param_list):
+ return param_function.eval(param_list)
return static_model[name][key]
def info_getter(name, key):