diff options
-rw-r--r--  lib/cli.py         | 14 ++++++++++++--
-rw-r--r--  lib/parameters.py  | 13 ++++++++++---
2 files changed, 22 insertions, 5 deletions
@@ -17,8 +17,18 @@ def sanity_check(args):
         )
         sys.exit(1)
     if args.skip_param_stats and not args.force_tree:
-        print("--skip-param-stats requires --force-tree", file=sys.stderr)
-        sys.exit(1)
+        if bool(int(os.getenv("DFATOOL_FIT_FOL", "0"))):
+            print(
+                "Note: DFATOOL_FIT_FOL=1 relies on param stats to skip useless features.",
+                file=sys.stderr,
+            )
+            print(
+                "Disabling it via --skip-param-stats will likely lead to unsatisfactory results.",
+                file=sys.stderr,
+            )
+        else:
+            print("--skip-param-stats requires --force-tree", file=sys.stderr)
+            sys.exit(1)
 
 
 def print_static(model, static_model, name, attribute, with_dependence=False):
diff --git a/lib/parameters.py b/lib/parameters.py
index 7cb1314..a101dca 100644
--- a/lib/parameters.py
+++ b/lib/parameters.py
@@ -585,6 +585,8 @@ class ModelAttribute:
         # LUT model used as upper bound of model accuracy
         self.by_param = None  # set via ParallelParamStats or get_by_param
 
+        self.stats = None  # set via ParallelParamStats
+
         # param model override
         self.function_override = None
 
@@ -911,9 +913,14 @@ class ModelAttribute:
             for param_index, param in enumerate(self.param_names):
                 if not self.stats.depends_on_param(param):
                     ignore_param_indexes.append(param_index)
-        for param_index, _ in enumerate(self.param_names):
-            if len(self.stats.distinct_values_by_param_index[param_index]) < 2:
-                ignore_param_indexes.append(param_index)
+        if not self.stats:
+            logger.warning(
+                "build_fol_model called with ModelAttribute.stats unavailable -- overfitting likely"
+            )
+        else:
+            for param_index, _ in enumerate(self.param_names):
+                if len(self.stats.distinct_values_by_param_index[param_index]) < 2:
+                    ignore_param_indexes.append(param_index)
         x = df.FOLFunction(self.median, self.param_names, n_samples=self.data.shape[0])
         x.fit(self.param_values, self.data, ignore_param_indexes=ignore_param_indexes)
         if x.fit_success: