From 9fb10653c723f10b36ae0567ff29f08a4e725ea3 Mon Sep 17 00:00:00 2001 From: Daniel Friesel Date: Tue, 15 Oct 2019 17:40:50 +0200 Subject: PTA: Add from_file constructor --- bin/analyze-archive.py | 14 +++++++------- bin/generate-dfa-benchmark.py | 6 +----- 2 files changed, 8 insertions(+), 12 deletions(-) (limited to 'bin') diff --git a/bin/analyze-archive.py b/bin/analyze-archive.py index 685ce92..4551cda 100755 --- a/bin/analyze-archive.py +++ b/bin/analyze-archive.py @@ -70,8 +70,8 @@ Options: not been set yet) are discarded. Note that this may remove entire function calls from the model. ---hwmodel= - Load DFA hardware model from JSON +--hwmodel= + Load DFA hardware model from JSON or YAML --export-energymodel= Export energy model. Requires --hwmodel. @@ -86,6 +86,7 @@ from dfatool import PTAModel, RawData, pta_trace_to_aggregate from dfatool import soft_cast_int, is_numeric, gplearn_to_function from dfatool import CrossValidator from utils import filter_aggregate_by_param +from automata import PTA opts = {} @@ -211,7 +212,7 @@ if __name__ == '__main__': function_override = {} show_models = [] show_quality = [] - hwmodel = None + pta = None energymodel_export_file = None xv_method = None xv_count = 10 @@ -262,8 +263,7 @@ if __name__ == '__main__': safe_functions_enabled = True if 'hwmodel' in opts: - with open(opts['hwmodel'], 'r') as f: - hwmodel = json.load(f) + pta = PTA.from_file(opts['hwmodel']) except getopt.GetoptError as err: print(err) @@ -280,7 +280,7 @@ if __name__ == '__main__': traces = preprocessed_data, discard_outliers = discard_outliers, function_override = function_override, - hwmodel = hwmodel) + pta = pta) if xv_method: xv = CrossValidator(PTAModel, by_name, parameters, arg_count) @@ -416,7 +416,7 @@ if __name__ == '__main__': plotter.plot_param(model, state_or_trans, attribute, model.param_index(param_name), extra_function=function) if 'export-energymodel' in opts: - if not hwmodel: + if not pta: print('[E] --export-energymodel requires --hwmodel to be set') sys.exit(1) json_model = model.to_json() diff --git a/bin/generate-dfa-benchmark.py b/bin/generate-dfa-benchmark.py index 7f67e1b..3ae51e6 100755 --- a/bin/generate-dfa-benchmark.py +++ b/bin/generate-dfa-benchmark.py @@ -302,11 +302,7 @@ if __name__ == '__main__': modelfile = args[0] - with open(modelfile, 'r') as f: - if '.json' in modelfile: - pta = PTA.from_json(json.load(f)) - else: - pta = PTA.from_yaml(yaml.safe_load(f)) + pta = PTA.from_file(modelfile) if 'shrink' in opt: pta.shrink_argument_values() -- cgit v1.2.3