summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rwxr-xr-xlib/automata.py76
-rw-r--r--lib/functions.py70
-rw-r--r--lib/model.py3
3 files changed, 81 insertions, 68 deletions
diff --git a/lib/automata.py b/lib/automata.py
index b1e5623..a89155c 100755
--- a/lib/automata.py
+++ b/lib/automata.py
@@ -276,7 +276,9 @@ class State:
def to_json(self) -> dict:
"""Return JSON encoding of this state object."""
- ret = {"name": self.name, "power": self.power.to_json()}
+ ret = {"name": self.name, "power": None}
+ if self.power is not None:
+ ret["power"] = self.power.to_json()
return ret
@@ -428,10 +430,16 @@ class Transition:
"argument_combination": self.argument_combination,
"arg_to_param_map": self.arg_to_param_map,
"set_param": self.set_param,
- "duration": self.duration.to_json(),
- "energy": self.energy.to_json(),
- "timeout": self.timeout.to_json(),
+ "duration": None,
+ "energy": None,
+ "timeout": None,
}
+ if self.duration is not None:
+ ret["duration"] = (self.duration.to_json(),)
+ if self.energy is not None:
+ ret["energy"] = (self.energy.to_json(),)
+ if self.timeout is not None:
+ ret["timeout"] = (self.timeout.to_json(),)
return ret
@@ -1158,21 +1166,15 @@ class PTA:
energy_mae=np.sqrt(total_energy_error),
)
- def update(self, static_model, param_model, static_error=None, analytic_error=None):
+ def update(self, model, static_error=None, function_error=None):
for state in self.state.values():
if state.name != "UNINITIALIZED":
try:
- state.power.value = static_model(state.name, "power")
+ state.power = model(state.name, "power")
if static_error is not None:
state.power.value_error = static_error[state.name]["power"]
- if param_model(state.name, "power"):
- state.power.function = param_model(state.name, "power")[
- "function"
- ]
- if analytic_error is not None:
- state.power.function_error = analytic_error[state.name][
- "power"
- ]
+ if function_error is not None:
+ state.power.function_error = function_error[state.name]["power"]
except KeyError:
logger.warning(
"skipping model update of state {} due to missing data".format(
@@ -1182,34 +1184,10 @@ class PTA:
pass
for transition in self.transitions:
try:
- transition.duration.value = static_model(transition.name, "duration")
- if param_model(transition.name, "duration"):
- transition.duration.function = param_model(
- transition.name, "duration"
- )["function"]
- if analytic_error is not None:
- transition.duration.function_error = analytic_error[
- transition.name
- ]["duration"]
- transition.energy.value = static_model(transition.name, "energy")
- if param_model(transition.name, "energy"):
- transition.energy.function = param_model(transition.name, "energy")[
- "function"
- ]
- if analytic_error is not None:
- transition.energy.function_error = analytic_error[
- transition.name
- ]["energy"]
+ transition.duration = model(transition.name, "duration")
+ transition.energy = model(transition.name, "energy")
if transition.is_interrupt:
- transition.timeout.value = static_model(transition.name, "timeout")
- if param_model(transition.name, "timeout"):
- transition.timeout.function = param_model(
- transition.name, "timeout"
- )["function"]
- if analytic_error is not None:
- transition.timeout.function_error = analytic_error[
- transition.name
- ]["timeout"]
+ transition.timeout = model(transition.name, "timeout")
if static_error is not None:
transition.duration.value_error = static_error[transition.name][
@@ -1218,9 +1196,21 @@ class PTA:
transition.energy.value_error = static_error[transition.name][
"energy"
]
- transition.timeout.value_error = static_error[transition.name][
- "timeout"
+ if transition.is_interrupt:
+ transition.timeout.value_error = static_error[transition.name][
+ "timeout"
+ ]
+ if function_error is not None:
+ transition.duration.function_error = function_error[
+ transition.name
+ ]["duration"]
+ transition.energy.function_error = function_error[transition.name][
+ "energy"
]
+ if transition.is_interrupt:
+ transition.timeout.function_error = function_error[
+ transition.name
+ ]["timeout"]
except KeyError:
logger.warning(
"skipping model update of transition {} due to missing data".format(
diff --git a/lib/functions.py b/lib/functions.py
index 3031a01..48049b5 100644
--- a/lib/functions.py
+++ b/lib/functions.py
@@ -160,7 +160,8 @@ class ModelFunction:
# parameter combination, and for use cases requiring static models
self.value = value
- # Legacy(?) attributes for PTA
+ # A ModelFunction may track its own accuracy, both of the static value and of the eval() method.
+ # However, it does not specify how the accuracy was calculated (e.g. which data was used and whether cross-validation was performed)
self.value_error = None
self.function_error = None
@@ -171,17 +172,30 @@ class ModelFunction:
raise NotImplementedError
def to_json(self):
- raise NotImplementedError
+ ret = {
+ "value": self.value,
+ "value_error": self.value_error,
+ "function_error": self.function_error,
+ }
+ return ret
@classmethod
def from_json(cls, data):
if data["type"] == "static":
- return StaticFunction.from_json(data)
- if data["type"] == "split":
- return SplitFunction.from_json(data)
- if data["type"] == "analytic":
- return AnalyticFunction.from_json(data)
- raise ValueError("Unknown ModelFunction type: " + data["type"])
+ mf = StaticFunction.from_json(data)
+ elif data["type"] == "split":
+ mf = SplitFunction.from_json(data)
+ elif data["type"] == "analytic":
+ mf = AnalyticFunction.from_json(data)
+ else:
+ raise ValueError("Unknown ModelFunction type: " + data["type"])
+
+ if "value_error" in data:
+ mf.value_error = data["value_error"]
+ if "function_error" in data:
+ mf.function_error = data["function_error"]
+
+ return mf
@classmethod
def from_json_maybe(cls, json_wrapped: dict, attribute: str):
@@ -220,7 +234,9 @@ class StaticFunction(ModelFunction):
return self.value
def to_json(self):
- return {"type": "static", "value": self.value}
+ ret = super().to_json()
+ ret.update({"type": "static", "value": self.value})
+ return ret
@classmethod
def from_json(cls, data):
@@ -257,12 +273,16 @@ class SplitFunction(ModelFunction):
return self.child[param_value].eval(param_list, arg_list)
def to_json(self):
- return {
- "type": "split",
- "value": self.value,
- "paramIndex": self.param_index,
- "child": dict([[k, v.to_json()] for k, v in self.child.items()]),
- }
+ ret = super().to_json()
+ ret.update(
+ {
+ "type": "split",
+ "value": self.value,
+ "paramIndex": self.param_index,
+ "child": dict([[k, v.to_json()] for k, v in self.child.items()]),
+ }
+ )
+ return ret
@classmethod
def from_json(cls, data):
@@ -471,14 +491,18 @@ class AnalyticFunction(ModelFunction):
return self._function(self.model_args, param_list)
def to_json(self):
- return {
- "type": "analytic",
- "value": self.value,
- "functionStr": self.model_function,
- "argCount": self._num_args,
- "parameterNames": self._parameter_names,
- "regressionModel": list(self.model_args),
- }
+ ret = super().to_json()
+ ret.update(
+ {
+ "type": "analytic",
+ "value": self.value,
+ "functionStr": self.model_function,
+ "argCount": self._num_args,
+ "parameterNames": self._parameter_names,
+ "regressionModel": list(self.model_args),
+ }
+ )
+ return ret
@classmethod
def from_json(cls, data):
diff --git a/lib/model.py b/lib/model.py
index ccbf719..4c4c226 100644
--- a/lib/model.py
+++ b/lib/model.py
@@ -790,10 +790,9 @@ class PTAModel(AnalyticModel):
if pta is None:
pta = PTA(self.states, parameters=self._parameter_names)
pta.update(
- static_model,
param_info,
static_error=static_quality["by_name"],
- analytic_error=analytic_quality["by_name"],
+ function_error=analytic_quality["by_name"],
)
return pta.to_json()