summaryrefslogtreecommitdiff
path: root/lib/model.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/model.py')
-rw-r--r--lib/model.py17
1 files changed, 6 insertions, 11 deletions
diff --git a/lib/model.py b/lib/model.py
index e908af4..bb4a45b 100644
--- a/lib/model.py
+++ b/lib/model.py
@@ -5,6 +5,7 @@ import numpy as np
from scipy import optimize
from sklearn.metrics import r2_score
from multiprocessing import Pool
+from .automata import PTA
from .functions import analytic
from .functions import AnalyticFunction
from .parameters import ParamStats
@@ -700,7 +701,6 @@ class PTAModel:
arg_count,
traces=[],
ignore_trace_indexes=[],
- discard_outliers=None,
function_override={},
use_corrcoef=False,
pta=None,
@@ -716,13 +716,6 @@ class PTAModel:
arg_count -- function arguments, as returned by pta_trace_to_aggregate
traces -- list of preprocessed DFA traces, as returned by RawData.get_preprocessed_data()
ignore_trace_indexes -- list of trace indexes. The corresponding traces will be ignored.
- discard_outliers -- currently not supported: threshold for outlier detection and removel (float).
- Outlier detection is performed individually for each state/transition in each trace,
- so it only works if the benchmark ran several times.
- Given "data" (a set of measurements of the same thing, e.g. TX duration in the third benchmark trace),
- "m" (the median of all attribute measurements with the same parameters, which may include data from other traces),
- a data point X is considered an outlier if
- | 0.6745 * (X - m) / median(|data - m|) | > discard_outliers .
function_override -- dict of overrides for automatic parameter function generation.
If (state or transition name, model attribute) is present in function_override,
the corresponding text string is the function used for analytic (parameter-aware/fitted)
@@ -749,7 +742,6 @@ class PTAModel:
)
self.cache = {}
np.seterr("raise")
- self._outlier_threshold = discard_outliers
self.function_override = function_override.copy()
self.pta = pta
self.ignore_trace_indexes = ignore_trace_indexes
@@ -940,13 +932,16 @@ class PTAModel:
static_quality = self.assess(static_model)
param_model, param_info = self.get_fitted()
analytic_quality = self.assess(param_model)
- self.pta.update(
+ pta = self.pta
+ if pta is None:
+ pta = PTA(self.states(), parameters=self._parameter_names)
+ pta.update(
static_model,
param_info,
static_error=static_quality["by_name"],
analytic_error=analytic_quality["by_name"],
)
- return self.pta.to_json()
+ return pta.to_json()
def states(self):
"""Return sorted list of state names."""