summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
authorBirte Kristina Friesel <birte.friesel@uos.de>2025-03-17 10:46:11 +0100
committerBirte Kristina Friesel <birte.friesel@uos.de>2025-03-17 10:46:11 +0100
commit61880a2d8cc36f306d142c8e63fd84792c449de4 (patch)
tree1fcdef416adac7f6ef6caef0faa3607093200165 /lib
parentfc83d91f138e6440dfdcb9d4fa0bf3fd3a559875 (diff)
Optionally export models with byParam data for easier loading
Diffstat (limited to 'lib')
-rw-r--r--lib/model.py30
1 files changed, 25 insertions, 5 deletions
diff --git a/lib/model.py b/lib/model.py
index 58f05a4..0026249 100644
--- a/lib/model.py
+++ b/lib/model.py
@@ -14,7 +14,13 @@ from .parameters import (
distinct_param_values,
)
from .paramfit import ParamFit
-from .utils import is_numeric, soft_cast_int, by_name_to_by_param, regression_measures
+from .utils import (
+ is_numeric,
+ soft_cast_int,
+ by_name_to_by_param,
+ by_param_to_by_name,
+ regression_measures,
+)
logger = logging.getLogger(__name__)
@@ -643,7 +649,7 @@ class AnalyticModel:
ret[f"xv/{name}/{attr_name}/{k}"] = np.mean(entry[k])
return ret
- def to_json(self, **kwargs) -> dict:
+ def to_json(self, with_by_param=False, **kwargs) -> dict:
"""
Return JSON encoding of this AnalyticModel.
"""
@@ -653,6 +659,12 @@ class AnalyticModel:
"paramValuesbyName": dict([[name, dict()] for name in self.names]),
}
+ if with_by_param:
+ by_param = self.get_by_param()
+ ret["byParam"] = list()
+ for k, v in by_param.items():
+ ret["byParam"].append((k, v))
+
for name in self.names:
for attr_name, attr in self.attr_by_name[name].items():
ret["name"][name][attr_name] = attr.to_json(**kwargs)
@@ -665,9 +677,17 @@ class AnalyticModel:
return ret
@classmethod
- def from_json(cls, data, by_name, parameters):
- assert data["parameters"] == parameters
- return cls(by_name, parameters, from_json=data)
+ def from_json(cls, data, by_name=None, parameters=None):
+ if by_name is None and parameters is None:
+ assert data["byParam"] is not None
+ by_param = dict()
+ for (nk, pk), v in data["byParam"]:
+ by_param[(nk, tuple(pk))] = v
+ by_name = by_param_to_by_name(by_param)
+ return cls(by_name, data["parameters"], from_json=data)
+ else:
+ assert data["parameters"] == parameters
+ return cls(by_name, parameters, from_json=data)
def webconf_function_map(self) -> list:
ret = list()