# Copyright (C) 2018 Jiaan Dai

"""Extract consensus spectra from .msp and .sptxt spectral library."""

import argparse
import json
import gzip
import re

# Mass subtracted once per unit of charge when the parsers turn the value of
# an 'MW' line into 'precursor_mass' (presumably the proton mass in Da --
# TODO confirm the unit against the library format).
PROTON = 1.00727


def msp_parse_name(name_field):
    """Split an .msp 'Name:' line into its sequence and charge.

    The second space-separated token of the line is expected to have the
    form '<sequence>/<charge>'.  Every '(O)' marker is removed from the
    sequence.

    Returns a ``(sequence, charge)`` tuple with ``charge`` as an int.
    Raises ValueError when the token does not contain exactly one '/' or
    the charge is not an integer, and IndexError when the line has no
    second token.
    """
    token = name_field.split(' ')[1].strip()
    peptide, charge_text = token.split('/')
    return peptide.replace('(O)', ''), int(charge_text)


def output(outfile, record):
    """Emit one record as a single line of JSON.

    When ``outfile`` is truthy the UTF-8 encoded JSON text followed by a
    newline byte is written to it (so it must accept bytes); otherwise the
    JSON text is printed to standard output.
    """
    if not outfile:
        print(json.dumps(record))
        return

    encoded = json.dumps(record).encode('utf-8')
    outfile.write(encoded)
    outfile.write(b'\n')


def msp_loop_lines(infile, outfile=None):
    """Parse an .msp library line by line and emit one record per spectrum.

    Recognised lines (after stripping surrounding whitespace):

    * ``###...``     -- comment, ignored.
    * ``Name:``      -- starts a new record; parsed by ``msp_parse_name``
      into ``sequence`` and ``charge``.
    * ``MW``         -- second space-separated token minus
      ``PROTON * charge`` is stored as ``precursor_mass``.
    * ``Comment:``   -- the integer before the first '/' of a ``Mods=``
      field is stored as ``num_mods`` (0 when no such field exists).
    * ``Num peaks:`` -- stores ``num_peaks`` and switches to peak mode, in
      which every line is ``mass<TAB>intensity<TAB>annotation`` until a
      blank line closes the record.

    Blank lines outside a peak list (leading blanks, doubled separators,
    extra trailing blanks) are skipped instead of aborting the parse.

    Args:
        infile: iterable of text lines.
        outfile: destination handed to ``output``; ``None`` prints records.

    Raises:
        Exception: on a non-blank line that matches none of the fields
            above; the message includes the offending line.
        AssertionError: when a field line appears before any ``Name:``.
        ValueError: when a peak line does not have exactly three
            tab-separated columns or a number cannot be parsed.
    """
    block = None
    reading_peaks = False

    for inline in infile:
        inline = inline.strip()
        if inline.startswith('###'):
            continue  # skip comments

        if reading_peaks:
            assert block is not None
            if not inline:
                # A blank line terminates the peak list and the record.
                reading_peaks = False
                output(outfile, block)
                block = None
            else:
                mass, intensity, annotation = inline.split('\t')
                block['peaks'].append({'mass': float(mass),
                                       'intensity': float(intensity),
                                       'annotation': annotation.strip('"')})
            continue

        if not inline:
            # Blank line between records or at the start/end of the file:
            # nothing to parse, and not an error.
            continue

        if inline.startswith('Name:'):
            if block is not None:
                # Previous record never reached a peak list; flush it
                # before opening the new one.
                output(outfile, block)
            sequence, charge = msp_parse_name(inline)
            block = {'sequence': sequence, 'charge': charge}
        elif inline.startswith('MW'):
            assert block is not None
            charged_mass = inline.split(' ')[1].strip()
            block['precursor_mass'] = (float(charged_mass)
                                       - PROTON * block['charge'])
        elif inline.startswith('Comment:'):
            assert block is not None
            num_mods = 0
            # 9 == len('Comment: '); the remainder is space-separated fields.
            for field in inline[9:].strip().split(' '):
                if field.startswith('Mods='):
                    num_mods = int(field.split('=')[1].split('/')[0])
            block['num_mods'] = num_mods
        elif inline.startswith('Num peaks:'):
            assert block is not None
            block['num_peaks'] = int(inline.split(':')[1].strip())
            block['peaks'] = []
            reading_peaks = True
        else:
            # Same message layout as sptxt_loop_lines so the offending line
            # is visible in the traceback.
            raise Exception('Invalid state.' + '\n[Current Line] ' + inline)

    if block is not None:
        # Last record was not closed by a blank line.
        output(outfile, block)


def sptxt_parse_name(name_field):
    """Split an .sptxt 'Name:' line into sequence, charge and mod count.

    The second space-separated token of the line is expected to have the
    form '<sequence>/<charge>'.  'C[160]' is rewritten to a plain 'C' first
    and is therefore not counted; every remaining bracketed number
    (e.g. '[147]' or '[167.5]') is counted as one modification and removed
    from the sequence.

    Returns a ``(sequence, charge, num_mods)`` tuple.
    """
    token = name_field.split(' ')[1].strip()
    peptide, charge_text = token.split('/')
    charge = int(charge_text)

    peptide = peptide.replace('C[160]', 'C')
    # subn strips the markers and reports how many it removed in one pass.
    bare_sequence, num_mods = re.subn(r'(\[\d+(\.\d+)?\])', '', peptide)
    return bare_sequence, charge, num_mods


def sptxt_loop_lines(infile, outfile=None):
    """Parse an .sptxt library line by line and emit one record per spectrum.

    Recognised lines (after stripping surrounding whitespace):

    * ``###...``    -- comment, ignored.
    * ``Name:``     -- starts a new record; parsed by ``sptxt_parse_name``
      into ``sequence``, ``charge`` and ``num_mods``.
    * ``MW``        -- second space-separated token minus
      ``PROTON * charge`` is stored as ``precursor_mass``.
    * ``NumPeaks:`` -- stores ``num_peaks`` and switches to peak mode, in
      which the first three tab-separated columns of every line are read
      as mass, intensity and annotation until a blank line closes the
      record.
    * ``Comment:``, ``FullName:``, ``LibID:``, ``PrecursorMZ:``,
      ``Status:`` -- accepted but ignored.

    Blank lines outside a peak list (leading blanks, doubled separators,
    extra trailing blanks) are skipped instead of aborting the parse.

    Args:
        infile: iterable of text lines.
        outfile: destination handed to ``output``; ``None`` prints records.

    Raises:
        Exception: on a non-blank line that matches none of the fields
            above; the message includes the offending line.
        AssertionError: when ``MW``/``NumPeaks:`` appears before ``Name:``.
        IndexError: when a peak line has fewer than three columns.
        ValueError: when a number cannot be parsed.
    """
    ignored_fields = ('Comment:', 'FullName:', 'LibID:',
                      'PrecursorMZ:', 'Status:')
    block = None
    reading_peaks = False

    for inline in infile:
        inline = inline.strip()
        if inline.startswith('###'):
            continue  # skip comments

        if reading_peaks:
            assert block is not None
            if not inline:
                # A blank line terminates the peak list and the record.
                reading_peaks = False
                output(outfile, block)
                block = None
            else:
                # Only the first three columns are used; any further
                # columns on the line are ignored.
                words = inline.split('\t')
                mass, intensity, annotation = words[0], words[1], words[2]
                block['peaks'].append({'mass': float(mass),
                                       'intensity': float(intensity),
                                       'annotation': annotation.strip('"')})
            continue

        if not inline:
            # Blank line between records or at the start/end of the file:
            # nothing to parse, and not an error.
            continue

        if inline.startswith('Name:'):
            if block is not None:
                # Previous record never reached a peak list; flush it
                # before opening the new one.
                output(outfile, block)
            sequence, charge, num_mods = sptxt_parse_name(inline)
            block = {'sequence': sequence, 'charge': charge, 'num_mods': num_mods}
        elif inline.startswith('MW'):
            assert block is not None
            charged_mass = inline.split(' ')[1].strip()
            block['precursor_mass'] = (float(charged_mass)
                                       - PROTON * block['charge'])
        elif inline.startswith('NumPeaks:'):
            assert block is not None
            block['num_peaks'] = int(inline.split(':')[1].strip())
            block['peaks'] = []
            reading_peaks = True
        elif inline.startswith(ignored_fields):
            pass
        else:
            raise Exception('Invalid state.' + '\n[Current Line] ' + inline)

    if block is not None:
        # Last record was not closed by a blank line.
        output(outfile, block)


def main(input_filename, output_filename=None):  # manage resources
    """Open the input (and optional output) file and run the matching parser.

    The parser is chosen from the input file name: ``.msp`` uses
    ``msp_loop_lines`` and ``.sptxt`` uses ``sptxt_loop_lines``.  Any other
    extension is read by no parser, so no records are produced (the output
    file, if requested, is still created).

    Args:
        input_filename: path of the spectral library to read as text.
        output_filename: path of a gzip file to write records to; when
            falsy the records are printed to standard output instead.
    """
    if input_filename.endswith('.msp'):
        loop_lines = msp_loop_lines
    elif input_filename.endswith('.sptxt'):
        loop_lines = sptxt_loop_lines
    else:
        loop_lines = None

    with open(input_filename, 'r') as infile:
        if output_filename:
            with gzip.open(output_filename, 'wb') as outfile:
                if loop_lines is not None:
                    loop_lines(infile, outfile)
        elif loop_lines is not None:
            # No output file: pass None explicitly so both parsers print
            # their records (there is no 'outfile' name in this branch).
            loop_lines(infile, outfile=None)


if __name__ == '__main__':
    # Command-line entry point: -i names the library to read and -o the
    # optional gzip output; without -o the records are printed to stdout.
    parser = argparse.ArgumentParser(description=__doc__)
    # main() dispatches on both extensions, so the help text names both.
    parser.add_argument('-i', required=True,
                        help='input .msp or .sptxt spectral library')
    parser.add_argument('-o', help='output gzip-compressed line-delimited .json file')
    args = parser.parse_args()
    main(input_filename=args.i, output_filename=args.o)
