summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--lib/model.py5
-rw-r--r--lib/parameters.py73
2 files changed, 66 insertions, 12 deletions
diff --git a/lib/model.py b/lib/model.py
index 829ca37..75f7195 100644
--- a/lib/model.py
+++ b/lib/model.py
@@ -5,7 +5,7 @@ import numpy as np
import os
from .automata import PTA, ModelAttribute
from .functions import StaticFunction, SubstateFunction
-from .parameters import ParallelParamStats
+from .parameters import ParallelParamStats, codependent_param_dict
from .paramfit import ParallelParamFit
from .utils import soft_cast_int, by_name_to_by_param, regression_measures
@@ -126,10 +126,12 @@ class AnalyticModel:
return f"AnalyticModel<names=[{names}]>"
def _compute_stats(self, by_name):
+
paramstats = ParallelParamStats()
for name, data in by_name.items():
self.attr_by_name[name] = dict()
+ codependent_param = codependent_param_dict(data["param"])
for attr in data["attributes"]:
model_attr = ModelAttribute(
name,
@@ -138,6 +140,7 @@ class AnalyticModel:
data["param"],
self.parameters,
self._num_args.get(name, 0),
+ codependent_param=codependent_param,
)
self.attr_by_name[name][attr] = model_attr
paramstats.enqueue((name, attr), model_attr)
diff --git a/lib/parameters.py b/lib/parameters.py
index 0eb3d8a..ea14ad2 100644
--- a/lib/parameters.py
+++ b/lib/parameters.py
@@ -227,8 +227,6 @@ def _compute_param_statistics(
ret["_depends_on_param"] = dict()
ret["_depends_on_arg"] = list()
- ret["codependent_param_pair"] = _codepenent_param_pair(param_tuples)
-
np.seterr("raise")
for param_idx, param in enumerate(param_names):
@@ -274,7 +272,7 @@ def _compute_param_statistics(
return ret
-def _codepenent_param_pair(param_values):
+def codependent_param_dict(param_values):
lut = [dict() for i in param_values[0]]
for param_index in range(len(param_values[0])):
uniqs = set(map(lambda param_tuple: param_tuple[param_index], param_values))
@@ -312,9 +310,6 @@ def _codepenent_param_pair(param_values):
if std_by_param[param1_i] > 0 and std_by_param[param2_i] > 0:
if std_by_param_pair[(param1_i, param2_i)] == 0:
ret[(param1_i, param2_i)] = True
- logger.warning(
- f"parameters ({param1_i}, {param2_i}) are codependent"
- )
return ret
@@ -589,7 +584,16 @@ class ModelAttribute:
- a fitted model (`model_function`, a `ModelFunction` instance)
"""
- def __init__(self, name, attr, data, param_values, param_names, arg_count=0):
+ def __init__(
+ self,
+ name,
+ attr,
+ data,
+ param_values,
+ param_names,
+ arg_count=0,
+ codependent_param=dict(),
+ ):
# Data for model generation
self.data = np.array(data)
@@ -601,6 +605,15 @@ class ModelAttribute:
self.param_names = sorted(param_names)
self.arg_count = arg_count
+ self.log_param_names = self.param_names + list(
+ map(lambda i: f"arg{i}", range(arg_count))
+ )
+
+ # Co-dependent parameters. If (paam1_index, param2_index) in codependent_param, they are codependent.
+ # In this case, only one of them must be used for parameter-dependent model attribute detection and modeling
+ self.codependent_param = codependent_param
+ self.ignore_param = dict()
+
# Static model used as lower bound of model accuracy
self.mean = np.mean(data)
self.median = np.median(data)
@@ -617,6 +630,8 @@ class ModelAttribute:
# The best model we have. May be Static, Split, or Param (and later perhaps Substate)
self.model_function = None
+ self._check_codependent_param()
+
def __repr__(self):
mean = np.mean(self.data)
return f"ModelAttribute<{self.name}, {self.attr}, mean={mean}>"
@@ -644,6 +659,28 @@ class ModelAttribute:
return self
+ def _check_codependent_param(self):
+ for (
+ param1_index,
+ param2_index,
+ ), is_codependent in self.codependent_param.items():
+ if not is_codependent:
+ continue
+ param1_values = map(lambda pv: pv[param1_index], self.param_values)
+ param1_numeric_count = sum(map(is_numeric, param1_values))
+ param2_values = map(lambda pv: pv[param2_index], self.param_values)
+ param2_numeric_count = sum(map(is_numeric, param2_values))
+ if param1_numeric_count >= param2_numeric_count:
+ self.ignore_param[param2_index] = True
+ logger.warning(
+ f"{self.name} {self.attr}: parameters ({self.log_param_names[param1_index]}, {self.log_param_names[param2_index]}) are codependent. Ignoring {self.log_param_names[param2_index]}"
+ )
+ else:
+ self.ignore_param[param1_index] = True
+ logger.warning(
+ f"{self.name} {self.attr}: parameters ({self.log_param_names[param1_index]}, {self.log_param_names[param2_index]}) are codependent. Ignoring {self.log_param_names[param1_index]}"
+ )
+
def get_static(self, use_mean=False):
if use_mean:
return self.mean
@@ -712,7 +749,11 @@ class ModelAttribute:
std_by_param = list()
for param_index, param_name in enumerate(self.param_names):
distinct_values = self.stats.distinct_values_by_param_index[param_index]
- if self.stats.depends_on_param(param_name) and len(distinct_values) == 2:
+ if (
+ self.stats.depends_on_param(param_name)
+ and len(distinct_values) == 2
+ and not param_index in self.ignore_param
+ ):
val1 = list(
map(
lambda i: self.param_values[i][param_index]
@@ -730,7 +771,11 @@ class ModelAttribute:
distinct_values = self.stats.distinct_values_by_param_index[
len(self.param_names) + arg_index
]
- if self.stats.depends_on_arg(arg_index) and len(distinct_values) == 2:
+ if (
+ self.stats.depends_on_arg(arg_index)
+ and len(distinct_values) == 2
+ and not len(self.param_names) + arg_index in self.ignore_param
+ ):
val1 = list(
map(
lambda i: self.param_values[i][
@@ -776,7 +821,10 @@ class ModelAttribute:
def get_data_for_paramfit_this(self, safe_functions_enabled=False):
ret = list()
for param_index, param_name in enumerate(self.param_names):
- if self.stats.depends_on_param(param_name):
+ if (
+ self.stats.depends_on_param(param_name)
+ and not param_index in self.ignore_param
+ ):
ret.append(
(
(self.name, self.attr),
@@ -786,7 +834,10 @@ class ModelAttribute:
)
if self.arg_count:
for arg_index in range(self.arg_count):
- if self.stats.depends_on_arg(arg_index):
+ if (
+ self.stats.depends_on_arg(arg_index)
+ and not arg_index + len(self.param_names) in self.ignore_param
+ ):
ret.append(
(
(self.name, self.attr),