diff options
author | Birte Kristina Friesel <birte.friesel@uos.de> | 2024-07-01 11:02:49 +0200 |
---|---|---|
committer | Birte Kristina Friesel <birte.friesel@uos.de> | 2024-07-01 11:02:49 +0200 |
commit | f1dfa367044344760331d79361363ba1860f2367 (patch) | |
tree | 3bab6fa0fbbca5508cfc610d332cdbf900144319 /lib | |
parent | 40642f6127d2961a10c9cde5c7fa18cdb55b97e9 (diff) |
mutual_information: handle skipped parameters; return a dict
Diffstat (limited to 'lib')
-rw-r--r-- | lib/cli.py | 7 | ||||
-rw-r--r-- | lib/parameters.py | 12 |
2 files changed, 10 insertions, 9 deletions
@@ -97,8 +97,11 @@ def print_information_gain_by_name(model, by_name): for attr in model.attributes(name): print(f"{name} {attr}:") mutual_information = model.mutual_information(name, attr) - for i, param in enumerate(model.parameters): - print(f" Parameter {param} : {mutual_information[i]:5.2f}") + for param in model.parameters: + if param in mutual_information: + print(f" Parameter {param} : {mutual_information[param]:5.2f}") + else: + print(f" Parameter {param} : -.--") def print_analyticinfo(prefix, info): diff --git a/lib/parameters.py b/lib/parameters.py index 352e7c7..0653100 100644 --- a/lib/parameters.py +++ b/lib/parameters.py @@ -712,17 +712,15 @@ class ModelAttribute: self.param_values, with_nan=False, categorical_to_scalar=True ) - param_to_fit_param = dict() + mutual_info_result = mutual_info_regression(fit_parameters, self.data) + + self.mutual_information_cache = dict() j = 0 - for i in range(len(self.param_names)): + for i, param_name in enumerate(self.param_names): if not ignore_index[i]: - param_to_fit_param[i] = j + self.mutual_information_cache[param_name] = mutual_info_result[j] j += 1 - self.mutual_information_cache = mutual_info_regression( - fit_parameters, self.data - ) - return self.mutual_information_cache @classmethod |