diff options
author | Daniel Friesel <daniel.friesel@uos.de> | 2021-02-25 14:38:46 +0100 |
---|---|---|
committer | Daniel Friesel <daniel.friesel@uos.de> | 2021-02-25 14:38:46 +0100 |
commit | f96a52a8b8e8e820f462b8f269a261b31a262441 (patch) | |
tree | 0c0d241fed4a9b6c763c46addbcb01b714dbb7f6 /lib/model.py | |
parent | fde97233d5e0bf8d9c357bac48caa8b5ac2c7a82 (diff) | |
Adjust ParamStats interface in preparation for decision-tree models
Diffstat (limited to 'lib/model.py')
-rw-r--r-- | lib/model.py | 189 |
1 file changed, 166 insertions, 23 deletions
```diff
diff --git a/lib/model.py b/lib/model.py
index 7f054f5..07e56f1 100644
--- a/lib/model.py
+++ b/lib/model.py
@@ -8,7 +8,7 @@ from multiprocessing import Pool
 from .automata import PTA
 from .functions import analytic
 from .functions import AnalyticFunction
-from .parameters import ParallelParamStats
+from .parameters import ParallelParamStats, ParamStats
 from .utils import is_numeric, soft_cast_int, param_slice_eq, remove_index_from_tuple
 from .utils import (
     by_name_to_by_param,
@@ -104,17 +104,19 @@ class ParallelParamFit:
         """Create a new ParallelParamFit object."""
         self.fit_queue = list()
 
-    def enqueue(self, key, args):
+    def enqueue(self, key, param, args):
         """
         Add state_or_tran/attribute/param_name to fit queue.
 
         This causes fit() to compute the best-fitting function for this model part.
 
-        :param key: (state/transition name, model attribute, parameter name)
+        :param key: arbitrary key used to retrieve param result in `get_result`. Typically (state/transition name, model attribute).
+            Different parameter names may have the same key. Identical parameter names must have different keys.
+        :param param: parameter name
         :param args: [by_param, param_index, safe_functions_enabled, param_filter]
             by_param[(param 1, param2, ...)] holds measurements.
         """
-        self.fit_queue.append({"key": key, "args": args})
+        self.fit_queue.append({"key": (key, param), "args": args})
 
     def fit(self):
         """
@@ -127,29 +129,29 @@ class ParallelParamFit:
         with Pool() as pool:
             self.results = pool.map(_try_fits_parallel, self.fit_queue)
 
-    def get_result(self, name, attr):
+    def get_result(self, key):
         """
         Parse and sanitize fit results.
 
         Filters out results where the best function is worse (or not much better than) static mean/median estimates.
 
-        :param param_filter:
+        :param key: arbitrary key used in `enqueue`. Typically (state/transition name, model attribute).
+        :param param: parameter name
         :returns: dict with fit result (see `_try_fits`) for each successfully fitted parameter. E.g. {'param 1': {'best' : 'function name', ...} }
         """
         fit_result = dict()
         for result in self.results:
             if (
-                result["key"][0] == name
-                and result["key"][1] == attr
-                and result["result"]["best"] is not None
+                result["key"][0] == key and result["result"]["best"] is not None
             ):  # dürfte an ['best'] != None liegen-> Fit für gefilterten Kram schlägt fehl?
                 this_result = result["result"]
                 if this_result["best_rmsd"] >= min(
                     this_result["mean_rmsd"], this_result["median_rmsd"]
                 ):
                     logger.debug(
-                        "Not modeling as function of {}: best ({:.0f}) is worse than ref ({:.0f}, {:.0f})".format(
-                            result["key"][2],
+                        "Not modeling {} as function of {}: best ({:.0f}) is worse than ref ({:.0f}, {:.0f})".format(
+                            result["key"][0],
+                            result["key"][1],
                             this_result["best_rmsd"],
                             this_result["mean_rmsd"],
                             this_result["median_rmsd"],
@@ -160,15 +162,16 @@ class ParallelParamFit:
                     this_result["mean_rmsd"], this_result["median_rmsd"]
                 ):
                     logger.debug(
-                        "Not modeling as function of {}: best ({:.0f}) is not much better than ref ({:.0f}, {:.0f})".format(
-                            result["key"][2],
+                        "Not modeling {} as function of {}: best ({:.0f}) is not much better than ref ({:.0f}, {:.0f})".format(
+                            result["key"][0],
+                            result["key"][1],
                             this_result["best_rmsd"],
                             this_result["mean_rmsd"],
                             this_result["median_rmsd"],
                         )
                     )
                 else:
-                    fit_result[result["key"][2]] = this_result
+                    fit_result[result["key"][1]] = this_result
 
         return fit_result
 
@@ -345,13 +348,14 @@ class ModelAttribute:
     def __init__(self, name, attr, data, param_values, param_names, arg_count=0):
         self.name = name
         self.attr = attr
-        self.data = data
+        self.data = np.array(data)
         self.param_values = param_values
         self.param_names = sorted(param_names)
         self.arg_count = arg_count
         self.by_param = None  # set via ParallelParamStats
         self.function_override = None
         self.param_model = None
+        self.split = None
 
     def __repr__(self):
         mean = np.mean(self.data)
@@ -367,18 +371,142 @@ class ModelAttribute:
             return np.mean(self.by_param[param])
         return np.median(self.by_param[param])
 
+    def build_dtree(self):
+        split_param_index = self.get_split_param_index()
+        if split_param_index is None:
+            return
+
+        distinct_values = self.stats.distinct_values_by_param_index[split_param_index]
+        tt1 = list(
+            map(
+                lambda i: self.param_values[i][split_param_index] == distinct_values[0],
+                range(len(self.param_values)),
+            )
+        )
+        tt2 = np.invert(tt1)
+
+        pv1 = list()
+        pv2 = list()
+
+        for i, param_tuple in enumerate(self.param_values):
+            if tt1[i]:
+                pv1.append(param_tuple)
+            else:
+                pv2.append(param_tuple)
+
+        # print(
+        #     f">>> split {self.name} {self.attr} by param #{split_param_index}"
+        # )
+
+        child1 = ModelAttribute(
+            self.name, self.attr, self.data[tt1], pv1, self.param_names, self.arg_count
+        )
+        child2 = ModelAttribute(
+            self.name, self.attr, self.data[tt2], pv2, self.param_names, self.arg_count
+        )
+
+        ParamStats.compute_for_attr(child1)
+        ParamStats.compute_for_attr(child2)
+
+        child1.build_dtree()
+        child2.build_dtree()
+
+        self.split = (
+            split_param_index,
+            {distinct_values[0]: child1, distinct_values[1]: child2},
+        )
+
+        # print(
+        #     f"<<< split {self.name} {self.attr} by param #{split_param_index}"
+        # )
+
+    # None -> kein split notwendig
+    # andernfalls: Parameter, anhand dessen eine Decision Tree-Ebene aufgespannt wird
+    # (Kinder sind wiederum ModelAttributes, in denen dieser Parameter konstant ist)
+    def get_split_param_index(self):
+        if not self.param_names:
+            return None
+        std_by_param = list()
+        for param_index, param_name in enumerate(self.param_names):
+            distinct_values = self.stats.distinct_values_by_param_index[param_index]
+            if self.stats.depends_on_param(param_name) and len(distinct_values) == 2:
+                val1 = list(
+                    map(
+                        lambda i: self.param_values[i][param_index]
+                        == distinct_values[0],
+                        range(len(self.param_values)),
+                    )
+                )
+                val2 = np.invert(val1)
+                val1_std = np.std(self.data[val1])
+                val2_std = np.std(self.data[val2])
+                std_by_param.append(np.mean([val1_std, val2_std]))
+            else:
+                std_by_param.append(np.inf)
+        for arg_index in range(self.arg_count):
+            distinct_values = self.stats.distinct_values_by_param_index[
+                len(self.param_names) + arg_index
+            ]
+            if self.stats.depends_on_arg(arg_index) and len(distinct_values) == 2:
+                val1 = list(
+                    map(
+                        lambda i: self.param_values[i][
+                            len(self.param_names) + arg_index
+                        ]
+                        == distinct_values[0],
+                        range(len(self.param_values)),
+                    )
+                )
+                val2 = np.invert(val1)
+                val1_std = np.std(self.data[val1])
+                val2_std = np.std(self.data[val2])
+                std_by_param.append(np.mean([val1_std, val2_std]))
+            else:
+                std_by_param.append(np.inf)
+        split_param_index = np.argmin(std_by_param)
+        split_std = std_by_param[split_param_index]
+        if split_std == np.inf:
+            return None
+        return split_param_index
+
     def get_data_for_paramfit(self, safe_functions_enabled=False):
+        if self.split and 0:
+            return self.get_data_for_paramfit_split(
+                safe_functions_enabled=safe_functions_enabled
+            )
+        else:
+            return self.get_data_for_paramfit_this(
+                safe_functions_enabled=safe_functions_enabled
+            )
+
+    def get_data_for_paramfit_split(self, safe_functions_enabled=False):
+        split_param_index, child_by_param_value = self.split
+        ret = list()
+        for param_value, child in child_by_param_value.items():
+            child_ret = child.get_data_for_paramfit(
+                safe_functions_enabled=safe_functions_enabled
+            )
+            for key, param, val in child_ret:
+                ret.append((key[:2] + (param_value,) + key[2:], param, val))
+        return ret
+
+    def get_data_for_paramfit_this(self, safe_functions_enabled=False):
         ret = list()
         for param_index, param_name in enumerate(self.param_names):
             if self.stats.depends_on_param(param_name):
                 ret.append(
-                    (param_name, (self.by_param, param_index, safe_functions_enabled))
+                    (
+                        (self.name, self.attr),
+                        param_name,
+                        (self.by_param, param_index, safe_functions_enabled),
+                    )
                 )
         if self.arg_count:
             for arg_index in range(self.arg_count):
                 if self.stats.depends_on_arg(arg_index):
                     ret.append(
                         (
+                            (self.name, self.attr),
                             arg_index,
                             (
                                 self.by_param,
@@ -387,9 +515,22 @@ class ModelAttribute:
                             ),
                         )
                     )
+
         return ret
 
-    def set_data_from_paramfit(self, fit_result):
+    def set_data_from_paramfit(self, paramfit, prefix=tuple()):
+        if self.split and 0:
+            self.set_data_from_paramfit_split(paramfit, prefix)
+        else:
+            self.set_data_from_paramfit_this(paramfit, prefix)
+
+    def set_data_from_paramfit_split(self, paramfit, prefix):
+        split_param_index, child_by_param_value = self.split
+        for param_value, child in child_by_param_value.items():
+            child.set_data_from_paramfit(paramfit, prefix + (param_value,))
+
+    def set_data_from_paramfit_this(self, paramfit, prefix):
+        fit_result = paramfit.get_result((self.name, self.attr) + prefix)
         param_model = (None, None)
         if self.function_override is not None:
             function_str = self.function_override
@@ -546,6 +687,11 @@ class AnalyticModel:
 
         paramstats.compute()
 
+        np.seterr("raise")
+        for name in self.names:
+            for attr in self.attr_by_name[name].values():
+                attr.build_dtree()
+
     def attributes(self, name):
         return self.attr_by_name[name].keys()
 
@@ -627,21 +773,18 @@ class AnalyticModel:
 
         for name in self.names:
             for attr in self.attr_by_name[name].keys():
-                for key, args in self.attr_by_name[name][
+                for key, param, args in self.attr_by_name[name][
                     attr
                 ].get_data_for_paramfit(
                     safe_functions_enabled=safe_functions_enabled
                 ):
-                    key = (name, attr, key)
-                    paramfit.enqueue(key, args)
+                    paramfit.enqueue(key, param, args)
 
         paramfit.fit()
 
         for name in self.names:
             for attr in self.attr_by_name[name].keys():
-                self.attr_by_name[name][attr].set_data_from_paramfit(
-                    paramfit.get_result(name, attr)
-                )
+                self.attr_by_name[name][attr].set_data_from_paramfit(paramfit)
 
         self.fit_done = True
 
```