diff options
-rw-r--r-- | lib/behaviour.py | 36 | ||||
-rw-r--r-- | lib/utils.py | 11 |
2 files changed, 36 insertions, 11 deletions
diff --git a/lib/behaviour.py b/lib/behaviour.py index a0ceb95..626d3c2 100644 --- a/lib/behaviour.py +++ b/lib/behaviour.py @@ -74,7 +74,7 @@ class SDKBehaviourModel: flat_model = info(name, t_from).flatten() else: flat_model = list() - logging.warning( + logger.warning( f"Model for {name} {t_from} is {info(name, t_from)}, expected SplitFunction" ) @@ -94,12 +94,32 @@ class SDKBehaviourModel: while current_state != "__end__": next_states = delta[current_state] + states_seen.add(current_state) next_states = list(filter(lambda q: q not in states_seen, next_states)) if len(next_states) == 0: raise RuntimeError( f"get_trace({name}, {param_dict}): found infinite loop at {trace}" ) + + if len(next_states) > 1 and self.transition_guard[current_state]: + matching_next_states = list() + for candidate in next_states: + for condition in self.transition_guard[current_state][candidate]: + valid = True + for key, value in condition: + if param_dict[key] != value: + valid = False + break + if valid: + matching_next_states.append(candidate) + break + next_states = matching_next_states + + if len(next_states) == 0: + raise RuntimeError( + f"get_trace({name}, {param_dict}): found no valid outbound transitions at {trace}, candidates {self.transition_guard[current_state]}" + ) if len(next_states) > 1: raise RuntimeError( f"get_trace({name}, {param_dict}): found non-deterministic outbound transitions {next_states} at {trace}" @@ -108,10 +128,8 @@ class SDKBehaviourModel: (next_state,) = next_states trace.append(next_state) - states_seen.add(current_state) current_state = next_state - print(trace) return trace def learn_pta(self, observations, annotation, delta=dict(), delta_param=dict()): @@ -130,7 +148,7 @@ class SDKBehaviourModel: if this in n_seen: if n_seen[this] == 1: - logging.debug( + logger.debug( f"Loop found in {annotation.start.name} {annotation.end.param}: {this} ⟳" ) n_seen[this] += 1 @@ -219,7 +237,7 @@ class SDKBehaviourModel: if this in n_seen: if n_seen[this] == 1: - logging.debug( + logger.debug( f"Loop found in {annotation.start.name} {annotation.end.param}: {this} ⟳" ) n_seen[this] += 1 @@ -332,7 +350,7 @@ class EventSequenceModel: param_list = utils.param_dict_to_list(param, ref_model.parameters) if not use_lut and not param_info(name, action).is_predictable(param_list): - logging.warning( + logger.warning( f"Cannot predict {name}.{action}({param}), falling back to static model" ) @@ -346,15 +364,15 @@ class EventSequenceModel: ) except KeyError: if use_lut: - logging.error( + logger.error( f"Cannot predict {name}.{action}({param}) from LUT model" ) else: - logging.error(f"Cannot predict {name}.{action}({param}) from model") + logger.error(f"Cannot predict {name}.{action}({param}) from model") raise except TypeError: if not use_lut: - logging.error(f"Cannot predict {name}.{action}({param}) from model") + logger.error(f"Cannot predict {name}.{action}({param}) from model") raise if aggregate == "sum": diff --git a/lib/utils.py b/lib/utils.py index 48a29d8..fb76367 100644 --- a/lib/utils.py +++ b/lib/utils.py @@ -728,11 +728,18 @@ def regression_measures(predicted: np.ndarray, ground_truth: np.ndarray): rsq -- R^2 measure, see sklearn.metrics.r2_score count -- Number of values """ - if type(predicted) != np.ndarray: + + if type(predicted) is list: + predicted = np.array(predicted) + + if type(ground_truth) is list: + ground_truth = np.array(ground_truth) + + if type(predicted) is not np.ndarray: raise ValueError( "first arg ('predicted') must be ndarray, is {}".format(type(predicted)) ) - if type(ground_truth) != np.ndarray: + if type(ground_truth) is not np.ndarray: raise ValueError( "second arg ('ground_truth') must be ndarray, is {}".format( type(ground_truth) |