# Copyright (C) 2017 Jiaan Dai

"""Simulated spectrum generator

This module defines models that generate artificial missing peaks and noise
peaks, as well as pseudo-PTMs (random mass shifts).
"""

import numpy as np
from .mass import *
from .types import *


class SimSpecGen:
    """Generate artificial spectra, with optional modifications, for a peptide.

    Three settings are stored at construction time:

    * ``missing_proportion`` -- fraction of the ``2 * len(sequence)``
      expected theoretical peaks that is removed from each spectrum.
    * ``signal_level`` -- divisor applied to ``2 * len(sequence)`` to obtain
      the number of noise peaks (a larger value means fewer noise peaks).
    * ``ms2tol`` -- width of the mass bins used to merge peaks; at most one
      peak survives per bin.
    """

    def __init__(self, missing_proportion, signal_level, ms2tol):
        self._missing_proportion = missing_proportion
        self._signal_level = signal_level
        self._ms2tol = ms2tol

    def gen_artificial_spec(self, peptide, scan_num,
                            theo_peaks, precursor_mass):
        """Build a spectrum from theoretical peaks plus noise and missing peaks.

        Noise masses are drawn uniformly from
        ``[PROTON, precursor_mass + 2 * PROTON)``.  Noise and theoretical
        peaks are then placed into bins of width ``ms2tol``; a theoretical
        peak overwrites whatever already occupies its bin.  Finally a random
        subset of the theoretical peaks is removed, together with anything
        sitting in the two directly adjacent bins.

        Args:
            peptide: Object exposing a ``sequence`` attribute; only its
                length is used here.
            scan_num: Passed through unchanged to ``Spectrum``.
            theo_peaks: One-dimensional sequence of theoretical peak masses.
            precursor_mass: Upper bound (plus ``2 * PROTON``) for the noise
                masses; also passed on to ``Spectrum`` as
                ``precursor_mass + PROTON``.

        Returns:
            ``Spectrum(scan_num, peaks, precursor_mass + PROTON, 1)`` where
            ``peaks`` is an ascending list of ``(mass, 1)`` tuples.
            NOTE(review): the constant ``1`` values presumably stand for a
            unit intensity and a charge of one -- confirm against
            ``Spectrum``.

        Raises:
            ValueError: Raised by ``np.random.choice`` when more peaks must
                be deleted than ``theo_peaks`` contains.
        """
        expected_peaks = 2 * len(peptide.sequence)
        bin_unit = float(self._ms2tol)

        noise_num = float(expected_peaks) / self._signal_level
        noise = np.random.uniform(
            low=PROTON, high=precursor_mass + 2 * PROTON, size=int(noise_num))

        # Bin index -> mass.  Theoretical peaks are inserted last so that they
        # take precedence over a noise peak falling into the same bin.
        peak_dict = {}
        for noise_mass in noise.tolist():
            peak_dict[int(noise_mass / bin_unit)] = noise_mass
        for theo in theo_peaks:
            peak_dict[int(theo / bin_unit)] = theo

        # Delete peaks: mark the bin of each chosen peak and both neighbouring
        # bins with a sentinel that the ``>= PROTON`` filter below discards.
        delete_num = int(self._missing_proportion * expected_peaks)
        del_peaks = np.random.choice(
            theo_peaks, size=delete_num, replace=False)
        for delp in del_peaks:
            center = int(delp / bin_unit)
            for bin_idx in (center - 1, center, center + 1):
                peak_dict[bin_idx] = -1

        kept = sorted(peak for peak in peak_dict.values() if peak >= PROTON)
        peaks = [(peak, 1) for peak in kept]

        return Spectrum(scan_num, peaks, precursor_mass + PROTON, 1)

    def gen_mods(self, sequence, num_mods, max_mass_shift):
        """Draw random mass shifts for ``num_mods`` distinct sequence sites.

        The sites are chosen once, uniformly and without replacement.  Each
        site then receives a shift drawn uniformly from
        ``[-MASS[residue], max_mass_shift)``; the whole set of shifts is
        redrawn until their sum lies within
        ``[-max_mass_shift, max_mass_shift]``.

        ``num_mods`` is used exactly as given.  An earlier variant drew it
        from a Poisson distribution; callers wanting that must sample it
        themselves.

        Args:
            sequence: Indexable sequence of residues that are keys of
                ``MASS``.
            num_mods: Number of modified sites; must not exceed
                ``len(sequence)``.
            max_mass_shift: Non-negative bound on each individual shift and
                on the absolute total shift.

        Returns:
            Dict mapping site index (plain ``int``) to its mass shift.

        Raises:
            ValueError: If ``max_mass_shift`` is negative (the acceptance
                condition could never be met), or, from
                ``np.random.choice``, if ``num_mods`` exceeds
                ``len(sequence)``.
        """
        if max_mass_shift < 0:
            # -max <= total <= max is unsatisfiable here; the rejection loop
            # below would otherwise spin forever.
            raise ValueError(
                "max_mass_shift must be non-negative, got %r"
                % (max_mass_shift,))

        # np.asscalar was removed in NumPy 1.23; int() yields the same
        # plain-int dictionary keys on every NumPy version.
        chosen = np.random.choice(len(sequence), size=num_mods, replace=False)
        sites = [int(site) for site in chosen]

        # Rejection sampling: redraw every shift until the total is in range.
        while True:
            mods = {}
            for site in sites:
                mods[site] = np.random.uniform(
                    -MASS[sequence[site]], max_mass_shift)
            if -max_mass_shift <= sum(mods.values()) <= max_mass_shift:
                return mods

    def gen_modified_theo_peaks(self, sequence, mods):
        """Given the sequence and the generated mods, generate the peaks.

        Two running sums are accumulated: one from the first residue to the
        last, starting at ``PROTON`` (named b-ions in the code), and one from
        the last residue to the first, starting at ``WATER + PROTON``
        (y-ions).  A shift in ``mods`` is added at its site in both passes.

        Args:
            sequence: Indexable sequence of residues that are keys of
                ``MASS``.
            mods: Dict mapping site index to mass shift, e.g. the result of
                ``gen_mods``.

        Returns:
            Ascending list of ``2 * len(sequence)`` peak masses.
        """
        theo_peaks = []

        current_b_ion = PROTON
        for idx, residue in enumerate(sequence):
            current_b_ion += MASS[residue]
            if idx in mods:
                current_b_ion += mods[idx]
            theo_peaks.append(current_b_ion)

        current_y_ion = WATER + PROTON
        for idx in reversed(range(len(sequence))):
            current_y_ion += MASS[sequence[idx]]
            if idx in mods:
                current_y_ion += mods[idx]
            theo_peaks.append(current_y_ion)

        return sorted(theo_peaks)
