diff options
Diffstat (limited to 'lib/model.py')
-rw-r--r-- | lib/model.py | 67 |
1 files changed, 67 insertions, 0 deletions
diff --git a/lib/model.py b/lib/model.py index 0550575..f53f645 100644 --- a/lib/model.py +++ b/lib/model.py @@ -698,6 +698,8 @@ class DecisionTreeModel: self.by_param = by_name_to_by_param(by_name) self.parameters = sorted(parameters) + self.dtree = None + # Dtree-Konzept: Für jeden Param: Split auf ==wert1 / == wert2 (-> zwei Partitionen der Daten) # Falls Menge beeinflussender Parameter in den beiden Partitionen unterschiedlich oder # Art der Abhängigkeit unterschiedlich: Hiernach muss der DTree unterscheiden. @@ -750,6 +752,7 @@ class DecisionTreeModel: candidates.append( (state, attribute, param_index, param_name, param_values) ) + # TODO erstmal rausfinden, anhand welches Parameters recursive descent gemacht wird. Es soll ja immer nur einer sein. # logger.debug('>>> recursive descent for {}/{} with {} = {}'.format(state, attribute, param_name, param_values[0])) # __class__(by_name_sub1, self.parameters) # logger.debug('<<< recursive descent for {}/{} with {} = {}'.format(state, attribute, param_name, param_values[0])) @@ -765,6 +768,70 @@ class DecisionTreeModel: (param_index, param_name, param_values) ) + for k, params in candidates_by_state_attribute.items(): + state, attribute = k + min_mae_global = np.inf + min_mae_param = None + for param_index, param_name, param_values in params: + by_name_sub1 = grep_aggregate_by_state_and_param( + by_name, state, param_index, param_values[0] + ) + by_name_sub2 = grep_aggregate_by_state_and_param( + by_name, state, param_index, param_values[1] + ) + m1 = PTAModel(by_name_sub1, self.parameters, dict()) + m2 = PTAModel(by_name_sub2, self.parameters, dict()) + pm1, pi1 = m1.get_fitted() + pm2, pi2 = m2.get_fitted() + mae1 = m1.assess(pm1)["by_name"][state][attribute]["mae"] + mae2 = m1.assess(pm1)["by_name"][state][attribute]["mae"] + min_mae = min(mae1, mae2) + if min_mae < min_mae_global: + min_mae_global = min_mae + min_mae_param = ( + param_index, + param_name, + param_values, + [by_name_sub1, by_name_sub2], + ) + logger.debug( + f"{state} {attribute}: splitting on {min_mae_param[1]} gives lowest MAE ({min_mae_global})" + ) + dtree = { + "param_index": min_mae_param[0], + "param_name": min_mae_param[1], + "min_mae": min_mae_global, + "param_value_subtree": dict(), + } + logger.debug( + f">>> recursive descent for {state} {attribute} with {min_mae_param[1]}={min_mae_param[2][0]}" + ) + dtree["param_value_subtree"][min_mae_param[2][0]] = __class__( + min_mae_param[3][0], self.parameters + ).get_tree(state, attribute) + logger.debug( + f"<<< recursive descent for {state} {attribute} with {min_mae_param[1]}={min_mae_param[2][0]}" + ) + logger.debug( + f">>> recursive descent for {state} {attribute} with {min_mae_param[1]}={min_mae_param[2][1]}" + ) + dtree["param_value_subtree"][min_mae_param[2][1]] = __class__( + min_mae_param[3][1], self.parameters + ).get_tree(state, attribute) + logger.debug( + f"<<< recursive descent for {state} {attribute} with {min_mae_param[1]}={min_mae_param[2][1]}" + ) + if self.dtree is None: + self.dtree = dict() + self.dtree[(state, attribute)] = dtree + + def get_tree(self, state=None, attribute=None): + if state is None or attribute is None: + return self.dtree + if self.dtree is not None and (state, attribute) in self.dtree: + return self.dtree[(state, attribute)] + return None + def states(self): """Return sorted list of state names.""" return sorted( |