summaryrefslogtreecommitdiff
path: root/lib/dfatool.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/dfatool.py')
-rw-r--r--lib/dfatool.py40
1 files changed, 25 insertions, 15 deletions
diff --git a/lib/dfatool.py b/lib/dfatool.py
index 478f800..a3d5c0f 100644
--- a/lib/dfatool.py
+++ b/lib/dfatool.py
@@ -18,7 +18,7 @@ from functions import analytic
from functions import AnalyticFunction
from parameters import ParamStats
from utils import vprint, is_numeric, soft_cast_int, param_slice_eq, remove_index_from_tuple
-from utils import by_name_to_by_param
+from utils import by_name_to_by_param, match_parameter_values
arg_support_enabled = True
@@ -505,12 +505,10 @@ class RawData:
self.cache_file = '{}/{}.json'.format(self.cache_dir, cache_key)
def load_cache(self):
- print('checking {}...'.format(self.cache_file))
if os.path.exists(self.cache_file):
with open(self.cache_file, 'r') as f:
self.traces = json.load(f)
self.preprocessed = True
- print('loaded cache')
def save_cache(self):
try:
@@ -902,15 +900,15 @@ class ParallelParamFit:
self.fit_queue = []
self.by_param = by_param
- def enqueue(self, state_or_tran, attribute, param_index, param_name, safe_functions_enabled = False):
+ def enqueue(self, state_or_tran, attribute, param_index, param_name, safe_functions_enabled = False, param_filter = None):
"""
Add state_or_tran/attribute/param_name to fit queue.
This causes fit() to compute the best-fitting function for this model part.
"""
self.fit_queue.append({
- 'key' : [state_or_tran, attribute, param_name],
- 'args' : [self.by_param, state_or_tran, attribute, param_index, safe_functions_enabled]
+ 'key' : [state_or_tran, attribute, param_name, param_filter],
+ 'args' : [self.by_param, state_or_tran, attribute, param_index, safe_functions_enabled, param_filter]
})
def fit(self):
@@ -935,16 +933,17 @@ def _try_fits_parallel(arg):
'result' : _try_fits(*arg['args'])
}
-
-def _try_fits(by_param, state_or_tran, model_attribute, param_index, safe_functions_enabled = False):
+def _try_fits(by_param, state_or_tran, model_attribute, param_index, safe_functions_enabled = False, param_filter: dict = None):
"""
Determine goodness-of-fit for prediction of `by_param[(state_or_tran, *)][model_attribute]` dependence on `param_index` using various functions.
This is done by varying `param_index` while keeping all other parameters constant and doing one least squares optimization for each function and for each combination of the remaining parameters.
The value of the parameter corresponding to `param_index` (e.g. txpower or packet length) is the sole input to the model function.
+ Only numeric parameter values (as determined by `utils.is_numeric`) are used for fitting, non-numeric values such as None or enum strings are ignored.
+ Fitting is only performed if at least three distinct parameter values exist in `by_param[(state_or_tran, *)]`.
- :return: a dictionary with the following elements:
- best -- name of the best-fitting function (see `analytic.functions`)
+ :returns: a dictionary with the following elements:
+ best -- name of the best-fitting function (see `analytic.functions`). `None` in case of insufficient data.
best_rmsd -- mean Root Mean Square Deviation of best-fitting function over all combinations of the remaining parameters
mean_rmsd -- mean Root Mean Square Deviation of a reference model using the mean of its respective input data as model value
median_rmsd -- mean Root Mean Square Deviation of a reference model using the median of its respective input data as model value
@@ -961,6 +960,7 @@ def _try_fits(by_param, state_or_tran, model_attribute, param_index, safe_functi
:param param_index: index of the parameter used as model input
:param safe_functions_enabled: Include "safe" variants of functions with limited argument range.
+ :param param_filter: Only use measurements whose parameters match param_filter for fitting.
"""
functions = analytic.functions(safe_functions_enabled = safe_functions_enabled)
@@ -987,7 +987,7 @@ def _try_fits(by_param, state_or_tran, model_attribute, param_index, safe_functi
seen_parameter_combinations = set()
# for each parameter combination:
- for param_key in filter(lambda x: x[0] == state_or_tran and remove_index_from_tuple(x[1], param_index) not in seen_parameter_combinations, by_param.keys()):
+ for param_key in filter(lambda x: x[0] == state_or_tran and remove_index_from_tuple(x[1], param_index) not in seen_parameter_combinations and len(by_param[x]['param']) and match_parameter_values(by_param[x]['param'][0], param_filter), by_param.keys()):
X = []
Y = []
num_valid = 0
@@ -1087,7 +1087,7 @@ def _num_args_from_by_name(by_name):
num_args[key] = len(value['args'][0])
return num_args
-def get_fit_result(results, name, attribute, verbose = False):
+def get_fit_result(results, name, attribute, verbose = False, param_filter: dict = None):
"""
Parse and sanitize fit results for state/transition/... 'name' and model attribute 'attribute'.
@@ -1097,10 +1097,12 @@ def get_fit_result(results, name, attribute, verbose = False):
:param name: state/transition/... name, e.g. 'TX'
:param attribute: model attribute, e.g. 'duration'
:param verbose: print debug message to stdout when deliberately not using a determined fit function
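+ :param param_filter: only consider results whose key carries this exact param_filter, i.e. the value passed to `ParallelParamFit.enqueue` (default None: results enqueued without a filter)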
+ :returns: dict with fit result (see `_try_fits`) for each successfully fitted parameter. E.g. {'param 1': {'best' : 'function name', ...} }
"""
fit_result = dict()
for result in results:
- if result['key'][0] == name and result['key'][1] == attribute and result['result']['best'] != None:
+ if result['key'][0] == name and result['key'][1] == attribute and result['key'][3] == param_filter and result['result']['best'] != None: # probably caused by the ['best'] != None check -> does the fit for the filtered data fail?
this_result = result['result']
if this_result['best_rmsd'] >= min(this_result['mean_rmsd'], this_result['median_rmsd']):
vprint(verbose, '[I] Not modeling {} {} as function of {}: best ({:.0f}) is worse than ref ({:.0f}, {:.0f})'.format(
@@ -1583,7 +1585,7 @@ class PTAModel:
"""
Get static model function: name, attribute -> model value.
- Uses the median of by_name for modeling.
+ Uses the median of by_name for modeling, unless `use_mean` is set.
"""
getter_function = np.median
@@ -1633,7 +1635,7 @@ class PTAModel:
def get_fitted(self, safe_functions_enabled = False):
"""
- Get paramete-aware model function and model information function.
+ Get parameter-aware model function and model information function.
Returns two functions:
model_function(name, attribute, param=parameter values) -> model value.
@@ -1651,6 +1653,8 @@ class PTAModel:
for parameter_index, parameter_name in enumerate(self._parameter_names):
if self.depends_on_param(state_or_tran, model_attribute, parameter_name):
paramfit.enqueue(state_or_tran, model_attribute, parameter_index, parameter_name, safe_functions_enabled)
+ for codependent_param_dict in self.stats.codependent_parameter_value_dicts(state_or_tran, model_attribute, parameter_name):
+ paramfit.enqueue(state_or_tran, model_attribute, parameter_index, parameter_name, safe_functions_enabled, codependent_param_dict)
if arg_support_enabled and self.by_name[state_or_tran]['isa'] == 'transition':
for arg_index in range(self._num_args[state_or_tran]):
if self.depends_on_arg(state_or_tran, model_attribute, arg_index):
@@ -1664,6 +1668,12 @@ class PTAModel:
for model_attribute in self.by_name[state_or_tran]['attributes']:
fit_results = get_fit_result(paramfit.results, state_or_tran, model_attribute, self.verbose)
+ for parameter_name in self._parameter_names:
+ if self.depends_on_param(state_or_tran, model_attribute, parameter_name):
+ for codependent_param_dict in self.stats.codependent_parameter_value_dicts(state_or_tran, model_attribute, parameter_name):
+ pass
+ # FIXME: get_fit_result does not take a parameter (name) as an argument at all...
+
if (state_or_tran, model_attribute) in self.function_override:
function_str = self.function_override[(state_or_tran, model_attribute)]
x = AnalyticFunction(function_str, self._parameter_names, num_args)