diff options
author | Daniel Friesel <daniel.friesel@uos.de> | 2021-03-03 09:36:43 +0100 |
---|---|---|
committer | Daniel Friesel <daniel.friesel@uos.de> | 2021-03-03 09:36:43 +0100 |
commit | f33c69dcaf24ecc7e039dec83a4a5c74908da52f (patch) | |
tree | df5ada7af874161f70d866b4f8c1d878f1a373cc /bin | |
parent | d0d3f335739d9333f15ede487574f78f1eb5e638 (diff) |
Remove ModelInfo; add info to ModelFunction instead
Diffstat (limited to 'bin')
-rwxr-xr-x | bin/analyze-archive.py | 42 | ||||
-rwxr-xr-x | bin/analyze-timing.py | 20 |
2 files changed, 35 insertions, 27 deletions
diff --git a/bin/analyze-archive.py b/bin/analyze-archive.py index 3344d8a..65e25cc 100755 --- a/bin/analyze-archive.py +++ b/bin/analyze-archive.py @@ -43,7 +43,12 @@ import random import sys from dfatool import plotter from dfatool.loader import RawData, pta_trace_to_aggregate -from dfatool.functions import gplearn_to_function, SplitInfo, AnalyticInfo, StaticInfo +from dfatool.functions import ( + gplearn_to_function, + SplitFunction, + AnalyticFunction, + StaticFunction, +) from dfatool.model import PTAModel from dfatool.validation import CrossValidator from dfatool.utils import filter_aggregate_by_param, detect_outliers_in_aggregate @@ -91,13 +96,14 @@ def model_quality_table(header, result_lists, info_list): info is None or ( key != "energy_Pt" - and type(info(state_or_tran, key)) is not StaticInfo + and type(info(state_or_tran, key)) is not StaticFunction ) or ( key == "energy_Pt" and ( - type(info(state_or_tran, "power")) is not StaticInfo - or type(info(state_or_tran, "duration")) is not StaticInfo + type(info(state_or_tran, "power")) is not StaticFunction + or type(info(state_or_tran, "duration")) + is not StaticFunction ) ) ): @@ -370,22 +376,22 @@ def print_static(model, static_model, name, attribute): def print_analyticinfo(prefix, info): empty = "" - print(f"{prefix}: {info.function.model_function}") - print(f"{empty:{len(prefix)}s} {info.function.model_args}") + print(f"{prefix}: {info.model_function}") + print(f"{empty:{len(prefix)}s} {info.model_args}") def print_splitinfo(param_names, info, prefix=""): - if type(info) is SplitInfo: + if type(info) is SplitFunction: for k, v in info.child.items(): if info.param_index < len(param_names): param_name = param_names[info.param_index] else: param_name = f"arg{info.param_index - len(param_names)}" print_splitinfo(param_names, v, f"{prefix} {param_name}={k}") - elif type(info) is AnalyticInfo: + elif type(info) is AnalyticFunction: print_analyticinfo(prefix, info) - elif type(info) is StaticInfo: - print(f"{prefix}: {info.median}") + elif type(info) is StaticFunction: + print(f"{prefix}: {info.value}") else: print(f"{prefix}: UNKNOWN") @@ -896,9 +902,9 @@ if __name__ == "__main__": ], ) ) - if type(info) is AnalyticInfo: - for param_name in sorted(info.fit_result.keys(), key=str): - param_fit = info.fit_result[param_name]["results"] + if type(info) is AnalyticFunction: + for param_name in sorted(info.fit_by_param.keys(), key=str): + param_fit = info.fit_by_param[param_name]["results"] for function_type in sorted(param_fit.keys()): function_rmsd = param_fit[function_type]["rmsd"] print( @@ -915,18 +921,18 @@ if __name__ == "__main__": for state in model.states: for attribute in model.attributes(state): info = param_info(state, attribute) - if type(info) is AnalyticInfo: + if type(info) is AnalyticFunction: print_analyticinfo(f"{state:10s} {attribute:15s}", info) - elif type(info) is SplitInfo: + elif type(info) is SplitFunction: print_splitinfo( model.parameters, info, f"{state:10s} {attribute:15s}" ) for trans in model.transitions: for attribute in model.attributes(trans): info = param_info(trans, attribute) - if type(info) is AnalyticInfo: + if type(info) is AnalyticFunction: print_analyticinfo(f"{trans:10s} {attribute:15s}", info) - elif type(info) is SplitInfo: + elif type(info) is SplitFunction: print_splitinfo( model.parameters, info, f"{trans:10s} {attribute:15s}" ) @@ -936,7 +942,7 @@ if __name__ == "__main__": for substate in submodel.states: for subattribute in submodel.attributes(substate): info = sub_param_info(substate, subattribute) - if type(info) is AnalyticInfo: + if type(info) is AnalyticFunction: print( "{:10s} {:15s}: {}".format( substate, subattribute, info.function.model_function diff --git a/bin/analyze-timing.py b/bin/analyze-timing.py index 10c8f1d..e8af0fb 100755 --- a/bin/analyze-timing.py +++ b/bin/analyze-timing.py @@ -80,10 +80,10 @@ import re import sys from dfatool import plotter from dfatool.loader import TimingData, pta_trace_to_aggregate -from dfatool.functions import gplearn_to_function, SplitInfo, AnalyticInfo +from dfatool.functions import gplearn_to_function, StaticFunction, AnalyticFunction from dfatool.model import AnalyticModel from dfatool.validation import CrossValidator -from dfatool.utils import filter_aggregate_by_param +from dfatool.utils import filter_aggregate_by_param, NpEncoder from dfatool.parameters import prune_dependent_parameters opt = dict() @@ -117,7 +117,7 @@ def model_quality_table(result_lists, info_list): for i, results in enumerate(result_lists): info = info_list[i] buf += " ||| " - if info is None or info(state_or_tran, key): + if info is None or type(info(state_or_tran, key)) is not StaticFunction: result = results["by_name"][state_or_tran][key] buf += format_quality_measures(result) else: @@ -387,9 +387,9 @@ if __name__ == "__main__": ].stats.arg_dependence_ratio(i), ) ) - if type(info) is AnalyticInfo: - for param_name in sorted(info.fit_result.keys(), key=str): - param_fit = info.fit_result[param_name]["results"] + if type(info) is AnalyticFunction: + for param_name in sorted(info.fit_by_param.keys(), key=str): + param_fit = info.fit_by_param[param_name]["results"] for function_type in sorted(param_fit.keys()): function_rmsd = param_fit[function_type]["rmsd"] print( @@ -406,13 +406,13 @@ if __name__ == "__main__": for trans in model.names: for attribute in ["duration"]: info = param_info(trans, attribute) - if type(info) is AnalyticInfo: + if type(info) is AnalyticFunction: print( "{:10s}: {:10s}: {}".format( - trans, attribute, info.function.model_function + trans, attribute, info.model_function ) ) - print("{:10s} {:10s} {}".format("", "", info.function.model_args)) + print("{:10s} {:10s} {}".format("", "", info.model_args)) if xv_method == "montecarlo": analytic_quality = xv.montecarlo(lambda m: m.get_fitted()[0], xv_count) @@ -451,4 +451,6 @@ if __name__ == "__main__": extra_function=function, ) + # print(json.dumps(model.to_json(), cls=NpEncoder, indent=2)) + sys.exit(0) |