summaryrefslogtreecommitdiff
path: root/lib/cli.py
blob: bb1b111d5a1cd5937d52df0c49ce30dbfa4542ee (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
#!/usr/bin/env python3

from dfatool.functions import (
    SplitFunction,
    AnalyticFunction,
    StaticFunction,
)


def print_static(model, static_model, name, attribute):
    """Print the static model value of one attribute plus its parameter dependence.

    The first line shows ``static_model(name, attribute)`` rounded to an
    integer, a two-character unit suffix (blank for attributes without a
    known unit) and the generic parameter dependence ratio. One further line
    per entry of ``model.parameters`` shows the dependence ratio on that
    parameter.

    :param model: object exposing ``attr_by_name[name][attribute].stats`` and
        ``parameters``
    :param static_model: callable ``(name, attribute) -> number``
    :param name: state/transition name
    :param attribute: attribute name, e.g. "power" or "duration"
    """
    unit_by_attribute = {"power": "µW", "duration": "µs", "substate_count": "su"}

    value = static_model(name, attribute)
    stats = model.attr_by_name[name][attribute].stats
    unit = unit_by_attribute.get(attribute, "  ")
    generic_ratio = stats.generic_param_dependence_ratio()
    print(f"{name:10s}: {value:.0f} {unit:s}  ({generic_ratio:.2f})")

    indent = " " * 10
    for param in model.parameters:
        ratio = stats.param_dependence_ratio(param)
        print(f"{indent}  dependence on {param:15s}: {ratio:.2f}")


def print_analyticinfo(prefix, info):
    """Print an analytic model function and its arguments on two lines.

    The first line is ``<prefix>: <info.model_function>``. The second line
    shows ``info.model_args``, indented so that it lines up with the
    function text of the first line.

    :param prefix: label printed before the function; may be empty
    :param info: object with ``model_function`` and ``model_args`` attributes
    """
    # Pad with plain spaces: a nested width spec such as f"{'':{len(prefix)}s}"
    # becomes "0s" for an empty prefix, which Python < 3.10 rejects with
    # "'=' alignment not allowed in string format specifier".
    indent = " " * len(prefix)
    print(f"{prefix}: {info.model_function}")
    print(f"{indent}  {info.model_args}")


def print_splitinfo(param_names, info, prefix=""):
    """Recursively print a model function tree.

    For a SplitFunction, `` <param>=<value>`` is appended to ``prefix`` for
    each child, and the function recurses into that child. Other nodes are
    printed as ``<prefix>: ...``:

    - AnalyticFunction: via print_analyticinfo.
    - StaticFunction: its ``value``.
    - Any other type: ``UNKNOWN``.

    :param param_names: parameter names used to label split points. A split
        index at or beyond ``len(param_names)`` is labelled
        ``arg<index - len(param_names)>``.
    :param info: node of the model function tree
    :param prefix: label accumulated from the enclosing split levels
    """
    # Dispatch uses exact type checks; switching to isinstance() would change
    # how subclasses of these classes are printed.
    if type(info) is SplitFunction:
        # The split parameter is the same for every child: resolve its label once.
        if info.param_index < len(param_names):
            param_name = param_names[info.param_index]
        else:
            param_name = f"arg{info.param_index - len(param_names)}"
        for value, child in info.child.items():
            print_splitinfo(param_names, child, f"{prefix} {param_name}={value}")
    elif type(info) is AnalyticFunction:
        print_analyticinfo(prefix, info)
    elif type(info) is StaticFunction:
        print(f"{prefix}: {info.value}")
    else:
        print(f"{prefix}: UNKNOWN")


def format_quality_measures(result):
    """Format the quality measures of one model as a table cell.

    :param result: mapping with key "mae" and optionally "smape"
    :returns: ``"<smape>% / <mae>"`` when "smape" is present. Otherwise a
        blank SMAPE field followed by the MAE. Both variants are 19
        characters wide as long as the numbers fit their fields.
    """
    mae = result["mae"]
    if "smape" not in result:
        # Blank space where "<smape>% / " would be, keeping the MAE aligned.
        return " " * 10 + format(mae, "9.0f")
    return "{:6.2f}% / {:9.0f}".format(result["smape"], mae)


def model_quality_table(header, result_lists, info_list):
    print(
        "{:20s} {:15s}       {:19s}       {:19s}       {:19s}".format(
            "key",
            "attribute",
            header[0].center(19),
            header[1].center(19),
            header[2].center(19),
        )
    )
    for state_or_tran in result_lists[0].keys():
        for key in result_lists[0][state_or_tran].keys():
            buf = "{:20s} {:15s}".format(state_or_tran, key)
            for i, results in enumerate(result_lists):
                info = info_list[i]
                buf += "  |||  "
                if (
                    info is None
                    or (
                        key != "energy_Pt"
                        and type(info(state_or_tran, key)) is not StaticFunction
                    )
                    or (
                        key == "energy_Pt"
                        and (
                            type(info(state_or_tran, "power")) is not StaticFunction
                            or type(info(state_or_tran, "duration"))
                            is not StaticFunction
                        )
                    )
                ):
                    result = results[state_or_tran][key]
                    buf += format_quality_measures(result)
                else:
                    buf += "{:7}----{:8}".format("", "")
            print(buf)