diff options
Diffstat (limited to 'lib/functions.py')
-rw-r--r-- | lib/functions.py | 70 |
1 files changed, 47 insertions, 23 deletions
diff --git a/lib/functions.py b/lib/functions.py index 3031a01..48049b5 100644 --- a/lib/functions.py +++ b/lib/functions.py @@ -160,7 +160,8 @@ class ModelFunction: # parameter combination, and for use cases requiring static models self.value = value - # Legacy(?) attributes for PTA + # A ModelFunction may track its own accuracy, both of the static value and of the eval() method. + # However, it does not specify how the accuracy was calculated (e.g. which data was used and whether cross-validation was performed) self.value_error = None self.function_error = None @@ -171,17 +172,30 @@ class ModelFunction: raise NotImplementedError def to_json(self): - raise NotImplementedError + ret = { + "value": self.value, + "value_error": self.value_error, + "function_error": self.function_error, + } + return ret @classmethod def from_json(cls, data): if data["type"] == "static": - return StaticFunction.from_json(data) - if data["type"] == "split": - return SplitFunction.from_json(data) - if data["type"] == "analytic": - return AnalyticFunction.from_json(data) - raise ValueError("Unknown ModelFunction type: " + data["type"]) + mf = StaticFunction.from_json(data) + elif data["type"] == "split": + mf = SplitFunction.from_json(data) + elif data["type"] == "analytic": + mf = AnalyticFunction.from_json(data) + else: + raise ValueError("Unknown ModelFunction type: " + data["type"]) + + if "value_error" in data: + mf.value_error = data["value_error"] + if "function_error" in data: + mf.function_error = data["function_error"] + + return mf @classmethod def from_json_maybe(cls, json_wrapped: dict, attribute: str): @@ -220,7 +234,9 @@ class StaticFunction(ModelFunction): return self.value def to_json(self): - return {"type": "static", "value": self.value} + ret = super().to_json() + ret.update({"type": "static", "value": self.value}) + return ret @classmethod def from_json(cls, data): @@ -257,12 +273,16 @@ class SplitFunction(ModelFunction): return self.child[param_value].eval(param_list, arg_list) def to_json(self): - return { - "type": "split", - "value": self.value, - "paramIndex": self.param_index, - "child": dict([[k, v.to_json()] for k, v in self.child.items()]), - } + ret = super().to_json() + ret.update( + { + "type": "split", + "value": self.value, + "paramIndex": self.param_index, + "child": dict([[k, v.to_json()] for k, v in self.child.items()]), + } + ) + return ret @classmethod def from_json(cls, data): @@ -471,14 +491,18 @@ class AnalyticFunction(ModelFunction): return self._function(self.model_args, param_list) def to_json(self): - return { - "type": "analytic", - "value": self.value, - "functionStr": self.model_function, - "argCount": self._num_args, - "parameterNames": self._parameter_names, - "regressionModel": list(self.model_args), - } + ret = super().to_json() + ret.update( + { + "type": "analytic", + "value": self.value, + "functionStr": self.model_function, + "argCount": self._num_args, + "parameterNames": self._parameter_names, + "regressionModel": list(self.model_args), + } + ) + return ret @classmethod def from_json(cls, data): |