# Copyright (C) 2018 Jiaan Dai

"""Script for annotating .pin file given .mgf file.

The main task is to annotate .pin file with additional feature obtained by
tag-based fitted model. Therefore, what we do is to loop pin_data and fill
fields of the additional feature. To achieve it, we need .mgf file to extract
parameters. Since the .mgf reader only provides an iterator, we loop over the
spectra and query pin_data by scan number. After all spectra are processed,
we drop rows with N/A in the feature column. Then we save the processed pin
files with and without the additional feature (because we may delete some
rows, we also need to store the one without the feature). Then, we call
Percolator externally to check the performance.
"""

import argparse
import re
import json
import numpy as np
import pandas as pd  # inapplicable for remote usage
from pyteomics import mgf  # inapplicable for remote usage
from progress.bar import Bar  # inapplicable for remote usage

import sys
sys.path.append('.')
from PowerSimTool.types import Peptide
from PowerSimTool.specgen import SimSpecGen


def load_spectra(mgf_path):
    """Yield ``(scan_num, mz_array)`` for every spectrum in an MGF file.

    The scan number is parsed from the spectrum title: the title is split on
    ``'.'`` and the second-to-last field is converted to ``int``.

    Args:
        mgf_path: path of the .mgf file, passed straight to ``mgf.read``.

    Yields:
        tuple: the integer scan number and the spectrum's ``'m/z array'``.

    Raises:
        ValueError: if the second-to-last title field is not an integer.
        IndexError: if the title contains no ``'.'`` separator.
    """
    # Use the reader as a context manager so the underlying file handle is
    # released even when the consumer abandons the generator early.
    with mgf.read(mgf_path) as reader:
        for raw_spec in reader:
            title = raw_spec['params']['title']
            scan_num = int(title.split('.')[-2])
            yield scan_num, raw_spec['m/z array']


def extract_params(exp_peaks, theo_peaks, ms2tol=0.02):
    """Compute the matched-signal and unmatched-noise ratios of a spectrum.

    Experimental m/z values are bucketed into bins of width ``ms2tol``. A
    theoretical peak counts as matched when its bin holds at least one
    experimental peak. Occupied bins that no theoretical peak falls into are
    counted as noise. Both counts are divided by the number of theoretical
    peaks.

    Args:
        exp_peaks: non-empty iterable of experimental m/z values.
        theo_peaks: non-empty iterable of theoretical m/z values.
        ms2tol: bin width used for matching.

    Returns:
        tuple: ``(signal_ratio, noise_ratio)`` as Python floats.

    Raises:
        ValueError: if ``exp_peaks`` is empty (raised by ``max``).
        ZeroDivisionError: if ``theo_peaks`` is empty.
    """
    num_bins = int(max(exp_peaks) / ms2tol) + 1

    # Boolean occupancy map. The original ``dtype=np.int`` alias was removed
    # in NumPy 1.24 and raises AttributeError there.
    occupied = np.zeros(num_bins, dtype=bool)
    for mz in exp_peaks:
        occupied[int(mz / ms2tol)] = True

    # Theoretical peaks beyond the largest experimental bin can never match.
    theo_bins = [int(mz / ms2tol) for mz in theo_peaks]
    in_range = [idx for idx in theo_bins if 0 <= idx < num_bins]
    num_theo = float(len(theo_bins))

    # Every theoretical peak is counted, even when two share one bin.
    match_count = sum(1 for idx in in_range if occupied[idx])
    signal_ratio = float(match_count) / num_theo

    # Clear every bin explained by a theoretical peak; what remains occupied
    # is treated as noise.
    for idx in in_range:
        occupied[idx] = False
    noise_count = int(np.count_nonzero(occupied))
    noise_ratio = float(noise_count) / num_theo

    return signal_ratio, noise_ratio


# One upper-case residue letter followed by a bracketed number, e.g. M[15.99].
mod_pattern = re.compile(r'([A-Z]\[[\d\.]+\])')


def separate_mods(raw_sequence):
    """Split an annotated peptide into its backbone and a modification map.

    Annotations of the form ``X[number]`` are stripped one at a time, always
    re-searching from the start of the partly cleaned string, so each recorded
    position is an index into the returned backbone.

    Args:
        raw_sequence: peptide string, possibly containing ``X[number]``
            annotations.

    Returns:
        tuple: ``(backbone, mods)`` where ``mods`` maps the 0-based residue
        position to the bracketed number as a float. Annotations on ``'C'``
        are stripped but not recorded.
    """
    backbone = raw_sequence
    mods = {}
    while True:
        hit = mod_pattern.search(backbone)
        if hit is None:
            break
        pos = hit.start()
        residue = backbone[pos]
        if residue != 'C':
            # Drop the residue letter, '[' and the trailing ']'.
            mods[pos] = float(hit.group()[2:-1])
        # Keep the residue letter, discard its bracketed annotation.
        backbone = backbone[:pos + 1] + backbone[hit.end():]
    return backbone, mods


def gen_mod_theo_peaks(sequence, mods):
    """Return the theoretical peaks of a modified peptide.

    Delegates to ``SimSpecGen.gen_modified_theo_peaks``.

    Args:
        sequence: backbone peptide sequence.
        mods: mapping of residue position to modification value, as produced
            by ``separate_mods``.

    Returns:
        Whatever ``gen_modified_theo_peaks`` returns for the given input.
    """
    # Per the original author's note, the constructor arguments do not
    # matter for this call.
    return SimSpecGen(0, 0, 0.02).gen_modified_theo_peaks(sequence, mods)


# The fitted model: a linear term in the modification count divided by two
# logistic-style denominators.
def predict(rs, rn, nmod):
    """Evaluate the fitted model for one PSM.

    The result is::

        (e * nmod + f)
        / (1 + exp(-(a + b * rs)))
        / (1 + exp(-(c + d * log10(rn))))

    with the hard-coded coefficients listed below.

    Args:
        rs: signal ratio.
        rn: noise ratio (its base-10 logarithm is used).
        nmod: number of modifications.

    Returns:
        The model output as a NumPy float.
    """
    signal_intercept, signal_slope = -15.1282, 21.9060   # a, b
    noise_intercept, noise_slope = 6.4384, -5.0234       # c, d
    mod_slope, mod_intercept = -0.1646, 1.0366           # e, f

    numerator = mod_slope * nmod + mod_intercept
    signal_denominator = 1 + np.exp(-(signal_intercept + signal_slope * rs))
    noise_denominator = 1 + np.exp(
        -(noise_intercept + noise_slope * np.log10(rn)))
    return numerator / signal_denominator / noise_denominator


def load_pin(pin_path):
    """Read a tab-separated .pin file and add empty annotation columns.

    The columns ``NewFeature``, ``_sequence``, ``_extract_rs``,
    ``_extract_rn`` and ``_extract_nm`` are added filled with NaN. A copy of
    ``ScanNr`` becomes the index under the name ``_scan_num``; the ``ScanNr``
    column itself is kept.

    Args:
        pin_path: path of the input .pin file.

    Returns:
        pandas.DataFrame: the table indexed by ``_scan_num``.

    Raises:
        KeyError: if the file has no ``ScanNr`` column.
    """
    pin_data = pd.read_csv(pin_path, sep='\t')

    for numeric_col in ('NewFeature', '_extract_rs', '_extract_rn',
                        '_extract_nm'):
        pin_data[numeric_col] = np.nan

    # ``_sequence`` later receives strings. Creating it as float64 makes those
    # assignments an incompatible-dtype write, which pandas deprecated in 2.1
    # and will reject, so give it object dtype from the start.
    pin_data['_sequence'] = pd.Series(np.nan, index=pin_data.index,
                                      dtype=object)

    pin_data['_scan_num'] = pin_data['ScanNr']
    pin_data.set_index('_scan_num', inplace=True)
    return pin_data


def match_mgf_pin(pin_data, mgf_path):
    """Annotate ``pin_data`` rows using the spectra of an MGF file.

    For every spectrum whose scan number is present in the index of
    ``pin_data``, the peptide string is cleaned, theoretical peaks are
    generated, signal/noise ratios are extracted and the model output is
    written into the annotation columns. Rows containing NaN in any column
    are dropped afterwards, not only rows without a feature value.

    Args:
        pin_data: DataFrame indexed by scan number, with a ``peptide`` column
            and the annotation columns ``NewFeature``, ``_sequence``,
            ``_extract_rs``, ``_extract_rn`` and ``_extract_nm``. It is
            modified in place.
        mgf_path: path of the .mgf file to read spectra from.

    Returns:
        The same ``pin_data`` object after annotation and row removal.
    """
    progress = Bar('Matching mgf...', max=len(pin_data),
                   suffix='%(index)d/%(max)d - %(eta)ds\t\t')

    for scan_num, mz_array in load_spectra(mgf_path):
        if scan_num not in pin_data.index:
            continue

        # Drop the first two and last two characters of the peptide string
        # (presumably the flanking residue and dot, e.g. 'K.' / '.R'; confirm
        # against the .pin producer).
        flanked = pin_data.loc[scan_num, 'peptide']
        backbone, mods = separate_mods(flanked[2:-2])
        backbone = backbone.replace('I', 'L')

        theo_peaks = gen_mod_theo_peaks(backbone, mods)
        rs, rn = extract_params(mz_array, theo_peaks, ms2tol=0.02)
        nm = len(mods)

        annotations = (
            ('NewFeature', predict(rs, rn, nm)),
            ('_sequence', backbone),
            ('_extract_rs', rs),
            ('_extract_rn', rn),
            ('_extract_nm', nm),
        )
        for column, value in annotations:
            pin_data.loc[scan_num, column] = value
        progress.next()

    progress.finish()
    pin_data.dropna(axis=0, how='any', inplace=True)
    return pin_data


def output(pin_data, output_path):
    """Write the annotated table as three tab-separated files.

    * ``<output_path>.all.pin``: every column of ``pin_data``.
    * ``<output_path>.old.pin``: helper columns and ``NewFeature`` removed.
    * ``<output_path>.new.pin``: like ``.old.pin`` but with ``NewFeature``
      inserted at column position 9.

    The caller's DataFrame is left untouched; the reduced tables are derived
    copies.

    Args:
        pin_data: DataFrame holding the helper columns listed below plus
            ``NewFeature``.
        output_path: prefix for the three output files.

    Raises:
        KeyError: if a helper column or ``NewFeature`` is missing.
        IndexError: if fewer than 9 columns remain, so position 9 is invalid.
    """
    helper_cols = ['_scan_num', '_label', '_num_mods', '_r_s', '_r_n',
                   '_sequence', '_extract_rs', '_extract_rn', '_extract_nm']

    pin_data.to_csv(output_path + '.all.pin', sep='\t', index=False)

    # ``drop`` without ``inplace`` returns a new frame, so the argument is
    # not stripped of its columns behind the caller's back.
    without_feature = pin_data.drop(helper_cols + ['NewFeature'], axis=1)
    without_feature.to_csv(output_path + '.old.pin', sep='\t', index=False)

    with_feature = without_feature.copy()
    with_feature.insert(9, 'NewFeature', pin_data['NewFeature'])
    with_feature.to_csv(output_path + '.new.pin', sep='\t', index=False)


def read_reference(jsonl_path):
    """Load the reference annotations from a JSON-lines file.

    Each non-blank line must be a JSON object with the keys ``scan``,
    ``seq``, ``nmod``, ``r_s`` and ``r_n``. They are stored as ``_scan_num``
    (used as the index), ``_label``, ``_num_mods``, ``_r_s`` and ``_r_n``.

    Args:
        jsonl_path: path of the reference .jsonl file.

    Returns:
        pandas.DataFrame: one row per record, indexed by ``_scan_num``.

    Raises:
        KeyError: if a record lacks one of the required keys.
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    field_to_column = (('scan', '_scan_num'),
                       ('seq', '_label'),
                       ('nmod', '_num_mods'),
                       ('r_s', '_r_s'),
                       ('r_n', '_r_n'))
    columns = {column: [] for _, column in field_to_column}

    with open(jsonl_path, 'r', encoding='utf-8') as file:
        for line in file:
            line = line.strip()
            if not line:
                # Tolerate blank lines instead of failing in json.loads('').
                continue
            record = json.loads(line)
            for field, column in field_to_column:
                columns[column].append(record[field])

    df = pd.DataFrame(columns)
    df.set_index('_scan_num', inplace=True)
    return df


def label_protein(data):
    """Flag rows whose extracted sequence disagrees with the reference label.

    On every row where ``_sequence`` differs from ``_label``, the value of
    ``proteinId1`` is prefixed with ``'_X_'``. The frame is modified in place.

    Args:
        data: DataFrame with ``_sequence``, ``_label`` and ``proteinId1``
            columns.

    Returns:
        The same ``data`` object.
    """
    def _flag(protein_id):
        return '_X_' + protein_id

    mismatch = data['_sequence'] != data['_label']
    flagged = data.loc[mismatch, 'proteinId1'].apply(_flag)
    data.loc[mismatch, 'proteinId1'] = flagged
    return data


def main(mgf_path, pin_path, output_path, reference_path):
    """Run the whole annotation pipeline.

    Loads the .pin table and the reference records, inner-joins them on their
    scan-number indexes, annotates the joined rows from the .mgf spectra,
    flags rows that disagree with the reference and writes the output files.

    Args:
        mgf_path: input .mgf file.
        pin_path: input .pin file.
        output_path: prefix of the output .pin files.
        reference_path: input reference .jsonl file.
    """
    pin_table = load_pin(pin_path)
    reference = read_reference(reference_path)

    # Inner join on the index: scans missing from either side are dropped.
    joined = pin_table.merge(reference, how='inner',
                             left_index=True, right_index=True)
    #joined = joined[joined['_num_mods'] >= 1]

    annotated = match_mgf_pin(joined, mgf_path)
    labelled = label_protein(annotated)

    # Turn the scan-number index back into the '_scan_num' column.
    labelled = labelled.reset_index()
    output(labelled, output_path)


if __name__ == '__main__':
    # Command-line entry point: all four options are mandatory and the module
    # docstring doubles as the help description.
    cli = argparse.ArgumentParser(description=__doc__)
    for flag, help_text in (('--mgf', 'input .mgf file'),
                            ('--pin', 'input .pin file'),
                            ('-o', 'output .pin files prefix'),
                            ('-r', 'input reference .jsonl file')):
        cli.add_argument(flag, required=True, help=help_text)
    opts = cli.parse_args()
    main(mgf_path=opts.mgf, pin_path=opts.pin,
         output_path=opts.o, reference_path=opts.r)
