summaryrefslogtreecommitdiff
path: root/lib/model.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/model.py')
-rw-r--r--lib/model.py32
1 files changed, 26 insertions, 6 deletions
diff --git a/lib/model.py b/lib/model.py
index 0026249..dbe05aa 100644
--- a/lib/model.py
+++ b/lib/model.py
@@ -20,6 +20,7 @@ from .utils import (
by_name_to_by_param,
by_param_to_by_name,
regression_measures,
+ param_eq_or_none,
)
logger = logging.getLogger(__name__)
@@ -85,6 +86,7 @@ class AnalyticModel:
compute_stats=True,
force_tree=False,
max_std=None,
+ by_param=None,
from_json=None,
):
"""
@@ -154,9 +156,18 @@ class AnalyticModel:
for name, name_data in from_json["name"].items():
self.attr_by_name[name] = dict()
for attr, attr_data in name_data.items():
- self.attr_by_name[name][attr] = ModelAttribute.from_json(
- name, attr, attr_data
- )
+ if by_param:
+ self.attr_by_name[name][attr] = ModelAttribute.from_json(
+ name,
+ attr,
+ attr_data,
+ data_values=by_name[name][attr],
+ param_values=by_name[name]["param"],
+ )
+ else:
+ self.attr_by_name[name][attr] = ModelAttribute.from_json(
+ name, attr, attr_data
+ )
self.fit_done = True
return
@@ -255,7 +266,7 @@ class AnalyticModel:
return static_model_getter
- def get_param_lut(self, use_mean=False, fallback=False):
+ def get_param_lut(self, use_mean=False, fallback=False, allow_none=False):
"""
Get parameter-look-up-table model function: name, attribute, parameter values -> model value.
@@ -285,7 +296,16 @@ class AnalyticModel:
try:
return lut_model[name][key][param]
except KeyError:
- if fallback:
+ if allow_none:
+ keys = filter(
+ lambda p: param_eq_or_none(param, p),
+ lut_model[name][key].keys(),
+ )
+ values = list(map(lambda p: lut_model[name][key][p], keys))
+ if not values:
+ raise
+ return np.mean(values)
+ elif fallback:
return static_model[name][key]
raise
params = kwargs["params"]
@@ -684,7 +704,7 @@ class AnalyticModel:
for (nk, pk), v in data["byParam"]:
by_param[(nk, tuple(pk))] = v
by_name = by_param_to_by_name(by_param)
- return cls(by_name, data["parameters"], from_json=data)
+ return cls(by_name, data["parameters"], by_param=by_param, from_json=data)
else:
assert data["parameters"] == parameters
return cls(by_name, parameters, from_json=data)