summary refs log tree commit diff
path: root/lib/model.py
diff options
context:
space:
mode:
author Daniel Friesel <daniel.friesel@uos.de> 2021-02-25 14:38:46 +0100
committer Daniel Friesel <daniel.friesel@uos.de> 2021-02-25 14:38:46 +0100
commit f96a52a8b8e8e820f462b8f269a261b31a262441 (patch)
tree 0c0d241fed4a9b6c763c46addbcb01b714dbb7f6 /lib/model.py
parent fde97233d5e0bf8d9c357bac48caa8b5ac2c7a82 (diff)
Adjust ParamStats interface in preparation for decision-tree models
Diffstat (limited to 'lib/model.py')
-rw-r--r-- lib/model.py 189
1 file changed, 166 insertions, 23 deletions
diff --git a/lib/model.py b/lib/model.py
index 7f054f5..07e56f1 100644
--- a/lib/model.py
+++ b/lib/model.py
@@ -8,7 +8,7 @@ from multiprocessing import Pool
from .automata import PTA
from .functions import analytic
from .functions import AnalyticFunction
-from .parameters import ParallelParamStats
+from .parameters import ParallelParamStats, ParamStats
from .utils import is_numeric, soft_cast_int, param_slice_eq, remove_index_from_tuple
from .utils import (
by_name_to_by_param,
@@ -104,17 +104,19 @@ class ParallelParamFit:
"""Create a new ParallelParamFit object."""
self.fit_queue = list()
- def enqueue(self, key, args):
+ def enqueue(self, key, param, args):
"""
Add state_or_tran/attribute/param_name to fit queue.
This causes fit() to compute the best-fitting function for this model part.
- :param key: (state/transition name, model attribute, parameter name)
+ :param key: arbitrary key used to retrieve param result in `get_result`. Typically (state/transition name, model attribute).
+ Different parameter names may have the same key. Identical parameter names must have different keys.
+ :param param: parameter name
:param args: [by_param, param_index, safe_functions_enabled, param_filter]
by_param[(param 1, param2, ...)] holds measurements.
"""
- self.fit_queue.append({"key": key, "args": args})
+ self.fit_queue.append({"key": (key, param), "args": args})
def fit(self):
"""
@@ -127,29 +129,29 @@ class ParallelParamFit:
with Pool() as pool:
self.results = pool.map(_try_fits_parallel, self.fit_queue)
- def get_result(self, name, attr):
+ def get_result(self, key):
"""
Parse and sanitize fit results.
Filters out results where the best function is worse (or not much better than) static mean/median estimates.
- :param param_filter:
+ :param key: arbitrary key used in `enqueue`. Typically (state/transition name, model attribute).
+ Results of all parameters enqueued under this key are considered; the returned dict is indexed by parameter name.
:returns: dict with fit result (see `_try_fits`) for each successfully fitted parameter. E.g. {'param 1': {'best' : 'function name', ...} }
"""
fit_result = dict()
for result in self.results:
if (
- result["key"][0] == name
- and result["key"][1] == attr
- and result["result"]["best"] is not None
+ result["key"][0] == key and result["result"]["best"] is not None
): # dürfte an ['best'] != None liegen-> Fit für gefilterten Kram schlägt fehl?
this_result = result["result"]
if this_result["best_rmsd"] >= min(
this_result["mean_rmsd"], this_result["median_rmsd"]
):
logger.debug(
- "Not modeling as function of {}: best ({:.0f}) is worse than ref ({:.0f}, {:.0f})".format(
- result["key"][2],
+ "Not modeling {} as function of {}: best ({:.0f}) is worse than ref ({:.0f}, {:.0f})".format(
+ result["key"][0],
+ result["key"][1],
this_result["best_rmsd"],
this_result["mean_rmsd"],
this_result["median_rmsd"],
@@ -160,15 +162,16 @@ class ParallelParamFit:
this_result["mean_rmsd"], this_result["median_rmsd"]
):
logger.debug(
- "Not modeling as function of {}: best ({:.0f}) is not much better than ref ({:.0f}, {:.0f})".format(
- result["key"][2],
+ "Not modeling {} as function of {}: best ({:.0f}) is not much better than ref ({:.0f}, {:.0f})".format(
+ result["key"][0],
+ result["key"][1],
this_result["best_rmsd"],
this_result["mean_rmsd"],
this_result["median_rmsd"],
)
)
else:
- fit_result[result["key"][2]] = this_result
+ fit_result[result["key"][1]] = this_result
return fit_result
@@ -345,13 +348,14 @@ class ModelAttribute:
def __init__(self, name, attr, data, param_values, param_names, arg_count=0):
self.name = name
self.attr = attr
- self.data = data
+ self.data = np.array(data)
self.param_values = param_values
self.param_names = sorted(param_names)
self.arg_count = arg_count
self.by_param = None # set via ParallelParamStats
self.function_override = None
self.param_model = None
+ self.split = None
def __repr__(self):
mean = np.mean(self.data)
@@ -367,18 +371,142 @@ class ModelAttribute:
return np.mean(self.by_param[param])
return np.median(self.by_param[param])
+ def build_dtree(self):
+ split_param_index = self.get_split_param_index()
+ if split_param_index is None:
+ return
+
+ distinct_values = self.stats.distinct_values_by_param_index[split_param_index]
+ tt1 = list(
+ map(
+ lambda i: self.param_values[i][split_param_index] == distinct_values[0],
+ range(len(self.param_values)),
+ )
+ )
+ tt2 = np.invert(tt1)
+
+ pv1 = list()
+ pv2 = list()
+
+ for i, param_tuple in enumerate(self.param_values):
+ if tt1[i]:
+ pv1.append(param_tuple)
+ else:
+ pv2.append(param_tuple)
+
+ # print(
+ # f">>> split {self.name} {self.attr} by param #{split_param_index}"
+ # )
+
+ child1 = ModelAttribute(
+ self.name, self.attr, self.data[tt1], pv1, self.param_names, self.arg_count
+ )
+ child2 = ModelAttribute(
+ self.name, self.attr, self.data[tt2], pv2, self.param_names, self.arg_count
+ )
+
+ ParamStats.compute_for_attr(child1)
+ ParamStats.compute_for_attr(child2)
+
+ child1.build_dtree()
+ child2.build_dtree()
+
+ self.split = (
+ split_param_index,
+ {distinct_values[0]: child1, distinct_values[1]: child2},
+ )
+
+ # print(
+ # f"<<< split {self.name} {self.attr} by param #{split_param_index}"
+ # )
+
+ # None -> no split necessary
+ # otherwise: index of the parameter along which a decision tree level is spanned
+ # (the children are in turn ModelAttributes in which this parameter is constant)
+ def get_split_param_index(self):
+ if not self.param_names:
+ return None
+ std_by_param = list()
+ for param_index, param_name in enumerate(self.param_names):
+ distinct_values = self.stats.distinct_values_by_param_index[param_index]
+ if self.stats.depends_on_param(param_name) and len(distinct_values) == 2:
+ val1 = list(
+ map(
+ lambda i: self.param_values[i][param_index]
+ == distinct_values[0],
+ range(len(self.param_values)),
+ )
+ )
+ val2 = np.invert(val1)
+ val1_std = np.std(self.data[val1])
+ val2_std = np.std(self.data[val2])
+ std_by_param.append(np.mean([val1_std, val2_std]))
+ else:
+ std_by_param.append(np.inf)
+ for arg_index in range(self.arg_count):
+ distinct_values = self.stats.distinct_values_by_param_index[
+ len(self.param_names) + arg_index
+ ]
+ if self.stats.depends_on_arg(arg_index) and len(distinct_values) == 2:
+ val1 = list(
+ map(
+ lambda i: self.param_values[i][
+ len(self.param_names) + arg_index
+ ]
+ == distinct_values[0],
+ range(len(self.param_values)),
+ )
+ )
+ val2 = np.invert(val1)
+ val1_std = np.std(self.data[val1])
+ val2_std = np.std(self.data[val2])
+ std_by_param.append(np.mean([val1_std, val2_std]))
+ else:
+ std_by_param.append(np.inf)
+ split_param_index = np.argmin(std_by_param)
+ split_std = std_by_param[split_param_index]
+ if split_std == np.inf:
+ return None
+ return split_param_index
+
def get_data_for_paramfit(self, safe_functions_enabled=False):
+ if self.split and 0:
+ return self.get_data_for_paramfit_split(
+ safe_functions_enabled=safe_functions_enabled
+ )
+ else:
+ return self.get_data_for_paramfit_this(
+ safe_functions_enabled=safe_functions_enabled
+ )
+
+ def get_data_for_paramfit_split(self, safe_functions_enabled=False):
+ split_param_index, child_by_param_value = self.split
+ ret = list()
+ for param_value, child in child_by_param_value.items():
+ child_ret = child.get_data_for_paramfit(
+ safe_functions_enabled=safe_functions_enabled
+ )
+ for key, param, val in child_ret:
+ ret.append((key[:2] + (param_value,) + key[2:], param, val))
+ return ret
+
+ def get_data_for_paramfit_this(self, safe_functions_enabled=False):
ret = list()
for param_index, param_name in enumerate(self.param_names):
if self.stats.depends_on_param(param_name):
ret.append(
- (param_name, (self.by_param, param_index, safe_functions_enabled))
+ (
+ (self.name, self.attr),
+ param_name,
+ (self.by_param, param_index, safe_functions_enabled),
+ )
)
if self.arg_count:
for arg_index in range(self.arg_count):
if self.stats.depends_on_arg(arg_index):
ret.append(
(
+ (self.name, self.attr),
arg_index,
(
self.by_param,
@@ -387,9 +515,22 @@ class ModelAttribute:
),
)
)
+
return ret
- def set_data_from_paramfit(self, fit_result):
+ def set_data_from_paramfit(self, paramfit, prefix=tuple()):
+ if self.split and 0:
+ self.set_data_from_paramfit_split(paramfit, prefix)
+ else:
+ self.set_data_from_paramfit_this(paramfit, prefix)
+
+ def set_data_from_paramfit_split(self, paramfit, prefix):
+ split_param_index, child_by_param_value = self.split
+ for param_value, child in child_by_param_value.items():
+ child.set_data_from_paramfit(paramfit, prefix + (param_value,))
+
+ def set_data_from_paramfit_this(self, paramfit, prefix):
+ fit_result = paramfit.get_result((self.name, self.attr) + prefix)
param_model = (None, None)
if self.function_override is not None:
function_str = self.function_override
@@ -546,6 +687,11 @@ class AnalyticModel:
paramstats.compute()
+ np.seterr("raise")
+ for name in self.names:
+ for attr in self.attr_by_name[name].values():
+ attr.build_dtree()
+
def attributes(self, name):
return self.attr_by_name[name].keys()
@@ -627,21 +773,18 @@ class AnalyticModel:
for name in self.names:
for attr in self.attr_by_name[name].keys():
- for key, args in self.attr_by_name[name][
+ for key, param, args in self.attr_by_name[name][
attr
].get_data_for_paramfit(
safe_functions_enabled=safe_functions_enabled
):
- key = (name, attr, key)
- paramfit.enqueue(key, args)
+ paramfit.enqueue(key, param, args)
paramfit.fit()
for name in self.names:
for attr in self.attr_by_name[name].keys():
- self.attr_by_name[name][attr].set_data_from_paramfit(
- paramfit.get_result(name, attr)
- )
+ self.attr_by_name[name][attr].set_data_from_paramfit(paramfit)
self.fit_done = True