From 6a9e041978f82017b74d1ce9d47296f1ebaaa0ff Mon Sep 17 00:00:00 2001 From: Daniel Friesel Date: Wed, 8 Jun 2022 14:15:47 +0200 Subject: FOL: Add JSON export --- lib/functions.py | 30 ++++++++++++++++++++++++++---- lib/parameters.py | 2 +- 2 files changed, 27 insertions(+), 5 deletions(-) diff --git a/lib/functions.py b/lib/functions.py index cd3c7d1..41885c7 100644 --- a/lib/functions.py +++ b/lib/functions.py @@ -635,12 +635,21 @@ class FOLFunction(ModelFunction): num_vars += 1 else: num_vars = fit_parameters.shape[0] - funbuf = "lambda reg_param, model_param: 0" + funbuf = "0" + rawbuf = "0" for i in range(num_vars): - funbuf += f" + reg_param[{i}] * model_param[{i}]" + rawbuf += f" + reg_param[{i}] * model_param[{i}]" + i = 0 + for j, param_name in enumerate(self.parameter_names): + if ignore_index[j]: + continue + else: + funbuf += f" + regression_arg({i}) * parameter({param_name})" + i += 1 - self._function_str = self.model_function = funbuf - self._function = eval(funbuf) + self.model_function = funbuf + self._function_str = "lambda reg_param, model_param:" + rawbuf + self._function = eval(self._function_str) error_function = lambda P, X, y: self._function(P, X) - y self.model_args = list(np.ones((num_vars))) @@ -697,6 +706,19 @@ class FOLFunction(ModelFunction): ) return self.value + def to_json(self, **kwargs): + ret = super().to_json(**kwargs) + ret.update( + { + "type": "analytic", + "functionStr": self.model_function, + "argCount": self._num_args, + "parameterNames": self.parameter_names, + "regressionModel": list(self.model_args), + } + ) + return ret + class AnalyticFunction(ModelFunction): """ diff --git a/lib/parameters.py b/lib/parameters.py index a636b52..207b8d8 100644 --- a/lib/parameters.py +++ b/lib/parameters.py @@ -606,7 +606,7 @@ class ModelAttribute: "argCount": self.arg_count, "modelFunction": self.model_function.to_json(**kwargs), } - if type(self.model_function) == df.CARTFunction: + if type(self.model_function) in (df.CARTFunction, df.FOLFunction): feature_names = self.param_names feature_names += list( map( -- cgit v1.2.3