summaryrefslogtreecommitdiff
path: root/lib/functions.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/functions.py')
-rw-r--r--lib/functions.py70
1 files changed, 47 insertions, 23 deletions
diff --git a/lib/functions.py b/lib/functions.py
index 3031a01..48049b5 100644
--- a/lib/functions.py
+++ b/lib/functions.py
@@ -160,7 +160,8 @@ class ModelFunction:
# parameter combination, and for use cases requiring static models
self.value = value
- # Legacy(?) attributes for PTA
+ # A ModelFunction may track its own accuracy, both of the static value and of the eval() method.
+ # However, it does not specify how the accuracy was calculated (e.g. which data was used and whether cross-validation was performed)
self.value_error = None
self.function_error = None
@@ -171,17 +172,30 @@ class ModelFunction:
raise NotImplementedError
def to_json(self):
- raise NotImplementedError
+ ret = {
+ "value": self.value,
+ "value_error": self.value_error,
+ "function_error": self.function_error,
+ }
+ return ret
@classmethod
def from_json(cls, data):
if data["type"] == "static":
- return StaticFunction.from_json(data)
- if data["type"] == "split":
- return SplitFunction.from_json(data)
- if data["type"] == "analytic":
- return AnalyticFunction.from_json(data)
- raise ValueError("Unknown ModelFunction type: " + data["type"])
+ mf = StaticFunction.from_json(data)
+ elif data["type"] == "split":
+ mf = SplitFunction.from_json(data)
+ elif data["type"] == "analytic":
+ mf = AnalyticFunction.from_json(data)
+ else:
+ raise ValueError("Unknown ModelFunction type: " + data["type"])
+
+ if "value_error" in data:
+ mf.value_error = data["value_error"]
+ if "function_error" in data:
+ mf.function_error = data["function_error"]
+
+ return mf
@classmethod
def from_json_maybe(cls, json_wrapped: dict, attribute: str):
@@ -220,7 +234,9 @@ class StaticFunction(ModelFunction):
return self.value
def to_json(self):
- return {"type": "static", "value": self.value}
+ ret = super().to_json()
+ ret.update({"type": "static", "value": self.value})
+ return ret
@classmethod
def from_json(cls, data):
@@ -257,12 +273,16 @@ class SplitFunction(ModelFunction):
return self.child[param_value].eval(param_list, arg_list)
def to_json(self):
- return {
- "type": "split",
- "value": self.value,
- "paramIndex": self.param_index,
- "child": dict([[k, v.to_json()] for k, v in self.child.items()]),
- }
+ ret = super().to_json()
+ ret.update(
+ {
+ "type": "split",
+ "value": self.value,
+ "paramIndex": self.param_index,
+ "child": dict([[k, v.to_json()] for k, v in self.child.items()]),
+ }
+ )
+ return ret
@classmethod
def from_json(cls, data):
@@ -471,14 +491,18 @@ class AnalyticFunction(ModelFunction):
return self._function(self.model_args, param_list)
def to_json(self):
- return {
- "type": "analytic",
- "value": self.value,
- "functionStr": self.model_function,
- "argCount": self._num_args,
- "parameterNames": self._parameter_names,
- "regressionModel": list(self.model_args),
- }
+ ret = super().to_json()
+ ret.update(
+ {
+ "type": "analytic",
+ "value": self.value,
+ "functionStr": self.model_function,
+ "argCount": self._num_args,
+ "parameterNames": self._parameter_names,
+ "regressionModel": list(self.model_args),
+ }
+ )
+ return ret
@classmethod
def from_json(cls, data):