diff options
-rwxr-xr-x | bin/analyze-archive.py | 42 | ||||
-rwxr-xr-x | bin/analyze-timing.py | 20 | ||||
-rw-r--r-- | lib/functions.py | 50 | ||||
-rw-r--r-- | lib/model.py | 8 | ||||
-rw-r--r-- | lib/parameters.py | 11 | ||||
-rwxr-xr-x | test/test_ptamodel.py | 40 | ||||
-rwxr-xr-x | test/test_timingharness.py | 56 |
7 files changed, 99 insertions, 128 deletions
diff --git a/bin/analyze-archive.py b/bin/analyze-archive.py index 3344d8a..65e25cc 100755 --- a/bin/analyze-archive.py +++ b/bin/analyze-archive.py @@ -43,7 +43,12 @@ import random import sys from dfatool import plotter from dfatool.loader import RawData, pta_trace_to_aggregate -from dfatool.functions import gplearn_to_function, SplitInfo, AnalyticInfo, StaticInfo +from dfatool.functions import ( + gplearn_to_function, + SplitFunction, + AnalyticFunction, + StaticFunction, +) from dfatool.model import PTAModel from dfatool.validation import CrossValidator from dfatool.utils import filter_aggregate_by_param, detect_outliers_in_aggregate @@ -91,13 +96,14 @@ def model_quality_table(header, result_lists, info_list): info is None or ( key != "energy_Pt" - and type(info(state_or_tran, key)) is not StaticInfo + and type(info(state_or_tran, key)) is not StaticFunction ) or ( key == "energy_Pt" and ( - type(info(state_or_tran, "power")) is not StaticInfo - or type(info(state_or_tran, "duration")) is not StaticInfo + type(info(state_or_tran, "power")) is not StaticFunction + or type(info(state_or_tran, "duration")) + is not StaticFunction ) ) ): @@ -370,22 +376,22 @@ def print_static(model, static_model, name, attribute): def print_analyticinfo(prefix, info): empty = "" - print(f"{prefix}: {info.function.model_function}") - print(f"{empty:{len(prefix)}s} {info.function.model_args}") + print(f"{prefix}: {info.model_function}") + print(f"{empty:{len(prefix)}s} {info.model_args}") def print_splitinfo(param_names, info, prefix=""): - if type(info) is SplitInfo: + if type(info) is SplitFunction: for k, v in info.child.items(): if info.param_index < len(param_names): param_name = param_names[info.param_index] else: param_name = f"arg{info.param_index - len(param_names)}" print_splitinfo(param_names, v, f"{prefix} {param_name}={k}") - elif type(info) is AnalyticInfo: + elif type(info) is AnalyticFunction: print_analyticinfo(prefix, info) - elif type(info) is StaticInfo: - print(f"{prefix}: {info.median}") + elif type(info) is StaticFunction: + print(f"{prefix}: {info.value}") else: print(f"{prefix}: UNKNOWN") @@ -896,9 +902,9 @@ if __name__ == "__main__": ], ) ) - if type(info) is AnalyticInfo: - for param_name in sorted(info.fit_result.keys(), key=str): - param_fit = info.fit_result[param_name]["results"] + if type(info) is AnalyticFunction: + for param_name in sorted(info.fit_by_param.keys(), key=str): + param_fit = info.fit_by_param[param_name]["results"] for function_type in sorted(param_fit.keys()): function_rmsd = param_fit[function_type]["rmsd"] print( @@ -915,18 +921,18 @@ if __name__ == "__main__": for state in model.states: for attribute in model.attributes(state): info = param_info(state, attribute) - if type(info) is AnalyticInfo: + if type(info) is AnalyticFunction: print_analyticinfo(f"{state:10s} {attribute:15s}", info) - elif type(info) is SplitInfo: + elif type(info) is SplitFunction: print_splitinfo( model.parameters, info, f"{state:10s} {attribute:15s}" ) for trans in model.transitions: for attribute in model.attributes(trans): info = param_info(trans, attribute) - if type(info) is AnalyticInfo: + if type(info) is AnalyticFunction: print_analyticinfo(f"{trans:10s} {attribute:15s}", info) - elif type(info) is SplitInfo: + elif type(info) is SplitFunction: print_splitinfo( model.parameters, info, f"{trans:10s} {attribute:15s}" ) @@ -936,7 +942,7 @@ if __name__ == "__main__": for substate in submodel.states: for subattribute in submodel.attributes(substate): info = sub_param_info(substate, subattribute) - if type(info) is AnalyticInfo: + if type(info) is AnalyticFunction: print( "{:10s} {:15s}: {}".format( substate, subattribute, info.function.model_function diff --git a/bin/analyze-timing.py b/bin/analyze-timing.py index 10c8f1d..e8af0fb 100755 --- a/bin/analyze-timing.py +++ b/bin/analyze-timing.py @@ -80,10 +80,10 @@ import re import sys from dfatool import plotter from dfatool.loader import TimingData, pta_trace_to_aggregate -from dfatool.functions import gplearn_to_function, SplitInfo, AnalyticInfo +from dfatool.functions import gplearn_to_function, StaticFunction, AnalyticFunction from dfatool.model import AnalyticModel from dfatool.validation import CrossValidator -from dfatool.utils import filter_aggregate_by_param +from dfatool.utils import filter_aggregate_by_param, NpEncoder from dfatool.parameters import prune_dependent_parameters opt = dict() @@ -117,7 +117,7 @@ def model_quality_table(result_lists, info_list): for i, results in enumerate(result_lists): info = info_list[i] buf += " ||| " - if info is None or info(state_or_tran, key): + if info is None or type(info(state_or_tran, key)) is not StaticFunction: result = results["by_name"][state_or_tran][key] buf += format_quality_measures(result) else: @@ -387,9 +387,9 @@ if __name__ == "__main__": ].stats.arg_dependence_ratio(i), ) ) - if type(info) is AnalyticInfo: - for param_name in sorted(info.fit_result.keys(), key=str): - param_fit = info.fit_result[param_name]["results"] + if type(info) is AnalyticFunction: + for param_name in sorted(info.fit_by_param.keys(), key=str): + param_fit = info.fit_by_param[param_name]["results"] for function_type in sorted(param_fit.keys()): function_rmsd = param_fit[function_type]["rmsd"] print( @@ -406,13 +406,13 @@ if __name__ == "__main__": for trans in model.names: for attribute in ["duration"]: info = param_info(trans, attribute) - if type(info) is AnalyticInfo: + if type(info) is AnalyticFunction: print( "{:10s}: {:10s}: {}".format( - trans, attribute, info.function.model_function + trans, attribute, info.model_function ) ) - print("{:10s} {:10s} {}".format("", "", info.function.model_args)) + print("{:10s} {:10s} {}".format("", "", info.model_args)) if xv_method == "montecarlo": analytic_quality = xv.montecarlo(lambda m: m.get_fitted()[0], xv_count) @@ -451,4 +451,6 @@ if __name__ == "__main__": extra_function=function, ) + # print(json.dumps(model.to_json(), cls=NpEncoder, indent=2)) + sys.exit(0) diff --git a/lib/functions.py b/lib/functions.py index f3e5ac8..663b65e 100644 --- a/lib/functions.py +++ b/lib/functions.py @@ -153,11 +153,6 @@ class NormalizationFunction: return self._function(param_value) -class ModelInfo: - def __init__(self): - self.error = None - - class ModelFunction: def __init__(self): pass @@ -194,17 +189,6 @@ class StaticFunction(ModelFunction): return {"type": "static", "value": self.value} -class StaticInfo(ModelInfo): - def __init__(self, data): - super() - self.mean = np.mean(data) - self.median = np.median(data) - self.std = np.std(data) - - def to_json(self): - return "FIXME" - - class SplitFunction(ModelFunction): def __init__(self, param_index, child): self.param_index = param_index @@ -237,16 +221,6 @@ class SplitFunction(ModelFunction): } -class SplitInfo(ModelInfo): - def __init__(self, param_index, child): - super() - self.param_index = param_index - self.child = child - - def to_json(self): - return "FIXME" - - class AnalyticFunction(ModelFunction): """ A multi-dimensional model function, generated from a string, which can be optimized using regression. @@ -256,7 +230,14 @@ class AnalyticFunction(ModelFunction): packet length. """ - def __init__(self, function_str, parameters, num_args, regression_args=None): + def __init__( + self, + function_str, + parameters, + num_args, + regression_args=None, + fit_by_param=None, + ): """ Create a new AnalyticFunction object from a function string. @@ -281,6 +262,7 @@ class AnalyticFunction(ModelFunction): rawfunction = function_str self._dependson = [False] * (len(parameters) + num_args) self.fit_success = False + self.fit_by_param = fit_by_param if type(function_str) == str: num_vars_re = re.compile(r"regression_arg\(([0-9]+)\)") @@ -440,16 +422,6 @@ class AnalyticFunction(ModelFunction): } -class AnalyticInfo(ModelInfo): - def __init__(self, fit_result, function): - super() - self.fit_result = fit_result - self.function = function - - def to_json(self): - return "FIXME" - - class analytic: """ Utilities for analytic description of parameter-dependent model attributes and regression analysis. @@ -644,4 +616,6 @@ class analytic: "parameter", function_item[0], function_item[1]["best"] ) ) - return AnalyticFunction(buf, parameter_names, num_args) + return AnalyticFunction( + buf, parameter_names, num_args, fit_by_param=fit_results + ) diff --git a/lib/model.py b/lib/model.py index be63a8a..ccbf719 100644 --- a/lib/model.py +++ b/lib/model.py @@ -4,7 +4,7 @@ import logging import numpy as np import os from .automata import PTA, ModelAttribute -from .functions import StaticInfo +from .functions import StaticFunction from .parameters import ParallelParamStats from .paramfit import ParallelParamFit from .utils import soft_cast_int, by_name_to_by_param, regression_measures @@ -255,10 +255,10 @@ class AnalyticModel: def model_getter(name, key, **kwargs): model_function = self.attr_by_name[name][key].model_function - model_info = self.attr_by_name[name][key].model_info + model_info = self.attr_by_name[name][key].model_function # shortcut - if type(model_info) is StaticInfo: + if type(model_info) is StaticFunction: return static_model[name][key] if "arg" in kwargs and "param" in kwargs: @@ -271,7 +271,7 @@ class AnalyticModel: def info_getter(name, key): try: - return self.attr_by_name[name][key].model_info + return self.attr_by_name[name][key].model_function except KeyError: return None diff --git a/lib/parameters.py b/lib/parameters.py index fa01804..1cad7a5 100644 --- a/lib/parameters.py +++ b/lib/parameters.py @@ -594,7 +594,6 @@ class ModelAttribute: # The best model we have. May be Static, Split, or Param (and later perhaps Substate) self.model_function = None - self.model_info = None def __repr__(self): mean = np.mean(self.data) @@ -605,7 +604,6 @@ class ModelAttribute: "paramNames": self.param_names, "argCount": self.arg_count, "modelFunction": self.model_function.to_json(), - "modelInfo": self.model_info.to_json(), } return ret @@ -784,21 +782,19 @@ class ModelAttribute: for param_value, child in child_by_param_value.items(): child.set_data_from_paramfit(paramfit, prefix + (param_value,)) function_child[param_value] = child.model_function - info_child[param_value] = child.model_info self.model_function = df.SplitFunction(split_param_index, function_child) - self.model_info = df.SplitInfo(split_param_index, info_child) def set_data_from_paramfit_this(self, paramfit, prefix): fit_result = paramfit.get_result((self.name, self.attr) + prefix) self.model_function = df.StaticFunction(self.median) - self.model_info = df.StaticInfo(self.data) if self.function_override is not None: function_str = self.function_override - x = df.AnalyticFunction(function_str, self.param_names, self.arg_count) + x = df.AnalyticFunction( + function_str, self.param_names, self.arg_count, fit_by_param=fit_result + ) x.fit(self.by_param) if x.fit_success: self.model_function = x - self.model_info = df.AnalyticInfo(fit_result, x) elif os.getenv("DFATOOL_NO_PARAM"): pass elif len(fit_result.keys()): @@ -809,4 +805,3 @@ class ModelAttribute: if x.fit_success: self.model_function = x - self.model_info = df.AnalyticInfo(fit_result, x) diff --git a/test/test_ptamodel.py b/test/test_ptamodel.py index e571dcc..9f5076c 100755 --- a/test/test_ptamodel.py +++ b/test/test_ptamodel.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 -from dfatool.functions import StaticInfo +from dfatool.functions import StaticFunction from dfatool.loader import RawData, pta_trace_to_aggregate from dfatool.model import PTAModel from dfatool.utils import by_name_to_by_param @@ -639,28 +639,26 @@ class TestFromFile(unittest.TestCase): ) param_model, param_info = model.get_fitted() - self.assertIsInstance(param_info("POWERDOWN", "power"), StaticInfo) + self.assertIsInstance(param_info("POWERDOWN", "power"), StaticFunction) self.assertEqual( - param_info("RX", "power").function.model_function, + param_info("RX", "power").model_function, "0 + regression_arg(0) + regression_arg(1) * np.sqrt(parameter(datarate))", ) self.assertAlmostEqual( - param_info("RX", "power").function.model_args[0], 48530.7, places=0 + param_info("RX", "power").model_args[0], 48530.7, places=0 ) - self.assertAlmostEqual( - param_info("RX", "power").function.model_args[1], 117, places=0 - ) - self.assertIsInstance(param_info("STANDBY1", "power"), StaticInfo) + self.assertAlmostEqual(param_info("RX", "power").model_args[1], 117, places=0) + self.assertIsInstance(param_info("STANDBY1", "power"), StaticFunction) self.assertEqual( - param_info("TX", "power").function.model_function, + param_info("TX", "power").model_function, "0 + regression_arg(0) + regression_arg(1) * 1/(parameter(datarate)) + regression_arg(2) * parameter(txpower) + regression_arg(3) * 1/(parameter(datarate)) * parameter(txpower)", ) self.assertEqual( - param_info("epilogue", "timeout").function.model_function, + param_info("epilogue", "timeout").model_function, "0 + regression_arg(0) + regression_arg(1) * 1/(parameter(datarate))", ) self.assertEqual( - param_info("stopListening", "duration").function.model_function, + param_info("stopListening", "duration").model_function, "0 + regression_arg(0) + regression_arg(1) * 1/(parameter(datarate))", ) @@ -1823,22 +1821,18 @@ class TestFromFile(unittest.TestCase): """ param_model, param_info = model.get_fitted() - self.assertIsInstance(param_info("IDLE", "power"), StaticInfo) + self.assertIsInstance(param_info("IDLE", "power"), StaticFunction) self.assertEqual( - param_info("RX", "power").function.model_function, + param_info("RX", "power").model_function, "0 + regression_arg(0) + regression_arg(1) * np.log(parameter(symbolrate) + 1)", ) - self.assertIsInstance(param_info("SLEEP", "power"), StaticInfo) - self.assertIsInstance(param_info("SLEEP_EWOR", "power"), StaticInfo) - self.assertIsInstance(param_info("SYNTH_ON", "power"), StaticInfo) - self.assertIsInstance(param_info("XOFF", "power"), StaticInfo) + self.assertIsInstance(param_info("SLEEP", "power"), StaticFunction) + self.assertIsInstance(param_info("SLEEP_EWOR", "power"), StaticFunction) + self.assertIsInstance(param_info("SYNTH_ON", "power"), StaticFunction) + self.assertIsInstance(param_info("XOFF", "power"), StaticFunction) - self.assertAlmostEqual( - param_info("RX", "power").function.model_args[0], 84415, places=0 - ) - self.assertAlmostEqual( - param_info("RX", "power").function.model_args[1], 206, places=0 - ) + self.assertAlmostEqual(param_info("RX", "power").model_args[0], 84415, places=0) + self.assertAlmostEqual(param_info("RX", "power").model_args[1], 206, places=0) if __name__ == "__main__": diff --git a/test/test_timingharness.py b/test/test_timingharness.py index 8c68e4a..06edc16 100755 --- a/test/test_timingharness.py +++ b/test/test_timingharness.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 -from dfatool.functions import StaticInfo +from dfatool.functions import StaticFunction from dfatool.loader import TimingData, pta_trace_to_aggregate from dfatool.model import AnalyticModel from dfatool.parameters import prune_dependent_parameters @@ -31,25 +31,25 @@ class TestModels(unittest.TestCase): ) param_model, param_info = model.get_fitted() - self.assertIsInstance(param_info("setPALevel", "duration"), StaticInfo) - self.assertIsInstance(param_info("setRetries", "duration"), StaticInfo) - self.assertIsInstance(param_info("setup", "duration"), StaticInfo) + self.assertIsInstance(param_info("setPALevel", "duration"), StaticFunction) + self.assertIsInstance(param_info("setRetries", "duration"), StaticFunction) + self.assertIsInstance(param_info("setup", "duration"), StaticFunction) self.assertEqual( - param_info("write", "duration").function.model_function, + param_info("write", "duration").model_function, "0 + regression_arg(0) + regression_arg(1) * parameter(max_retry_count) + regression_arg(2) * parameter(retry_delay) + regression_arg(3) * parameter(max_retry_count) * parameter(retry_delay)", ) self.assertAlmostEqual( - param_info("write", "duration").function.model_args[0], 1163, places=0 + param_info("write", "duration").model_args[0], 1163, places=0 ) self.assertAlmostEqual( - param_info("write", "duration").function.model_args[1], 464, places=0 + param_info("write", "duration").model_args[1], 464, places=0 ) self.assertAlmostEqual( - param_info("write", "duration").function.model_args[2], 1, places=0 + param_info("write", "duration").model_args[2], 1, places=0 ) self.assertAlmostEqual( - param_info("write", "duration").function.model_args[3], 1, places=0 + param_info("write", "duration").model_args[3], 1, places=0 ) def test_dependent_parameter_pruning(self): @@ -78,26 +78,26 @@ class TestModels(unittest.TestCase): ) param_model, param_info = model.get_fitted() - self.assertIsInstance(param_info("getObserveTx", "duration"), StaticInfo) - self.assertIsInstance(param_info("setPALevel", "duration"), StaticInfo) - self.assertIsInstance(param_info("setRetries", "duration"), StaticInfo) - self.assertIsInstance(param_info("setup", "duration"), StaticInfo) + self.assertIsInstance(param_info("getObserveTx", "duration"), StaticFunction) + self.assertIsInstance(param_info("setPALevel", "duration"), StaticFunction) + self.assertIsInstance(param_info("setRetries", "duration"), StaticFunction) + self.assertIsInstance(param_info("setup", "duration"), StaticFunction) self.assertEqual( - param_info("write", "duration").function.model_function, + param_info("write", "duration").model_function, "0 + regression_arg(0) + regression_arg(1) * parameter(max_retry_count) + regression_arg(2) * parameter(retry_delay) + regression_arg(3) * parameter(max_retry_count) * parameter(retry_delay)", ) self.assertAlmostEqual( - param_info("write", "duration").function.model_args[0], 1163, places=0 + param_info("write", "duration").model_args[0], 1163, places=0 ) self.assertAlmostEqual( - param_info("write", "duration").function.model_args[1], 464, places=0 + param_info("write", "duration").model_args[1], 464, places=0 ) self.assertAlmostEqual( - param_info("write", "duration").function.model_args[2], 1, places=0 + param_info("write", "duration").model_args[2], 1, places=0 ) self.assertAlmostEqual( - param_info("write", "duration").function.model_args[3], 1, places=0 + param_info("write", "duration").model_args[3], 1, places=0 ) def test_function_override(self): @@ -136,29 +136,29 @@ class TestModels(unittest.TestCase): ) param_model, param_info = model.get_fitted() - self.assertIsInstance(param_info("setAutoAck", "duration"), StaticInfo) - self.assertIsInstance(param_info("setPALevel", "duration"), StaticInfo) - self.assertIsInstance(param_info("setRetries", "duration"), StaticInfo) - self.assertIsInstance(param_info("setup", "duration"), StaticInfo) + self.assertIsInstance(param_info("setAutoAck", "duration"), StaticFunction) + self.assertIsInstance(param_info("setPALevel", "duration"), StaticFunction) + self.assertIsInstance(param_info("setRetries", "duration"), StaticFunction) + self.assertIsInstance(param_info("setup", "duration"), StaticFunction) self.assertEqual( - param_info("write", "duration").function.model_function, + param_info("write", "duration").model_function, "(parameter(auto_ack!) * (regression_arg(0) + regression_arg(1) * parameter(max_retry_count) + regression_arg(2) * parameter(retry_delay) + regression_arg(3) * parameter(max_retry_count) * parameter(retry_delay))) + ((1 - parameter(auto_ack!)) * regression_arg(4))", ) self.assertAlmostEqual( - param_info("write", "duration").function.model_args[0], 1162, places=0 + param_info("write", "duration").model_args[0], 1162, places=0 ) self.assertAlmostEqual( - param_info("write", "duration").function.model_args[1], 464, places=0 + param_info("write", "duration").model_args[1], 464, places=0 ) self.assertAlmostEqual( - param_info("write", "duration").function.model_args[2], 1, places=0 + param_info("write", "duration").model_args[2], 1, places=0 ) self.assertAlmostEqual( - param_info("write", "duration").function.model_args[3], 1, places=0 + param_info("write", "duration").model_args[3], 1, places=0 ) self.assertAlmostEqual( - param_info("write", "duration").function.model_args[4], 1086, places=0 + param_info("write", "duration").model_args[4], 1086, places=0 ) os.environ.pop("DFATOOL_NO_DECISIONTREES") |