summaryrefslogtreecommitdiff
path: root/lib/model.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/model.py')
-rw-r--r--lib/model.py67
1 files changed, 67 insertions, 0 deletions
diff --git a/lib/model.py b/lib/model.py
index 0550575..f53f645 100644
--- a/lib/model.py
+++ b/lib/model.py
@@ -698,6 +698,8 @@ class DecisionTreeModel:
self.by_param = by_name_to_by_param(by_name)
self.parameters = sorted(parameters)
+ self.dtree = None
+
# Dtree-Konzept: Für jeden Param: Split auf ==wert1 / == wert2 (-> zwei Partitionen der Daten)
# Falls Menge beeinflussender Parameter in den beiden Partitionen unterschiedlich oder
# Art der Abhängigkeit unterschiedlich: Hiernach muss der DTree unterscheiden.
@@ -750,6 +752,7 @@ class DecisionTreeModel:
candidates.append(
(state, attribute, param_index, param_name, param_values)
)
+ # TODO erstmal rausfinden, anhand welches Parameters recursive descent gemacht wird. Es soll ja immer nur einer sein.
# logger.debug('>>> recursive descent for {}/{} with {} = {}'.format(state, attribute, param_name, param_values[0]))
# __class__(by_name_sub1, self.parameters)
# logger.debug('<<< recursive descent for {}/{} with {} = {}'.format(state, attribute, param_name, param_values[0]))
@@ -765,6 +768,70 @@ class DecisionTreeModel:
(param_index, param_name, param_values)
)
+ for k, params in candidates_by_state_attribute.items():
+ state, attribute = k
+ min_mae_global = np.inf
+ min_mae_param = None
+ for param_index, param_name, param_values in params:
+ by_name_sub1 = grep_aggregate_by_state_and_param(
+ by_name, state, param_index, param_values[0]
+ )
+ by_name_sub2 = grep_aggregate_by_state_and_param(
+ by_name, state, param_index, param_values[1]
+ )
+ m1 = PTAModel(by_name_sub1, self.parameters, dict())
+ m2 = PTAModel(by_name_sub2, self.parameters, dict())
+ pm1, pi1 = m1.get_fitted()
+ pm2, pi2 = m2.get_fitted()
+ mae1 = m1.assess(pm1)["by_name"][state][attribute]["mae"]
+ mae2 = m1.assess(pm1)["by_name"][state][attribute]["mae"]
+ min_mae = min(mae1, mae2)
+ if min_mae < min_mae_global:
+ min_mae_global = min_mae
+ min_mae_param = (
+ param_index,
+ param_name,
+ param_values,
+ [by_name_sub1, by_name_sub2],
+ )
+ logger.debug(
+ f"{state} {attribute}: splitting on {min_mae_param[1]} gives lowest MAE ({min_mae_global})"
+ )
+ dtree = {
+ "param_index": min_mae_param[0],
+ "param_name": min_mae_param[1],
+ "min_mae": min_mae_global,
+ "param_value_subtree": dict(),
+ }
+ logger.debug(
+ f">>> recursive descent for {state} {attribute} with {min_mae_param[1]}={min_mae_param[2][0]}"
+ )
+ dtree["param_value_subtree"][min_mae_param[2][0]] = __class__(
+ min_mae_param[3][0], self.parameters
+ ).get_tree(state, attribute)
+ logger.debug(
+ f"<<< recursive descent for {state} {attribute} with {min_mae_param[1]}={min_mae_param[2][0]}"
+ )
+ logger.debug(
+ f">>> recursive descent for {state} {attribute} with {min_mae_param[1]}={min_mae_param[2][1]}"
+ )
+ dtree["param_value_subtree"][min_mae_param[2][1]] = __class__(
+ min_mae_param[3][1], self.parameters
+ ).get_tree(state, attribute)
+ logger.debug(
+ f"<<< recursive descent for {state} {attribute} with {min_mae_param[1]}={min_mae_param[2][1]}"
+ )
+ if self.dtree is None:
+ self.dtree = dict()
+ self.dtree[(state, attribute)] = dtree
+
+ def get_tree(self, state=None, attribute=None):
+ if state is None or attribute is None:
+ return self.dtree
+ if self.dtree is not None and (state, attribute) in self.dtree:
+ return self.dtree[(state, attribute)]
+ return None
+
def states(self):
"""Return sorted list of state names."""
return sorted(