-rwxr-xr-x  bin/analyze-archive.py     |  81
-rwxr-xr-x  bin/analyze-timing.py      |  21
-rw-r--r--  lib/functions.py           |  77
-rw-r--r--  lib/model.py               |  45
-rwxr-xr-x  test/test_ptamodel.py      |  18
-rwxr-xr-x  test/test_timingharness.py |  32

6 files changed, 164 insertions(+), 110 deletions(-)
diff --git a/bin/analyze-archive.py b/bin/analyze-archive.py
index 0be9ab0..ee23a75 100755
--- a/bin/analyze-archive.py
+++ b/bin/analyze-archive.py
@@ -43,7 +43,7 @@ import random
import sys
from dfatool import plotter
from dfatool.loader import RawData, pta_trace_to_aggregate
-from dfatool.functions import gplearn_to_function
+from dfatool.functions import gplearn_to_function, SplitInfo, AnalyticInfo
from dfatool.model import PTAModel
from dfatool.validation import CrossValidator
from dfatool.utils import filter_aggregate_by_param, detect_outliers_in_aggregate
@@ -365,6 +365,24 @@ def print_static(model, static_model, name, attribute):
)
+def print_analyticinfo(prefix, info):
+ empty = ""
+ print(f"{prefix}: {info.function.model_function}")
+ print(f"{empty:{len(prefix)}s} {info.function.model_args}")
+
+
+def print_splitinfo(param_names, info, prefix=""):
+ if type(info) is SplitInfo:
+ for k, v in info.child.items():
+ print_splitinfo(
+ param_names, v, f"{prefix} {param_names[info.param_index]}={k}"
+ )
+ elif type(info) is AnalyticInfo:
+ print(f"{prefix} = analytic")
+ else:
+ print(f"{prefix} = static")
+
+
if __name__ == "__main__":
ignored_trace_indexes = []
@@ -871,9 +889,9 @@ if __name__ == "__main__":
],
)
)
- if info is not None:
- for param_name in sorted(info["fit_result"].keys(), key=str):
- param_fit = info["fit_result"][param_name]["results"]
+ if type(info) is AnalyticInfo:
+ for param_name in sorted(info.fit_result.keys(), key=str):
+ param_fit = info.fit_result[param_name]["results"]
for function_type in sorted(param_fit.keys()):
function_rmsd = param_fit[function_type]["rmsd"]
print(
@@ -889,60 +907,37 @@ if __name__ == "__main__":
if "param" in show_models or "all" in show_models:
for state in model.states():
for attribute in model.attributes(state):
- if param_info(state, attribute):
- print(
- "{:10s} {:15s}: {}".format(
- state,
- attribute,
- param_info(state, attribute)["function"].model_function,
- )
- )
- print(
- "{:10s} {:15s} {}".format(
- "", "", param_info(state, attribute)["function"].model_args
- )
+ info = param_info(state, attribute)
+ if type(info) is AnalyticInfo:
+ print_analyticinfo(f"{state:10s} {attribute:15s}", info)
+ elif type(info) is SplitInfo:
+ print_splitinfo(
+ model.parameters, info, f"{state:10s} {attribute:15s}"
)
for trans in model.transitions():
for attribute in model.attributes(trans):
- if param_info(trans, attribute):
- print(
- "{:10s} {:15s}: {:10s}: {}".format(
- trans,
- attribute,
- attribute,
- param_info(trans, attribute)["function"].model_function,
- )
- )
- print(
- "{:10s} {:15s} {:10s} {}".format(
- "",
- "",
- "",
- param_info(trans, attribute)["function"].model_args,
- )
+ info = param_info(trans, attribute)
+ if type(info) is AnalyticInfo:
+ print_analyticinfo(f"{trans:10s} {attribute:15s}", info)
+ elif type(info) is SplitInfo:
+ print_splitinfo(
+ model.parameters, info, f"{trans:10s} {attribute:15s}"
)
if args.with_substates:
for submodel in model.submodel_by_name.values():
sub_param_model, sub_param_info = submodel.get_fitted()
for substate in submodel.states():
for subattribute in submodel.attributes(substate):
- if sub_param_info(substate, subattribute):
+ info = sub_param_info(substate, subattribute)
+ if type(info) is AnalyticInfo:
print(
"{:10s} {:15s}: {}".format(
- substate,
- subattribute,
- sub_param_info(substate, subattribute)[
- "function"
- ].model_function,
+ substate, subattribute, info.function.model_function
)
)
print(
"{:10s} {:15s} {}".format(
- "",
- "",
- sub_param_info(substate, subattribute)[
- "function"
- ].model_args,
+ "", "", info.function.model_args
)
)
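
Aside: the print_splitinfo helper added above walks a SplitInfo tree recursively and prints one line per leaf. Below is a minimal, self-contained sketch of that recursion; the stand-in classes only mirror the structures this commit adds to lib/functions.py, and the parameter names and values (datarate, txpower) are made up for illustration.

# Illustration only: stand-ins mirroring SplitInfo from lib/functions.py.
class SplitInfo:
    def __init__(self, param_index, child):
        self.param_index = param_index  # index of the parameter the model splits on
        self.child = child              # maps parameter value -> sub-model info


class StaticInfo:
    pass


def print_splitinfo(param_names, info, prefix=""):
    if type(info) is SplitInfo:
        for k, v in info.child.items():
            print_splitinfo(param_names, v, f"{prefix} {param_names[info.param_index]}={k}")
    else:
        print(f"{prefix} = static")


# Hypothetical model: "power" splits on datarate, and for one datarate value also on txpower.
info = SplitInfo(0, {19200: StaticInfo(), 38400: SplitInfo(1, {0: StaticInfo()})})
print_splitinfo(["datarate", "txpower"], info, "TX        power")
# TX        power datarate=19200 = static
# TX        power datarate=38400 txpower=0 = static
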
diff --git a/bin/analyze-timing.py b/bin/analyze-timing.py
index 1460dd3..4a11298 100755
--- a/bin/analyze-timing.py
+++ b/bin/analyze-timing.py
@@ -80,7 +80,7 @@ import re
import sys
from dfatool import plotter
from dfatool.loader import TimingData, pta_trace_to_aggregate
-from dfatool.functions import gplearn_to_function
+from dfatool.functions import gplearn_to_function, SplitInfo, AnalyticInfo
from dfatool.model import AnalyticModel
from dfatool.validation import CrossValidator
from dfatool.utils import filter_aggregate_by_param
@@ -387,9 +387,9 @@ if __name__ == "__main__":
].stats.arg_dependence_ratio(i),
)
)
- if info is not None:
- for param_name in sorted(info["fit_result"].keys(), key=str):
- param_fit = info["fit_result"][param_name]["results"]
+ if type(info) is AnalyticInfo:
+ for param_name in sorted(info.fit_result.keys(), key=str):
+ param_fit = info.fit_result[param_name]["results"]
for function_type in sorted(param_fit.keys()):
function_rmsd = param_fit[function_type]["rmsd"]
print(
@@ -405,19 +405,14 @@ if __name__ == "__main__":
if "param" in show_models or "all" in show_models:
for trans in model.names:
for attribute in ["duration"]:
- if param_info(trans, attribute):
+ info = param_info(trans, attribute)
+ if type(info) is AnalyticInfo:
print(
"{:10s}: {:10s}: {}".format(
- trans,
- attribute,
- param_info(trans, attribute)["function"].model_function,
- )
- )
- print(
- "{:10s} {:10s} {}".format(
- "", "", param_info(trans, attribute)["function"].model_args
+ trans, attribute, info.function.model_function
)
)
+ print("{:10s} {:10s} {}".format("", "", info.function.model_args))
if xv_method == "montecarlo":
analytic_quality = xv.montecarlo(lambda m: m.get_fitted()[0], xv_count)
diff --git a/lib/functions.py b/lib/functions.py
index 0bdea45..067514f 100644
--- a/lib/functions.py
+++ b/lib/functions.py
@@ -152,7 +152,82 @@ class NormalizationFunction:
return self._function(param_value)
-class AnalyticFunction:
+class ModelInfo:
+ def __init__(self):
+ pass
+
+
+class AnalyticInfo(ModelInfo):
+ def __init__(self, fit_result, function):
+ self.fit_result = fit_result
+ self.function = function
+
+
+class SplitInfo(ModelInfo):
+ def __init__(self, param_index, child):
+ self.param_index = param_index
+ self.child = child
+
+
+class ModelFunction:
+ def __init__(self):
+ pass
+
+ def is_predictable(self, param_list):
+ raise NotImplementedError
+
+ def eval(self, param_list, arg_list):
+ raise NotImplementedError
+
+
+class StaticFunction(ModelFunction):
+ def __init__(self, value):
+ self.value = value
+
+ def is_predictable(self, param_list=None):
+ """
+ Return whether the model function can be evaluated on the given parameter values.
+
+ For a StaticFunction, this is always the case (i.e., this function always returns true).
+ """
+ return True
+
+ def eval(self, param_list=None, arg_list=None):
+ """
+ Evaluate model function with specified param/arg values.
+
+        For a StaticFunction, this is simply the static value.
+
+ """
+ return self.value
+
+
+class SplitFunction(ModelFunction):
+ def __init__(self, param_index, child):
+ self.param_index = param_index
+ self.child = child
+
+ def is_predictable(self, param_list):
+ """
+ Return whether the model function can be evaluated on the given parameter values.
+
+ The first value corresponds to the lexically first model parameter, etc.
+ All parameters must be set, not just the ones this function depends on.
+
+ Returns False iff a parameter the function depends on is not numeric
+ (e.g. None).
+ """
+ param_value = param_list[self.param_index]
+ if param_value not in self.child:
+ return False
+ return self.child[param_value].is_predictable(param_list)
+
+ def eval(self, param_list, arg_list=list()):
+ param_value = param_list[self.param_index]
+ return self.child[param_value].eval(param_list, arg_list)
+
+
+class AnalyticFunction(ModelFunction):
"""
A multi-dimensional model function, generated from a string, which can be optimized using regression.
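
Aside: with this patch applied, the new ModelFunction classes are importable from dfatool.functions. A short sketch (parameter values made up) of how a SplitFunction dispatches evaluation on one parameter value while a StaticFunction ignores its parameters:

from dfatool.functions import SplitFunction, StaticFunction

# Hypothetical model: the attribute depends on parameter 0 (say, a boolean flag);
# each observed parameter value maps to its own sub-model, here simply static values.
model = SplitFunction(0, {True: StaticFunction(100.0), False: StaticFunction(250.0)})

print(model.is_predictable([True]))   # True  -- value was seen during fitting
print(model.is_predictable([None]))   # False -- unknown / unset parameter value
print(model.eval([False]))            # 250.0 -- delegated to the matching StaticFunction
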
diff --git a/lib/model.py b/lib/model.py
index 83c31b1..cddfe27 100644
--- a/lib/model.py
+++ b/lib/model.py
@@ -7,8 +7,7 @@ from scipy import optimize
from sklearn.metrics import r2_score
from multiprocessing import Pool
from .automata import PTA
-from .functions import analytic
-from .functions import AnalyticFunction
+import dfatool.functions as df
from .parameters import ParallelParamStats, ParamStats
from .utils import is_numeric, soft_cast_int, param_slice_eq, remove_index_from_tuple
from .utils import (
@@ -211,7 +210,7 @@ def _try_fits(
:param param_filter: Only use measurements whose parameters match param_filter for fitting.
"""
- functions = analytic.functions(safe_functions_enabled=safe_functions_enabled)
+ functions = df.analytic.functions(safe_functions_enabled=safe_functions_enabled)
for param_key in n_by_param.keys():
# We might remove elements from 'functions' while iterating over
@@ -532,31 +531,33 @@ class ModelAttribute:
"child": dict(),
"child_static": dict(),
}
- info_map = {"split_by": split_param_index, "child": dict()}
+ function_child = dict()
+ info_child = dict()
for param_value, child in child_by_param_value.items():
child.set_data_from_paramfit(paramfit, prefix + (param_value,))
- function_map["child"][param_value], info_map["child"][
- param_value
- ] = child.get_fitted()
- function_map["child_static"][param_value] = child.get_static()
+ function_child[param_value], info_child[param_value] = child.get_fitted()
+ function_map = df.SplitFunction(split_param_index, function_child)
+ info_map = df.SplitInfo(split_param_index, info_child)
self.param_model = function_map, info_map
def set_data_from_paramfit_this(self, paramfit, prefix):
fit_result = paramfit.get_result((self.name, self.attr) + prefix)
- param_model = (None, None)
+ param_model = (df.StaticFunction(np.median(self.data)), None)
if self.function_override is not None:
function_str = self.function_override
- x = AnalyticFunction(function_str, self.param_names, self.arg_count)
+ x = df.AnalyticFunction(function_str, self.param_names, self.arg_count)
x.fit(self.by_param)
if x.fit_success:
- param_model = (x, fit_result)
+ param_model = (x, df.AnalyticInfo(fit_result, x))
elif len(fit_result.keys()):
- x = analytic.function_powerset(fit_result, self.param_names, self.arg_count)
+ x = df.analytic.function_powerset(
+ fit_result, self.param_names, self.arg_count
+ )
x.fit(self.by_param)
if x.fit_success:
- param_model = (x, fit_result)
+ param_model = (x, df.AnalyticInfo(fit_result, x))
self.param_model = param_model
@@ -810,22 +811,12 @@ class AnalyticModel:
def model_getter(name, key, **kwargs):
param_function, param_info = self.attr_by_name[name][key].get_fitted()
- if param_function is None:
+ if param_info is None:
return static_model[name][key]
if "arg" in kwargs and "param" in kwargs:
kwargs["param"].extend(map(soft_cast_int, kwargs["arg"]))
- while type(param_function) is dict and "split_by" in param_function:
- split_param_value = kwargs["param"][param_function["split_by"]]
- split_static = param_function["child_static"][split_param_value]
- param_function = param_function["child"][split_param_value]
- param_info = param_info["child"][split_param_value]
-
- if param_function is None:
- # TODO return static model of child
- return split_static
-
if param_function.is_predictable(kwargs["param"]):
return param_function.eval(kwargs["param"])
@@ -833,12 +824,10 @@ class AnalyticModel:
def info_getter(name, key):
try:
- model_function, fit_result = self.attr_by_name[name][key].get_fitted()
+ model_function, model_info = self.attr_by_name[name][key].get_fitted()
except KeyError:
return None
- if model_function is None:
- return None
- return {"function": model_function, "fit_result": fit_result}
+ return model_info
return model_getter, info_getter
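
Aside: callers now dispatch on the type of the returned info object instead of probing dict keys. A hedged usage sketch of the new getter contract follows; show_info is a hypothetical helper and expects an already fitted AnalyticModel or PTAModel instance.

from dfatool.functions import AnalyticInfo, SplitInfo

def show_info(model, name, attribute):
    """Print how a fitted dfatool model explains `attribute` of `name`."""
    param_model, param_info = model.get_fitted()
    info = param_info(name, attribute)
    if info is None:
        print(f"{name} {attribute}: static model only")
    elif isinstance(info, AnalyticInfo):
        print(f"{name} {attribute}: {info.function.model_function} {info.function.model_args}")
    elif isinstance(info, SplitInfo):
        print(f"{name} {attribute}: split on parameter #{info.param_index}")
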
diff --git a/test/test_ptamodel.py b/test/test_ptamodel.py
index 55f84b8..257215f 100755
--- a/test/test_ptamodel.py
+++ b/test/test_ptamodel.py
@@ -640,26 +640,26 @@ class TestFromFile(unittest.TestCase):
param_model, param_info = model.get_fitted()
self.assertEqual(param_info("POWERDOWN", "power"), None)
self.assertEqual(
- param_info("RX", "power")["function"].model_function,
+ param_info("RX", "power").function.model_function,
"0 + regression_arg(0) + regression_arg(1) * np.sqrt(parameter(datarate))",
)
self.assertAlmostEqual(
- param_info("RX", "power")["function"].model_args[0], 48530.7, places=0
+ param_info("RX", "power").function.model_args[0], 48530.7, places=0
)
self.assertAlmostEqual(
- param_info("RX", "power")["function"].model_args[1], 117, places=0
+ param_info("RX", "power").function.model_args[1], 117, places=0
)
self.assertEqual(param_info("STANDBY1", "power"), None)
self.assertEqual(
- param_info("TX", "power")["function"].model_function,
+ param_info("TX", "power").function.model_function,
"0 + regression_arg(0) + regression_arg(1) * 1/(parameter(datarate)) + regression_arg(2) * parameter(txpower) + regression_arg(3) * 1/(parameter(datarate)) * parameter(txpower)",
)
self.assertEqual(
- param_info("epilogue", "timeout")["function"].model_function,
+ param_info("epilogue", "timeout").function.model_function,
"0 + regression_arg(0) + regression_arg(1) * 1/(parameter(datarate))",
)
self.assertEqual(
- param_info("stopListening", "duration")["function"].model_function,
+ param_info("stopListening", "duration").function.model_function,
"0 + regression_arg(0) + regression_arg(1) * 1/(parameter(datarate))",
)
@@ -1825,7 +1825,7 @@ class TestFromFile(unittest.TestCase):
param_model, param_info = model.get_fitted()
self.assertEqual(param_info("IDLE", "power"), None)
self.assertEqual(
- param_info("RX", "power")["function"].model_function,
+ param_info("RX", "power").function.model_function,
"0 + regression_arg(0) + regression_arg(1) * np.log(parameter(symbolrate) + 1)",
)
self.assertEqual(param_info("SLEEP", "power"), None)
@@ -1834,10 +1834,10 @@ class TestFromFile(unittest.TestCase):
self.assertEqual(param_info("XOFF", "power"), None)
self.assertAlmostEqual(
- param_info("RX", "power")["function"].model_args[0], 84415, places=0
+ param_info("RX", "power").function.model_args[0], 84415, places=0
)
self.assertAlmostEqual(
- param_info("RX", "power")["function"].model_args[1], 206, places=0
+ param_info("RX", "power").function.model_args[1], 206, places=0
)
diff --git a/test/test_timingharness.py b/test/test_timingharness.py
index 9b55231..5cc0bec 100755
--- a/test/test_timingharness.py
+++ b/test/test_timingharness.py
@@ -34,21 +34,21 @@ class TestModels(unittest.TestCase):
self.assertEqual(param_info("setRetries", "duration"), None)
self.assertEqual(param_info("setup", "duration"), None)
self.assertEqual(
- param_info("write", "duration")["function"].model_function,
+ param_info("write", "duration").function.model_function,
"0 + regression_arg(0) + regression_arg(1) * parameter(max_retry_count) + regression_arg(2) * parameter(retry_delay) + regression_arg(3) * parameter(max_retry_count) * parameter(retry_delay)",
)
self.assertAlmostEqual(
- param_info("write", "duration")["function"].model_args[0], 1163, places=0
+ param_info("write", "duration").function.model_args[0], 1163, places=0
)
self.assertAlmostEqual(
- param_info("write", "duration")["function"].model_args[1], 464, places=0
+ param_info("write", "duration").function.model_args[1], 464, places=0
)
self.assertAlmostEqual(
- param_info("write", "duration")["function"].model_args[2], 1, places=0
+ param_info("write", "duration").function.model_args[2], 1, places=0
)
self.assertAlmostEqual(
- param_info("write", "duration")["function"].model_args[3], 1, places=0
+ param_info("write", "duration").function.model_args[3], 1, places=0
)
def test_dependent_parameter_pruning(self):
@@ -82,21 +82,21 @@ class TestModels(unittest.TestCase):
self.assertEqual(param_info("setRetries", "duration"), None)
self.assertEqual(param_info("setup", "duration"), None)
self.assertEqual(
- param_info("write", "duration")["function"].model_function,
+ param_info("write", "duration").function.model_function,
"0 + regression_arg(0) + regression_arg(1) * parameter(max_retry_count) + regression_arg(2) * parameter(retry_delay) + regression_arg(3) * parameter(max_retry_count) * parameter(retry_delay)",
)
self.assertAlmostEqual(
- param_info("write", "duration")["function"].model_args[0], 1163, places=0
+ param_info("write", "duration").function.model_args[0], 1163, places=0
)
self.assertAlmostEqual(
- param_info("write", "duration")["function"].model_args[1], 464, places=0
+ param_info("write", "duration").function.model_args[1], 464, places=0
)
self.assertAlmostEqual(
- param_info("write", "duration")["function"].model_args[2], 1, places=0
+ param_info("write", "duration").function.model_args[2], 1, places=0
)
self.assertAlmostEqual(
- param_info("write", "duration")["function"].model_args[3], 1, places=0
+ param_info("write", "duration").function.model_args[3], 1, places=0
)
def test_function_override(self):
@@ -140,24 +140,24 @@ class TestModels(unittest.TestCase):
self.assertEqual(param_info("setRetries", "duration"), None)
self.assertEqual(param_info("setup", "duration"), None)
self.assertEqual(
- param_info("write", "duration")["function"].model_function,
+ param_info("write", "duration").function.model_function,
"(parameter(auto_ack!) * (regression_arg(0) + regression_arg(1) * parameter(max_retry_count) + regression_arg(2) * parameter(retry_delay) + regression_arg(3) * parameter(max_retry_count) * parameter(retry_delay))) + ((1 - parameter(auto_ack!)) * regression_arg(4))",
)
self.assertAlmostEqual(
- param_info("write", "duration")["function"].model_args[0], 1162, places=0
+ param_info("write", "duration").function.model_args[0], 1162, places=0
)
self.assertAlmostEqual(
- param_info("write", "duration")["function"].model_args[1], 464, places=0
+ param_info("write", "duration").function.model_args[1], 464, places=0
)
self.assertAlmostEqual(
- param_info("write", "duration")["function"].model_args[2], 1, places=0
+ param_info("write", "duration").function.model_args[2], 1, places=0
)
self.assertAlmostEqual(
- param_info("write", "duration")["function"].model_args[3], 1, places=0
+ param_info("write", "duration").function.model_args[3], 1, places=0
)
self.assertAlmostEqual(
- param_info("write", "duration")["function"].model_args[4], 1086, places=0
+ param_info("write", "duration").function.model_args[4], 1086, places=0
)
os.environ.pop("DFATOOL_NO_DECISIONTREES")