summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBirte Kristina Friesel <birte.friesel@uos.de>2024-01-10 09:41:50 +0100
committerBirte Kristina Friesel <birte.friesel@uos.de>2024-01-10 09:41:50 +0100
commit5adc39c8de1fbda6e1ba9abcaaaf17516eb046ae (patch)
treef52dbe13e5cf9fa70f38a8d0e8d75020856ce781
parent72fc17a4dcff42bedb21b456f7da63065a72fd9f (diff)
--show-model=param: add CART support
-rwxr-xr-xbin/analyze-archive.py38
-rwxr-xr-xbin/analyze-kconfig.py13
-rwxr-xr-xbin/analyze-log.py4
-rw-r--r--lib/cli.py20
4 files changed, 53 insertions, 22 deletions
diff --git a/bin/analyze-archive.py b/bin/analyze-archive.py
index 5432ffd..ffc0d67 100755
--- a/bin/analyze-archive.py
+++ b/bin/analyze-archive.py
@@ -45,15 +45,9 @@ import sys
import time
import dfatool.cli
import dfatool.utils
+import dfatool.functions as df
from dfatool import plotter
from dfatool.loader import RawData, pta_trace_to_aggregate
-from dfatool.functions import (
- gplearn_to_function,
- SplitFunction,
- AnalyticFunction,
- SubstateFunction,
- StaticFunction,
-)
from dfatool.model import PTAModel
from dfatool.validation import CrossValidator
from dfatool.automata import PTA
@@ -666,7 +660,7 @@ if __name__ == "__main__":
],
)
)
- if type(info) is AnalyticFunction:
+ if type(info) is df.AnalyticFunction:
for param_name in sorted(info.fit_by_param.keys(), key=str):
param_fit = info.fit_by_param[param_name]["results"]
for function_type in sorted(param_fit.keys()):
@@ -685,26 +679,34 @@ if __name__ == "__main__":
for state in model.states:
for attribute in model.attributes(state):
info = param_info(state, attribute)
- if type(info) is AnalyticFunction:
+ if type(info) is df.AnalyticFunction:
dfatool.cli.print_analyticinfo(f"{state:10s} {attribute:15s}", info)
- elif type(info) is SplitFunction:
+ elif type(info) is df.CARTFunction:
+ dfatool.cli.print_cartinfo(
+ f"{state:10s} {attribute:15s}", info, model.parameters
+ )
+ elif type(info) is df.SplitFunction:
dfatool.cli.print_splitinfo(
model.parameters, info, f"{state:10s} {attribute:15s}"
)
- elif type(info) is StaticFunction:
+ elif type(info) is df.StaticFunction:
dfatool.cli.print_staticinfo(f"{state:10s} {attribute:15s}", info)
- elif type(info) is SubstateFunction:
+ elif type(info) is df.SubstateFunction:
print(f"{state:10s} {attribute:15s}: Substate (TODO)")
for trans in model.transitions:
for attribute in model.attributes(trans):
info = param_info(trans, attribute)
- if type(info) is AnalyticFunction:
+ if type(info) is df.AnalyticFunction:
dfatool.cli.print_analyticinfo(f"{trans:10s} {attribute:15s}", info)
- elif type(info) is SplitFunction:
+ elif type(info) is df.CARTFunction:
+ dfatool.cli.print_cartinfo(
+ f"{trans:10s} {attribute:15s}", info, model.parameters
+ )
+ elif type(info) is df.SplitFunction:
dfatool.cli.print_splitinfo(
model.parameters, info, f"{trans:10s} {attribute:15s}"
)
- elif type(info) is SubstateFunction:
+ elif type(info) is df.SubstateFunction:
print(f"{state:10s} {attribute:15s}: Substate (TODO)")
if args.with_substates:
for submodel in model.submodel_by_name.values():
@@ -712,7 +714,7 @@ if __name__ == "__main__":
for substate in submodel.states:
for subattribute in submodel.attributes(substate):
info = sub_param_info(substate, subattribute)
- if type(info) is AnalyticFunction:
+ if type(info) is df.AnalyticFunction:
print(
"{:10s} {:15s}: {}".format(
substate, subattribute, info.model_function
@@ -724,7 +726,7 @@ if __name__ == "__main__":
for state in model.states:
if (
type(model.attr_by_name[state]["power"].model_function)
- is SubstateFunction
+ is df.SubstateFunction
):
# sub-state models need to know the duration of the state / transition. only needed for eval.
model.attr_by_name[state]["power"].model_function.static_duration = (
@@ -835,7 +837,7 @@ if __name__ == "__main__":
)
sys.exit(1)
if len(function):
- function = gplearn_to_function(" ".join(function))
+ function = df.gplearn_to_function(" ".join(function))
else:
function = None
plotter.plot_param(
diff --git a/bin/analyze-kconfig.py b/bin/analyze-kconfig.py
index 8d8b63a..77b76aa 100755
--- a/bin/analyze-kconfig.py
+++ b/bin/analyze-kconfig.py
@@ -21,6 +21,7 @@ import numpy as np
import dfatool.cli
import dfatool.plotter
import dfatool.utils
+import dfatool.functions as df
from dfatool.loader.kconfig import KConfigAttributes
from dfatool.model import AnalyticModel
from dfatool.validation import CrossValidator
@@ -527,15 +528,19 @@ def main():
for name in model.names:
for attribute in model.attributes(name):
info = param_info(name, attribute)
- if type(info) is dfatool.cli.AnalyticFunction:
+ if type(info) is df.AnalyticFunction:
dfatool.cli.print_analyticinfo(f"{name:20s} {attribute:15s}", info)
- elif type(info) is dfatool.cli.FOLFunction:
+ elif type(info) is df.CARTFunction:
+ dfatool.cli.print_cartinfo(
+ f"{name:20s} {attribute:15s}", info, model.parameters
+ )
+ elif type(info) is df.FOLFunction:
dfatool.cli.print_analyticinfo(f"{name:20s} {attribute:15s}", info)
- elif type(info) is dfatool.cli.SplitFunction:
+ elif type(info) is df.SplitFunction:
dfatool.cli.print_splitinfo(
model.parameters, info, f"{name:20s} {attribute:15s}"
)
- elif type(info) is dfatool.cli.StaticFunction:
+ elif type(info) is df.StaticFunction:
dfatool.cli.print_staticinfo(f"{state:10s} {attribute:15s}", info)
if "table" in args.show_quality or "all" in args.show_quality:
diff --git a/bin/analyze-log.py b/bin/analyze-log.py
index 8171cbb..b1dc6ba 100755
--- a/bin/analyze-log.py
+++ b/bin/analyze-log.py
@@ -229,6 +229,10 @@ def main():
info = param_info(name, attribute)
if type(info) is df.AnalyticFunction:
dfatool.cli.print_analyticinfo(f"{name:10s} {attribute:15s}", info)
+ elif type(info) is df.CARTFunction:
+ dfatool.cli.print_cartinfo(
+ f"{name:10s} {attribute:15s}", info, model.parameters
+ )
elif type(info) is df.SplitFunction:
dfatool.cli.print_splitinfo(
model.parameters, info, f"{name:10s} {attribute:15s}"
diff --git a/lib/cli.py b/lib/cli.py
index 314c6dd..a1655c6 100644
--- a/lib/cli.py
+++ b/lib/cli.py
@@ -94,6 +94,26 @@ def print_staticinfo(prefix, info):
print(f"{prefix}: {info.value}")
+def print_cartinfo(prefix, info, feature_names):
+ _print_cartinfo(prefix, info.to_json(feature_names=feature_names), feature_names)
+
+
+def _print_cartinfo(prefix, model, feature_names):
+ if model["type"] == "static":
+ print(f"""{prefix}: {model["value"]}""")
+ else:
+ _print_cartinfo(
+ f"""{prefix} {model["paramName"]}<{model["paramDecisionValue"]} """,
+ model["left"],
+ feature_names,
+ )
+ _print_cartinfo(
+ f"""{prefix} {model["paramName"]}≥{model["paramDecisionValue"]} """,
+ model["right"],
+ feature_names,
+ )
+
+
def print_splitinfo(param_names, info, prefix=""):
if type(info) is SplitFunction:
for k, v in info.child.items():