summaryrefslogtreecommitdiff
path: root/bin/analyze-timing.py
diff options
context:
space:
mode:
authorDaniel Friesel <daniel.friesel@uos.de>2021-03-03 09:36:43 +0100
committerDaniel Friesel <daniel.friesel@uos.de>2021-03-03 09:36:43 +0100
commitf33c69dcaf24ecc7e039dec83a4a5c74908da52f (patch)
treedf5ada7af874161f70d866b4f8c1d878f1a373cc /bin/analyze-timing.py
parentd0d3f335739d9333f15ede487574f78f1eb5e638 (diff)
Remove ModelInfo; add info to ModelFunction instead
Diffstat (limited to 'bin/analyze-timing.py')
-rwxr-xr-xbin/analyze-timing.py20
1 files changed, 11 insertions, 9 deletions
diff --git a/bin/analyze-timing.py b/bin/analyze-timing.py
index 10c8f1d..e8af0fb 100755
--- a/bin/analyze-timing.py
+++ b/bin/analyze-timing.py
@@ -80,10 +80,10 @@ import re
import sys
from dfatool import plotter
from dfatool.loader import TimingData, pta_trace_to_aggregate
-from dfatool.functions import gplearn_to_function, SplitInfo, AnalyticInfo
+from dfatool.functions import gplearn_to_function, StaticFunction, AnalyticFunction
from dfatool.model import AnalyticModel
from dfatool.validation import CrossValidator
-from dfatool.utils import filter_aggregate_by_param
+from dfatool.utils import filter_aggregate_by_param, NpEncoder
from dfatool.parameters import prune_dependent_parameters
opt = dict()
@@ -117,7 +117,7 @@ def model_quality_table(result_lists, info_list):
for i, results in enumerate(result_lists):
info = info_list[i]
buf += " ||| "
- if info is None or info(state_or_tran, key):
+ if info is None or type(info(state_or_tran, key)) is not StaticFunction:
result = results["by_name"][state_or_tran][key]
buf += format_quality_measures(result)
else:
@@ -387,9 +387,9 @@ if __name__ == "__main__":
].stats.arg_dependence_ratio(i),
)
)
- if type(info) is AnalyticInfo:
- for param_name in sorted(info.fit_result.keys(), key=str):
- param_fit = info.fit_result[param_name]["results"]
+ if type(info) is AnalyticFunction:
+ for param_name in sorted(info.fit_by_param.keys(), key=str):
+ param_fit = info.fit_by_param[param_name]["results"]
for function_type in sorted(param_fit.keys()):
function_rmsd = param_fit[function_type]["rmsd"]
print(
@@ -406,13 +406,13 @@ if __name__ == "__main__":
for trans in model.names:
for attribute in ["duration"]:
info = param_info(trans, attribute)
- if type(info) is AnalyticInfo:
+ if type(info) is AnalyticFunction:
print(
"{:10s}: {:10s}: {}".format(
- trans, attribute, info.function.model_function
+ trans, attribute, info.model_function
)
)
- print("{:10s} {:10s} {}".format("", "", info.function.model_args))
+ print("{:10s} {:10s} {}".format("", "", info.model_args))
if xv_method == "montecarlo":
analytic_quality = xv.montecarlo(lambda m: m.get_fitted()[0], xv_count)
@@ -451,4 +451,6 @@ if __name__ == "__main__":
extra_function=function,
)
+ # print(json.dumps(model.to_json(), cls=NpEncoder, indent=2))
+
sys.exit(0)