summaryrefslogtreecommitdiff
path: root/lib/parameters.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/parameters.py')
-rw-r--r--lib/parameters.py8
1 files changed, 6 insertions, 2 deletions
diff --git a/lib/parameters.py b/lib/parameters.py
index 9350a06..4e98f54 100644
--- a/lib/parameters.py
+++ b/lib/parameters.py
@@ -666,7 +666,11 @@ class ModelAttribute:
def to_dref(self, unit=None):
ret = {"mean": (self.mean, unit), "median": (self.median, unit)}
- if type(self.model_function) in (df.SplitFunction, df.CARTFunction):
+ if type(self.model_function) in (
+ df.SplitFunction,
+ df.CARTFunction,
+ df.XGBoostFunction,
+ ):
ret["decision tree/nodes"] = self.model_function.get_number_of_nodes()
ret["decision tree/max depth"] = self.model_function.get_max_depth()
@@ -961,7 +965,7 @@ class ModelAttribute:
self.model_function = df.StaticFunction(np.mean(data))
return
xgb.fit(fit_parameters, np.reshape(data, (-1, 1)))
- self.model_function = df.SKLearnRegressionFunction(
+ self.model_function = df.XGBoostFunction(
np.mean(data), xgb, category_to_index, ignore_index
)
output_filename = os.getenv("DFATOOL_XGB_DUMP_MODEL", None)