summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--lib/model.py38
1 files changed, 32 insertions, 6 deletions
diff --git a/lib/model.py b/lib/model.py
index 56bf1f2..7b840ec 100644
--- a/lib/model.py
+++ b/lib/model.py
@@ -79,6 +79,7 @@ class AnalyticModel:
compute_stats=True,
force_tree=False,
max_std=None,
+ from_json=None,
):
"""
Create a new AnalyticModel and compute parameter statistics.
@@ -117,13 +118,20 @@ class AnalyticModel:
:param use_corrcoef: use correlation coefficient instead of stddev comparison to detect whether a model attribute depends on a parameter
"""
self.cache = dict()
- self.by_name = by_name # no longer required?
+ self.by_name = by_name
self.attr_by_name = dict()
self.names = sorted(by_name.keys())
self.parameters = sorted(parameters)
- self.function_override = function_override.copy()
- self.dtree_max_std = max_std
- self._use_corrcoef = use_corrcoef
+
+ if from_json:
+ self.function_override = dict()
+ self.dtree_max_std = None
+ self._use_corrcoef = False
+ else:
+ self.function_override = function_override.copy()
+ self.dtree_max_std = max_std
+ self._use_corrcoef = use_corrcoef
+
self._num_args = arg_count
if self._num_args is None:
self._num_args = _num_args_from_by_name(by_name)
@@ -136,6 +144,16 @@ class AnalyticModel:
distinct_values, values_are_distinct=True
)
+ if from_json:
+ for name, name_data in from_json["name"].items():
+ self.attr_by_name[name] = dict()
+ for attr, attr_data in name_data.items():
+ self.attr_by_name[name][attr] = ModelAttribute.from_json(
+ name, attr, attr_data
+ )
+ self.fit_done = True
+ return
+
self.fit_done = False
if compute_stats:
@@ -182,7 +200,7 @@ class AnalyticModel:
return self.cache["by_param"]
def __repr__(self):
- names = ", ".join(self.by_name.keys())
+ names = ", ".join(self.names)
return f"AnalyticModel<names=[{names}]>"
def _compute_stats(self, by_name):
@@ -259,7 +277,10 @@ class AnalyticModel:
for k, v in attr.items():
static_model[name][k] = v.get_static(use_mean=use_mean)
lut_model[name][k] = dict()
- for param, model_value in v.get_by_param().items():
+ by_param = v.get_by_param()
+ if by_param is None:
+ return None
+ for param, model_value in by_param.items():
lut_model[name][k][param] = v.get_lut(param, use_mean=use_mean)
def lut_median_getter(name, key, **kwargs):
@@ -596,6 +617,11 @@ class AnalyticModel:
return ret
+ @classmethod
+ def from_json(cls, data, by_name, parameters):
+ assert data["parameters"] == parameters
+ return cls(by_name, parameters, from_json=data)
+
def webconf_function_map(self) -> list:
ret = list()
for name in self.names: