diff options
author | Daniel Friesel <daniel.friesel@uos.de> | 2021-03-04 15:37:05 +0100 |
---|---|---|
committer | Daniel Friesel <daniel.friesel@uos.de> | 2021-03-04 15:37:05 +0100 |
commit | 26d949fb61f874e072f2a495c38ea343e912a02f (patch) | |
tree | 96a3b5f0bbf28da6d43f471cb4ddb3d9727f2016 /lib/automata.py | |
parent | 846c3f4420e67613c2a6a359d4b117ac942200b8 (diff) |
Restore --export-energymodel
Diffstat (limited to 'lib/automata.py')
-rwxr-xr-x | lib/automata.py | 76 |
1 files changed, 33 insertions, 43 deletions
diff --git a/lib/automata.py b/lib/automata.py index b1e5623..a89155c 100755 --- a/lib/automata.py +++ b/lib/automata.py @@ -276,7 +276,9 @@ class State: def to_json(self) -> dict: """Return JSON encoding of this state object.""" - ret = {"name": self.name, "power": self.power.to_json()} + ret = {"name": self.name, "power": None} + if self.power is not None: + ret["power"] = self.power.to_json() return ret @@ -428,10 +430,16 @@ class Transition: "argument_combination": self.argument_combination, "arg_to_param_map": self.arg_to_param_map, "set_param": self.set_param, - "duration": self.duration.to_json(), - "energy": self.energy.to_json(), - "timeout": self.timeout.to_json(), + "duration": None, + "energy": None, + "timeout": None, } + if self.duration is not None: + ret["duration"] = (self.duration.to_json(),) + if self.energy is not None: + ret["energy"] = (self.energy.to_json(),) + if self.timeout is not None: + ret["timeout"] = (self.timeout.to_json(),) return ret @@ -1158,21 +1166,15 @@ class PTA: energy_mae=np.sqrt(total_energy_error), ) - def update(self, static_model, param_model, static_error=None, analytic_error=None): + def update(self, model, static_error=None, function_error=None): for state in self.state.values(): if state.name != "UNINITIALIZED": try: - state.power.value = static_model(state.name, "power") + state.power = model(state.name, "power") if static_error is not None: state.power.value_error = static_error[state.name]["power"] - if param_model(state.name, "power"): - state.power.function = param_model(state.name, "power")[ - "function" - ] - if analytic_error is not None: - state.power.function_error = analytic_error[state.name][ - "power" - ] + if function_error is not None: + state.power.function_error = function_error[state.name]["power"] except KeyError: logger.warning( "skipping model update of state {} due to missing data".format( @@ -1182,34 +1184,10 @@ class PTA: pass for transition in self.transitions: try: - transition.duration.value = static_model(transition.name, "duration") - if param_model(transition.name, "duration"): - transition.duration.function = param_model( - transition.name, "duration" - )["function"] - if analytic_error is not None: - transition.duration.function_error = analytic_error[ - transition.name - ]["duration"] - transition.energy.value = static_model(transition.name, "energy") - if param_model(transition.name, "energy"): - transition.energy.function = param_model(transition.name, "energy")[ - "function" - ] - if analytic_error is not None: - transition.energy.function_error = analytic_error[ - transition.name - ]["energy"] + transition.duration = model(transition.name, "duration") + transition.energy = model(transition.name, "energy") if transition.is_interrupt: - transition.timeout.value = static_model(transition.name, "timeout") - if param_model(transition.name, "timeout"): - transition.timeout.function = param_model( - transition.name, "timeout" - )["function"] - if analytic_error is not None: - transition.timeout.function_error = analytic_error[ - transition.name - ]["timeout"] + transition.timeout = model(transition.name, "timeout") if static_error is not None: transition.duration.value_error = static_error[transition.name][ @@ -1218,9 +1196,21 @@ class PTA: transition.energy.value_error = static_error[transition.name][ "energy" ] - transition.timeout.value_error = static_error[transition.name][ - "timeout" + if transition.is_interrupt: + transition.timeout.value_error = static_error[transition.name][ + "timeout" + ] + if function_error is not None: + transition.duration.function_error = function_error[ + transition.name + ]["duration"] + transition.energy.function_error = function_error[transition.name][ + "energy" ] + if transition.is_interrupt: + transition.timeout.function_error = function_error[ + transition.name + ]["timeout"] except KeyError: logger.warning( "skipping model update of transition {} due to missing data".format( |