Diffstat (limited to 'lib/model.py')
-rw-r--r-- | lib/model.py | 519
1 file changed, 210 insertions(+), 309 deletions(-)
diff --git a/lib/model.py b/lib/model.py
index 518566c..ce73a02 100644
--- a/lib/model.py
+++ b/lib/model.py
@@ -8,9 +8,14 @@ from multiprocessing import Pool
 from .automata import PTA
 from .functions import analytic
 from .functions import AnalyticFunction
-from .parameters import ParamStats
+from .parameters import ParallelParamStats
 from .utils import is_numeric, soft_cast_int, param_slice_eq, remove_index_from_tuple
-from .utils import by_name_to_by_param, by_param_to_by_name, match_parameter_values
+from .utils import (
+    by_name_to_by_param,
+    by_param_to_by_name,
+    match_parameter_values,
+    partition_by_param,
+)

 logger = logging.getLogger(__name__)
 arg_support_enabled = True
@@ -95,38 +100,21 @@ class ParallelParamFit:
     function type for each parameter.
     """

-    def __init__(self, by_param):
+    def __init__(self):
         """Create a new ParallelParamFit object."""
-        self.fit_queue = []
-        self.by_param = by_param
+        self.fit_queue = list()

-    def enqueue(
-        self,
-        state_or_tran,
-        attribute,
-        param_index,
-        param_name,
-        safe_functions_enabled=False,
-        param_filter=None,
-    ):
+    def enqueue(self, key, args):
         """
         Add state_or_tran/attribute/param_name to fit queue.

         This causes fit() to compute the best-fitting function for this model part.
+
+        :param key: (state/transition name, model attribute, parameter name)
+        :param args: [by_param, param_index, safe_functions_enabled, param_filter]
+            by_param[(param 1, param2, ...)] holds measurements.
         """
-        # Transform by_param[(state_or_tran, param_value)][attribute] = ...
-        # into n_by_param[param_value] = ...
-        # (param_value is dynamic, the rest is fixed)
-        n_by_param = dict()
-        for k, v in self.by_param.items():
-            if k[0] == state_or_tran:
-                n_by_param[k[1]] = v[attribute]
-        self.fit_queue.append(
-            {
-                "key": [state_or_tran, attribute, param_name, param_filter],
-                "args": [n_by_param, param_index, safe_functions_enabled, param_filter],
-            }
-        )
+        self.fit_queue.append({"key": key, "args": args})

     def fit(self):
         """
@@ -139,14 +127,12 @@ class ParallelParamFit:
         with Pool() as pool:
             self.results = pool.map(_try_fits_parallel, self.fit_queue)

-    def get_result(self, name, attribute, param_filter: dict = None):
+    def get_result(self, name, attr):
         """
-        Parse and sanitize fit results for state/transition/... 'name' and model attribute 'attribute'.
+        Parse and sanitize fit results.

         Filters out results where the best function is worse (or not much better than) static mean/median estimates.

-        :param name: state/transition/... name, e.g. 'TX'
-        :param attribute: model attribute, e.g. 'duration'
         :param param_filter:
         :returns: dict with fit result (see `_try_fits`) for each successfully fitted parameter. E.g. {'param 1': {'best' : 'function name', ...} }
         """
@@ -154,8 +140,7 @@
         for result in self.results:
             if (
                 result["key"][0] == name
-                and result["key"][1] == attribute
-                and result["key"][3] == param_filter
+                and result["key"][1] == attr
                 and result["result"]["best"] is not None
             ):  # probably due to ['best'] != None -> fit for filtered data fails?
                 this_result = result["result"]
@@ -163,9 +148,7 @@ class ParallelParamFit:
                     this_result["mean_rmsd"], this_result["median_rmsd"]
                 ):
                     logger.debug(
-                        "Not modeling {} {} as function of {}: best ({:.0f}) is worse than ref ({:.0f}, {:.0f})".format(
-                            name,
-                            attribute,
+                        "Not modeling as function of {}: best ({:.0f}) is worse than ref ({:.0f}, {:.0f})".format(
                             result["key"][2],
                             this_result["best_rmsd"],
                             this_result["mean_rmsd"],
@@ -177,9 +160,7 @@ class ParallelParamFit:
                     this_result["mean_rmsd"], this_result["median_rmsd"]
                 ):
                     logger.debug(
-                        "Not modeling {} {} as function of {}: best ({:.0f}) is not much better than ref ({:.0f}, {:.0f})".format(
-                            name,
-                            attribute,
+                        "Not modeling as function of {}: best ({:.0f}) is not much better than ref ({:.0f}, {:.0f})".format(
                             result["key"][2],
                             this_result["best_rmsd"],
                             this_result["mean_rmsd"],
@@ -360,6 +341,81 @@ def _num_args_from_by_name(by_name):
     return num_args


+class ModelAttribute:
+    def __init__(self, name, attr, data, param_values, param_names, arg_count=0):
+        self.name = name
+        self.attr = attr
+        self.data = data
+        self.param_values = param_values
+        self.param_names = sorted(param_names)
+        self.arg_count = arg_count
+        self.by_param = None  # set via ParallelParamStats
+        self.function_override = None
+        self.param_model = None
+
+    def get_static(self, use_mean=False):
+        if use_mean:
+            return np.mean(self.data)
+        return np.median(self.data)
+
+    def get_lut(self, param, use_mean=False):
+        if use_mean:
+            return np.mean(self.by_param[param])
+        return np.median(self.by_param[param])
+
+    def get_data_for_paramfit(self, safe_functions_enabled=False):
+        ret = list()
+        for param_index, param_name in enumerate(self.param_names):
+            if self.stats.depends_on_param(param_name):
+                ret.append(
+                    (param_name, (self.by_param, param_index, safe_functions_enabled))
+                )
+        if self.arg_count:
+            for arg_index in range(self.arg_count):
+                if self.stats.depends_on_arg(arg_index):
+                    ret.append(
+                        (
+                            arg_index,
+                            (
+                                self.by_param,
+                                len(self.param_names) + arg_index,
+                                safe_functions_enabled,
+                            ),
+                        )
+                    )
+        return ret
+
+    def set_data_from_paramfit(self, fit_result):
+        param_model = (None, None)
+        if self.function_override is not None:
+            function_str = self.function_override
+            x = AnalyticFunction(function_str, self.param_names, self.arg_count)
+            x.fit(self.by_param)
+            if x.fit_success:
+                param_model = (x, fit_result)
+        elif len(fit_result.keys()):
+            x = analytic.function_powerset(fit_result, self.param_names, self.arg_count)
+            x.fit(self.by_param)
+
+            if x.fit_success:
+                param_model = (x, fit_result)
+
+        self.param_model = param_model
+
+    def get_fitted(self):
+        """
+        Get parameter-aware model function and model information function.
+        They must have been set via get_data_for_paramfit -> ParallelParamFit -> set_data_from_paramfit first.
+
+        Returns a tuple (function, info):
+        function -> AnalyticFunction for model. function(param=parameter values) -> model value.
+        info -> {'fit_result' : ..., 'function' : ... }
+
+        Returns (None, None) if fitting failed. Returns None if ParamFit has not been performed yet.
+        """
+        return self.param_model
+
+
 class AnalyticModel:
     """
     Parameter-aware analytic energy/data size/... model.
@@ -447,8 +503,8 @@ class AnalyticModel:
         :param use_corrcoef: use correlation coefficient instead of stddev comparison to detect whether a model attribute depends on a parameter
         """
         self.cache = dict()
-        self.by_name = by_name
-        self.by_param = by_name_to_by_param(by_name)
+        self.by_name = by_name  # no longer required?
+        self.attr_by_name = dict()
         self.names = sorted(by_name.keys())
         self.parameters = sorted(parameters)
         self.function_override = function_override.copy()
@@ -457,26 +513,33 @@ class AnalyticModel:
         if self._num_args is None:
             self._num_args = _num_args_from_by_name(by_name)

-        self.stats = ParamStats(
-            self.by_name,
-            self.by_param,
-            self.parameters,
-            self._num_args,
-            use_corrcoef=use_corrcoef,
-        )
+        self.fit_done = False

-    def _get_model_from_dict(self, model_dict, model_function):
-        model = {}
-        for name, elem in model_dict.items():
-            model[name] = {}
-            for key in elem["attributes"]:
-                try:
-                    model[name][key] = model_function(elem[key])
-                except RuntimeWarning:
-                    logger.warning("Got no data for {} {}".format(name, key))
-                except FloatingPointError as fpe:
-                    logger.warning("Got no data for {} {}: {}".format(name, key, fpe))
-        return model
+        self._compute_stats(by_name)
+
+    def _compute_stats(self, by_name):
+        paramstats = ParallelParamStats()
+
+        for name, data in by_name.items():
+            self.attr_by_name[name] = dict()
+            for attr in data["attributes"]:
+                model_attr = ModelAttribute(
+                    name,
+                    attr,
+                    data[attr],
+                    data["param"],
+                    self.parameters,
+                    self._num_args.get(name, 0),
+                )
+                self.attr_by_name[name][attr] = model_attr
+                paramstats.enqueue((name, attr), model_attr)
+                if (name, attr) in self.function_override:
+                    model_attr.function_override = self.function_override[(name, attr)]
+
+        paramstats.compute()
+
+    def attributes(self, name):
+        return self.attr_by_name[name].keys()

     def param_index(self, param_name):
         if param_name in self.parameters:
@@ -492,21 +555,20 @@
         """
         Get static model function: name, attribute -> model value.

-        Uses the median of by_name for modeling.
+        Uses the median of by_name for modeling, unless `use_mean` is set.
         """
-        getter_function = np.median
-
-        if use_mean:
-            getter_function = np.mean
-
-        static_model = self._get_model_from_dict(self.by_name, getter_function)
+        model = dict()
+        for name, attr in self.attr_by_name.items():
+            model[name] = dict()
+            for k, v in attr.items():
+                model[name][k] = v.get_static(use_mean=use_mean)

         def static_model_getter(name, key, **kwargs):
-            return static_model[name][key]
+            return model[name][key]

         return static_model_getter

-    def get_param_lut(self, fallback=False):
+    def get_param_lut(self, use_mean=False, fallback=False):
         """
         Get parameter-look-up-table model function: name, attribute, parameter values -> model value.

@@ -516,13 +578,22 @@
         arguments:
         fallback -- Fall back to the (non-parameter-aware) static model when encountering unknown parameter values
         """
-        static_model = self._get_model_from_dict(self.by_name, np.median)
-        lut_model = self._get_model_from_dict(self.by_param, np.median)
-
-        def lut_median_getter(name, key, param, arg=[], **kwargs):
+        static_model = dict()
+        lut_model = dict()
+        for name, attr in self.attr_by_name.items():
+            static_model[name] = dict()
+            lut_model[name] = dict()
+            for k, v in attr.items():
+                static_model[name][k] = v.get_static(use_mean=use_mean)
+                lut_model[name][k] = dict()
+                for param, model_value in v.by_param.items():
+                    lut_model[name][k][param] = v.get_lut(param, use_mean=use_mean)
+
+        def lut_median_getter(name, key, param, arg=list(), **kwargs):
             param.extend(map(soft_cast_int, arg))
+            param = tuple(param)
             try:
-                return lut_model[(name, tuple(param))][key]
+                return lut_model[name][key][param]
             except KeyError:
                 if fallback:
                     return static_model[name][key]
@@ -530,84 +601,67 @@

         return lut_median_getter

-    def get_fitted(self, safe_functions_enabled=False):
+    def get_fitted(self, use_mean=False, safe_functions_enabled=False):
         """
-        Get paramete-aware model function and model information function.
+        Get parameter-aware model function and model information function.

         Returns two functions:
         model_function(name, attribute, param=parameter values) -> model value.
         model_info(name, attribute) -> {'fit_result' : ..., 'function' : ... } or None
         """
-        if "fitted_model_getter" in self.cache and "fitted_info_getter" in self.cache:
-            return self.cache["fitted_model_getter"], self.cache["fitted_info_getter"]
-
-        static_model = self._get_model_from_dict(self.by_name, np.median)
-        param_model = dict([[name, {}] for name in self.by_name.keys()])
-        paramfit = ParallelParamFit(self.by_param)
-
-        for name in self.by_name.keys():
-            for attribute in self.by_name[name]["attributes"]:
-                for param_index, param in enumerate(self.parameters):
-                    if self.stats.depends_on_param(name, attribute, param):
-                        paramfit.enqueue(name, attribute, param_index, param, False)
-                if arg_support_enabled and name in self._num_args:
-                    for arg_index in range(self._num_args[name]):
-                        if self.stats.depends_on_arg(name, attribute, arg_index):
-                            paramfit.enqueue(
-                                name,
-                                attribute,
-                                len(self.parameters) + arg_index,
-                                arg_index,
-                                False,
-                            )
-        paramfit.fit()
-
-        for name in self.by_name.keys():
-            num_args = 0
-            if name in self._num_args:
-                num_args = self._num_args[name]
-            for attribute in self.by_name[name]["attributes"]:
-                fit_result = paramfit.get_result(name, attribute)
-
-                if (name, attribute) in self.function_override:
-                    function_str = self.function_override[(name, attribute)]
-                    x = AnalyticFunction(function_str, self.parameters, num_args)
-                    x.fit(self.by_param, name, attribute)
-                    if x.fit_success:
-                        param_model[name][attribute] = {
-                            "fit_result": fit_result,
-                            "function": x,
-                        }
-                elif len(fit_result.keys()):
-                    x = analytic.function_powerset(
-                        fit_result, self.parameters, num_args
+        if not self.fit_done:
+
+            paramfit = ParallelParamFit()
+
+            for name in self.names:
+                for attr in self.attr_by_name[name].keys():
+                    for key, args in self.attr_by_name[name][
+                        attr
+                    ].get_data_for_paramfit(
+                        safe_functions_enabled=safe_functions_enabled
+                    ):
+                        key = (name, attr, key)
+                        paramfit.enqueue(key, args)
+
+            paramfit.fit()
+
+            for name in self.names:
+                for attr in self.attr_by_name[name].keys():
+                    self.attr_by_name[name][attr].set_data_from_paramfit(
+                        paramfit.get_result(name, attr)
                     )
-                    x.fit(self.by_param, name, attribute)
-                    if x.fit_success:
-                        param_model[name][attribute] = {
-                            "fit_result": fit_result,
-                            "function": x,
-                        }
+            self.fit_done = True
+
+        static_model = dict()
+        for name, attr in self.attr_by_name.items():
+            static_model[name] = dict()
+            for k, v in attr.items():
+                static_model[name][k] = v.get_static(use_mean=use_mean)

         def model_getter(name, key, **kwargs):
+            param_function, _ = self.attr_by_name[name][key].get_fitted()
+
+            if param_function is None:
+                return static_model[name][key]
+
             if "arg" in kwargs and "param" in kwargs:
                 kwargs["param"].extend(map(soft_cast_int, kwargs["arg"]))
-            if key in param_model[name]:
-                param_list = kwargs["param"]
-                param_function = param_model[name][key]["function"]
-                if param_function.is_predictable(param_list):
-                    return param_function.eval(param_list)
+
+            if param_function.is_predictable(kwargs["param"]):
+                return param_function.eval(kwargs["param"])
+
             return static_model[name][key]

         def info_getter(name, key):
-            if key in param_model[name]:
-                return param_model[name][key]
-            return None
-
-        self.cache["fitted_model_getter"] = model_getter
-        self.cache["fitted_info_getter"] = info_getter
+            try:
+                model_function, fit_result = self.attr_by_name[name][key].get_fitted()
+            except KeyError:
+                return None
+            if model_function is None:
+                return None
+            return {"function": model_function, "fit_result": fit_result}

         return model_getter, info_getter

@@ -625,20 +679,22 @@
         overfitting cannot be detected.
         """
         detailed_results = {}
-        for name, elem in sorted(self.by_name.items()):
+        for name in self.names:
             detailed_results[name] = {}
-            for attribute in elem["attributes"]:
+            for attribute in self.attr_by_name[name].keys():
+                data = self.attr_by_name[name][attribute].data
+                param_values = self.attr_by_name[name][attribute].param_values
                 predicted_data = np.array(
                     list(
                         map(
                             lambda i: model_function(
-                                name, attribute, param=elem["param"][i]
+                                name, attribute, param=param_values[i]
                             ),
-                            range(len(elem[attribute])),
+                            range(len(data)),
                         )
                     )
                 )
-                measures = regression_measures(predicted_data, elem[attribute])
+                measures = regression_measures(predicted_data, data)
                 detailed_results[name][attribute] = measures

         return {"by_name": detailed_results}
@@ -648,7 +704,7 @@ class AnalyticModel:
         pass


-class PTAModel:
+class PTAModel(AnalyticModel):
     """
     Parameter-aware PTA-based energy model.

@@ -718,11 +774,18 @@
        pelt -- perform sub-state detection via PELT and model sub-states as well. Requires traces to be set.
        """
        self.by_name = by_name
+        self.attr_by_name = dict()
        self.by_param = by_name_to_by_param(by_name)
+        self.names = sorted(by_name.keys())
        self._parameter_names = sorted(parameters)
+        self.parameters = sorted(parameters)
        self._num_args = arg_count
        self._use_corrcoef = use_corrcoef
        self.traces = traces
+        self.function_override = function_override.copy()
+
+        self.fit_done = False
+

        if traces is not None and pelt is not None:
            from .pelt import PELT
@@ -730,19 +793,14 @@
             self.find_substates()
         else:
             self.pelt = None
-        self.stats = ParamStats(
-            self.by_name,
-            self.by_param,
-            self._parameter_names,
-            self._num_args,
-            self._use_corrcoef,
-        )
-        self.cache = {}
+
+        self._aggregate_to_ndarray(self.by_name)
+
+        self._compute_stats(by_name)
+
         np.seterr("raise")
-        self.function_override = function_override.copy()
         self.pta = pta
         self.ignore_trace_indexes = ignore_trace_indexes
-        self._aggregate_to_ndarray(self.by_name)

     def _aggregate_to_ndarray(self, aggregate):
         for elem in aggregate.values():
@@ -773,157 +831,6 @@ class PTAModel:
                 logger.warning("Got no data for {} {}: {}".format(name, key, fpe))
         return model

-    def get_static(self, use_mean=False):
-        """
-        Get static model function: name, attribute -> model value.
-
-        Uses the median of by_name for modeling, unless `use_mean` is set.
-        """
-        getter_function = np.median
-
-        if use_mean:
-            getter_function = np.mean
-
-        static_model = self._get_model_from_dict(self.by_name, getter_function)
-
-        def static_model_getter(name, key, **kwargs):
-            return static_model[name][key]
-
-        return static_model_getter
-
-    def get_param_lut(self, fallback=False):
-        """
-        Get parameter-look-up-table model function: name, attribute, parameter values -> model value.
-
-        The function can only give model values for parameter combinations
-        present in by_param. By default, it raises KeyError for other values.
-
-        arguments:
-        fallback -- Fall back to the (non-parameter-aware) static model when encountering unknown parameter values
-        """
-        static_model = self._get_model_from_dict(self.by_name, np.median)
-        lut_model = self._get_model_from_dict(self.by_param, np.median)
-
-        def lut_median_getter(name, key, param, arg=[], **kwargs):
-            param.extend(map(soft_cast_int, arg))
-            try:
-                return lut_model[(name, tuple(param))][key]
-            except KeyError:
-                if fallback:
-                    return static_model[name][key]
-                raise
-
-        return lut_median_getter
-
-    def param_index(self, param_name):
-        if param_name in self._parameter_names:
-            return self._parameter_names.index(param_name)
-        return len(self._parameter_names) + int(param_name)
-
-    def param_name(self, param_index):
-        if param_index < len(self._parameter_names):
-            return self._parameter_names[param_index]
-        return str(param_index)
-
-    def get_fitted(self, safe_functions_enabled=False):
-        """
-        Get parameter-aware model function and model information function.
-
-        Returns two functions:
-        model_function(name, attribute, param=parameter values) -> model value.
-        model_info(name, attribute) -> {'fit_result' : ..., 'function' : ... } or None
-        """
-        if "fitted_model_getter" in self.cache and "fitted_info_getter" in self.cache:
-            return self.cache["fitted_model_getter"], self.cache["fitted_info_getter"]
-
-        static_model = self._get_model_from_dict(self.by_name, np.median)
-        param_model = dict(
-            [[state_or_tran, {}] for state_or_tran in self.by_name.keys()]
-        )
-        paramfit = ParallelParamFit(self.by_param)
-        for state_or_tran in self.by_name.keys():
-            for model_attribute in self.by_name[state_or_tran]["attributes"]:
-                fit_results = {}
-                for parameter_index, parameter_name in enumerate(self._parameter_names):
-                    if self.depends_on_param(
-                        state_or_tran, model_attribute, parameter_name
-                    ):
-                        paramfit.enqueue(
-                            state_or_tran,
-                            model_attribute,
-                            parameter_index,
-                            parameter_name,
-                            safe_functions_enabled,
-                        )
-                if (
-                    arg_support_enabled
-                    and self.by_name[state_or_tran]["isa"] == "transition"
-                ):
-                    for arg_index in range(self._num_args[state_or_tran]):
-                        if self.depends_on_arg(
-                            state_or_tran, model_attribute, arg_index
-                        ):
-                            paramfit.enqueue(
-                                state_or_tran,
-                                model_attribute,
-                                len(self._parameter_names) + arg_index,
-                                arg_index,
-                                safe_functions_enabled,
-                            )
-        paramfit.fit()
-
-        for state_or_tran in self.by_name.keys():
-            num_args = 0
-            if (
-                arg_support_enabled
-                and self.by_name[state_or_tran]["isa"] == "transition"
-            ):
-                num_args = self._num_args[state_or_tran]
-            for model_attribute in self.by_name[state_or_tran]["attributes"]:
-                fit_results = paramfit.get_result(state_or_tran, model_attribute)
-
-                if (state_or_tran, model_attribute) in self.function_override:
-                    function_str = self.function_override[
-                        (state_or_tran, model_attribute)
-                    ]
-                    x = AnalyticFunction(function_str, self._parameter_names, num_args)
-                    x.fit(self.by_param, state_or_tran, model_attribute)
-                    if x.fit_success:
-                        param_model[state_or_tran][model_attribute] = {
-                            "fit_result": fit_results,
-                            "function": x,
-                        }
-                elif len(fit_results.keys()):
-                    x = analytic.function_powerset(
-                        fit_results, self._parameter_names, num_args
-                    )
-                    x.fit(self.by_param, state_or_tran, model_attribute)
-                    if x.fit_success:
-                        param_model[state_or_tran][model_attribute] = {
-                            "fit_result": fit_results,
-                            "function": x,
-                        }
-
-        def model_getter(name, key, **kwargs):
-            if "arg" in kwargs and "param" in kwargs:
-                kwargs["param"].extend(map(soft_cast_int, kwargs["arg"]))
-            if key in param_model[name]:
-                param_list = kwargs["param"]
-                param_function = param_model[name][key]["function"]
-                if param_function.is_predictable(param_list):
-                    return param_function.eval(param_list)
-            return static_model[name][key]
-
-        def info_getter(name, key):
-            if key in param_model[name]:
-                return param_model[name][key]
-            return None
-
-        self.cache["fitted_model_getter"] = model_getter
-        self.cache["fitted_info_getter"] = info_getter
-
-        return model_getter, info_getter
-
     def pelt_refine(self, by_param_key):
         logger.debug(f"PELT: {by_param_key} needs refinement")

@@ -1112,12 +1019,6 @@
             ret.extend(self.transitions())
         return ret

-    def parameters(self):
-        return self._parameter_names
-
-    def attributes(self, state_or_trans):
-        return self.by_name[state_or_trans]["attributes"]
-
     def assess(self, model_function, ref=None):
         """
         Calculate MAE, SMAPE, etc. of model_function for each by_name entry.