#!/usr/bin/env python3
# vim:tabstop=4:softtabstop=4:shiftwidth=4:textwidth=160:smarttab:expandtab

import getopt
import itertools
import matplotlib.pyplot as plt
import numpy as np
import os
import re
import sys
from dfatool.loader import MIMOSA
from dfatool.model import aggregate_measures
from dfatool.utils import running_mean

# Parsed command-line options, keyed by option name without the leading "--".
# Filled in by the __main__ block below.
opt = {}


def show_help():
    """Print the usage summary and option reference to stdout."""
    print(
        """mimosa-etv - MIMOSA Analyzer and Visualizer

USAGE

mimosa-etv [--help] [--skip <count>] [--threshold <watts>|mean]
           [--threshold-peakcount <num>] [--plot] [--stat]
           <voltage> <shunt> <file>

DESCRIPTION

mimosa-etv analyzes measurements taken via MIMOSA. Data can be plotted or aggregated on stdout.

OPTIONS

  --help
    Show this help text and exit.
  --skip <count>
    Skip the first <count> data samples.
  --threshold <watts>|mean
    Partition data into points with mean power >= <watts> and points with
    mean power < <watts>, and print some statistics. Higher power is handled
    as peaks, whereas low-power measurements constitute the baseline.
    If the threshold is set to "mean", the mean power of all measurements
    will be used.
  --threshold-peakcount <num>
    Automatically determine threshold so that there are exactly <num> peaks.
    A peak is a group of consecutive measurements with mean power >= threshold.
  --plot
    Show power/time plot.
  --stat
    Show mean voltage, current, and power as well as total energy consumption.
    """
    )


def peak_search(data, lower, upper, direction_function):
    """Bisect the interval [lower, upper] for an acceptable power threshold.

    Each iteration tests the midpoint of the current interval: the number of
    peaks (maximal runs of consecutive entries of `data` that are >= the
    midpoint) is counted and handed to `direction_function` together with
    the midpoint.

    :param data: iterable of power readings
    :param lower: lower bound of the search interval
    :param upper: upper bound of the search interval
    :param direction_function: callable(peakcount, threshold). Returning 0
        accepts the threshold, 1 continues in the upper half of the
        interval, any other value continues in the lower half.
    :returns: the accepted threshold, or None if the interval shrank to
        1e-6 or less without any threshold being accepted.
    """
    while upper - lower > 1e-6:
        midpoint = np.mean([lower, upper])
        # Classify each sample, then count the runs of "above threshold".
        above = (sample >= midpoint for sample in data)
        n_peaks = sum(1 for is_peak, _ in itertools.groupby(above) if is_peak)
        verdict = direction_function(n_peaks, midpoint)
        if verdict == 0:
            return midpoint
        if verdict == 1:
            lower = midpoint
        else:
            upper = midpoint
    return None


def peak_search2(data, lower, upper, check_function):
    """Scan [lower, upper) in steps of 1e-6 for an acceptable power threshold.

    For every candidate threshold, the number of peaks (maximal runs of
    consecutive entries of `data` that are >= the candidate) is counted and
    handed to `check_function` together with the candidate.

    :param data: iterable of power readings
    :param lower: first candidate threshold
    :param upper: exclusive end of the candidate range
    :param check_function: callable(peakcount, threshold); a return value of
        0 accepts the threshold, anything else moves on to the next one.
    :returns: the first accepted threshold, or None if none was accepted.
    """
    for candidate in np.arange(lower, upper, 1e-6):
        above = (sample >= candidate for sample in data)
        n_peaks = sum(1 for is_peak, _ in itertools.groupby(above) if is_peak)
        if check_function(n_peaks, candidate) == 0:
            return candidate
    return None


if __name__ == "__main__":
    # Command-line entry point: parse the options, load the measurement log,
    # then run the requested analyses (--threshold*, --stat, --plot).
    try:
        optspec = "help skip= threshold= threshold-peakcount= plot stat"
        raw_opts, args = getopt.getopt(sys.argv[1:], "", optspec.split(" "))

        for option, parameter in raw_opts:
            optname = re.sub(r"^--", "", option)
            opt[optname] = parameter

        if "help" in opt:
            show_help()
            sys.exit(0)

        opt["skip"] = int(opt.get("skip", 0))
        if opt["skip"] < 0:
            raise ValueError("--skip must not be negative")

        if "threshold" in opt and opt["threshold"] != "mean":
            opt["threshold"] = float(opt["threshold"])

        if "threshold-peakcount" in opt:
            opt["threshold-peakcount"] = int(opt["threshold-peakcount"])

        # Validate the positional arguments here so that a wrong count or a
        # non-numeric value gives a usage error instead of a traceback.
        if len(args) != 3:
            print(
                "Usage: mimosa-etv [options] <voltage> <shunt> <file>",
                file=sys.stderr,
            )
            sys.exit(2)
        voltage, shunt, inputfile = args
        voltage = float(voltage)
        shunt = int(shunt)

    except getopt.GetoptError as err:
        print(err)
        sys.exit(2)
    except ValueError as err:
        print("Error: invalid numeric argument: {}".format(err))
        sys.exit(2)

    mim = MIMOSA(voltage, shunt)
    charges, triggers = mim.load_file(inputfile)

    currents = mim.charge_to_current_nocal(charges) * 1e-6
    # Honour --skip: drop the first <count> samples before any analysis.
    currents = currents[opt["skip"] :]
    powers = currents * voltage

    if len(powers) == 0:
        print("Error: no data samples to analyze")
        sys.exit(1)

    # Time covered by one sample, in seconds. Used for peak durations and for
    # the time axis of the plot (labelled "Time [s]").
    sample_interval = 1e-5

    if "threshold-peakcount" in opt:
        bs_mean = np.mean(powers)

        # Finding the correct threshold is tricky. If #peaks < peakcount, our
        # current threshold may be too low (extreme case: a single peak
        # containing all measurements), but it may also be too high (extreme
        # case: a single peak containing just one data point). Similarly,
        # #peaks > peakcount may be due to baseline noise causing lots of
        # small peaks, or due to peak noise (if the threshold is already rather
        # high).
        # For now, we first try a simple binary search:
        # The threshold is probably somewhere around the mean, so if
        # #peaks != peakcount and threshold < mean, we go up, and if
        # #peaks != peakcount and threshold >= mean, we go down.
        # If that doesn't work, we fall back to a linear search in 1 µW steps
        def direction_function(peakcount, power):
            if peakcount == opt["threshold-peakcount"]:
                return 0
            if power < bs_mean:
                return 1
            return -1

        power_min = np.min(powers)
        power_max = np.max(powers)
        threshold = peak_search(powers, power_min, power_max, direction_function)
        if threshold is None:
            threshold = peak_search2(
                powers, power_min, power_max, direction_function
            )

        if threshold is not None:
            print(
                "Threshold set to {:.0f} µW         : {:.9f}".format(
                    threshold * 1e6, threshold
                )
            )
            # An automatically determined threshold replaces any --threshold.
            opt["threshold"] = threshold
        else:
            print("Found no working threshold")

    if "threshold" in opt:
        if opt["threshold"] == "mean":
            opt["threshold"] = np.mean(powers)
            print(
                "Threshold set to {:.0f} µW         : {:.9f}".format(
                    opt["threshold"] * 1e6, opt["threshold"]
                )
            )

        is_baseline = powers < opt["threshold"]
        is_peak = powers >= opt["threshold"]

        # Stays 0 if no sample lies below the threshold.
        baseline_mean = 0
        if np.any(is_baseline):
            baseline_mean = np.mean(powers[is_baseline])
            print(
                "Baseline mean: {:.0f} µW           : {:.9f}".format(
                    baseline_mean * 1e6, baseline_mean
                )
            )
        if np.any(is_peak):
            peak_power_mean = np.mean(powers[is_peak])
            print(
                "Peak mean: {:.0f} µW               : {:.9f}".format(
                    peak_power_mean * 1e6, peak_power_mean
                )
            )

        # Collect peaks as (start, end) index pairs with an exclusive end.
        # NOTE(review): a peak that is still open at the last sample is never
        # appended here, although peak_search() counts it. Confirm whether
        # dropping such truncated peaks is intended.
        peaks = []
        peak_start = -1
        for i, dp in enumerate(powers):
            if dp >= opt["threshold"] and peak_start == -1:
                peak_start = i
            elif dp < opt["threshold"] and peak_start != -1:
                peaks.append((peak_start, i))
                peak_start = -1

        total_energy = 0
        delta_energy = 0
        for peak_begin, peak_end in peaks:
            duration = (peak_end - peak_begin) * sample_interval
            peak_powers = powers[peak_begin:peak_end]
            peak_mean = np.mean(peak_powers)
            total_energy += peak_mean * duration
            delta_energy += (peak_mean - baseline_mean) * duration
            delta_powers = peak_powers - baseline_mean
            print(
                "{:.2f}ms peak ({:f} -> {:f})".format(
                    duration * 1000, peak_begin, peak_end
                )
            )
            print(
                "    {:f} µJ / mean {:f} µW".format(
                    peak_mean * duration * 1e6,
                    peak_mean * 1e6,
                )
            )
            measures = aggregate_measures(np.mean(delta_powers), delta_powers)
            print(
                "    {:f} µW delta mean = {:0.1f}% / {:f} µW error".format(
                    np.mean(delta_powers) * 1e6,
                    measures["smape"],
                    measures["rmsd"] * 1e6,
                )
            )

        if peaks:
            print(
                "Peak energy mean: {:.0f} µJ         : {:.9f}".format(
                    total_energy * 1e6 / len(peaks), total_energy / len(peaks)
                )
            )
            print(
                "Average per-peak energy (delta over baseline): {:.0f} µJ         : {:.9f}".format(
                    delta_energy * 1e6 / len(peaks), delta_energy / len(peaks)
                )
            )
        else:
            # Without this guard the per-peak averages divide by zero.
            print("No peaks found")

    if "stat" in opt:
        mean_current = np.mean(currents)
        mean_power = np.mean(powers)
        print(
            "Mean current: {:.0f} µA       : {:.9f}".format(
                mean_current * 1e6, mean_current
            )
        )
        print(
            "Mean power: {:.0f} µW       : {:.9f}".format(mean_power * 1e6, mean_power)
        )

    if "plot" in opt:
        timestamps = np.arange(len(powers)) * sample_interval
        (pwrhandle,) = plt.plot(timestamps, powers, "b-", label="U*I", markersize=1)
        plt.legend(handles=[pwrhandle])
        plt.xlabel("Time [s]")
        plt.ylabel("Power [W]")
        plt.grid(True)
        plt.show()