summaryrefslogtreecommitdiff
path: root/lib/cli.py
blob: a1e4a58c1577619dd6638b9d03696cf6ada18667 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
#!/usr/bin/env python3

from dfatool.functions import SplitFunction, AnalyticFunction, StaticFunction


def print_static(model, static_model, name, attribute):
    """Print the static model value of one attribute plus its dependence ratios.

    The first output line shows ``name``, the value returned by
    ``static_model(name, attribute)`` rounded to zero decimals, a two-character
    unit and the generic parameter dependence ratio. It is followed by one line
    per entry of ``model.parameters`` giving that parameter's dependence ratio.

    :param model: object providing ``parameters`` and ``attr_by_name``
    :param static_model: callable taking ``(name, attribute)``
    :param name: key into ``model.attr_by_name``
    :param attribute: attribute key; also selects the printed unit
    """
    # Attributes without a known unit get two blanks so columns stay aligned.
    unit_by_attribute = {
        "power": "µW",
        "duration": "µs",
        "substate_count": "su",
    }
    unit = unit_by_attribute.get(attribute, "  ")

    value = static_model(name, attribute)
    stats = model.attr_by_name[name][attribute].stats
    generic_ratio = stats.generic_param_dependence_ratio()
    print(f"{name:10s}: {value:.0f} {unit:s}  ({generic_ratio:.2f})")

    # Same width as the name column of the first line, but left empty.
    blank_label = " " * 10
    for param in model.parameters:
        ratio = stats.param_dependence_ratio(param)
        print(f"{blank_label}  dependence on {param:15s}: {ratio:.2f}")


def print_analyticinfo(prefix, info):
    """Print an analytic model function and its arguments on two lines.

    The first line is ``<prefix>: <info.model_function>``. The second line
    shows ``info.model_args`` indented by ``len(prefix)`` blanks plus two, so
    it starts in the same column as the function text above it.

    :param prefix: label printed in front of the model function; may be empty
    :param info: object providing ``model_function`` and ``model_args``
    """
    # Build the indentation by repetition rather than via a computed format
    # width: for an empty prefix the spec would become "0s", where the "0" is
    # parsed as the zero-padding flag instead of a width (and that flag is
    # rejected for strings by Python versions before 3.10).
    indent = " " * len(prefix)
    print(f"{prefix}: {info.model_function}")
    print(f"{indent}  {info.model_args}")


def print_splitinfo(param_names, info, prefix=""):
    """Recursively print a model function tree, one line group per leaf.

    A ``SplitFunction`` node is descended into: for every entry of
    ``info.child`` the prefix is extended by `` <parameter>=<key>``. Leaves are
    printed via ``print_analyticinfo`` (``AnalyticFunction``) or as
    ``<prefix>: <value>`` (``StaticFunction``); any other type is reported as
    ``UNKNOWN``.

    :param param_names: parameter names, indexed by ``info.param_index``
    :param info: model function object to print
    :param prefix: text accumulated from the enclosing split nodes
    """
    # Exact type comparison (not isinstance), so subclasses count as unknown.
    kind = type(info)

    if kind is AnalyticFunction:
        print_analyticinfo(prefix, info)
        return
    if kind is StaticFunction:
        print(f"{prefix}: {info.value}")
        return
    if kind is not SplitFunction:
        print(f"{prefix}: UNKNOWN")
        return

    for split_value, subtree in info.child.items():
        index = info.param_index
        # Indices beyond the named parameters are shown as argN, counted from
        # the end of param_names.
        label = (
            param_names[index]
            if index < len(param_names)
            else f"arg{index - len(param_names)}"
        )
        print_splitinfo(param_names, subtree, f"{prefix} {label}={split_value}")


def format_quality_measures(result):
    """Format a quality-measure dict as a fixed-width 19-character table cell.

    With an ``"smape"`` entry the cell reads ``<smape>% / <mae>``; without it
    the SMAPE part is replaced by blanks of the same width so that the MAE
    column stays aligned.

    :param result: dict with a ``"mae"`` entry and optionally ``"smape"``
    :returns: the formatted cell text
    """
    if "smape" in result:
        return f"{result['smape']:6.2f}% / {result['mae']:9.0f}"
    # Six blanks stand in for the percentage, four for the "% / " separator.
    return f"{'':6}    {result['mae']:9.0f}"


def model_quality_table(header, result_lists, info_list):
    """Print a table comparing the quality of several models side by side.

    The header row carries one title per entry of ``header`` (previously
    exactly three titles were required); with three titles the output is
    unchanged. Each following row covers one ``(state_or_tran, key)`` pair
    taken from ``result_lists[0]`` and holds one cell per entry of
    ``result_lists``.

    A cell shows ``format_quality_measures(results[state_or_tran][key])``
    unless the matching ``info_list`` entry reports a ``StaticFunction`` for
    that attribute, in which case dashes are printed instead. For the key
    ``"energy_Pt"`` the ``"power"`` and ``"duration"`` functions are checked,
    and the cell is only blanked if both are static.

    :param header: column titles, one per result list
    :param result_lists: list of nested dicts
        ``{state_or_tran: {key: quality measures}}``; the first one defines
        which rows are printed
    :param info_list: per result list either ``None`` (always show the cell)
        or a callable taking ``(state_or_tran, attribute)``
    """
    header_cells = ["{:20s} {:15s}".format("key", "attribute")]
    header_cells.extend("{:19s}".format(title.center(19)) for title in header)
    print("       ".join(header_cells))

    # Placeholder with the same width (19) as a formatted quality cell.
    hidden_cell = "{:7}----{:8}".format("", "")

    for state_or_tran, attributes in result_lists[0].items():
        for key in attributes:
            cells = ["{:20s} {:15s}".format(state_or_tran, key)]
            for i, results in enumerate(result_lists):
                info = info_list[i]
                # Exact type comparison is kept on purpose: only plain
                # StaticFunction objects hide a cell.
                if info is None:
                    show = True
                elif key == "energy_Pt":
                    show = (
                        type(info(state_or_tran, "power")) is not StaticFunction
                        or type(info(state_or_tran, "duration"))
                        is not StaticFunction
                    )
                else:
                    show = type(info(state_or_tran, key)) is not StaticFunction

                if show:
                    cells.append(format_quality_measures(results[state_or_tran][key]))
                else:
                    cells.append(hidden_cell)
            print("  |||  ".join(cells))


def add_standard_arguments(parser):
    """Register the shared command-line options on ``parser``.

    Adds ``--export-dref``, ``--cross-validate`` and
    ``--parameter-aware-cross-validation`` via ``parser.add_argument``.

    :param parser: object offering an argparse-style ``add_argument`` method
    """
    option_specs = (
        (
            "--export-dref",
            {
                "metavar": "FILE",
                "type": str,
                "help": "Export model and model quality to LaTeX dataref file",
            },
        ),
        (
            "--cross-validate",
            {
                "metavar": "<method>:<count>",
                "type": str,
                "help": "Perform cross validation when computing model quality. "
                "Only works with --show-quality=table at the moment.",
            },
        ),
        (
            "--parameter-aware-cross-validation",
            {
                "action": "store_true",
                "help": "Perform parameter-aware cross-validation: ensure that parameter values (and not just observations) are mutually exclusive between training and validation sets.",
            },
        ),
    )
    # Options are registered in declaration order, which is also the order
    # argparse uses when listing them in its help output.
    for flag, settings in option_specs:
        parser.add_argument(flag, **settings)