From e654b3dd34c32cb5ae0320e1064bf7934d500e85 Mon Sep 17 00:00:00 2001
From: Birte Kristina Friesel <birte.friesel@uos.de>
Date: Mon, 15 Jan 2024 16:21:01 +0100
Subject: improve(?) loss_ignore_scalar implementation: more opportunities for
 FFS

---
 lib/parameters.py | 50 ++++++++++++++++++++++++++++++++++++++++++++------
 1 file changed, 44 insertions(+), 6 deletions(-)

diff --git a/lib/parameters.py b/lib/parameters.py
index a00e376..0c4401c 100644
--- a/lib/parameters.py
+++ b/lib/parameters.py
@@ -1284,7 +1284,7 @@ class ModelAttribute:
             children = list()
             for child in child_indexes:
                 child_data = list(map(lambda i: data[i], child))
-                if loss_ignore_scalar:
+                if loss_ignore_scalar and False:
                     child_param = list(map(lambda i: parameters[i], child))
                     child_data_by_scalar = partition_by_param(
                         child_data,
@@ -1292,7 +1292,9 @@ class ModelAttribute:
                         ignore_parameters=list(self.ignore_param.keys())
                         + ffs_unsuitable_params,
                     )
+                    logger.debug(f"got {len(child_data_by_scalar)} partitions")
                     for sub_data in child_data_by_scalar.values():
+                        assert len(sub_data)
                         children.extend((np.array(sub_data) - np.mean(sub_data)) ** 2)
                 else:
                     children.extend((np.array(child_data) - np.mean(child_data)) ** 2)
@@ -1303,7 +1305,6 @@ class ModelAttribute:
                 loss.append(np.sum(children))
 
         if np.all(np.isinf(loss)):
-            # all children have the same configuration. We shouldn't get here due to the threshold check above...
             if ffs_feasible:
                 # try generating a function. if it fails, model_function is a StaticFunction.
                 ma = ModelAttribute(
@@ -1323,12 +1324,48 @@ class ModelAttribute:
                 paramfit.fit()
                 ma.set_data_from_paramfit(paramfit)
                 return ma.model_function
-            # else:
-            #    logging.warning(
-            #        f"While building DTree for configurations {parameters}: Children have identical configuration, but high stddev ({np.std(data)}). Falling back to Staticfunction"
-            #    )
             return df.StaticFunction(np.mean(data))
 
+        split_feasible = True
+        if loss_ignore_scalar:
+            data_by_scalar = partition_by_param(
+                data,
+                parameters,
+                ignore_parameters=list(self.ignore_param.keys())
+                + ffs_unsuitable_params,
+            )
+            if np.all(
+                np.array([np.std(partition) for partition in data_by_scalar.values()])
+                <= threshold
+            ):
+                # Varying non-scalar params in partitions with fixed scalar params does not affect system behaviour
+                # -> further non-scalar splits are _probably_ not sensible
+                # (_probably_ because this implicitly assumes that there are multiple scalar configurations for each non-scalar configuration.
+                split_feasible = False
+
+        if ffs_feasible and not split_feasible:
+            # There is a _probably_ above: the heuristic assumes that there are multiple scalar configurations for each non-scalar configuration.
+            # If there is just one it may recommend to stop splitting too early.
+            # Hence, we will try generating an FFS leaf node here, but continue splitting if it turns out that it is no good.
+            ma = ModelAttribute(
+                self.name + "_",
+                self.attr,
+                data,
+                parameters,
+                self.param_names,
+                arg_count=self.arg_count,
+                param_type=self.param_type,
+                codependent_param=codependent_param_dict(parameters),
+            )
+            ParamStats.compute_for_attr(ma)
+            paramfit = ParamFit(parallel=False)
+            for key, param, args, kwargs in ma.get_data_for_paramfit():
+                paramfit.enqueue(key, param, args, kwargs)
+            paramfit.fit()
+            ma.set_data_from_paramfit(paramfit)
+            if type(ma.model_function) == df.AnalyticFunction:
+                return ma.model_function
+
         symbol_index = np.argmin(loss)
         unique_values = list(set(map(lambda p: p[symbol_index], parameters)))
 
@@ -1352,6 +1389,7 @@ class ModelAttribute:
                 ignore_irrelevant_parameters=ignore_irrelevant_parameters,
                 loss_ignore_scalar=loss_ignore_scalar,
                 threshold=threshold,
+                relevance_threshold=relevance_threshold,
                 level=level + 1,
             )
 
-- 
cgit v1.2.3