summaryrefslogtreecommitdiff
path: root/lib/model.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/model.py')
-rw-r--r--lib/model.py519
1 files changed, 210 insertions, 309 deletions
diff --git a/lib/model.py b/lib/model.py
index 518566c..ce73a02 100644
--- a/lib/model.py
+++ b/lib/model.py
@@ -8,9 +8,14 @@ from multiprocessing import Pool
from .automata import PTA
from .functions import analytic
from .functions import AnalyticFunction
-from .parameters import ParamStats
+from .parameters import ParallelParamStats
from .utils import is_numeric, soft_cast_int, param_slice_eq, remove_index_from_tuple
-from .utils import by_name_to_by_param, by_param_to_by_name, match_parameter_values
+from .utils import (
+ by_name_to_by_param,
+ by_param_to_by_name,
+ match_parameter_values,
+ partition_by_param,
+)
logger = logging.getLogger(__name__)
arg_support_enabled = True
@@ -95,38 +100,21 @@ class ParallelParamFit:
function type for each parameter.
"""
- def __init__(self, by_param):
+ def __init__(self):
"""Create a new ParallelParamFit object."""
- self.fit_queue = []
- self.by_param = by_param
+ self.fit_queue = list()
- def enqueue(
- self,
- state_or_tran,
- attribute,
- param_index,
- param_name,
- safe_functions_enabled=False,
- param_filter=None,
- ):
+ def enqueue(self, key, args):
"""
Add state_or_tran/attribute/param_name to fit queue.
This causes fit() to compute the best-fitting function for this model part.
+
+ :param key: (state/transition name, model attribute, parameter name)
+ :param args: [by_param, param_index, safe_functions_enabled, param_filter]
+ by_param[(param 1, param2, ...)] holds measurements.
"""
- # Transform by_param[(state_or_tran, param_value)][attribute] = ...
- # into n_by_param[param_value] = ...
- # (param_value is dynamic, the rest is fixed)
- n_by_param = dict()
- for k, v in self.by_param.items():
- if k[0] == state_or_tran:
- n_by_param[k[1]] = v[attribute]
- self.fit_queue.append(
- {
- "key": [state_or_tran, attribute, param_name, param_filter],
- "args": [n_by_param, param_index, safe_functions_enabled, param_filter],
- }
- )
+ self.fit_queue.append({"key": key, "args": args})
def fit(self):
"""
@@ -139,14 +127,12 @@ class ParallelParamFit:
with Pool() as pool:
self.results = pool.map(_try_fits_parallel, self.fit_queue)
- def get_result(self, name, attribute, param_filter: dict = None):
+ def get_result(self, name, attr):
"""
- Parse and sanitize fit results for state/transition/... 'name' and model attribute 'attribute'.
+ Parse and sanitize fit results.
Filters out results where the best function is worse (or not much better than) static mean/median estimates.
- :param name: state/transition/... name, e.g. 'TX'
- :param attribute: model attribute, e.g. 'duration'
:param param_filter:
:returns: dict with fit result (see `_try_fits`) for each successfully fitted parameter. E.g. {'param 1': {'best' : 'function name', ...} }
"""
@@ -154,8 +140,7 @@ class ParallelParamFit:
for result in self.results:
if (
result["key"][0] == name
- and result["key"][1] == attribute
- and result["key"][3] == param_filter
+ and result["key"][1] == attr
and result["result"]["best"] is not None
): # dürfte an ['best'] != None liegen-> Fit für gefilterten Kram schlägt fehl?
this_result = result["result"]
@@ -163,9 +148,7 @@ class ParallelParamFit:
this_result["mean_rmsd"], this_result["median_rmsd"]
):
logger.debug(
- "Not modeling {} {} as function of {}: best ({:.0f}) is worse than ref ({:.0f}, {:.0f})".format(
- name,
- attribute,
+ "Not modeling as function of {}: best ({:.0f}) is worse than ref ({:.0f}, {:.0f})".format(
result["key"][2],
this_result["best_rmsd"],
this_result["mean_rmsd"],
@@ -177,9 +160,7 @@ class ParallelParamFit:
this_result["mean_rmsd"], this_result["median_rmsd"]
):
logger.debug(
- "Not modeling {} {} as function of {}: best ({:.0f}) is not much better than ref ({:.0f}, {:.0f})".format(
- name,
- attribute,
+ "Not modeling as function of {}: best ({:.0f}) is not much better than ref ({:.0f}, {:.0f})".format(
result["key"][2],
this_result["best_rmsd"],
this_result["mean_rmsd"],
@@ -360,6 +341,81 @@ def _num_args_from_by_name(by_name):
return num_args
+class ModelAttribute:
+ def __init__(self, name, attr, data, param_values, param_names, arg_count=0):
+ self.name = name
+ self.attr = attr
+ self.data = data
+ self.param_values = param_values
+ self.param_names = sorted(param_names)
+ self.arg_count = arg_count
+ self.by_param = None # set via ParallelParamStats
+ self.function_override = None
+ self.param_model = None
+
+ def get_static(self, use_mean=False):
+ if use_mean:
+ return np.mean(self.data)
+ return np.median(self.data)
+
+ def get_lut(self, param, use_mean=False):
+ if use_mean:
+ return np.mean(self.by_param[param])
+ return np.median(self.by_param[param])
+
+ def get_data_for_paramfit(self, safe_functions_enabled=False):
+ ret = list()
+ for param_index, param_name in enumerate(self.param_names):
+ if self.stats.depends_on_param(param_name):
+ ret.append(
+ (param_name, (self.by_param, param_index, safe_functions_enabled))
+ )
+ if self.arg_count:
+ for arg_index in range(self.arg_count):
+ if self.stats.depends_on_arg(arg_index):
+ ret.append(
+ (
+ arg_index,
+ (
+ self.by_param,
+ len(self.param_names) + arg_index,
+ safe_functions_enabled,
+ ),
+ )
+ )
+ return ret
+
+ def set_data_from_paramfit(self, fit_result):
+ param_model = (None, None)
+ if self.function_override is not None:
+ function_str = self.function_override
+ x = AnalyticFunction(function_str, self.param_names, self.arg_count)
+ x.fit(self.by_param)
+ if x.fit_success:
+ param_model = (x, fit_result)
+ elif len(fit_result.keys()):
+ x = analytic.function_powerset(fit_result, self.param_names, self.arg_count)
+ x.fit(self.by_param)
+
+ if x.fit_success:
+ param_model = (x, fit_result)
+
+ self.param_model = param_model
+
+ def get_fitted(self):
+ """
+        Get parameter-aware model function and model information function.
+        They must have been set via get_data_for_paramfit -> ParallelParamFit -> set_data_from_paramfit first.
+
+ Returns a tuple (function, info):
+ function -> AnalyticFunction for model. function(param=parameter values) -> model value.
+ info -> {'fit_result' : ..., 'function' : ... }
+
+ Returns (None, None) if fitting failed. Returns None if ParamFit has not been performed yet.
+ """
+ return self.param_model
+
+
class AnalyticModel:
"""
Parameter-aware analytic energy/data size/... model.
@@ -447,8 +503,8 @@ class AnalyticModel:
:param use_corrcoef: use correlation coefficient instead of stddev comparison to detect whether a model attribute depends on a parameter
"""
self.cache = dict()
- self.by_name = by_name
- self.by_param = by_name_to_by_param(by_name)
+ self.by_name = by_name # no longer required?
+ self.attr_by_name = dict()
self.names = sorted(by_name.keys())
self.parameters = sorted(parameters)
self.function_override = function_override.copy()
@@ -457,26 +513,33 @@ class AnalyticModel:
if self._num_args is None:
self._num_args = _num_args_from_by_name(by_name)
- self.stats = ParamStats(
- self.by_name,
- self.by_param,
- self.parameters,
- self._num_args,
- use_corrcoef=use_corrcoef,
- )
+ self.fit_done = False
- def _get_model_from_dict(self, model_dict, model_function):
- model = {}
- for name, elem in model_dict.items():
- model[name] = {}
- for key in elem["attributes"]:
- try:
- model[name][key] = model_function(elem[key])
- except RuntimeWarning:
- logger.warning("Got no data for {} {}".format(name, key))
- except FloatingPointError as fpe:
- logger.warning("Got no data for {} {}: {}".format(name, key, fpe))
- return model
+ self._compute_stats(by_name)
+
+ def _compute_stats(self, by_name):
+ paramstats = ParallelParamStats()
+
+ for name, data in by_name.items():
+ self.attr_by_name[name] = dict()
+ for attr in data["attributes"]:
+ model_attr = ModelAttribute(
+ name,
+ attr,
+ data[attr],
+ data["param"],
+ self.parameters,
+ self._num_args.get(name, 0),
+ )
+ self.attr_by_name[name][attr] = model_attr
+ paramstats.enqueue((name, attr), model_attr)
+ if (name, attr) in self.function_override:
+ model_attr.function_override = self.function_override[(name, attr)]
+
+ paramstats.compute()
+
+ def attributes(self, name):
+ return self.attr_by_name[name].keys()
def param_index(self, param_name):
if param_name in self.parameters:
@@ -492,21 +555,20 @@ class AnalyticModel:
"""
Get static model function: name, attribute -> model value.
- Uses the median of by_name for modeling.
+ Uses the median of by_name for modeling, unless `use_mean` is set.
"""
- getter_function = np.median
-
- if use_mean:
- getter_function = np.mean
-
- static_model = self._get_model_from_dict(self.by_name, getter_function)
+ model = dict()
+ for name, attr in self.attr_by_name.items():
+ model[name] = dict()
+ for k, v in attr.items():
+ model[name][k] = v.get_static(use_mean=use_mean)
def static_model_getter(name, key, **kwargs):
- return static_model[name][key]
+ return model[name][key]
return static_model_getter
- def get_param_lut(self, fallback=False):
+ def get_param_lut(self, use_mean=False, fallback=False):
"""
Get parameter-look-up-table model function: name, attribute, parameter values -> model value.
@@ -516,13 +578,22 @@ class AnalyticModel:
arguments:
fallback -- Fall back to the (non-parameter-aware) static model when encountering unknown parameter values
"""
- static_model = self._get_model_from_dict(self.by_name, np.median)
- lut_model = self._get_model_from_dict(self.by_param, np.median)
-
- def lut_median_getter(name, key, param, arg=[], **kwargs):
+ static_model = dict()
+ lut_model = dict()
+ for name, attr in self.attr_by_name.items():
+ static_model[name] = dict()
+ lut_model[name] = dict()
+ for k, v in attr.items():
+ static_model[name][k] = v.get_static(use_mean=use_mean)
+ lut_model[name][k] = dict()
+ for param, model_value in v.by_param.items():
+ lut_model[name][k][param] = v.get_lut(param, use_mean=use_mean)
+
+ def lut_median_getter(name, key, param, arg=list(), **kwargs):
param.extend(map(soft_cast_int, arg))
+ param = tuple(param)
try:
- return lut_model[(name, tuple(param))][key]
+ return lut_model[name][key][param]
except KeyError:
if fallback:
return static_model[name][key]
@@ -530,84 +601,67 @@ class AnalyticModel:
return lut_median_getter
- def get_fitted(self, safe_functions_enabled=False):
+ def get_fitted(self, use_mean=False, safe_functions_enabled=False):
"""
- Get paramete-aware model function and model information function.
+ Get parameter-aware model function and model information function.
Returns two functions:
model_function(name, attribute, param=parameter values) -> model value.
model_info(name, attribute) -> {'fit_result' : ..., 'function' : ... } or None
"""
- if "fitted_model_getter" in self.cache and "fitted_info_getter" in self.cache:
- return self.cache["fitted_model_getter"], self.cache["fitted_info_getter"]
-
- static_model = self._get_model_from_dict(self.by_name, np.median)
- param_model = dict([[name, {}] for name in self.by_name.keys()])
- paramfit = ParallelParamFit(self.by_param)
-
- for name in self.by_name.keys():
- for attribute in self.by_name[name]["attributes"]:
- for param_index, param in enumerate(self.parameters):
- if self.stats.depends_on_param(name, attribute, param):
- paramfit.enqueue(name, attribute, param_index, param, False)
- if arg_support_enabled and name in self._num_args:
- for arg_index in range(self._num_args[name]):
- if self.stats.depends_on_arg(name, attribute, arg_index):
- paramfit.enqueue(
- name,
- attribute,
- len(self.parameters) + arg_index,
- arg_index,
- False,
- )
- paramfit.fit()
-
- for name in self.by_name.keys():
- num_args = 0
- if name in self._num_args:
- num_args = self._num_args[name]
- for attribute in self.by_name[name]["attributes"]:
- fit_result = paramfit.get_result(name, attribute)
-
- if (name, attribute) in self.function_override:
- function_str = self.function_override[(name, attribute)]
- x = AnalyticFunction(function_str, self.parameters, num_args)
- x.fit(self.by_param, name, attribute)
- if x.fit_success:
- param_model[name][attribute] = {
- "fit_result": fit_result,
- "function": x,
- }
- elif len(fit_result.keys()):
- x = analytic.function_powerset(
- fit_result, self.parameters, num_args
+ if not self.fit_done:
+
+ paramfit = ParallelParamFit()
+
+ for name in self.names:
+ for attr in self.attr_by_name[name].keys():
+ for key, args in self.attr_by_name[name][
+ attr
+ ].get_data_for_paramfit(
+ safe_functions_enabled=safe_functions_enabled
+ ):
+ key = (name, attr, key)
+ paramfit.enqueue(key, args)
+
+ paramfit.fit()
+
+ for name in self.names:
+ for attr in self.attr_by_name[name].keys():
+ self.attr_by_name[name][attr].set_data_from_paramfit(
+ paramfit.get_result(name, attr)
)
- x.fit(self.by_param, name, attribute)
- if x.fit_success:
- param_model[name][attribute] = {
- "fit_result": fit_result,
- "function": x,
- }
+ self.fit_done = True
+
+ static_model = dict()
+ for name, attr in self.attr_by_name.items():
+ static_model[name] = dict()
+ for k, v in attr.items():
+ static_model[name][k] = v.get_static(use_mean=use_mean)
def model_getter(name, key, **kwargs):
+ param_function, _ = self.attr_by_name[name][key].get_fitted()
+
+ if param_function is None:
+ return static_model[name][key]
+
if "arg" in kwargs and "param" in kwargs:
kwargs["param"].extend(map(soft_cast_int, kwargs["arg"]))
- if key in param_model[name]:
- param_list = kwargs["param"]
- param_function = param_model[name][key]["function"]
- if param_function.is_predictable(param_list):
- return param_function.eval(param_list)
+
+ if param_function.is_predictable(kwargs["param"]):
+ return param_function.eval(kwargs["param"])
+
return static_model[name][key]
def info_getter(name, key):
- if key in param_model[name]:
- return param_model[name][key]
- return None
-
- self.cache["fitted_model_getter"] = model_getter
- self.cache["fitted_info_getter"] = info_getter
+ try:
+ model_function, fit_result = self.attr_by_name[name][key].get_fitted()
+ except KeyError:
+ return None
+ if model_function is None:
+ return None
+ return {"function": model_function, "fit_result": fit_result}
return model_getter, info_getter
@@ -625,20 +679,22 @@ class AnalyticModel:
overfitting cannot be detected.
"""
detailed_results = {}
- for name, elem in sorted(self.by_name.items()):
+ for name in self.names:
detailed_results[name] = {}
- for attribute in elem["attributes"]:
+ for attribute in self.attr_by_name[name].keys():
+ data = self.attr_by_name[name][attribute].data
+ param_values = self.attr_by_name[name][attribute].param_values
predicted_data = np.array(
list(
map(
lambda i: model_function(
- name, attribute, param=elem["param"][i]
+ name, attribute, param=param_values[i]
),
- range(len(elem[attribute])),
+ range(len(data)),
)
)
)
- measures = regression_measures(predicted_data, elem[attribute])
+ measures = regression_measures(predicted_data, data)
detailed_results[name][attribute] = measures
return {"by_name": detailed_results}
@@ -648,7 +704,7 @@ class AnalyticModel:
pass
-class PTAModel:
+class PTAModel(AnalyticModel):
"""
Parameter-aware PTA-based energy model.
@@ -718,11 +774,18 @@ class PTAModel:
pelt -- perform sub-state detection via PELT and model sub-states as well. Requires traces to be set.
"""
self.by_name = by_name
+ self.attr_by_name = dict()
self.by_param = by_name_to_by_param(by_name)
+ self.names = sorted(by_name.keys())
self._parameter_names = sorted(parameters)
+ self.parameters = sorted(parameters)
self._num_args = arg_count
self._use_corrcoef = use_corrcoef
self.traces = traces
+ self.function_override = function_override.copy()
+
+ self.fit_done = False
+
if traces is not None and pelt is not None:
from .pelt import PELT
@@ -730,19 +793,14 @@ class PTAModel:
self.find_substates()
else:
self.pelt = None
- self.stats = ParamStats(
- self.by_name,
- self.by_param,
- self._parameter_names,
- self._num_args,
- self._use_corrcoef,
- )
- self.cache = {}
+
+ self._aggregate_to_ndarray(self.by_name)
+
+ self._compute_stats(by_name)
+
np.seterr("raise")
- self.function_override = function_override.copy()
self.pta = pta
self.ignore_trace_indexes = ignore_trace_indexes
- self._aggregate_to_ndarray(self.by_name)
def _aggregate_to_ndarray(self, aggregate):
for elem in aggregate.values():
@@ -773,157 +831,6 @@ class PTAModel:
logger.warning("Got no data for {} {}: {}".format(name, key, fpe))
return model
- def get_static(self, use_mean=False):
- """
- Get static model function: name, attribute -> model value.
-
- Uses the median of by_name for modeling, unless `use_mean` is set.
- """
- getter_function = np.median
-
- if use_mean:
- getter_function = np.mean
-
- static_model = self._get_model_from_dict(self.by_name, getter_function)
-
- def static_model_getter(name, key, **kwargs):
- return static_model[name][key]
-
- return static_model_getter
-
- def get_param_lut(self, fallback=False):
- """
- Get parameter-look-up-table model function: name, attribute, parameter values -> model value.
-
- The function can only give model values for parameter combinations
- present in by_param. By default, it raises KeyError for other values.
-
- arguments:
- fallback -- Fall back to the (non-parameter-aware) static model when encountering unknown parameter values
- """
- static_model = self._get_model_from_dict(self.by_name, np.median)
- lut_model = self._get_model_from_dict(self.by_param, np.median)
-
- def lut_median_getter(name, key, param, arg=[], **kwargs):
- param.extend(map(soft_cast_int, arg))
- try:
- return lut_model[(name, tuple(param))][key]
- except KeyError:
- if fallback:
- return static_model[name][key]
- raise
-
- return lut_median_getter
-
- def param_index(self, param_name):
- if param_name in self._parameter_names:
- return self._parameter_names.index(param_name)
- return len(self._parameter_names) + int(param_name)
-
- def param_name(self, param_index):
- if param_index < len(self._parameter_names):
- return self._parameter_names[param_index]
- return str(param_index)
-
- def get_fitted(self, safe_functions_enabled=False):
- """
- Get parameter-aware model function and model information function.
-
- Returns two functions:
- model_function(name, attribute, param=parameter values) -> model value.
- model_info(name, attribute) -> {'fit_result' : ..., 'function' : ... } or None
- """
- if "fitted_model_getter" in self.cache and "fitted_info_getter" in self.cache:
- return self.cache["fitted_model_getter"], self.cache["fitted_info_getter"]
-
- static_model = self._get_model_from_dict(self.by_name, np.median)
- param_model = dict(
- [[state_or_tran, {}] for state_or_tran in self.by_name.keys()]
- )
- paramfit = ParallelParamFit(self.by_param)
- for state_or_tran in self.by_name.keys():
- for model_attribute in self.by_name[state_or_tran]["attributes"]:
- fit_results = {}
- for parameter_index, parameter_name in enumerate(self._parameter_names):
- if self.depends_on_param(
- state_or_tran, model_attribute, parameter_name
- ):
- paramfit.enqueue(
- state_or_tran,
- model_attribute,
- parameter_index,
- parameter_name,
- safe_functions_enabled,
- )
- if (
- arg_support_enabled
- and self.by_name[state_or_tran]["isa"] == "transition"
- ):
- for arg_index in range(self._num_args[state_or_tran]):
- if self.depends_on_arg(
- state_or_tran, model_attribute, arg_index
- ):
- paramfit.enqueue(
- state_or_tran,
- model_attribute,
- len(self._parameter_names) + arg_index,
- arg_index,
- safe_functions_enabled,
- )
- paramfit.fit()
-
- for state_or_tran in self.by_name.keys():
- num_args = 0
- if (
- arg_support_enabled
- and self.by_name[state_or_tran]["isa"] == "transition"
- ):
- num_args = self._num_args[state_or_tran]
- for model_attribute in self.by_name[state_or_tran]["attributes"]:
- fit_results = paramfit.get_result(state_or_tran, model_attribute)
-
- if (state_or_tran, model_attribute) in self.function_override:
- function_str = self.function_override[
- (state_or_tran, model_attribute)
- ]
- x = AnalyticFunction(function_str, self._parameter_names, num_args)
- x.fit(self.by_param, state_or_tran, model_attribute)
- if x.fit_success:
- param_model[state_or_tran][model_attribute] = {
- "fit_result": fit_results,
- "function": x,
- }
- elif len(fit_results.keys()):
- x = analytic.function_powerset(
- fit_results, self._parameter_names, num_args
- )
- x.fit(self.by_param, state_or_tran, model_attribute)
- if x.fit_success:
- param_model[state_or_tran][model_attribute] = {
- "fit_result": fit_results,
- "function": x,
- }
-
- def model_getter(name, key, **kwargs):
- if "arg" in kwargs and "param" in kwargs:
- kwargs["param"].extend(map(soft_cast_int, kwargs["arg"]))
- if key in param_model[name]:
- param_list = kwargs["param"]
- param_function = param_model[name][key]["function"]
- if param_function.is_predictable(param_list):
- return param_function.eval(param_list)
- return static_model[name][key]
-
- def info_getter(name, key):
- if key in param_model[name]:
- return param_model[name][key]
- return None
-
- self.cache["fitted_model_getter"] = model_getter
- self.cache["fitted_info_getter"] = info_getter
-
- return model_getter, info_getter
-
def pelt_refine(self, by_param_key):
logger.debug(f"PELT: {by_param_key} needs refinement")
@@ -1112,12 +1019,6 @@ class PTAModel:
ret.extend(self.transitions())
return ret
- def parameters(self):
- return self._parameter_names
-
- def attributes(self, state_or_trans):
- return self.by_name[state_or_trans]["attributes"]
-
def assess(self, model_function, ref=None):
"""
Calculate MAE, SMAPE, etc. of model_function for each by_name entry.