diff options
-rwxr-xr-x | lib/automata.py | 76 | ||||
-rw-r--r-- | lib/functions.py | 70 | ||||
-rw-r--r-- | lib/model.py | 3 |
3 files changed, 81 insertions, 68 deletions
diff --git a/lib/automata.py b/lib/automata.py index b1e5623..a89155c 100755 --- a/lib/automata.py +++ b/lib/automata.py @@ -276,7 +276,9 @@ class State: def to_json(self) -> dict: """Return JSON encoding of this state object.""" - ret = {"name": self.name, "power": self.power.to_json()} + ret = {"name": self.name, "power": None} + if self.power is not None: + ret["power"] = self.power.to_json() return ret @@ -428,10 +430,16 @@ class Transition: "argument_combination": self.argument_combination, "arg_to_param_map": self.arg_to_param_map, "set_param": self.set_param, - "duration": self.duration.to_json(), - "energy": self.energy.to_json(), - "timeout": self.timeout.to_json(), + "duration": None, + "energy": None, + "timeout": None, } + if self.duration is not None: + ret["duration"] = (self.duration.to_json(),) + if self.energy is not None: + ret["energy"] = (self.energy.to_json(),) + if self.timeout is not None: + ret["timeout"] = (self.timeout.to_json(),) return ret @@ -1158,21 +1166,15 @@ class PTA: energy_mae=np.sqrt(total_energy_error), ) - def update(self, static_model, param_model, static_error=None, analytic_error=None): + def update(self, model, static_error=None, function_error=None): for state in self.state.values(): if state.name != "UNINITIALIZED": try: - state.power.value = static_model(state.name, "power") + state.power = model(state.name, "power") if static_error is not None: state.power.value_error = static_error[state.name]["power"] - if param_model(state.name, "power"): - state.power.function = param_model(state.name, "power")[ - "function" - ] - if analytic_error is not None: - state.power.function_error = analytic_error[state.name][ - "power" - ] + if function_error is not None: + state.power.function_error = function_error[state.name]["power"] except KeyError: logger.warning( "skipping model update of state {} due to missing data".format( @@ -1182,34 +1184,10 @@ class PTA: pass for transition in self.transitions: try: - transition.duration.value = static_model(transition.name, "duration") - if param_model(transition.name, "duration"): - transition.duration.function = param_model( - transition.name, "duration" - )["function"] - if analytic_error is not None: - transition.duration.function_error = analytic_error[ - transition.name - ]["duration"] - transition.energy.value = static_model(transition.name, "energy") - if param_model(transition.name, "energy"): - transition.energy.function = param_model(transition.name, "energy")[ - "function" - ] - if analytic_error is not None: - transition.energy.function_error = analytic_error[ - transition.name - ]["energy"] + transition.duration = model(transition.name, "duration") + transition.energy = model(transition.name, "energy") if transition.is_interrupt: - transition.timeout.value = static_model(transition.name, "timeout") - if param_model(transition.name, "timeout"): - transition.timeout.function = param_model( - transition.name, "timeout" - )["function"] - if analytic_error is not None: - transition.timeout.function_error = analytic_error[ - transition.name - ]["timeout"] + transition.timeout = model(transition.name, "timeout") if static_error is not None: transition.duration.value_error = static_error[transition.name][ @@ -1218,9 +1196,21 @@ class PTA: transition.energy.value_error = static_error[transition.name][ "energy" ] - transition.timeout.value_error = static_error[transition.name][ - "timeout" + if transition.is_interrupt: + transition.timeout.value_error = static_error[transition.name][ + "timeout" + ] + if function_error is not None: + transition.duration.function_error = function_error[ + transition.name + ]["duration"] + transition.energy.function_error = function_error[transition.name][ + "energy" ] + if transition.is_interrupt: + transition.timeout.function_error = function_error[ + transition.name + ]["timeout"] except KeyError: logger.warning( "skipping model update of transition {} due to missing data".format( diff --git a/lib/functions.py b/lib/functions.py index 3031a01..48049b5 100644 --- a/lib/functions.py +++ b/lib/functions.py @@ -160,7 +160,8 @@ class ModelFunction: # parameter combination, and for use cases requiring static models self.value = value - # Legacy(?) attributes for PTA + # A ModelFunction may track its own accuracy, both of the static value and of the eval() method. + # However, it does not specify how the accuracy was calculated (e.g. which data was used and whether cross-validation was performed) self.value_error = None self.function_error = None @@ -171,17 +172,30 @@ class ModelFunction: raise NotImplementedError def to_json(self): - raise NotImplementedError + ret = { + "value": self.value, + "value_error": self.value_error, + "function_error": self.function_error, + } + return ret @classmethod def from_json(cls, data): if data["type"] == "static": - return StaticFunction.from_json(data) - if data["type"] == "split": - return SplitFunction.from_json(data) - if data["type"] == "analytic": - return AnalyticFunction.from_json(data) - raise ValueError("Unknown ModelFunction type: " + data["type"]) + mf = StaticFunction.from_json(data) + elif data["type"] == "split": + mf = SplitFunction.from_json(data) + elif data["type"] == "analytic": + mf = AnalyticFunction.from_json(data) + else: + raise ValueError("Unknown ModelFunction type: " + data["type"]) + + if "value_error" in data: + mf.value_error = data["value_error"] + if "function_error" in data: + mf.function_error = data["function_error"] + + return mf @classmethod def from_json_maybe(cls, json_wrapped: dict, attribute: str): @@ -220,7 +234,9 @@ class StaticFunction(ModelFunction): return self.value def to_json(self): - return {"type": "static", "value": self.value} + ret = super().to_json() + ret.update({"type": "static", "value": self.value}) + return ret @classmethod def from_json(cls, data): @@ -257,12 +273,16 @@ class SplitFunction(ModelFunction): return self.child[param_value].eval(param_list, arg_list) def to_json(self): - return { - "type": "split", - "value": self.value, - "paramIndex": self.param_index, - "child": dict([[k, v.to_json()] for k, v in self.child.items()]), - } + ret = super().to_json() + ret.update( + { + "type": "split", + "value": self.value, + "paramIndex": self.param_index, + "child": dict([[k, v.to_json()] for k, v in self.child.items()]), + } + ) + return ret @classmethod def from_json(cls, data): @@ -471,14 +491,18 @@ class AnalyticFunction(ModelFunction): return self._function(self.model_args, param_list) def to_json(self): - return { - "type": "analytic", - "value": self.value, - "functionStr": self.model_function, - "argCount": self._num_args, - "parameterNames": self._parameter_names, - "regressionModel": list(self.model_args), - } + ret = super().to_json() + ret.update( + { + "type": "analytic", + "value": self.value, + "functionStr": self.model_function, + "argCount": self._num_args, + "parameterNames": self._parameter_names, + "regressionModel": list(self.model_args), + } + ) + return ret @classmethod def from_json(cls, data): diff --git a/lib/model.py b/lib/model.py index ccbf719..4c4c226 100644 --- a/lib/model.py +++ b/lib/model.py @@ -790,10 +790,9 @@ class PTAModel(AnalyticModel): if pta is None: pta = PTA(self.states, parameters=self._parameter_names) pta.update( - static_model, param_info, static_error=static_quality["by_name"], - analytic_error=analytic_quality["by_name"], + function_error=analytic_quality["by_name"], ) return pta.to_json() |