summary refs log tree commit diff
diff options
context:
space:
mode:
-rwxr-xr-x bin/analyze-log.py 13
-rw-r--r-- lib/cli.py 14
-rw-r--r-- lib/parameters.py 26
3 files changed, 53 insertions, 0 deletions
diff --git a/bin/analyze-log.py b/bin/analyze-log.py
index a394b10..50c0344 100755
--- a/bin/analyze-log.py
+++ b/bin/analyze-log.py
@@ -124,6 +124,9 @@ def main():
if args.info:
dfatool.cli.print_info_by_name(model, by_name)
+ if args.information_gain:
+ dfatool.cli.print_information_gain_by_name(model, by_name)
+
if args.export_csv_unparam:
dfatool.cli.export_csv_unparam(
model, args.export_csv_unparam, dialect=args.export_csv_dialect
@@ -278,6 +281,16 @@ def main():
dref = model.to_dref(static_quality, lut_quality, analytic_quality)
for key, value in timing.items():
dref[f"timing/{key}"] = (value, r"\second")
+
+ if args.information_gain:
+ for name in model.names:
+ for attr in model.attributes(name):
+ mutual_information = model.mutual_information(name, attr)
+ for i, param in enumerate(model.parameters):
+ dref[f"mutual information/{name}/{attr}/{param}"] = (
+ mutual_information[i]
+ )
+
dfatool.cli.export_dataref(
args.export_dref, dref, precision=args.dref_precision
)
diff --git a/lib/cli.py b/lib/cli.py
index d2c9840..d59c997 100644
--- a/lib/cli.py
+++ b/lib/cli.py
@@ -92,6 +92,15 @@ def print_info_by_name(model, by_name):
)
+def print_information_gain_by_name(model, by_name):
+    """
+    Print the mutual information of each model parameter for each attribute.
+
+    For every name in model.names and every attribute returned by
+    model.attributes(name), a "<name> <attribute>:" header line is printed,
+    followed by one line per entry of model.parameters showing the value found
+    at the same position in model.mutual_information(name, attribute).
+
+    :param model: object providing names, attributes(name), parameters, and
+        mutual_information(name, attribute)
+    :param by_name: not used by this function
+    """
+    for entry_name in model.names:
+        for attribute in model.attributes(entry_name):
+            print(f"{entry_name} {attribute}:")
+            information = model.mutual_information(entry_name, attribute)
+            # Values are looked up by the parameter's position in model.parameters.
+            for index, parameter in enumerate(model.parameters):
+                print(f" Parameter {parameter} : {information[index]:5.2f}")
+
+
def print_analyticinfo(prefix, info):
model_function = info.model_function.removeprefix("0 + ")
for i in range(len(info.model_args)):
@@ -576,6 +585,11 @@ def add_standard_arguments(parser):
help="Show benchmark information (number of measurements, parameter values, ...)",
)
parser.add_argument(
+ "--information-gain",
+ action="store_true",
+ help="Show information gain of parameters",
+ )
+ parser.add_argument(
"--log-level",
metavar="LEVEL",
choices=["debug", "info", "warning", "error"],
diff --git a/lib/parameters.py b/lib/parameters.py
index 4047c10..352e7c7 100644
--- a/lib/parameters.py
+++ b/lib/parameters.py
@@ -604,6 +604,9 @@ class ModelAttribute:
# The best model we have. May be Static, Split, or Param (and later perhaps Substate)
self.model_function = None
+ # Information gain cache. Used for statistical analysis
+ self.mutual_information_cache = None
+
self._check_codependent_param()
# There must be at least 3 distinct data values (≠ None) if an analytic model
@@ -699,6 +702,29 @@ class ModelAttribute:
def webconf_function_map(self):
return self.model_function.webconf_function_map()
+    def mutual_information(self):
+        """
+        Return the mutual information between each parameter and self.data.
+
+        The values are computed with scikit-learn's mutual_info_regression on the
+        array returned by param_to_ndarray and stored in
+        self.mutual_information_cache; later calls return the cached result.
+
+        :returns: list with one entry per element of self.param_names, in the
+            same order. Parameters flagged in param_to_ndarray's ignore index
+            are not part of the regression input; their entry is NaN.
+        """
+        if self.mutual_information_cache is not None:
+            return self.mutual_information_cache
+
+        # Function-scope import: scikit-learn is only loaded once this method runs.
+        from sklearn.feature_selection import mutual_info_regression
+
+        fit_parameters, _, ignore_index = param_to_ndarray(
+            self.param_values, with_nan=False, categorical_to_scalar=True
+        )
+
+        # Indexes (into self.param_names) of the parameters which were kept.
+        used_indexes = [
+            i for i in range(len(self.param_names)) if not ignore_index[i]
+        ]
+
+        if used_indexes:
+            fit_information = mutual_info_regression(fit_parameters, self.data)
+        else:
+            # Every parameter was dropped, so there is nothing to regress on.
+            fit_information = []
+
+        # The regression result has one value per fit parameter, whereas callers
+        # index the result by position in the full parameter list. Map the values
+        # back so that ignored parameters do not shift the remaining entries.
+        # NOTE(review): assumes the columns of fit_parameters follow the order of
+        # the non-ignored parameters -- confirm against param_to_ndarray.
+        mutual_information = [float("nan")] * len(self.param_names)
+        for fit_index, param_index in enumerate(used_indexes):
+            mutual_information[param_index] = fit_information[fit_index]
+
+        self.mutual_information_cache = mutual_information
+        return self.mutual_information_cache
+
@classmethod
def from_json(cls, name, attr, data):
param_names = data["paramNames"]