From f1dfa367044344760331d79361363ba1860f2367 Mon Sep 17 00:00:00 2001 From: Birte Kristina Friesel Date: Mon, 1 Jul 2024 11:02:49 +0200 Subject: mutual_information: handle skipped parameters; return a dict --- bin/analyze-log.py | 9 +++++---- lib/cli.py | 7 +++++-- lib/parameters.py | 12 +++++------- 3 files changed, 15 insertions(+), 13 deletions(-) diff --git a/bin/analyze-log.py b/bin/analyze-log.py index 50c0344..43a7d11 100755 --- a/bin/analyze-log.py +++ b/bin/analyze-log.py @@ -286,10 +286,11 @@ def main(): for name in model.names: for attr in model.attributes(name): mutual_information = model.mutual_information(name, attr) - for i, param in enumerate(model.parameters): - dref[f"mutual information/{name}/{attr}/{param}"] = ( - mutual_information[i] - ) + for param in model.parameters: + if param in mutual_information: + dref[f"mutual information/{name}/{attr}/{param}"] = ( + mutual_information[param] + ) dfatool.cli.export_dataref( args.export_dref, dref, precision=args.dref_precision diff --git a/lib/cli.py b/lib/cli.py index d59c997..00a0d11 100644 --- a/lib/cli.py +++ b/lib/cli.py @@ -97,8 +97,11 @@ def print_information_gain_by_name(model, by_name): for attr in model.attributes(name): print(f"{name} {attr}:") mutual_information = model.mutual_information(name, attr) - for i, param in enumerate(model.parameters): - print(f" Parameter {param} : {mutual_information[i]:5.2f}") + for param in model.parameters: + if param in mutual_information: + print(f" Parameter {param} : {mutual_information[param]:5.2f}") + else: + print(f" Parameter {param} : -.--") def print_analyticinfo(prefix, info): diff --git a/lib/parameters.py b/lib/parameters.py index 352e7c7..0653100 100644 --- a/lib/parameters.py +++ b/lib/parameters.py @@ -712,17 +712,15 @@ class ModelAttribute: self.param_values, with_nan=False, categorical_to_scalar=True ) - param_to_fit_param = dict() + mutual_info_result = mutual_info_regression(fit_parameters, self.data) + + self.mutual_information_cache = dict() j = 0 - for i in range(len(self.param_names)): + for i, param_name in enumerate(self.param_names): if not ignore_index[i]: - param_to_fit_param[i] = j + self.mutual_information_cache[param_name] = mutual_info_result[j] j += 1 - self.mutual_information_cache = mutual_info_regression( - fit_parameters, self.data - ) - return self.mutual_information_cache @classmethod -- cgit v1.2.3