diff options
-rwxr-xr-x | bin/analyze-archive.py | 81 | ||||
-rwxr-xr-x | bin/analyze-timing.py | 21 | ||||
-rw-r--r-- | lib/functions.py | 77 | ||||
-rw-r--r-- | lib/model.py | 45 | ||||
-rwxr-xr-x | test/test_ptamodel.py | 18 | ||||
-rwxr-xr-x | test/test_timingharness.py | 32 |
6 files changed, 164 insertions, 110 deletions
diff --git a/bin/analyze-archive.py b/bin/analyze-archive.py index 0be9ab0..ee23a75 100755 --- a/bin/analyze-archive.py +++ b/bin/analyze-archive.py @@ -43,7 +43,7 @@ import random import sys from dfatool import plotter from dfatool.loader import RawData, pta_trace_to_aggregate -from dfatool.functions import gplearn_to_function +from dfatool.functions import gplearn_to_function, SplitInfo, AnalyticInfo from dfatool.model import PTAModel from dfatool.validation import CrossValidator from dfatool.utils import filter_aggregate_by_param, detect_outliers_in_aggregate @@ -365,6 +365,24 @@ def print_static(model, static_model, name, attribute): ) +def print_analyticinfo(prefix, info): + empty = "" + print(f"{prefix}: {info.function.model_function}") + print(f"{empty:{len(prefix)}s} {info.function.model_args}") + + +def print_splitinfo(param_names, info, prefix=""): + if type(info) is SplitInfo: + for k, v in info.child.items(): + print_splitinfo( + param_names, v, f"{prefix} {param_names[info.param_index]}={k}" + ) + elif type(info) is AnalyticInfo: + print(f"{prefix} = analytic") + else: + print(f"{prefix} = static") + + if __name__ == "__main__": ignored_trace_indexes = [] @@ -871,9 +889,9 @@ if __name__ == "__main__": ], ) ) - if info is not None: - for param_name in sorted(info["fit_result"].keys(), key=str): - param_fit = info["fit_result"][param_name]["results"] + if type(info) is AnalyticInfo: + for param_name in sorted(info.fit_result.keys(), key=str): + param_fit = info.fit_result[param_name]["results"] for function_type in sorted(param_fit.keys()): function_rmsd = param_fit[function_type]["rmsd"] print( @@ -889,60 +907,37 @@ if __name__ == "__main__": if "param" in show_models or "all" in show_models: for state in model.states(): for attribute in model.attributes(state): - if param_info(state, attribute): - print( - "{:10s} {:15s}: {}".format( - state, - attribute, - param_info(state, attribute)["function"].model_function, - ) - ) - print( - "{:10s} {:15s} {}".format( - "", "", param_info(state, attribute)["function"].model_args - ) + info = param_info(state, attribute) + if type(info) is AnalyticInfo: + print_analyticinfo(f"{state:10s} {attribute:15s}", info) + elif type(info) is SplitInfo: + print_splitinfo( + model.parameters, info, f"{state:10s} {attribute:15s}" ) for trans in model.transitions(): for attribute in model.attributes(trans): - if param_info(trans, attribute): - print( - "{:10s} {:15s}: {:10s}: {}".format( - trans, - attribute, - attribute, - param_info(trans, attribute)["function"].model_function, - ) - ) - print( - "{:10s} {:15s} {:10s} {}".format( - "", - "", - "", - param_info(trans, attribute)["function"].model_args, - ) + info = param_info(trans, attribute) + if type(info) is AnalyticInfo: + print_analyticinfo(f"{trans:10s} {attribute:15s}", info) + elif type(info) is SplitInfo: + print_splitinfo( + model.parameters, info, f"{trans:10s} {attribute:15s}" ) if args.with_substates: for submodel in model.submodel_by_name.values(): sub_param_model, sub_param_info = submodel.get_fitted() for substate in submodel.states(): for subattribute in submodel.attributes(substate): - if sub_param_info(substate, subattribute): + info = sub_param_info(substate, subattribute) + if type(info) is AnalyticInfo: print( "{:10s} {:15s}: {}".format( - substate, - subattribute, - sub_param_info(substate, subattribute)[ - "function" - ].model_function, + substate, subattribute, info.function.model_function ) ) print( "{:10s} {:15s} {}".format( - "", - "", - sub_param_info(substate, subattribute)[ - "function" - ].model_args, + "", "", info.function.model_args ) ) diff --git a/bin/analyze-timing.py b/bin/analyze-timing.py index 1460dd3..4a11298 100755 --- a/bin/analyze-timing.py +++ b/bin/analyze-timing.py @@ -80,7 +80,7 @@ import re import sys from dfatool import plotter from dfatool.loader import TimingData, pta_trace_to_aggregate -from dfatool.functions import gplearn_to_function +from dfatool.functions import gplearn_to_function, SplitInfo, AnalyticInfo from dfatool.model import AnalyticModel from dfatool.validation import CrossValidator from dfatool.utils import filter_aggregate_by_param @@ -387,9 +387,9 @@ if __name__ == "__main__": ].stats.arg_dependence_ratio(i), ) ) - if info is not None: - for param_name in sorted(info["fit_result"].keys(), key=str): - param_fit = info["fit_result"][param_name]["results"] + if type(info) is AnalyticInfo: + for param_name in sorted(info.fit_result.keys(), key=str): + param_fit = info.fit_result[param_name]["results"] for function_type in sorted(param_fit.keys()): function_rmsd = param_fit[function_type]["rmsd"] print( @@ -405,19 +405,14 @@ if __name__ == "__main__": if "param" in show_models or "all" in show_models: for trans in model.names: for attribute in ["duration"]: - if param_info(trans, attribute): + info = param_info(trans, attribute) + if type(info) is AnalyticInfo: print( "{:10s}: {:10s}: {}".format( - trans, - attribute, - param_info(trans, attribute)["function"].model_function, - ) - ) - print( - "{:10s} {:10s} {}".format( - "", "", param_info(trans, attribute)["function"].model_args + trans, attribute, info.function.model_function ) ) + print("{:10s} {:10s} {}".format("", "", info.function.model_args)) if xv_method == "montecarlo": analytic_quality = xv.montecarlo(lambda m: m.get_fitted()[0], xv_count) diff --git a/lib/functions.py b/lib/functions.py index 0bdea45..067514f 100644 --- a/lib/functions.py +++ b/lib/functions.py @@ -152,7 +152,82 @@ class NormalizationFunction: return self._function(param_value) -class AnalyticFunction: +class ModelInfo: + def __init__(self): + pass + + +class AnalyticInfo(ModelInfo): + def __init__(self, fit_result, function): + self.fit_result = fit_result + self.function = function + + +class SplitInfo(ModelInfo): + def __init__(self, param_index, child): + self.param_index = param_index + self.child = child + + +class ModelFunction: + def __init__(self): + pass + + def is_predictable(self, param_list): + raise NotImplementedError + + def eval(self, param_list, arg_list): + raise NotImplementedError + + +class StaticFunction(ModelFunction): + def __init__(self, value): + self.value = value + + def is_predictable(self, param_list=None): + """ + Return whether the model function can be evaluated on the given parameter values. + + For a StaticFunction, this is always the case (i.e., this function always returns true). + """ + return True + + def eval(self, param_list=None, arg_list=None): + """ + Evaluate model function with specified param/arg values. + + Far a Staticfunction, this is just the static value + + """ + return self.value + + +class SplitFunction(ModelFunction): + def __init__(self, param_index, child): + self.param_index = param_index + self.child = child + + def is_predictable(self, param_list): + """ + Return whether the model function can be evaluated on the given parameter values. + + The first value corresponds to the lexically first model parameter, etc. + All parameters must be set, not just the ones this function depends on. + + Returns False iff a parameter the function depends on is not numeric + (e.g. None). + """ + param_value = param_list[self.param_index] + if param_value not in self.child: + return False + return self.child[param_value].is_predictable(param_list) + + def eval(self, param_list, arg_list=list()): + param_value = param_list[self.param_index] + return self.child[param_value].eval(param_list, arg_list) + + +class AnalyticFunction(ModelFunction): """ A multi-dimensional model function, generated from a string, which can be optimized using regression. diff --git a/lib/model.py b/lib/model.py index 83c31b1..cddfe27 100644 --- a/lib/model.py +++ b/lib/model.py @@ -7,8 +7,7 @@ from scipy import optimize from sklearn.metrics import r2_score from multiprocessing import Pool from .automata import PTA -from .functions import analytic -from .functions import AnalyticFunction +import dfatool.functions as df from .parameters import ParallelParamStats, ParamStats from .utils import is_numeric, soft_cast_int, param_slice_eq, remove_index_from_tuple from .utils import ( @@ -211,7 +210,7 @@ def _try_fits( :param param_filter: Only use measurements whose parameters match param_filter for fitting. """ - functions = analytic.functions(safe_functions_enabled=safe_functions_enabled) + functions = df.analytic.functions(safe_functions_enabled=safe_functions_enabled) for param_key in n_by_param.keys(): # We might remove elements from 'functions' while iterating over @@ -532,31 +531,33 @@ class ModelAttribute: "child": dict(), "child_static": dict(), } - info_map = {"split_by": split_param_index, "child": dict()} + function_child = dict() + info_child = dict() for param_value, child in child_by_param_value.items(): child.set_data_from_paramfit(paramfit, prefix + (param_value,)) - function_map["child"][param_value], info_map["child"][ - param_value - ] = child.get_fitted() - function_map["child_static"][param_value] = child.get_static() + function_child[param_value], info_child[param_value] = child.get_fitted() + function_map = df.SplitFunction(split_param_index, function_child) + info_map = df.SplitInfo(split_param_index, info_child) self.param_model = function_map, info_map def set_data_from_paramfit_this(self, paramfit, prefix): fit_result = paramfit.get_result((self.name, self.attr) + prefix) - param_model = (None, None) + param_model = (df.StaticFunction(np.median(self.data)), None) if self.function_override is not None: function_str = self.function_override - x = AnalyticFunction(function_str, self.param_names, self.arg_count) + x = df.AnalyticFunction(function_str, self.param_names, self.arg_count) x.fit(self.by_param) if x.fit_success: - param_model = (x, fit_result) + param_model = (x, df.AnalyticInfo(fit_result, x)) elif len(fit_result.keys()): - x = analytic.function_powerset(fit_result, self.param_names, self.arg_count) + x = df.analytic.function_powerset( + fit_result, self.param_names, self.arg_count + ) x.fit(self.by_param) if x.fit_success: - param_model = (x, fit_result) + param_model = (x, df.AnalyticInfo(fit_result, x)) self.param_model = param_model @@ -810,22 +811,12 @@ class AnalyticModel: def model_getter(name, key, **kwargs): param_function, param_info = self.attr_by_name[name][key].get_fitted() - if param_function is None: + if param_info is None: return static_model[name][key] if "arg" in kwargs and "param" in kwargs: kwargs["param"].extend(map(soft_cast_int, kwargs["arg"])) - while type(param_function) is dict and "split_by" in param_function: - split_param_value = kwargs["param"][param_function["split_by"]] - split_static = param_function["child_static"][split_param_value] - param_function = param_function["child"][split_param_value] - param_info = param_info["child"][split_param_value] - - if param_function is None: - # TODO return static model of child - return split_static - if param_function.is_predictable(kwargs["param"]): return param_function.eval(kwargs["param"]) @@ -833,12 +824,10 @@ class AnalyticModel: def info_getter(name, key): try: - model_function, fit_result = self.attr_by_name[name][key].get_fitted() + model_function, model_info = self.attr_by_name[name][key].get_fitted() except KeyError: return None - if model_function is None: - return None - return {"function": model_function, "fit_result": fit_result} + return model_info return model_getter, info_getter diff --git a/test/test_ptamodel.py b/test/test_ptamodel.py index 55f84b8..257215f 100755 --- a/test/test_ptamodel.py +++ b/test/test_ptamodel.py @@ -640,26 +640,26 @@ class TestFromFile(unittest.TestCase): param_model, param_info = model.get_fitted() self.assertEqual(param_info("POWERDOWN", "power"), None) self.assertEqual( - param_info("RX", "power")["function"].model_function, + param_info("RX", "power").function.model_function, "0 + regression_arg(0) + regression_arg(1) * np.sqrt(parameter(datarate))", ) self.assertAlmostEqual( - param_info("RX", "power")["function"].model_args[0], 48530.7, places=0 + param_info("RX", "power").function.model_args[0], 48530.7, places=0 ) self.assertAlmostEqual( - param_info("RX", "power")["function"].model_args[1], 117, places=0 + param_info("RX", "power").function.model_args[1], 117, places=0 ) self.assertEqual(param_info("STANDBY1", "power"), None) self.assertEqual( - param_info("TX", "power")["function"].model_function, + param_info("TX", "power").function.model_function, "0 + regression_arg(0) + regression_arg(1) * 1/(parameter(datarate)) + regression_arg(2) * parameter(txpower) + regression_arg(3) * 1/(parameter(datarate)) * parameter(txpower)", ) self.assertEqual( - param_info("epilogue", "timeout")["function"].model_function, + param_info("epilogue", "timeout").function.model_function, "0 + regression_arg(0) + regression_arg(1) * 1/(parameter(datarate))", ) self.assertEqual( - param_info("stopListening", "duration")["function"].model_function, + param_info("stopListening", "duration").function.model_function, "0 + regression_arg(0) + regression_arg(1) * 1/(parameter(datarate))", ) @@ -1825,7 +1825,7 @@ class TestFromFile(unittest.TestCase): param_model, param_info = model.get_fitted() self.assertEqual(param_info("IDLE", "power"), None) self.assertEqual( - param_info("RX", "power")["function"].model_function, + param_info("RX", "power").function.model_function, "0 + regression_arg(0) + regression_arg(1) * np.log(parameter(symbolrate) + 1)", ) self.assertEqual(param_info("SLEEP", "power"), None) @@ -1834,10 +1834,10 @@ class TestFromFile(unittest.TestCase): self.assertEqual(param_info("XOFF", "power"), None) self.assertAlmostEqual( - param_info("RX", "power")["function"].model_args[0], 84415, places=0 + param_info("RX", "power").function.model_args[0], 84415, places=0 ) self.assertAlmostEqual( - param_info("RX", "power")["function"].model_args[1], 206, places=0 + param_info("RX", "power").function.model_args[1], 206, places=0 ) diff --git a/test/test_timingharness.py b/test/test_timingharness.py index 9b55231..5cc0bec 100755 --- a/test/test_timingharness.py +++ b/test/test_timingharness.py @@ -34,21 +34,21 @@ class TestModels(unittest.TestCase): self.assertEqual(param_info("setRetries", "duration"), None) self.assertEqual(param_info("setup", "duration"), None) self.assertEqual( - param_info("write", "duration")["function"].model_function, + param_info("write", "duration").function.model_function, "0 + regression_arg(0) + regression_arg(1) * parameter(max_retry_count) + regression_arg(2) * parameter(retry_delay) + regression_arg(3) * parameter(max_retry_count) * parameter(retry_delay)", ) self.assertAlmostEqual( - param_info("write", "duration")["function"].model_args[0], 1163, places=0 + param_info("write", "duration").function.model_args[0], 1163, places=0 ) self.assertAlmostEqual( - param_info("write", "duration")["function"].model_args[1], 464, places=0 + param_info("write", "duration").function.model_args[1], 464, places=0 ) self.assertAlmostEqual( - param_info("write", "duration")["function"].model_args[2], 1, places=0 + param_info("write", "duration").function.model_args[2], 1, places=0 ) self.assertAlmostEqual( - param_info("write", "duration")["function"].model_args[3], 1, places=0 + param_info("write", "duration").function.model_args[3], 1, places=0 ) def test_dependent_parameter_pruning(self): @@ -82,21 +82,21 @@ class TestModels(unittest.TestCase): self.assertEqual(param_info("setRetries", "duration"), None) self.assertEqual(param_info("setup", "duration"), None) self.assertEqual( - param_info("write", "duration")["function"].model_function, + param_info("write", "duration").function.model_function, "0 + regression_arg(0) + regression_arg(1) * parameter(max_retry_count) + regression_arg(2) * parameter(retry_delay) + regression_arg(3) * parameter(max_retry_count) * parameter(retry_delay)", ) self.assertAlmostEqual( - param_info("write", "duration")["function"].model_args[0], 1163, places=0 + param_info("write", "duration").function.model_args[0], 1163, places=0 ) self.assertAlmostEqual( - param_info("write", "duration")["function"].model_args[1], 464, places=0 + param_info("write", "duration").function.model_args[1], 464, places=0 ) self.assertAlmostEqual( - param_info("write", "duration")["function"].model_args[2], 1, places=0 + param_info("write", "duration").function.model_args[2], 1, places=0 ) self.assertAlmostEqual( - param_info("write", "duration")["function"].model_args[3], 1, places=0 + param_info("write", "duration").function.model_args[3], 1, places=0 ) def test_function_override(self): @@ -140,24 +140,24 @@ class TestModels(unittest.TestCase): self.assertEqual(param_info("setRetries", "duration"), None) self.assertEqual(param_info("setup", "duration"), None) self.assertEqual( - param_info("write", "duration")["function"].model_function, + param_info("write", "duration").function.model_function, "(parameter(auto_ack!) * (regression_arg(0) + regression_arg(1) * parameter(max_retry_count) + regression_arg(2) * parameter(retry_delay) + regression_arg(3) * parameter(max_retry_count) * parameter(retry_delay))) + ((1 - parameter(auto_ack!)) * regression_arg(4))", ) self.assertAlmostEqual( - param_info("write", "duration")["function"].model_args[0], 1162, places=0 + param_info("write", "duration").function.model_args[0], 1162, places=0 ) self.assertAlmostEqual( - param_info("write", "duration")["function"].model_args[1], 464, places=0 + param_info("write", "duration").function.model_args[1], 464, places=0 ) self.assertAlmostEqual( - param_info("write", "duration")["function"].model_args[2], 1, places=0 + param_info("write", "duration").function.model_args[2], 1, places=0 ) self.assertAlmostEqual( - param_info("write", "duration")["function"].model_args[3], 1, places=0 + param_info("write", "duration").function.model_args[3], 1, places=0 ) self.assertAlmostEqual( - param_info("write", "duration")["function"].model_args[4], 1086, places=0 + param_info("write", "duration").function.model_args[4], 1086, places=0 ) os.environ.pop("DFATOOL_NO_DECISIONTREES") |