summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
authorBirte Kristina Friesel <birte.friesel@uos.de>2025-03-20 14:21:11 +0100
committerBirte Kristina Friesel <birte.friesel@uos.de>2025-03-20 14:21:11 +0100
commit192dc253d9c576d2b2bfffce6cc645b695ab1e77 (patch)
tree50488bd58bbdedb08d414947391c58394993cf63 /lib
parent5872348e31cc4698c2bb976662df5efac4467cae (diff)
workload: Support LUT lookups
Diffstat (limited to 'lib')
-rw-r--r--lib/model.py32
-rw-r--r--lib/parameters.py4
-rw-r--r--lib/utils.py12
3 files changed, 40 insertions, 8 deletions
diff --git a/lib/model.py b/lib/model.py
index 0026249..dbe05aa 100644
--- a/lib/model.py
+++ b/lib/model.py
@@ -20,6 +20,7 @@ from .utils import (
by_name_to_by_param,
by_param_to_by_name,
regression_measures,
+ param_eq_or_none,
)
logger = logging.getLogger(__name__)
@@ -85,6 +86,7 @@ class AnalyticModel:
compute_stats=True,
force_tree=False,
max_std=None,
+ by_param=None,
from_json=None,
):
"""
@@ -154,9 +156,18 @@ class AnalyticModel:
for name, name_data in from_json["name"].items():
self.attr_by_name[name] = dict()
for attr, attr_data in name_data.items():
- self.attr_by_name[name][attr] = ModelAttribute.from_json(
- name, attr, attr_data
- )
+ if by_param:
+ self.attr_by_name[name][attr] = ModelAttribute.from_json(
+ name,
+ attr,
+ attr_data,
+ data_values=by_name[name][attr],
+ param_values=by_name[name]["param"],
+ )
+ else:
+ self.attr_by_name[name][attr] = ModelAttribute.from_json(
+ name, attr, attr_data
+ )
self.fit_done = True
return
@@ -255,7 +266,7 @@ class AnalyticModel:
return static_model_getter
- def get_param_lut(self, use_mean=False, fallback=False):
+ def get_param_lut(self, use_mean=False, fallback=False, allow_none=False):
"""
Get parameter-look-up-table model function: name, attribute, parameter values -> model value.
@@ -285,7 +296,16 @@ class AnalyticModel:
try:
return lut_model[name][key][param]
except KeyError:
- if fallback:
+ if allow_none:
+ keys = filter(
+ lambda p: param_eq_or_none(param, p),
+ lut_model[name][key].keys(),
+ )
+ values = list(map(lambda p: lut_model[name][key][p], keys))
+ if not values:
+ raise
+ return np.mean(values)
+ elif fallback:
return static_model[name][key]
raise
params = kwargs["params"]
@@ -684,7 +704,7 @@ class AnalyticModel:
for (nk, pk), v in data["byParam"]:
by_param[(nk, tuple(pk))] = v
by_name = by_param_to_by_name(by_param)
- return cls(by_name, data["parameters"], from_json=data)
+ return cls(by_name, data["parameters"], by_param=by_param, from_json=data)
else:
assert data["parameters"] == parameters
return cls(by_name, parameters, from_json=data)
diff --git a/lib/parameters.py b/lib/parameters.py
index b648c4c..acb044c 100644
--- a/lib/parameters.py
+++ b/lib/parameters.py
@@ -731,11 +731,11 @@ class ModelAttribute:
return self.mutual_information_cache
@classmethod
- def from_json(cls, name, attr, data):
+ def from_json(cls, name, attr, data, data_values=None, param_values=None):
param_names = data["paramNames"]
arg_count = data["argCount"]
- self = cls(name, attr, None, None, param_names, arg_count)
+ self = cls(name, attr, data_values, param_values, param_names, arg_count)
self.model_function = df.ModelFunction.from_json(data["modelFunction"])
self.mean = self.model_function.value
diff --git a/lib/utils.py b/lib/utils.py
index 208db44..228e78c 100644
--- a/lib/utils.py
+++ b/lib/utils.py
@@ -207,6 +207,18 @@ def param_slice_eq(a, b, index):
return False
+def param_eq_or_none(a, b):
+ """
+ Check if by_param keys a and b are identical, allowing a None in a to match any key in b.
+ """
+ set_keys = tuple(filter(lambda i: a[i] is not None, range(len(a))))
+ a_not_none = tuple(map(lambda i: a[i], set_keys))
+ b_not_none = tuple(map(lambda i: b[i], set_keys))
+ if a_not_none == b_not_none:
+ return True
+ return False
+
+
def match_parameter_values(input_param: dict, match_param: dict):
"""
Check whether one of the paramaters in `input_param` has the same value in `match_param`.