diff options
Diffstat (limited to 'lib')
-rw-r--r-- | lib/model.py | 35 | ||||
-rw-r--r-- | lib/parameters.py | 34 |
2 files changed, 56 insertions, 13 deletions
diff --git a/lib/model.py b/lib/model.py index 2ee3c03..10e2928 100644 --- a/lib/model.py +++ b/lib/model.py @@ -149,10 +149,19 @@ class AnalyticModel: with_function_leaves = bool( int(os.getenv("DFATOOL_DTREE_FUNCTION_LEAVES", "1")) ) + with_nonbinary_nodes = bool( + int(os.getenv("DFATOOL_DTREE_NONBINARY_NODES", "1")) + ) logger.debug( - f"build_dtree({name}, {attr}, threshold={threshold}, with_function_leaves={with_function_leaves})" + f"build_dtree({name}, {attr}, threshold={threshold}, with_function_leaves={with_function_leaves}, with_nonbinary_nodes={with_nonbinary_nodes})" + ) + self.build_dtree( + name, + attr, + threshold=threshold, + with_function_leaves=with_function_leaves, + with_nonbinary_nodes=with_nonbinary_nodes, ) - self.build_dtree(name, attr, threshold, with_function_leaves) self.fit_done = True def __repr__(self): @@ -295,6 +304,9 @@ class AnalyticModel: with_function_leaves = bool( int(os.getenv("DFATOOL_DTREE_FUNCTION_LEAVES", "1")) ) + with_nonbinary_nodes = bool( + int(os.getenv("DFATOOL_DTREE_NONBINARY_NODES", "1")) + ) threshold = self.attr_by_name[name][attr].stats.std_param_lut if ( self.dtree_max_std @@ -303,13 +315,14 @@ class AnalyticModel: ): threshold = self.dtree_max_std[name][attr] logger.debug( - f"build_dtree({name}, {attr}, threshold={threshold}, with_function_leaves={with_function_leaves})" + f"build_dtree({name}, {attr}, threshold={threshold}, with_function_leaves={with_function_leaves}, with_nonbinary_nodes={with_nonbinary_nodes})" ) self.build_dtree( name, attr, - threshold, + threshold=threshold, with_function_leaves=with_function_leaves, + with_nonbinary_nodes=with_nonbinary_nodes, ) else: self.attr_by_name[name][attr].set_data_from_paramfit(paramfit) @@ -380,7 +393,14 @@ class AnalyticModel: return detailed_results - def build_dtree(self, name, attribute, threshold=100, with_function_leaves=False): + def build_dtree( + self, + name, + attribute, + threshold=100, + with_function_leaves=False, + with_nonbinary_nodes=True, + ): if name not in self.attr_by_name: self.attr_by_name[name] = dict() @@ -404,8 +424,9 @@ class AnalyticModel: self.attr_by_name[name][attribute].build_dtree( self.by_name[name]["param"], self.by_name[name][attribute], - with_function_leaves, - threshold, + with_function_leaves=with_function_leaves, + with_nonbinary_nodes=with_nonbinary_nodes, + threshold=threshold, ) def to_dref(self, static_quality, lut_quality, model_quality) -> dict: diff --git a/lib/parameters.py b/lib/parameters.py index a290db0..5ebf25c 100644 --- a/lib/parameters.py +++ b/lib/parameters.py @@ -786,7 +786,14 @@ class ModelAttribute: if x.fit_success: self.model_function = x - def build_dtree(self, parameters, data, with_function_leaves=False, threshold=100): + def build_dtree( + self, + parameters, + data, + with_function_leaves=False, + with_nonbinary_nodes=True, + threshold=100, + ): """ Build a Decision Tree on `param` / `data` for kconfig models. @@ -799,11 +806,21 @@ class ModelAttribute: :returns: SplitFunction or StaticFunction """ self.model_function = self._build_dtree( - parameters, data, with_function_leaves, threshold + parameters, + data, + with_function_leaves=with_function_leaves, + with_nonbinary_nodes=with_nonbinary_nodes, + threshold=threshold, ) def _build_dtree( - self, parameters, data, with_function_leaves=False, threshold=100, level=0 + self, + parameters, + data, + with_function_leaves=False, + with_nonbinary_nodes=True, + threshold=100, + level=0, ): """ Build a Decision Tree on `param` / `data` for kconfig models. @@ -838,6 +855,10 @@ class ModelAttribute: mean_stds.append(np.inf) continue + if not with_nonbinary_nodes and len(unique_values) > 2: + mean_stds.append(np.inf) + continue + if ( with_function_leaves and len(unique_values) >= self.min_values_for_analytic_model @@ -915,9 +936,10 @@ class ModelAttribute: child[value] = self._build_dtree( child_parameters, child_data, - with_function_leaves, - threshold, - level + 1, + with_function_leaves=with_function_leaves, + with_nonbinary_nodes=with_nonbinary_nodes, + threshold=threshold, + level=level + 1, ) assert len(child.values()) >= 2 |