summaryrefslogtreecommitdiff
path: root/lib/functions.py
diff options
context:
space:
mode:
authorDaniel Friesel <daniel.friesel@uos.de>2021-03-04 13:32:36 +0100
committerDaniel Friesel <daniel.friesel@uos.de>2021-03-04 13:32:36 +0100
commit9bf7d10f3310147c7e85330a79da655b9f7a5bad (patch)
tree00283ccf137cd3a87e1d5d08869717a8ffddd4cc /lib/functions.py
parentf33c69dcaf24ecc7e039dec83a4a5c74908da52f (diff)
PTA State/Transition: Use ModelFunction instead of PTAAttribute
Diffstat (limited to 'lib/functions.py')
-rw-r--r--lib/functions.py82
1 files changed, 73 insertions, 9 deletions
diff --git a/lib/functions.py b/lib/functions.py
index 663b65e..7950b5a 100644
--- a/lib/functions.py
+++ b/lib/functions.py
@@ -154,8 +154,15 @@ class NormalizationFunction:
class ModelFunction:
- def __init__(self):
- pass
+ def __init__(self, value):
+ # a model always has a static (median/mean) value. For StaticFunction, it's the only data point.
+ # For more complex models, it's usede both as fallback in case the model cannot predict the current
+ # parameter combination, and for use cases requiring static models
+ self.value = value
+
+ # Legacy(?) attributes for PTA
+ self.value_error = None
+ self.function_error = None
def is_predictable(self, param_list):
raise NotImplementedError
@@ -163,11 +170,28 @@ class ModelFunction:
def eval(self, param_list, arg_list):
raise NotImplementedError
+ def to_json(self):
+ raise NotImplementedError
+
+ @classmethod
+ def from_json(cls, data):
+ if data["type"] == "static":
+ return StaticFunction.from_json(data)
+ if data["type"] == "split":
+ return SplitFunction.from_json(data)
+ if data["type"] == "analytic":
+ return AnalyticFunction.from_json(data)
+ raise ValueError("Unknown ModelFunction type: " + data["type"])
+
+ @classmethod
+ def from_json_maybe(cls, json_wrapped: dict, attribute: str):
+ # Legacy Code for PTA / tests. Do not use.
+ if type(json_wrapped) is dict and attribute in json_wrapped:
+ return cls.from_json(json_wrapped[attribute])
+ return StaticFunction(0)
-class StaticFunction(ModelFunction):
- def __init__(self, value):
- self.value = value
+class StaticFunction(ModelFunction):
def is_predictable(self, param_list=None):
"""
Return whether the model function can be evaluated on the given parameter values.
@@ -188,9 +212,18 @@ class StaticFunction(ModelFunction):
def to_json(self):
return {"type": "static", "value": self.value}
+ @classmethod
+ def from_json(cls, data):
+ assert data["type"] == "static"
+ return cls(data["value"])
+
+ def __repr__(self):
+ return f"StaticFunction({self.value})"
+
class SplitFunction(ModelFunction):
- def __init__(self, param_index, child):
+ def __init__(self, value, param_index, child):
+ super().__init__(value)
self.param_index = param_index
self.child = child
@@ -216,10 +249,22 @@ class SplitFunction(ModelFunction):
def to_json(self):
return {
"type": "split",
+ "value": self.value,
"paramIndex": self.param_index,
"child": dict([[k, v.to_json()] for k, v in self.child.items()]),
}
+ @classmethod
+ def from_json(cls, data):
+ assert data["type"] == "split"
+ self = cls(data["value"], data["paramIndex"], dict())
+
+ for k, v in data["child"].items():
+ self.child[k] = ModelFunction.from_json(v)
+
+ def __repr__(self):
+ return f"SplitFunction<{self.value}, param_index={self.param_index}>"
+
class AnalyticFunction(ModelFunction):
"""
@@ -232,9 +277,10 @@ class AnalyticFunction(ModelFunction):
def __init__(
self,
+ value,
function_str,
parameters,
- num_args,
+ num_args=0,
regression_args=None,
fit_by_param=None,
):
@@ -256,6 +302,7 @@ class AnalyticFunction(ModelFunction):
both for function usage and least squares optimization.
If unset, defaults to [1, 1, 1, ...]
"""
+ super().__init__(value)
self._parameter_names = parameters
self._num_args = num_args
self.model_function = function_str
@@ -416,11 +463,28 @@ class AnalyticFunction(ModelFunction):
def to_json(self):
return {
"type": "analytic",
+ "value": self.value,
"functionStr": self.model_function,
- "dependsOnParam": self._dependson,
+ "argCount": self._num_args,
+ "parameterNames": self._parameter_names,
"regressionModel": list(self.model_args),
}
+ @classmethod
+ def from_json(cls, data):
+ assert data["type"] == "analytic"
+
+ return cls(
+ data["value"],
+ data["functionStr"],
+ data["parameterNames"],
+ data["argCount"],
+ data["regressionModel"],
+ )
+
+ def __repr__(self):
+ return f"AnalyticFunction<{self.value}, {self.model_function}>"
+
class analytic:
"""
@@ -617,5 +681,5 @@ class analytic:
)
)
return AnalyticFunction(
- buf, parameter_names, num_args, fit_by_param=fit_results
+ None, buf, parameter_names, num_args, fit_by_param=fit_results
)