diff options
author | Daniel Friesel <daniel.friesel@uos.de> | 2021-03-04 13:32:36 +0100 |
---|---|---|
committer | Daniel Friesel <daniel.friesel@uos.de> | 2021-03-04 13:32:36 +0100 |
commit | 9bf7d10f3310147c7e85330a79da655b9f7a5bad (patch) | |
tree | 00283ccf137cd3a87e1d5d08869717a8ffddd4cc /lib/functions.py | |
parent | f33c69dcaf24ecc7e039dec83a4a5c74908da52f (diff) |
PTA State/Transition: Use ModelFunction instead of PTAAttribute
Diffstat (limited to 'lib/functions.py')
-rw-r--r-- | lib/functions.py | 82 |
1 files changed, 73 insertions, 9 deletions
diff --git a/lib/functions.py b/lib/functions.py index 663b65e..7950b5a 100644 --- a/lib/functions.py +++ b/lib/functions.py @@ -154,8 +154,15 @@ class NormalizationFunction: class ModelFunction: - def __init__(self): - pass + def __init__(self, value): + # a model always has a static (median/mean) value. For StaticFunction, it's the only data point. + # For more complex models, it's usede both as fallback in case the model cannot predict the current + # parameter combination, and for use cases requiring static models + self.value = value + + # Legacy(?) attributes for PTA + self.value_error = None + self.function_error = None def is_predictable(self, param_list): raise NotImplementedError @@ -163,11 +170,28 @@ class ModelFunction: def eval(self, param_list, arg_list): raise NotImplementedError + def to_json(self): + raise NotImplementedError + + @classmethod + def from_json(cls, data): + if data["type"] == "static": + return StaticFunction.from_json(data) + if data["type"] == "split": + return SplitFunction.from_json(data) + if data["type"] == "analytic": + return AnalyticFunction.from_json(data) + raise ValueError("Unknown ModelFunction type: " + data["type"]) + + @classmethod + def from_json_maybe(cls, json_wrapped: dict, attribute: str): + # Legacy Code for PTA / tests. Do not use. + if type(json_wrapped) is dict and attribute in json_wrapped: + return cls.from_json(json_wrapped[attribute]) + return StaticFunction(0) -class StaticFunction(ModelFunction): - def __init__(self, value): - self.value = value +class StaticFunction(ModelFunction): def is_predictable(self, param_list=None): """ Return whether the model function can be evaluated on the given parameter values. @@ -188,9 +212,18 @@ class StaticFunction(ModelFunction): def to_json(self): return {"type": "static", "value": self.value} + @classmethod + def from_json(cls, data): + assert data["type"] == "static" + return cls(data["value"]) + + def __repr__(self): + return f"StaticFunction({self.value})" + class SplitFunction(ModelFunction): - def __init__(self, param_index, child): + def __init__(self, value, param_index, child): + super().__init__(value) self.param_index = param_index self.child = child @@ -216,10 +249,22 @@ class SplitFunction(ModelFunction): def to_json(self): return { "type": "split", + "value": self.value, "paramIndex": self.param_index, "child": dict([[k, v.to_json()] for k, v in self.child.items()]), } + @classmethod + def from_json(cls, data): + assert data["type"] == "split" + self = cls(data["value"], data["paramIndex"], dict()) + + for k, v in data["child"].items(): + self.child[k] = ModelFunction.from_json(v) + + def __repr__(self): + return f"SplitFunction<{self.value}, param_index={self.param_index}>" + class AnalyticFunction(ModelFunction): """ @@ -232,9 +277,10 @@ class AnalyticFunction(ModelFunction): def __init__( self, + value, function_str, parameters, - num_args, + num_args=0, regression_args=None, fit_by_param=None, ): @@ -256,6 +302,7 @@ class AnalyticFunction(ModelFunction): both for function usage and least squares optimization. If unset, defaults to [1, 1, 1, ...] """ + super().__init__(value) self._parameter_names = parameters self._num_args = num_args self.model_function = function_str @@ -416,11 +463,28 @@ class AnalyticFunction(ModelFunction): def to_json(self): return { "type": "analytic", + "value": self.value, "functionStr": self.model_function, - "dependsOnParam": self._dependson, + "argCount": self._num_args, + "parameterNames": self._parameter_names, "regressionModel": list(self.model_args), } + @classmethod + def from_json(cls, data): + assert data["type"] == "analytic" + + return cls( + data["value"], + data["functionStr"], + data["parameterNames"], + data["argCount"], + data["regressionModel"], + ) + + def __repr__(self): + return f"AnalyticFunction<{self.value}, {self.model_function}>" + class analytic: """ @@ -617,5 +681,5 @@ class analytic: ) ) return AnalyticFunction( - buf, parameter_names, num_args, fit_by_param=fit_results + None, buf, parameter_names, num_args, fit_by_param=fit_results ) |