diff options
author | Daniel Friesel <daniel.friesel@uos.de> | 2021-12-23 15:33:45 +0100 |
---|---|---|
committer | Daniel Friesel <daniel.friesel@uos.de> | 2021-12-23 15:33:45 +0100 |
commit | 3473ff227b9c085b5333147f0c2f6cb1e431c875 (patch) | |
tree | ab782568c0adea3e4b00be786299525f8b3e7692 /lib/model.py | |
parent | a9d538afbd9d766a35093e851fbe5c12112fb2eb (diff) |
model: add sklearn CART support (CART with scalar features)
Diffstat (limited to 'lib/model.py')
-rw-r--r-- | lib/model.py | 16 |
1 files changed, 13 insertions, 3 deletions
diff --git a/lib/model.py b/lib/model.py index 9133196..4f5f60f 100644 --- a/lib/model.py +++ b/lib/model.py @@ -157,6 +157,9 @@ class AnalyticModel: with_nonbinary_nodes = bool( int(os.getenv("DFATOOL_DTREE_NONBINARY_NODES", "1")) ) + with_sklearn_cart = bool( + int(os.getenv("DFATOOL_DTREE_SKLEARN_CART", "0")) + ) loss_ignore_scalar = bool( int(os.getenv("DFATOOL_DTREE_LOSS_IGNORE_SCALAR", "0")) ) @@ -169,6 +172,7 @@ class AnalyticModel: threshold=threshold, with_function_leaves=with_function_leaves, with_nonbinary_nodes=with_nonbinary_nodes, + with_sklearn_cart=with_sklearn_cart, loss_ignore_scalar=loss_ignore_scalar, ) self.fit_done = True @@ -317,6 +321,12 @@ class AnalyticModel: with_nonbinary_nodes = bool( int(os.getenv("DFATOOL_DTREE_NONBINARY_NODES", "1")) ) + with_sklearn_cart = bool( + int(os.getenv("DFATOOL_DTREE_SKLEARN_CART", "0")) + ) + loss_ignore_scalar = bool( + int(os.getenv("DFATOOL_DTREE_LOSS_IGNORE_SCALAR", "0")) + ) threshold = self.attr_by_name[name][attr].stats.std_param_lut if ( self.dtree_max_std @@ -324,9 +334,6 @@ class AnalyticModel: and attr in self.dtree_max_std[name] ): threshold = self.dtree_max_std[name][attr] - loss_ignore_scalar = bool( - int(os.getenv("DFATOOL_DTREE_LOSS_IGNORE_SCALAR", "0")) - ) logger.debug( f"build_dtree({name}, {attr}, threshold={threshold}, with_function_leaves={with_function_leaves}, with_nonbinary_nodes={with_nonbinary_nodes}, loss_ignore_scalar={loss_ignore_scalar})" ) @@ -336,6 +343,7 @@ class AnalyticModel: threshold=threshold, with_function_leaves=with_function_leaves, with_nonbinary_nodes=with_nonbinary_nodes, + with_sklearn_cart=with_sklearn_cart, loss_ignore_scalar=loss_ignore_scalar, ) else: @@ -414,6 +422,7 @@ class AnalyticModel: threshold=100, with_function_leaves=False, with_nonbinary_nodes=True, + with_sklearn_cart=False, loss_ignore_scalar=False, ): @@ -435,6 +444,7 @@ class AnalyticModel: self.by_name[name][attribute], with_function_leaves=with_function_leaves, with_nonbinary_nodes=with_nonbinary_nodes, + with_sklearn_cart=with_sklearn_cart, loss_ignore_scalar=loss_ignore_scalar, threshold=threshold, ) |