# Copyright (C) 2017 Jiaan Dai

"""Validation script using real data."""

########################
# data configurations
########################

# NOTE: these are absolute, machine-specific paths; edit them before running
# the script on another machine.

# MGF file holding the spectra to validate (read with pyteomics' mgf reader).
SPECTRA_PATH = r'C:\Users\Jiaan\Desktop\PIPI suppl\Simulation Data\02A.pin.selected_psm.mgf'
# Tab-separated reference identifications, one "scan number<TAB>peptide" pair
# per line; the first line is skipped as a header.
REFERENCE_ID = r'C:\Users\Jiaan\Desktop\ReferencePepId.txt'
# Source file of the target peptide database.
FASTADB_PATH = r'C:\Users\Jiaan\Desktop\PowerSimTool\data\uniprot-all.fasta'
# Source file of the decoy peptide database.
DECOYDB_PATH = r'C:\Users\Jiaan\Desktop\PowerSimTool\data\uniprot-homo-sapiens.shuffle.fasta'

########################
# environment setup
########################

import sys
sys.path.append('.')
from PowerSimTool.types import Spectrum
from PowerSimTool.loaders import load_peptides
from PowerSimTool.denovo import infer_tags
from PowerSimTool.search import search
from PowerSimTool.search import cosine_similarity as score

########################
# spectra data loading
########################

from pyteomics import mgf  # inapplicable for remote usage

# Build one Spectrum per MGF entry.  The reader is used as a context manager
# so the underlying file is closed as soon as every entry has been consumed.
spectra = []
with mgf.read(SPECTRA_PATH) as reader:
    for raw_spec in reader:
        # Every peak is given unit intensity; the measured 'intensity array'
        # of the entry is not used.
        peaks = [(mz, 1) for mz in raw_spec['m/z array']]
        params = raw_spec['params']
        precursor_mz = params['pepmass'][0]
        charge = params['charge'][0]
        scan_num = int(params['scans'])
        spectra.append(Spectrum(scan_num, peaks, precursor_mz, charge))

########################
# database setup
########################

def _load_database(path):
    """Load a peptide database and precompute its search inputs.

    Peptides are read with ``load_peptides(path)`` and only those whose
    ``mass`` lies in the closed interval [700, 5000] are kept.

    Returns a 3-tuple of parallel lists ``(masses, peptides, tags)`` where
    ``masses[k]`` is ``peptides[k].mass`` and ``tags[k]`` is
    ``peptides[k].get_theo_tags(3)``.
    """
    mass_window = (700, 5000)  # inclusive lower / upper peptide mass

    kept = []
    for candidate in load_peptides(path):
        if mass_window[0] <= candidate.mass <= mass_window[1]:
            kept.append(candidate)

    masses = [candidate.mass for candidate in kept]
    tags = [candidate.get_theo_tags(3) for candidate in kept]
    return (masses, kept, tags)

# Each database is the (masses, peptides, tags) tuple of parallel lists
# returned by _load_database: index 0 holds the peptide masses, index 1 the
# peptide objects and index 2 their theoretical tags.
targets = _load_database(FASTADB_PATH)
decoys = _load_database(DECOYDB_PATH)

########################
# reference data loading
########################

reference_dict = {}  # scan_num -> sequence
valid_dict = {}  # sequence -> None, waiting for mapping to peptide type
with open(REFERENCE_ID, 'r') as ref_file:
    next(ref_file, None)  # skip the header line (no-op on an empty file)
    for line in ref_file:
        line = line.strip()
        if not line:
            # Tolerate blank lines (e.g. a trailing one at the end of the
            # file); they would otherwise break the two-column unpacking.
            continue
        scan_num, peptide_id = line.split('\t')
        scan_num = int(scan_num)
        # Every 'I' in the reference sequence is rewritten to 'L' before it
        # is used as a lookup key.
        # NOTE(review): presumably the database sequences are normalised the
        # same way -- confirm against load_peptides.
        peptide_id = peptide_id.strip().replace('I', 'L')
        reference_dict[scan_num] = peptide_id
        valid_dict[peptide_id] = None

########################
# validate existence
########################
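# Loop over the whole target database to check that each reference id exists.
# References found in the database are copied into a new dict (remain_dict),
# keyed by scan number; references missing from the database are left out.
# reference_dict itself is not modified.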

# Attach the matching target peptide object to every reference sequence that
# occurs in the target database (a later duplicate overwrites an earlier one).
for db_peptide in targets[1]:
    db_sequence = db_peptide.sequence
    if db_sequence in valid_dict:
        valid_dict[db_sequence] = db_peptide  # sequence -> peptide

# scan_num -> peptide, restricted to references that were found above.
remain_dict = {
    scan: valid_dict[sequence]
    for scan, sequence in reference_dict.items()
    if valid_dict[sequence] is not None
}

########################
# bind labels with spec
########################

# Keep only spectra whose scan number has a reference peptide, and pair each
# of them with that peptide as (spectrum, peptide).
labelled_spectra = []
for spectrum in spectra:
    if spectrum._scan_num in remain_dict:
        labelled_spectra.append((spectrum, remain_dict[spectrum._scan_num]))
spectra = labelled_spectra

print('Total {} spectra are under examination.'.format(len(spectra)))

########################
# search in tags
########################

from datetime import datetime

# For every labelled spectrum, score it against its own reference peptide and
# record the best hit found in the target and in the decoy database.
records = []
for done, (spec, peptide) in enumerate(spectra, start=1):
    # NOTE(review): presumably "[3]" requests length-3 tags and 0.02 is the
    # mass tolerance -- confirm against infer_tags.
    spec_tags = infer_tags(spec, [3], 0.02)[0]
    self_score = score(spec_tags, peptide.get_theo_tags(3))

    target_score, target_idx = search(
        spec_tags, targets[0], targets[2], 250, spec.precursor_mass, score)
    decoy_score, decoy_idx = search(
        spec_tags, decoys[0], decoys[2], 250, spec.precursor_mass, score)

    records.append({
        'spec': spec,
        'label': peptide,
        'self_score': self_score,
        'target_id': targets[1][target_idx],
        'target_score': target_score,
        'decoy_id': decoys[1][decoy_idx],
        'decoy_score': decoy_score,
    })

    # Progress report every 500 spectra.
    if done % 500 == 0:
        print('{}\tFinish {} spectra.'.format(datetime.now(), done))

########################
# extract quality params
########################

import numpy as np

def extract_params(spectrum, peptide_label, ms2tol=0.02):
    """Extract quality params given sequence label.

    The m/z axis is cut into bins of width ``ms2tol`` and every observed
    peak marks its bin as occupied.  Each theoretical peak of the label then
    claims the occupied bins among its own bin and the two adjacent ones;
    claimed bins are cleared so they cannot be counted again.

    Args:
        spectrum: object whose ``peaks`` is an iterable of
            ``(mz, intensity)`` pairs; only ``mz`` is used.
        peptide_label: object whose ``theo_peaks`` is a sized iterable of
            theoretical m/z values.
        ms2tol: bin width used for matching (defaults to 0.02).

    Returns:
        ``(signal_ratio, noise_ratio)`` as floats.  ``signal_ratio`` is the
        fraction of theoretical peaks that claimed at least one occupied
        bin.  ``noise_ratio`` is the number of occupied bins left unclaimed,
        divided by the number of theoretical peaks (not by the number of
        observed peaks).

    Raises:
        ZeroDivisionError: if ``peptide_label.theo_peaks`` is empty.
    """
    # An empty spectrum yields a single, unoccupied bin instead of a crash.
    max_mz = max((peak[0] for peak in spectrum.peaks), default=0.0)
    num_bins = int(max_mz / ms2tol) + 1
    # Plain ``bool`` dtype: the ``np.int`` alias no longer exists in NumPy.
    occupied = np.zeros(num_bins, dtype=bool)

    # Mark the bin of every observed peak.
    for mz, _intensity in spectrum.peaks:
        occupied[int(mz / ms2tol)] = True

    # Let each theoretical peak claim its own bin and both neighbours.
    exist_count = 0
    for mz in peptide_label.theo_peaks:
        idx = int(mz / ms2tol)
        ion_exist = False
        for neighbor in (idx, idx + 1, idx - 1):
            if 0 <= neighbor < num_bins and occupied[neighbor]:
                ion_exist = True
                occupied[neighbor] = False
        if ion_exist:
            exist_count += 1

    expect_count = len(peptide_label.theo_peaks)
    signal_ratio = float(exist_count) / float(expect_count)

    # Whatever is still occupied was not explained by the label.
    noise_count = int(np.count_nonzero(occupied))
    noise_ratio = float(noise_count) / float(expect_count)

    return (signal_ratio, noise_ratio)


# Annotate every search record with the signal / noise ratios of its
# spectrum measured against its reference label.
for entry in records:
    ratios = extract_params(entry['spec'], entry['label'])
    entry['signal_ratio'], entry['noise_ratio'] = ratios

########################
# print search results
########################

# Dump all records to stdout as comma-separated rows, header first.
csv_header = ('ScanNum,Label,SelfScore,TargetScore,TargetId,DecoyScore,DecoyId'
              ',SignalRatio,NoiseRatio,PrecursorMass,LabelMass,TargetMass,DecoyMass')
csv_row = '{},{},{},{},{},{},{},{},{},{},{},{},{}'

print(csv_header)
for rec in records:
    row_fields = (
        rec['spec']._scan_num,
        rec['label'].sequence,
        rec['self_score'],
        rec['target_score'],
        rec['target_id'].sequence,
        rec['decoy_score'],
        rec['decoy_id'].sequence,
        rec['signal_ratio'],
        rec['noise_ratio'],
        rec['spec'].precursor_mass,
        rec['label'].mass,
        rec['target_id'].mass,
        rec['decoy_id'].mass,
    )
    print(csv_row.format(*row_fields))
