diff options
Diffstat (limited to 'lib/model.py')
-rw-r--r-- | lib/model.py | 70 |
1 files changed, 46 insertions, 24 deletions
diff --git a/lib/model.py b/lib/model.py index e77db01..427b5ec 100644 --- a/lib/model.py +++ b/lib/model.py @@ -237,6 +237,8 @@ class AnalyticModel: model[name][k] = v.get_static(use_mean=use_mean) def static_model_getter(name, key, **kwargs): + if "params" in kwargs: + return [model[name][key] for p in kwargs["params"]] return model[name][key] return static_model_getter @@ -266,18 +268,27 @@ class AnalyticModel: for param, model_value in v.by_param.items(): lut_model[name][k][param] = v.get_lut(param, use_mean=use_mean) - def lut_median_getter(name, key, param, arg=list(), **kwargs): - if arg: - if type(param) is tuple: - param = list(param) - param.extend(map(soft_cast_int, arg)) - param = tuple(param) - try: - return lut_model[name][key][param] - except KeyError: - if fallback: - return static_model[name][key] - raise + def lut_median_getter(name, key, **kwargs): + if "param" in kwargs: + param = tuple(kwargs["param"]) + try: + return lut_model[name][key][param] + except KeyError: + if fallback: + return static_model[name][key] + raise + params = kwargs["params"] + if fallback: + return list( + map( + lambda p: lut_model[name][key][tuple(p)] + if tuple(p) in lut_model[name][key] + else static_model[name][key], + params, + ) + ) + else: + return list(map(lambda p: lut_model[name][key][tuple(p)], params)) return lut_median_getter @@ -351,14 +362,32 @@ class AnalyticModel: # shortcut if type(model_info) is StaticFunction: + if "params" in kwargs: + return [static_model[name][key] for p in kwargs["params"]] return static_model[name][key] - if "arg" in kwargs and "param" in kwargs: - kwargs["param"].extend(map(soft_cast_int, kwargs["arg"])) - - if model_function.is_predictable(kwargs["param"]): + if "param" in kwargs and model_function.is_predictable(kwargs["param"]): return model_function.eval(kwargs["param"]) + if "params" in kwargs: + if model_function.has_eval_arr and ( + model_function.always_predictable + or all( + map( + lambda p: model_function.is_predictable(p), kwargs["params"] + ) + ) + ): + return model_function.eval_arr(kwargs["params"]) + return list( + map( + lambda p: model_function.eval(p) + if model_function.is_predictable(p) + else static_model[name][key], + kwargs["params"], + ) + ) + return static_model[name][key] def info_getter(name, key): @@ -395,14 +424,7 @@ class AnalyticModel: } for attribute in elem["attributes"]: predicted_data = np.array( - list( - map( - lambda i: model_function( - name, attribute, param=elem["param"][i] - ), - range(len(elem[attribute])), - ) - ) + model_function(name, attribute, params=elem["param"]) ) measures = regression_measures(predicted_data, elem[attribute]) detailed_results[name][attribute] = measures |