Diffstat (limited to 'lib/parameters.py')
-rw-r--r-- | lib/parameters.py | 48
1 file changed, 48 insertions, 0 deletions
diff --git a/lib/parameters.py b/lib/parameters.py
index e199153..9cfe145 100644
--- a/lib/parameters.py
+++ b/lib/parameters.py
@@ -38,6 +38,37 @@ def distinct_param_values(param_tuples):
     return distinct_values
 
 
+def param_to_ndarray(param_tuples, with_nan=True):
+    has_nan = dict()
+    has_non_numeric = dict()
+
+    for param_tuple in param_tuples:
+        for i, param in enumerate(param_tuple):
+            if not is_numeric(param):
+                if param is None:
+                    has_nan[i] = True
+                else:
+                    has_non_numeric[i] = True
+
+    ignore_index = dict()
+    for i in range(len(param_tuples[0])):
+        if has_non_numeric.get(i, False):
+            ignore_index[i] = True
+        elif not with_nan and has_nan.get(i, False):
+            ignore_index[i] = True
+        else:
+            ignore_index[i] = False
+
+    ret_tuples = list()
+    for param_tuple in param_tuples:
+        ret_tuple = list()
+        for i, param in enumerate(param_tuple):
+            if not ignore_index[i]:
+                ret_tuple.append(param)
+        ret_tuples.append(ret_tuple)
+    return np.asarray(ret_tuples), ignore_index
+
+
 def _depends_on_param(corr_param, std_param, std_lut):
     # if self.use_corrcoef:
     if False:
@@ -843,6 +874,7 @@ class ModelAttribute:
         data,
         with_function_leaves=False,
         with_nonbinary_nodes=True,
+        with_sklearn_cart=False,
         loss_ignore_scalar=False,
         threshold=100,
     ):
@@ -853,12 +885,28 @@ class ModelAttribute:
         :param data: Measurements. [data 1, data 2, data 3, ...]
         :param with_function_leaves: Use fitted function sets to generate function leaves for scalar parameters
         :param with_nonbinary_nodes: Allow non-binary nodes for enum and scalar parameters (i.e., nodes with more than two children)
+        :param with_sklearn_cart: Use the `sklearn.tree.DecisionTreeRegressor` CART implementation for tree generation. It does not support categorical (enum)
+            and sparse parameters; both are ignored during fitting. All other options are ignored as well.
         :param loss_ignore_scalar: Ignore scalar parameters when computing the loss for split candidates. Only sensible if with_function_leaves is enabled.
         :param threshold: Return a StaticFunction leaf node if std(data) < threshold. Default 100.
 
         :returns: SplitFunction or StaticFunction
         """
 
+        if with_sklearn_cart:
+            from sklearn.tree import DecisionTreeRegressor
+
+            max_depth = int(os.getenv("DFATOOL_CART_MAX_DEPTH", "0"))
+            if max_depth == 0:
+                max_depth = None
+            cart = DecisionTreeRegressor(max_depth=max_depth)
+            fit_parameters, ignore_index = param_to_ndarray(parameters, with_nan=False)
+            cart.fit(fit_parameters, data)
+            self.model_function = df.SKLearnRegressionFunction(
+                np.mean(data), cart, ignore_index
+            )
+            return
+
         if loss_ignore_scalar and not with_function_leaves:
             logger.warning(
                 "build_dtree called with loss_ignore_scalar=True, with_function_leaves=False. This does not make sense."
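
The new param_to_ndarray helper strips out parameter columns that a purely numeric regressor cannot handle. A minimal usage sketch, assuming lib/ is on the Python path (adjust the import to your checkout) and using made-up parameter tuples that are illustrative only:

    from parameters import param_to_ndarray  # import path is an assumption

    param_tuples = [
        (100, "uart", 19200),
        (200, "spi", None),
        (300, "i2c", 38400),
    ]

    ndarray, ignore_index = param_to_ndarray(param_tuples, with_nan=False)
    # Column 1 is non-numeric and column 2 contains a None entry; with
    # with_nan=False both are dropped, so only column 0 survives:
    #   ndarray      -> [[100], [200], [300]]
    #   ignore_index -> {0: False, 1: True, 2: True}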
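
The with_sklearn_cart branch itself boils down to an environment-configured DecisionTreeRegressor fit. The sketch below reproduces that step outside of ModelAttribute, with made-up measurements and without the df.SKLearnRegressionFunction wrapper; it illustrates the technique rather than the committed code:

    import os

    import numpy as np
    from sklearn.tree import DecisionTreeRegressor

    # Made-up training data: one numeric parameter column, one observation each.
    fit_parameters = np.asarray([[19200], [38400], [57600]], dtype=float)
    data = np.asarray([10.3, 21.8, 29.9])

    # DFATOOL_CART_MAX_DEPTH=0 (the default) means "no depth limit"; it is
    # translated to max_depth=None before being passed to sklearn.
    max_depth = int(os.getenv("DFATOOL_CART_MAX_DEPTH", "0")) or None

    cart = DecisionTreeRegressor(max_depth=max_depth)
    cart.fit(fit_parameters, data)
    print(cart.predict([[38400.0]]))  # -> [21.8]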