# Copyright (C) 2017 Jiaan Dai

"""De novo sequencing utilities."""

import bisect
from .mass import *


class _Vertex:
    """Graph node holding one peak and the edges that leave it.

    Attributes:
        mass: value stored as the peak's mass (first constructor argument).
        intensity: value stored as the peak's intensity.
        edges: list of outgoing edges; created empty and filled by the
            graph builder.
    """

    def __init__(self, mass, intensity):
        # A fresh list per instance, so vertexes never share edge storage.
        self.mass, self.intensity, self.edges = mass, intensity, []


class _Edge:
    """Labelled, directed edge of the spectrum graph.

    Attributes:
        aa: the label carried by the edge (a key of the mass table when the
            edge is created by the graph builder).
        vertex_idx: index of the destination vertex in the graph's vertex
            list.
    """

    def __init__(self, aa, vertex_idx):
        self.aa, self.vertex_idx = aa, vertex_idx


class SpectrumGraph:
    """Spectrum graph and algorithms used for de novo sequencing.

    Every vertex is a peak: an input peak, its complementary peak, or one of
    four artificial start/end peaks. A vertex ``u`` gets an edge labelled
    ``aa`` to a later-added vertex ``v`` when ``v.mass - MASS[aa]`` lies
    within ``ms2tol`` of ``u.mass``.
    """

    def __init__(self, peaks, precursor_mass, ms2tol, max_tag_len):
        """Build the graph from a list of peaks.

        Args:
            peaks: sequence of peaks indexed as ``peak[0]`` (mass) and
                ``peak[1]`` (intensity). The peaks are merged with their
                complementary peaks and re-sorted by mass before the graph
                is built.
            precursor_mass: mass used to derive complementary peaks and the
                end vertexes.
            ms2tol: absolute tolerance applied on both sides of a target
                mass when matching vertexes.
            max_tag_len: maximum length of the path strings returned by
                ``iterate_paths``; longer paths are truncated.
        """
        self._ms2tol = ms2tol
        self._max_tag_len = max_tag_len
        # The artificial start/end vertexes get the lowest observed
        # intensity, or 1 when there are no peaks at all.
        if peaks:
            self._min_int = min(peak[1] for peak in peaks)
        else:
            self._min_int = 1
        self._vertexes = []
        self._vertexes_mass = []  # parallel copy of masses for bisect
        self._build_graph(peaks, precursor_mass)

    def _append_complementary_peaks(self, peaks, precursor_mass):
        """Return the peaks plus complementary and start/end peaks, sorted.

        Complementary peaks are calculated with

            y_ion = precursor_mass - b_ion + 2 * PROTON

        and keep the intensity of the peak they are derived from. The
        starting and ending points of the b-ion and y-ion series are added
        with the minimum intensity. The result is sorted by mass.
        """
        low = self._min_int
        new_peaks = [
            (PROTON, low),  # b ion start vertex
            (WATER + PROTON, low),  # y ion start vertex
            # b ion end vertex
            (precursor_mass - (WATER + PROTON) + 2 * PROTON, low),
            # y ion end vertex
            (precursor_mass - PROTON + 2 * PROTON, low),
        ]
        for peak in peaks:
            new_peaks.append(peak)
            new_peaks.append((precursor_mass - peak[0] + 2 * PROTON, peak[1]))
        return sorted(new_peaks, key=lambda x: x[0])

    def _build_graph(self, peaks, precursor_mass):
        """Add one vertex per (original or derived) peak in mass order."""
        # Vertexes must be appended from left to right so that
        # ``_vertexes_mass`` stays sorted for the bisect lookups.
        for peak in self._append_complementary_peaks(peaks, precursor_mass):
            self._add_vertex(peak)

    def _add_vertex(self, peak):
        """Append a vertex for ``peak`` and connect earlier vertexes to it.

        For each ``(aa, mass)`` entry of ``MASS``, every existing vertex
        whose mass is within ``ms2tol`` of ``peak[0] - mass`` receives an
        edge labelled ``aa`` pointing at the new vertex.
        """
        self._vertexes.append(_Vertex(peak[0], peak[1]))
        self._vertexes_mass.append(peak[0])
        new_idx = len(self._vertexes) - 1  # invariant for all edges below
        for aa, mass in MASS.items():
            target_mass = peak[0] - mass
            # Inclusive window [target - tol, target + tol] on sorted masses.
            start = bisect.bisect_left(self._vertexes_mass,
                                       target_mass - self._ms2tol)
            end = bisect.bisect_right(self._vertexes_mass,
                                      target_mass + self._ms2tol)
            for source in self._vertexes[start:end]:
                source.edges.append(_Edge(aa, new_idx))

    def iterate_paths(self):
        """Return all possible path strings.

        A path string is the concatenation of edge labels along a walk that
        starts at any vertex with outgoing edges and runs until a vertex
        with no recorded paths is reached, truncated to ``max_tag_len``
        characters.

        Returns:
            A list of unique path strings, in no particular order.
        """
        max_len = self._max_tag_len
        # dptable[i] holds every (truncated) path string starting at vertex
        # i. Edges only point to vertexes handled earlier in this reverse
        # scan, so each lookup below is already final.
        dptable = [set() for _ in self._vertexes]
        for i in reversed(range(len(self._vertexes))):
            paths = set()
            for edge in self._vertexes[i].edges:
                rest = dptable[edge.vertex_idx]
                if not rest:
                    # Destination is a dead end: the path is this edge alone.
                    paths.add(edge.aa)
                    continue
                for partial_path in rest:
                    # Slicing leaves short strings intact and trims long
                    # ones, which bounds the number of distinct paths.
                    paths.add((edge.aa + partial_path)[:max_len])
            dptable[i] = paths
        # Merge once instead of copying a growing set per vertex.
        return list(set().union(*dptable))


def infer_tags(spectrum, tag_lengths, ms2tol):
    """Infer tags of several lengths from one spectrum.

    A spectrum graph is built once with the largest requested length, and
    every path string it yields is cut into all contiguous substrings of
    each requested length. A tag and its reverse are treated as the same
    tag: only the lexicographically smaller of the two is stored.

    Args:
        spectrum: object providing ``peaks`` and ``precursor_mass``.
        tag_lengths: sequence of tag lengths; may be in any order.
        ms2tol: mass tolerance forwarded to ``SpectrumGraph``.

    Returns:
        A list of sets, one per entry of ``tag_lengths`` and in the same
        order, each holding the tags of that length. An empty
        ``tag_lengths`` gives an empty list.
    """
    if len(tag_lengths) == 0:
        # Nothing requested; avoid calling max() on an empty sequence.
        return []
    pools = [set() for _ in tag_lengths]
    max_tag_len = max(tag_lengths)
    spectrum_graph = SpectrumGraph(spectrum.peaks, spectrum.precursor_mass,
                                   ms2tol, max_tag_len)
    for path in spectrum_graph.iterate_paths():
        for pool, tag_length in zip(pools, tag_lengths):
            if len(path) < tag_length:
                # Skip only this length: tag_lengths need not be sorted, so
                # a later, shorter length may still fit this path.
                continue
            for start_idx in range(len(path) - tag_length + 1):
                tag = path[start_idx:start_idx + tag_length]
                # Canonical orientation: smaller of the tag and its reverse.
                pool.add(min(tag, tag[::-1]))
    return pools

