diff options
author | Daniel Friesel <daniel.friesel@uos.de> | 2021-03-01 11:43:39 +0100 |
---|---|---|
committer | Daniel Friesel <daniel.friesel@uos.de> | 2021-03-01 11:43:39 +0100 |
commit | 4516319ff93c65650dade21a34f33fb7e9f8b96c (patch) | |
tree | 373d36351d3df8d12f6e9d8df7a00770ca1294d7 | |
parent | f0c6aced6ec7670b38daa18a0dc0631eb19f2e11 (diff) |
get_fitted: also provide information on static (sub)models
-rwxr-xr-x | bin/analyze-archive.py | 17 | ||||
-rw-r--r-- | lib/functions.py | 7 | ||||
-rw-r--r-- | lib/model.py | 7 |
3 files changed, 23 insertions, 8 deletions
diff --git a/bin/analyze-archive.py b/bin/analyze-archive.py index 872025e..78f5d79 100755 --- a/bin/analyze-archive.py +++ b/bin/analyze-archive.py @@ -43,7 +43,7 @@ import random import sys from dfatool import plotter from dfatool.loader import RawData, pta_trace_to_aggregate -from dfatool.functions import gplearn_to_function, SplitInfo, AnalyticInfo +from dfatool.functions import gplearn_to_function, SplitInfo, AnalyticInfo, StaticInfo from dfatool.model import PTAModel from dfatool.validation import CrossValidator from dfatool.utils import filter_aggregate_by_param, detect_outliers_in_aggregate @@ -89,12 +89,15 @@ def model_quality_table(header, result_lists, info_list): buf += " ||| " if ( info is None - or info(state_or_tran, key) + or ( + key != "energy_Pt" + and type(info(state_or_tran, key)) is not StaticInfo + ) or ( key == "energy_Pt" and ( - info(state_or_tran, "power") - or info(state_or_tran, "duration") + type(info(state_or_tran, "power")) is not StaticInfo + or type(info(state_or_tran, "duration")) is not StaticInfo ) ) ): @@ -380,9 +383,11 @@ def print_splitinfo(param_names, info, prefix=""): param_name = f"arg{info.param_index - len(param_names)}" print_splitinfo(param_names, v, f"{prefix} {param_name}={k}") elif type(info) is AnalyticInfo: - print(f"{prefix} = analytic") + print_analyticinfo(prefix, info) + elif type(info) is StaticInfo: + print(f"{prefix}: {info.median}") else: - print(f"{prefix} = static") + print(f"{prefix}: UNKNOWN") if __name__ == "__main__": diff --git a/lib/functions.py b/lib/functions.py index 067514f..0f0bf47 100644 --- a/lib/functions.py +++ b/lib/functions.py @@ -157,6 +157,13 @@ class ModelInfo: pass +class StaticInfo: + def __init__(self, data): + self.mean = np.mean(data) + self.median = np.median(data) + self.std = np.std(data) + + class AnalyticInfo(ModelInfo): def __init__(self, fit_result, function): self.fit_result = fit_result diff --git a/lib/model.py b/lib/model.py index cddfe27..5ee9e4f 100644 --- a/lib/model.py +++ b/lib/model.py @@ -543,7 +543,10 @@ class ModelAttribute: def set_data_from_paramfit_this(self, paramfit, prefix): fit_result = paramfit.get_result((self.name, self.attr) + prefix) - param_model = (df.StaticFunction(np.median(self.data)), None) + param_model = ( + df.StaticFunction(np.median(self.data)), + df.StaticInfo(self.data), + ) if self.function_override is not None: function_str = self.function_override x = df.AnalyticFunction(function_str, self.param_names, self.arg_count) @@ -811,7 +814,7 @@ class AnalyticModel: def model_getter(name, key, **kwargs): param_function, param_info = self.attr_by_name[name][key].get_fitted() - if param_info is None: + if type(param_info) is df.StaticInfo: return static_model[name][key] if "arg" in kwargs and "param" in kwargs: |