diff options
-rw-r--r-- | lib/model.py | 38 |
1 files changed, 32 insertions, 6 deletions
diff --git a/lib/model.py b/lib/model.py index 56bf1f2..7b840ec 100644 --- a/lib/model.py +++ b/lib/model.py @@ -79,6 +79,7 @@ class AnalyticModel: compute_stats=True, force_tree=False, max_std=None, + from_json=None, ): """ Create a new AnalyticModel and compute parameter statistics. @@ -117,13 +118,20 @@ class AnalyticModel: :param use_corrcoef: use correlation coefficient instead of stddev comparison to detect whether a model attribute depends on a parameter """ self.cache = dict() - self.by_name = by_name # no longer required? + self.by_name = by_name self.attr_by_name = dict() self.names = sorted(by_name.keys()) self.parameters = sorted(parameters) - self.function_override = function_override.copy() - self.dtree_max_std = max_std - self._use_corrcoef = use_corrcoef + + if from_json: + self.function_override = dict() + self.dtree_max_std = None + self._use_corrcoef = False + else: + self.function_override = function_override.copy() + self.dtree_max_std = max_std + self._use_corrcoef = use_corrcoef + self._num_args = arg_count if self._num_args is None: self._num_args = _num_args_from_by_name(by_name) @@ -136,6 +144,16 @@ class AnalyticModel: distinct_values, values_are_distinct=True ) + if from_json: + for name, name_data in from_json["name"].items(): + self.attr_by_name[name] = dict() + for attr, attr_data in name_data.items(): + self.attr_by_name[name][attr] = ModelAttribute.from_json( + name, attr, attr_data + ) + self.fit_done = True + return + self.fit_done = False if compute_stats: @@ -182,7 +200,7 @@ class AnalyticModel: return self.cache["by_param"] def __repr__(self): - names = ", ".join(self.by_name.keys()) + names = ", ".join(self.names) return f"AnalyticModel<names=[{names}]>" def _compute_stats(self, by_name): @@ -259,7 +277,10 @@ class AnalyticModel: for k, v in attr.items(): static_model[name][k] = v.get_static(use_mean=use_mean) lut_model[name][k] = dict() - for param, model_value in v.get_by_param().items(): + by_param = v.get_by_param() + if by_param is None: + return None + for param, model_value in by_param.items(): lut_model[name][k][param] = v.get_lut(param, use_mean=use_mean) def lut_median_getter(name, key, **kwargs): @@ -596,6 +617,11 @@ class AnalyticModel: return ret + @classmethod + def from_json(cls, data, by_name, parameters): + assert data["parameters"] == parameters + return cls(by_name, parameters, from_json=data) + def webconf_function_map(self) -> list: ret = list() for name in self.names: |