diff options
Diffstat (limited to 'lib/parameters.py')
-rw-r--r-- | lib/parameters.py | 8 |
1 files changed, 6 insertions, 2 deletions
diff --git a/lib/parameters.py b/lib/parameters.py index 9350a06..4e98f54 100644 --- a/lib/parameters.py +++ b/lib/parameters.py @@ -666,7 +666,11 @@ class ModelAttribute: def to_dref(self, unit=None): ret = {"mean": (self.mean, unit), "median": (self.median, unit)} - if type(self.model_function) in (df.SplitFunction, df.CARTFunction): + if type(self.model_function) in ( + df.SplitFunction, + df.CARTFunction, + df.XGBoostFunction, + ): ret["decision tree/nodes"] = self.model_function.get_number_of_nodes() ret["decision tree/max depth"] = self.model_function.get_max_depth() @@ -961,7 +965,7 @@ class ModelAttribute: self.model_function = df.StaticFunction(np.mean(data)) return xgb.fit(fit_parameters, np.reshape(data, (-1, 1))) - self.model_function = df.SKLearnRegressionFunction( + self.model_function = df.XGBoostFunction( np.mean(data), xgb, category_to_index, ignore_index ) output_filename = os.getenv("DFATOOL_XGB_DUMP_MODEL", None) |