summaryrefslogtreecommitdiff
path: root/lib/model.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/model.py')
-rw-r--r--lib/model.py119
1 files changed, 117 insertions, 2 deletions
diff --git a/lib/model.py b/lib/model.py
index bb4a45b..a953c46 100644
--- a/lib/model.py
+++ b/lib/model.py
@@ -2,6 +2,7 @@
import logging
import numpy as np
+import kconfiglib
from scipy import optimize
from sklearn.metrics import r2_score
from multiprocessing import Pool
@@ -375,7 +376,7 @@ def _num_args_from_by_name(by_name):
class AnalyticModel:
- u"""
+ """
Parameter-aware analytic energy/data size/... model.
Supports both static and parameter-based model attributes, and automatic detection of parameter-dependence.
@@ -663,7 +664,7 @@ class AnalyticModel:
class PTAModel:
- u"""
+ """
Parameter-aware PTA-based energy model.
Supports both static and parameter-based model attributes, and automatic detection of parameter-dependence.
@@ -1154,3 +1155,117 @@ class PTAModel:
np.array(model_state_energy_list), np.array(real_energy_list)
),
}
+
+
+class KConfigModel:
+ class Leaf:
+ def __init__(self, value, stddev):
+ self.value = value
+ self.stddev = stddev
+
+ def model(self, kconf):
+ return self.value
+
+ def __repr__(self):
+ return f"<Leaf({self.value}, {self.stddev})>"
+
+ def to_json(self):
+ return {"value": self.value, "stddev": self.stddev}
+
+ class Node:
+ def __init__(self, symbol):
+ self.symbol = symbol
+ self.child_n = None
+ self.child_y = None
+
+ def set_child_n(self, child_node):
+ self.child_n = child_node
+
+ def set_child_y(self, child_node):
+ self.child_y = child_node
+
+ def model(self, kconf):
+ if kconf.syms[self.symbol].tri_value == 0 and self.child_n:
+ return self.child_n.model(kconf)
+ if kconf.syms[self.symbol].tri_value == 2 and self.child_y:
+ return self.child_y.model(kconf)
+ return None
+
+ def __repr__(self):
+ return f"<Node(n={self.child_n}, y={self.child_y})>"
+
+ def to_json(self):
+ ret = {"symbol": self.symbol}
+ if self.child_n:
+ ret["n"] = self.child_n.to_json()
+ else:
+ ret["n"] = None
+ if self.child_y:
+ ret["y"] = self.child_y.to_json()
+ else:
+ ret["y"] = None
+ return ret
+
+ def __init__(self, kconfig_benchmark):
+ self.data = kconfig_benchmark.data
+ self.symbols = kconfig_benchmark.symbols
+ model = self.get_min(self.symbols, self.data, 0)
+
+ output = {"model": model.to_json(), "symbols": self.symbols}
+ print(output)
+
+ # with open("kconfigmodel.json", "w") as f:
+ # json.dump(output, f)
+
+ def get_min(self, this_symbols, this_data, level):
+
+ rom_sizes = list(map(lambda x: x[1]["total"]["ROM"], this_data))
+
+ if np.std(rom_sizes) < 100 or len(this_symbols) == 0:
+ return self.Leaf(np.mean(rom_sizes), np.std(rom_sizes))
+
+ mean_stds = list()
+ for i, param in enumerate(this_symbols):
+ enabled = list(filter(lambda vrr: vrr[0][i] == True, this_data))
+ disabled = list(filter(lambda vrr: vrr[0][i] == False, this_data))
+
+ enabled_std_rom = np.std(list(map(lambda x: x[1]["total"]["ROM"], enabled)))
+ disabled_std_rom = np.std(
+ list(map(lambda x: x[1]["total"]["ROM"], disabled))
+ )
+ children = [enabled_std_rom, disabled_std_rom]
+
+ if np.any(np.isnan(children)):
+ mean_stds.append(np.inf)
+ else:
+ mean_stds.append(np.mean(children))
+
+ symbol_index = np.argmin(mean_stds)
+ symbol = this_symbols[symbol_index]
+ enabled = list(filter(lambda vrr: vrr[0][symbol_index] == True, this_data))
+ disabled = list(filter(lambda vrr: vrr[0][symbol_index] == False, this_data))
+
+ node = self.Node(symbol)
+
+ new_symbols = this_symbols[:symbol_index] + this_symbols[symbol_index + 1 :]
+ enabled = list(
+ map(
+ lambda x: (x[0][:symbol_index] + x[0][symbol_index + 1 :], x[1]),
+ enabled,
+ )
+ )
+ disabled = list(
+ map(
+ lambda x: (x[0][:symbol_index] + x[0][symbol_index + 1 :], x[1]),
+ disabled,
+ )
+ )
+ print(
+ f"Level {level} split on {symbol} has {len(enabled)} children when enabled and {len(disabled)} children when disabled"
+ )
+ if len(enabled):
+ node.set_child_y(self.get_min(new_symbols, enabled, level + 1))
+ if len(disabled):
+ node.set_child_n(self.get_min(new_symbols, disabled, level + 1))
+
+ return node