summaryrefslogtreecommitdiff
path: root/lib/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/utils.py')
-rw-r--r--lib/utils.py13
1 files changed, 13 insertions, 0 deletions
diff --git a/lib/utils.py b/lib/utils.py
index d28ecda..adcb534 100644
--- a/lib/utils.py
+++ b/lib/utils.py
@@ -1,3 +1,4 @@
+import json
import numpy as np
import re
import logging
@@ -6,6 +7,18 @@ arg_support_enabled = True
logger = logging.getLogger(__name__)
+class NpEncoder(json.JSONEncoder):
+ def default(self, obj):
+ if isinstance(obj, np.integer):
+ return int(obj)
+ elif isinstance(obj, np.floating):
+ return float(obj)
+ elif isinstance(obj, np.ndarray):
+ return obj.tolist()
+ else:
+ return super(NpEncoder, self).default(obj)
+
+
def running_mean(x: np.ndarray, N: int) -> np.ndarray:
"""
Compute `N` elements wide running average over `x`.