summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rwxr-xr-xbin/analyze-archive.py42
-rwxr-xr-xbin/analyze-timing.py20
-rw-r--r--lib/functions.py50
-rw-r--r--lib/model.py8
-rw-r--r--lib/parameters.py11
-rwxr-xr-xtest/test_ptamodel.py40
-rwxr-xr-xtest/test_timingharness.py56
7 files changed, 99 insertions, 128 deletions
diff --git a/bin/analyze-archive.py b/bin/analyze-archive.py
index 3344d8a..65e25cc 100755
--- a/bin/analyze-archive.py
+++ b/bin/analyze-archive.py
@@ -43,7 +43,12 @@ import random
import sys
from dfatool import plotter
from dfatool.loader import RawData, pta_trace_to_aggregate
-from dfatool.functions import gplearn_to_function, SplitInfo, AnalyticInfo, StaticInfo
+from dfatool.functions import (
+ gplearn_to_function,
+ SplitFunction,
+ AnalyticFunction,
+ StaticFunction,
+)
from dfatool.model import PTAModel
from dfatool.validation import CrossValidator
from dfatool.utils import filter_aggregate_by_param, detect_outliers_in_aggregate
@@ -91,13 +96,14 @@ def model_quality_table(header, result_lists, info_list):
info is None
or (
key != "energy_Pt"
- and type(info(state_or_tran, key)) is not StaticInfo
+ and type(info(state_or_tran, key)) is not StaticFunction
)
or (
key == "energy_Pt"
and (
- type(info(state_or_tran, "power")) is not StaticInfo
- or type(info(state_or_tran, "duration")) is not StaticInfo
+ type(info(state_or_tran, "power")) is not StaticFunction
+ or type(info(state_or_tran, "duration"))
+ is not StaticFunction
)
)
):
@@ -370,22 +376,22 @@ def print_static(model, static_model, name, attribute):
def print_analyticinfo(prefix, info):
empty = ""
- print(f"{prefix}: {info.function.model_function}")
- print(f"{empty:{len(prefix)}s} {info.function.model_args}")
+ print(f"{prefix}: {info.model_function}")
+ print(f"{empty:{len(prefix)}s} {info.model_args}")
def print_splitinfo(param_names, info, prefix=""):
- if type(info) is SplitInfo:
+ if type(info) is SplitFunction:
for k, v in info.child.items():
if info.param_index < len(param_names):
param_name = param_names[info.param_index]
else:
param_name = f"arg{info.param_index - len(param_names)}"
print_splitinfo(param_names, v, f"{prefix} {param_name}={k}")
- elif type(info) is AnalyticInfo:
+ elif type(info) is AnalyticFunction:
print_analyticinfo(prefix, info)
- elif type(info) is StaticInfo:
- print(f"{prefix}: {info.median}")
+ elif type(info) is StaticFunction:
+ print(f"{prefix}: {info.value}")
else:
print(f"{prefix}: UNKNOWN")
@@ -896,9 +902,9 @@ if __name__ == "__main__":
],
)
)
- if type(info) is AnalyticInfo:
- for param_name in sorted(info.fit_result.keys(), key=str):
- param_fit = info.fit_result[param_name]["results"]
+ if type(info) is AnalyticFunction:
+ for param_name in sorted(info.fit_by_param.keys(), key=str):
+ param_fit = info.fit_by_param[param_name]["results"]
for function_type in sorted(param_fit.keys()):
function_rmsd = param_fit[function_type]["rmsd"]
print(
@@ -915,18 +921,18 @@ if __name__ == "__main__":
for state in model.states:
for attribute in model.attributes(state):
info = param_info(state, attribute)
- if type(info) is AnalyticInfo:
+ if type(info) is AnalyticFunction:
print_analyticinfo(f"{state:10s} {attribute:15s}", info)
- elif type(info) is SplitInfo:
+ elif type(info) is SplitFunction:
print_splitinfo(
model.parameters, info, f"{state:10s} {attribute:15s}"
)
for trans in model.transitions:
for attribute in model.attributes(trans):
info = param_info(trans, attribute)
- if type(info) is AnalyticInfo:
+ if type(info) is AnalyticFunction:
print_analyticinfo(f"{trans:10s} {attribute:15s}", info)
- elif type(info) is SplitInfo:
+ elif type(info) is SplitFunction:
print_splitinfo(
model.parameters, info, f"{trans:10s} {attribute:15s}"
)
@@ -936,7 +942,7 @@ if __name__ == "__main__":
for substate in submodel.states:
for subattribute in submodel.attributes(substate):
info = sub_param_info(substate, subattribute)
- if type(info) is AnalyticInfo:
+ if type(info) is AnalyticFunction:
print(
"{:10s} {:15s}: {}".format(
substate, subattribute, info.function.model_function
diff --git a/bin/analyze-timing.py b/bin/analyze-timing.py
index 10c8f1d..e8af0fb 100755
--- a/bin/analyze-timing.py
+++ b/bin/analyze-timing.py
@@ -80,10 +80,10 @@ import re
import sys
from dfatool import plotter
from dfatool.loader import TimingData, pta_trace_to_aggregate
-from dfatool.functions import gplearn_to_function, SplitInfo, AnalyticInfo
+from dfatool.functions import gplearn_to_function, StaticFunction, AnalyticFunction
from dfatool.model import AnalyticModel
from dfatool.validation import CrossValidator
-from dfatool.utils import filter_aggregate_by_param
+from dfatool.utils import filter_aggregate_by_param, NpEncoder
from dfatool.parameters import prune_dependent_parameters
opt = dict()
@@ -117,7 +117,7 @@ def model_quality_table(result_lists, info_list):
for i, results in enumerate(result_lists):
info = info_list[i]
buf += " ||| "
- if info is None or info(state_or_tran, key):
+ if info is None or type(info(state_or_tran, key)) is not StaticFunction:
result = results["by_name"][state_or_tran][key]
buf += format_quality_measures(result)
else:
@@ -387,9 +387,9 @@ if __name__ == "__main__":
].stats.arg_dependence_ratio(i),
)
)
- if type(info) is AnalyticInfo:
- for param_name in sorted(info.fit_result.keys(), key=str):
- param_fit = info.fit_result[param_name]["results"]
+ if type(info) is AnalyticFunction:
+ for param_name in sorted(info.fit_by_param.keys(), key=str):
+ param_fit = info.fit_by_param[param_name]["results"]
for function_type in sorted(param_fit.keys()):
function_rmsd = param_fit[function_type]["rmsd"]
print(
@@ -406,13 +406,13 @@ if __name__ == "__main__":
for trans in model.names:
for attribute in ["duration"]:
info = param_info(trans, attribute)
- if type(info) is AnalyticInfo:
+ if type(info) is AnalyticFunction:
print(
"{:10s}: {:10s}: {}".format(
- trans, attribute, info.function.model_function
+ trans, attribute, info.model_function
)
)
- print("{:10s} {:10s} {}".format("", "", info.function.model_args))
+ print("{:10s} {:10s} {}".format("", "", info.model_args))
if xv_method == "montecarlo":
analytic_quality = xv.montecarlo(lambda m: m.get_fitted()[0], xv_count)
@@ -451,4 +451,6 @@ if __name__ == "__main__":
extra_function=function,
)
+ # print(json.dumps(model.to_json(), cls=NpEncoder, indent=2))
+
sys.exit(0)
diff --git a/lib/functions.py b/lib/functions.py
index f3e5ac8..663b65e 100644
--- a/lib/functions.py
+++ b/lib/functions.py
@@ -153,11 +153,6 @@ class NormalizationFunction:
return self._function(param_value)
-class ModelInfo:
- def __init__(self):
- self.error = None
-
-
class ModelFunction:
def __init__(self):
pass
@@ -194,17 +189,6 @@ class StaticFunction(ModelFunction):
return {"type": "static", "value": self.value}
-class StaticInfo(ModelInfo):
- def __init__(self, data):
- super()
- self.mean = np.mean(data)
- self.median = np.median(data)
- self.std = np.std(data)
-
- def to_json(self):
- return "FIXME"
-
-
class SplitFunction(ModelFunction):
def __init__(self, param_index, child):
self.param_index = param_index
@@ -237,16 +221,6 @@ class SplitFunction(ModelFunction):
}
-class SplitInfo(ModelInfo):
- def __init__(self, param_index, child):
- super()
- self.param_index = param_index
- self.child = child
-
- def to_json(self):
- return "FIXME"
-
-
class AnalyticFunction(ModelFunction):
"""
A multi-dimensional model function, generated from a string, which can be optimized using regression.
@@ -256,7 +230,14 @@ class AnalyticFunction(ModelFunction):
packet length.
"""
- def __init__(self, function_str, parameters, num_args, regression_args=None):
+ def __init__(
+ self,
+ function_str,
+ parameters,
+ num_args,
+ regression_args=None,
+ fit_by_param=None,
+ ):
"""
Create a new AnalyticFunction object from a function string.
@@ -281,6 +262,7 @@ class AnalyticFunction(ModelFunction):
rawfunction = function_str
self._dependson = [False] * (len(parameters) + num_args)
self.fit_success = False
+ self.fit_by_param = fit_by_param
if type(function_str) == str:
num_vars_re = re.compile(r"regression_arg\(([0-9]+)\)")
@@ -440,16 +422,6 @@ class AnalyticFunction(ModelFunction):
}
-class AnalyticInfo(ModelInfo):
- def __init__(self, fit_result, function):
- super()
- self.fit_result = fit_result
- self.function = function
-
- def to_json(self):
- return "FIXME"
-
-
class analytic:
"""
Utilities for analytic description of parameter-dependent model attributes and regression analysis.
@@ -644,4 +616,6 @@ class analytic:
"parameter", function_item[0], function_item[1]["best"]
)
)
- return AnalyticFunction(buf, parameter_names, num_args)
+ return AnalyticFunction(
+ buf, parameter_names, num_args, fit_by_param=fit_results
+ )
diff --git a/lib/model.py b/lib/model.py
index be63a8a..ccbf719 100644
--- a/lib/model.py
+++ b/lib/model.py
@@ -4,7 +4,7 @@ import logging
import numpy as np
import os
from .automata import PTA, ModelAttribute
-from .functions import StaticInfo
+from .functions import StaticFunction
from .parameters import ParallelParamStats
from .paramfit import ParallelParamFit
from .utils import soft_cast_int, by_name_to_by_param, regression_measures
@@ -255,10 +255,10 @@ class AnalyticModel:
def model_getter(name, key, **kwargs):
model_function = self.attr_by_name[name][key].model_function
- model_info = self.attr_by_name[name][key].model_info
+ model_info = self.attr_by_name[name][key].model_function
# shortcut
- if type(model_info) is StaticInfo:
+ if type(model_info) is StaticFunction:
return static_model[name][key]
if "arg" in kwargs and "param" in kwargs:
@@ -271,7 +271,7 @@ class AnalyticModel:
def info_getter(name, key):
try:
- return self.attr_by_name[name][key].model_info
+ return self.attr_by_name[name][key].model_function
except KeyError:
return None
diff --git a/lib/parameters.py b/lib/parameters.py
index fa01804..1cad7a5 100644
--- a/lib/parameters.py
+++ b/lib/parameters.py
@@ -594,7 +594,6 @@ class ModelAttribute:
# The best model we have. May be Static, Split, or Param (and later perhaps Substate)
self.model_function = None
- self.model_info = None
def __repr__(self):
mean = np.mean(self.data)
@@ -605,7 +604,6 @@ class ModelAttribute:
"paramNames": self.param_names,
"argCount": self.arg_count,
"modelFunction": self.model_function.to_json(),
- "modelInfo": self.model_info.to_json(),
}
return ret
@@ -784,21 +782,19 @@ class ModelAttribute:
for param_value, child in child_by_param_value.items():
child.set_data_from_paramfit(paramfit, prefix + (param_value,))
function_child[param_value] = child.model_function
- info_child[param_value] = child.model_info
self.model_function = df.SplitFunction(split_param_index, function_child)
- self.model_info = df.SplitInfo(split_param_index, info_child)
def set_data_from_paramfit_this(self, paramfit, prefix):
fit_result = paramfit.get_result((self.name, self.attr) + prefix)
self.model_function = df.StaticFunction(self.median)
- self.model_info = df.StaticInfo(self.data)
if self.function_override is not None:
function_str = self.function_override
- x = df.AnalyticFunction(function_str, self.param_names, self.arg_count)
+ x = df.AnalyticFunction(
+ function_str, self.param_names, self.arg_count, fit_by_param=fit_result
+ )
x.fit(self.by_param)
if x.fit_success:
self.model_function = x
- self.model_info = df.AnalyticInfo(fit_result, x)
elif os.getenv("DFATOOL_NO_PARAM"):
pass
elif len(fit_result.keys()):
@@ -809,4 +805,3 @@ class ModelAttribute:
if x.fit_success:
self.model_function = x
- self.model_info = df.AnalyticInfo(fit_result, x)
diff --git a/test/test_ptamodel.py b/test/test_ptamodel.py
index e571dcc..9f5076c 100755
--- a/test/test_ptamodel.py
+++ b/test/test_ptamodel.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python3
-from dfatool.functions import StaticInfo
+from dfatool.functions import StaticFunction
from dfatool.loader import RawData, pta_trace_to_aggregate
from dfatool.model import PTAModel
from dfatool.utils import by_name_to_by_param
@@ -639,28 +639,26 @@ class TestFromFile(unittest.TestCase):
)
param_model, param_info = model.get_fitted()
- self.assertIsInstance(param_info("POWERDOWN", "power"), StaticInfo)
+ self.assertIsInstance(param_info("POWERDOWN", "power"), StaticFunction)
self.assertEqual(
- param_info("RX", "power").function.model_function,
+ param_info("RX", "power").model_function,
"0 + regression_arg(0) + regression_arg(1) * np.sqrt(parameter(datarate))",
)
self.assertAlmostEqual(
- param_info("RX", "power").function.model_args[0], 48530.7, places=0
+ param_info("RX", "power").model_args[0], 48530.7, places=0
)
- self.assertAlmostEqual(
- param_info("RX", "power").function.model_args[1], 117, places=0
- )
- self.assertIsInstance(param_info("STANDBY1", "power"), StaticInfo)
+ self.assertAlmostEqual(param_info("RX", "power").model_args[1], 117, places=0)
+ self.assertIsInstance(param_info("STANDBY1", "power"), StaticFunction)
self.assertEqual(
- param_info("TX", "power").function.model_function,
+ param_info("TX", "power").model_function,
"0 + regression_arg(0) + regression_arg(1) * 1/(parameter(datarate)) + regression_arg(2) * parameter(txpower) + regression_arg(3) * 1/(parameter(datarate)) * parameter(txpower)",
)
self.assertEqual(
- param_info("epilogue", "timeout").function.model_function,
+ param_info("epilogue", "timeout").model_function,
"0 + regression_arg(0) + regression_arg(1) * 1/(parameter(datarate))",
)
self.assertEqual(
- param_info("stopListening", "duration").function.model_function,
+ param_info("stopListening", "duration").model_function,
"0 + regression_arg(0) + regression_arg(1) * 1/(parameter(datarate))",
)
@@ -1823,22 +1821,18 @@ class TestFromFile(unittest.TestCase):
"""
param_model, param_info = model.get_fitted()
- self.assertIsInstance(param_info("IDLE", "power"), StaticInfo)
+ self.assertIsInstance(param_info("IDLE", "power"), StaticFunction)
self.assertEqual(
- param_info("RX", "power").function.model_function,
+ param_info("RX", "power").model_function,
"0 + regression_arg(0) + regression_arg(1) * np.log(parameter(symbolrate) + 1)",
)
- self.assertIsInstance(param_info("SLEEP", "power"), StaticInfo)
- self.assertIsInstance(param_info("SLEEP_EWOR", "power"), StaticInfo)
- self.assertIsInstance(param_info("SYNTH_ON", "power"), StaticInfo)
- self.assertIsInstance(param_info("XOFF", "power"), StaticInfo)
+ self.assertIsInstance(param_info("SLEEP", "power"), StaticFunction)
+ self.assertIsInstance(param_info("SLEEP_EWOR", "power"), StaticFunction)
+ self.assertIsInstance(param_info("SYNTH_ON", "power"), StaticFunction)
+ self.assertIsInstance(param_info("XOFF", "power"), StaticFunction)
- self.assertAlmostEqual(
- param_info("RX", "power").function.model_args[0], 84415, places=0
- )
- self.assertAlmostEqual(
- param_info("RX", "power").function.model_args[1], 206, places=0
- )
+ self.assertAlmostEqual(param_info("RX", "power").model_args[0], 84415, places=0)
+ self.assertAlmostEqual(param_info("RX", "power").model_args[1], 206, places=0)
if __name__ == "__main__":
diff --git a/test/test_timingharness.py b/test/test_timingharness.py
index 8c68e4a..06edc16 100755
--- a/test/test_timingharness.py
+++ b/test/test_timingharness.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python3
-from dfatool.functions import StaticInfo
+from dfatool.functions import StaticFunction
from dfatool.loader import TimingData, pta_trace_to_aggregate
from dfatool.model import AnalyticModel
from dfatool.parameters import prune_dependent_parameters
@@ -31,25 +31,25 @@ class TestModels(unittest.TestCase):
)
param_model, param_info = model.get_fitted()
- self.assertIsInstance(param_info("setPALevel", "duration"), StaticInfo)
- self.assertIsInstance(param_info("setRetries", "duration"), StaticInfo)
- self.assertIsInstance(param_info("setup", "duration"), StaticInfo)
+ self.assertIsInstance(param_info("setPALevel", "duration"), StaticFunction)
+ self.assertIsInstance(param_info("setRetries", "duration"), StaticFunction)
+ self.assertIsInstance(param_info("setup", "duration"), StaticFunction)
self.assertEqual(
- param_info("write", "duration").function.model_function,
+ param_info("write", "duration").model_function,
"0 + regression_arg(0) + regression_arg(1) * parameter(max_retry_count) + regression_arg(2) * parameter(retry_delay) + regression_arg(3) * parameter(max_retry_count) * parameter(retry_delay)",
)
self.assertAlmostEqual(
- param_info("write", "duration").function.model_args[0], 1163, places=0
+ param_info("write", "duration").model_args[0], 1163, places=0
)
self.assertAlmostEqual(
- param_info("write", "duration").function.model_args[1], 464, places=0
+ param_info("write", "duration").model_args[1], 464, places=0
)
self.assertAlmostEqual(
- param_info("write", "duration").function.model_args[2], 1, places=0
+ param_info("write", "duration").model_args[2], 1, places=0
)
self.assertAlmostEqual(
- param_info("write", "duration").function.model_args[3], 1, places=0
+ param_info("write", "duration").model_args[3], 1, places=0
)
def test_dependent_parameter_pruning(self):
@@ -78,26 +78,26 @@ class TestModels(unittest.TestCase):
)
param_model, param_info = model.get_fitted()
- self.assertIsInstance(param_info("getObserveTx", "duration"), StaticInfo)
- self.assertIsInstance(param_info("setPALevel", "duration"), StaticInfo)
- self.assertIsInstance(param_info("setRetries", "duration"), StaticInfo)
- self.assertIsInstance(param_info("setup", "duration"), StaticInfo)
+ self.assertIsInstance(param_info("getObserveTx", "duration"), StaticFunction)
+ self.assertIsInstance(param_info("setPALevel", "duration"), StaticFunction)
+ self.assertIsInstance(param_info("setRetries", "duration"), StaticFunction)
+ self.assertIsInstance(param_info("setup", "duration"), StaticFunction)
self.assertEqual(
- param_info("write", "duration").function.model_function,
+ param_info("write", "duration").model_function,
"0 + regression_arg(0) + regression_arg(1) * parameter(max_retry_count) + regression_arg(2) * parameter(retry_delay) + regression_arg(3) * parameter(max_retry_count) * parameter(retry_delay)",
)
self.assertAlmostEqual(
- param_info("write", "duration").function.model_args[0], 1163, places=0
+ param_info("write", "duration").model_args[0], 1163, places=0
)
self.assertAlmostEqual(
- param_info("write", "duration").function.model_args[1], 464, places=0
+ param_info("write", "duration").model_args[1], 464, places=0
)
self.assertAlmostEqual(
- param_info("write", "duration").function.model_args[2], 1, places=0
+ param_info("write", "duration").model_args[2], 1, places=0
)
self.assertAlmostEqual(
- param_info("write", "duration").function.model_args[3], 1, places=0
+ param_info("write", "duration").model_args[3], 1, places=0
)
def test_function_override(self):
@@ -136,29 +136,29 @@ class TestModels(unittest.TestCase):
)
param_model, param_info = model.get_fitted()
- self.assertIsInstance(param_info("setAutoAck", "duration"), StaticInfo)
- self.assertIsInstance(param_info("setPALevel", "duration"), StaticInfo)
- self.assertIsInstance(param_info("setRetries", "duration"), StaticInfo)
- self.assertIsInstance(param_info("setup", "duration"), StaticInfo)
+ self.assertIsInstance(param_info("setAutoAck", "duration"), StaticFunction)
+ self.assertIsInstance(param_info("setPALevel", "duration"), StaticFunction)
+ self.assertIsInstance(param_info("setRetries", "duration"), StaticFunction)
+ self.assertIsInstance(param_info("setup", "duration"), StaticFunction)
self.assertEqual(
- param_info("write", "duration").function.model_function,
+ param_info("write", "duration").model_function,
"(parameter(auto_ack!) * (regression_arg(0) + regression_arg(1) * parameter(max_retry_count) + regression_arg(2) * parameter(retry_delay) + regression_arg(3) * parameter(max_retry_count) * parameter(retry_delay))) + ((1 - parameter(auto_ack!)) * regression_arg(4))",
)
self.assertAlmostEqual(
- param_info("write", "duration").function.model_args[0], 1162, places=0
+ param_info("write", "duration").model_args[0], 1162, places=0
)
self.assertAlmostEqual(
- param_info("write", "duration").function.model_args[1], 464, places=0
+ param_info("write", "duration").model_args[1], 464, places=0
)
self.assertAlmostEqual(
- param_info("write", "duration").function.model_args[2], 1, places=0
+ param_info("write", "duration").model_args[2], 1, places=0
)
self.assertAlmostEqual(
- param_info("write", "duration").function.model_args[3], 1, places=0
+ param_info("write", "duration").model_args[3], 1, places=0
)
self.assertAlmostEqual(
- param_info("write", "duration").function.model_args[4], 1086, places=0
+ param_info("write", "duration").model_args[4], 1086, places=0
)
os.environ.pop("DFATOOL_NO_DECISIONTREES")