# Copyright (C) 2017 Jiaan Dai

"""Data types for spectra and peptides.

Defines ``Spectrum`` (a scan with its peaks and precursor mass) and
``Peptide`` (a residue sequence with its mass and lazily cached
theoretical peaks and tags).
"""

# The star import is the only source in this module of the PROTON, WATER
# and MASS names used below.
from .mass import *

__all__ = ['Spectrum', 'Peptide']


class Spectrum:
    """An observed spectrum: scan number, peaks and precursor mass."""

    def __init__(self, scan_num, peaks, precursor_mz, precursor_charge):
        """Store the scan and convert the precursor m/z into a mass.

        Args:
            scan_num: scan identifier, stored as-is and exposed via
                ``scan_num``.
            peaks: peak data, stored as-is (not copied or validated here).
            precursor_mz: precursor m/z value.
            precursor_charge: precursor charge state.

        The precursor mass is ``(precursor_mz - PROTON) * precursor_charge``,
        i.e. one PROTON is removed per charge (presumably giving the neutral
        mass, assuming PROTON is the proton mass from ``.mass``).
        """
        self._scan_num = scan_num
        self._peaks = peaks
        self._precursor_mass = (precursor_mz - PROTON) * precursor_charge

    @property
    def scan_num(self):
        """Scan identifier passed to the constructor."""
        return self._scan_num

    @property
    def peaks(self):
        """Peak data exactly as passed to the constructor."""
        return self._peaks

    @property
    def precursor_mass(self):
        """Precursor mass computed in ``__init__``."""
        return self._precursor_mass


class Peptide:
    """A peptide sequence with its mass and lazily cached theoretical data.

    Equality and hashing depend only on the sequence, so two peptides with
    the same sequence but different ``protein``/``nterm``/``cterm`` compare
    equal.
    """

    def __init__(self, sequence, protein, nterm, cterm):
        """Create a peptide and compute its mass.

        'I' should be already replaced by 'L' in ``sequence``.

        Args:
            sequence: residue string; every character must be a key of MASS.
            protein: associated protein, stored as-is.
            nterm: stored as-is and exposed via ``nterm``.
            cterm: stored as-is and exposed via ``cterm``.

        Raises:
            KeyError: if ``sequence`` contains a residue missing from MASS.
                Callers are expected to catch it.
        """
        self._sequence = sequence
        self._protein = protein
        self._nterm = nterm
        self._cterm = cterm

        # Accumulate residue masses one at a time (rather than with ``sum``)
        # so the float result is identical across Python versions: ``sum``
        # switched to compensated summation for floats in 3.12.
        mass = WATER
        for aa in sequence:
            mass += MASS[aa]
        self._mass = mass

        # Lazily built caches, filled on first request.
        self._theo_peaks = None
        # tag_length -> set of canonical tags. A dict keyed by length means
        # one length can never read or overwrite another length's entry.
        self._theo_tags = {}

    @property
    def mass(self):
        """WATER plus the sum of the residue masses of ``sequence``."""
        return self._mass

    @property
    def sequence(self):
        """Residue string passed to the constructor."""
        return self._sequence

    @property
    def protein(self):
        """Protein value passed to the constructor."""
        return self._protein

    @property
    def nterm(self):
        """``nterm`` value passed to the constructor."""
        return self._nterm

    @property
    def cterm(self):
        """``cterm`` value passed to the constructor."""
        return self._cterm

    @property
    def theo_peaks(self):
        """Sorted theoretical b- and y-ion values, computed once and cached."""
        if self._theo_peaks is None:
            self._theo_peaks = self._gen_theo_peaks()
        return self._theo_peaks

    def get_theo_tags(self, tag_length):
        """Return the canonical length-``tag_length`` tags, caching per length.

        Any positive length is supported.

        Raises:
            ValueError: if ``tag_length`` is less than 1.
        """
        if tag_length < 1:
            raise ValueError(
                'tag_length must be >= 1, got {}'.format(tag_length))
        tags = self._theo_tags.get(tag_length)
        if tags is None:
            tags = self._gen_theo_tags(tag_length)
            self._theo_tags[tag_length] = tags
        return tags

    def __hash__(self):
        return hash(self._sequence)

    def __eq__(self, other):
        # Defer to the other operand for foreign types instead of raising
        # AttributeError on ``other._sequence``.
        if not isinstance(other, Peptide):
            return NotImplemented
        return self._sequence == other._sequence

    def _gen_theo_peaks(self):
        """Compute cumulative b-ion (prefix) and y-ion (suffix) values.

        b ions start from PROTON and y ions from WATER + PROTON; each
        residue mass is added in turn, including the full-length prefix and
        suffix. Returns the combined values sorted ascending.
        """
        theo_peaks = []
        b_ion = PROTON
        for aa in self._sequence:
            b_ion += MASS[aa]
            theo_peaks.append(b_ion)
        y_ion = WATER + PROTON
        for aa in reversed(self._sequence):
            y_ion += MASS[aa]
            theo_peaks.append(y_ion)
        theo_peaks.sort()
        return theo_peaks

    def _gen_theo_tags(self, tag_length):
        """Return the set of canonical substrings of length ``tag_length``.

        Each substring is canonicalised to the lexicographically smaller of
        itself and its reverse, so a tag and its reverse map to the same
        entry. Returns an empty set if the sequence is shorter than
        ``tag_length``.
        """
        seq = self._sequence
        theo_tags = set()
        # range() is empty when len(seq) < tag_length.
        for start in range(len(seq) - tag_length + 1):
            tag = seq[start:start + tag_length]
            theo_tags.add(min(tag, tag[::-1]))
        return theo_tags

