# Copyright (C) 2018 Jiaan Dai

"""Entry script for spectrum-based parallel simulations."""

import os
import argparse
import json
import random
from multiprocessing import Pool
from datetime import datetime

import sys
sys.path.append('.')
from PowerSimTool import Simulator, RecordWriter, RecordMerger, logger

from PowerSimTool.loaders import load_peptides
from PowerSimTool.search import search
from PowerSimTool.search import cosine_similarity as score

# Script version; run() embeds it in the description passed to RecordMerger.
__version__ = '0.5'


class SpectrumBasedSimulator(Simulator):
    """Simulator whose ``run`` performs a spectrum-based database search.

    Each simulated spectrum is binned by ``params['ms2tol']`` and scored with
    ``score`` (cosine similarity) against the peptide it was simulated from,
    and against the best target and best decoy candidates found by ``search``.
    """

    def _convert_theo_spec(self, peaks, ms2tol):
        """Discretise peak positions into a set of integer bin indices.

        Args:
            peaks: iterable of numeric peak positions.
            ms2tol: bin width; each peak ``p`` maps to ``int(p / ms2tol)``.

        Returns:
            set of bin indices (peaks that fall into the same bin collapse).
        """
        return {int(p / ms2tol) for p in peaks}

    def _load_database(self, filename):
        """Load a peptide database and precompute per-peptide search data.

        Peptides are read from ``self._dbprefix + filename`` and kept only if
        their mass lies within the inclusive ``params['peptide_mass_range']``.

        Returns:
            ``(masses, peptides, binned_theo_peaks)``: three parallel lists
            holding each kept peptide's mass, the peptide objects, and each
            peptide's theoretical peaks binned by ``_convert_theo_spec``.
        """
        peptide_lower_mass = self._params['peptide_mass_range'][0]
        peptide_upper_mass = self._params['peptide_mass_range'][1]
        ms2tol = self._params['ms2tol']
        peptides = [p for p in load_peptides(self._dbprefix + filename)
                    if peptide_lower_mass <= p.mass <= peptide_upper_mass]
        peptide_mass = [p.mass for p in peptides]
        peptide_theo_peaks = [self._convert_theo_spec(p.theo_peaks, ms2tol)
                              for p in peptides]
        return (peptide_mass, peptides, peptide_theo_peaks)

    def run(self):
        """Simulate, search and record ``self._size`` spectra.

        For spectrum ``i`` the record stores the simulated peaks, the score
        and matched-bin count against the source peptide, and the rank-1
        target and decoy matches returned by ``search``. Records are buffered
        in ``self._writer`` and written to disk once, after the loop.
        """
        # format() rather than '+' so a non-str _local_params cannot raise.
        logger.info('Start simulation on {}'.format(self._local_params))
        ms1tol = self._params['ms1tol']
        ms2tol = self._params['ms2tol']
        for i in range(self._size):
            spec, idx, mods = self._get_spectrum(i)
            peptide = self._targets[1][idx]
            sim_peaks = [p[0] for p in spec.peaks]
            exp_spec = self._convert_theo_spec(sim_peaks, ms2tol)
            # Bin the source peptide once; it feeds both the self score and
            # the self match count below.
            self_spec = self._convert_theo_spec(peptide.theo_peaks, ms2tol)

            self_score = score(exp_spec, self_spec)  # original scoring
            target_score, target_idx = search(
                exp_spec, self._targets[0], self._targets[2], ms1tol,
                spec.precursor_mass, score)
            decoy_score, decoy_idx = search(
                exp_spec, self._decoys[0], self._decoys[2], ms1tol,
                spec.precursor_mass, score)

            if i % 5000 == 0 and i > 0:
                logger.info('Progress {}/{} spectra done.'.format(
                    i, self._size))

            record = {
                'scan_num': i,
                'random_index': idx,
                'label': peptide.sequence,
                'mods': mods,  # index: shift, where index starts from zero
                'sim_peaks': sim_peaks,
                'self_match': {
                    'score': self_score,
                    'match_num': len(exp_spec.intersection(self_spec)),
                },
                'target_match': [{
                    'sequence': self._targets[1][target_idx].sequence,
                    'rank': 1,
                    'score': target_score,
                    'match_num': len(exp_spec.intersection(
                        self._targets[2][target_idx])),
                }],
                'decoy_match': [{
                    'sequence': self._decoys[1][decoy_idx].sequence,
                    'rank': 1,
                    'score': decoy_score,
                    'match_num': len(exp_spec.intersection(
                        self._decoys[2][decoy_idx])),
                }],
            }
            self._writer.append(record)

        logger.info('End simulation. Writing data to disk...')
        self._writer.write(self._local_params)  # write to disk


def mapper(task):
    """Pool worker: initialise one simulation task, then execute it.

    Defined at module level so ``multiprocessing`` can pickle it. Always
    returns ``None``; results are persisted by the task itself.
    """
    for step in ('init', 'run'):
        getattr(task, step)()


def run(params, parallel):
    """Run every configured simulation and merge the results into one file.

    One ``SpectrumBasedSimulator`` is built for each combination of
    ``params['missing_proportions']``, ``params['signal_levels']`` and
    ``params['num_mods']``. Each gets its own seed, drawn from an RNG seeded
    with ``params.get('seed')``, and its own temporary output file. After all
    simulators finish, ``RecordMerger`` merges the temporary files into
    ``'output/' + params['output_path']``.

    Args:
        params: configuration dict (parsed from the JSON config file).
        parallel: if true, run simulators in a ``multiprocessing.Pool`` of
            ``params['num_threads']`` worker processes; otherwise run them
            sequentially in this process.
    """
    output_path = 'output/' + params['output_path']
    # Create the directory that will hold the output file. Using its dirname
    # (not just 'output/') also covers output paths with sub-directories,
    # and exist_ok avoids the check-then-create race of exists()+makedirs().
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    tempprefix = output_path + '.' + datetime.now().strftime('PST%Y%m%d')
    tempfilenames = []

    # Seed the master RNG once so the per-simulator seeds drawn below are
    # reproducible whenever params['seed'] is given.
    random.seed(params.get('seed', None))
    tasks = []
    for mu in params['missing_proportions']:
        for rho in params['signal_levels']:
            for num_mods in params['num_mods']:
                local_seed = random.randint(0, 4294967295)  # 0 .. 2**32 - 1
                temp_path = '{}.Temp{:03d}'.format(tempprefix, len(tasks))
                tempfilenames.append(temp_path)
                tasks.append(SpectrumBasedSimulator(
                    mu, rho, num_mods, params['tag_length'],
                    params['size'], local_seed,
                    RecordWriter(temp_path), params))

    if parallel:
        logger.info('Multiprocessing is enabled. Simulations run in parallel '
                    'with {} threads.'.format(params['num_threads']))
        # maxtasksperchild=1: each worker exits after a single simulator, so
        # whatever that simulator allocated is released before the next one.
        with Pool(processes=params['num_threads'], maxtasksperchild=1) as pool:
            pool.map(mapper, tasks, chunksize=1)
    else:
        logger.info('Multiprocessing is disabled. '
                    'Simulations run sequentially.')
        # Pop each simulator off the list as it runs, so no reference outlives
        # its run and it can be garbage-collected to spare memory.
        while tasks:
            mapper(tasks.pop(0))

    logger.info('Workers closed. Merging files...')
    description = 'Generated by PST version ' + __version__
    merger = RecordMerger(output_path, description, params, tempfilenames)
    merger.run()


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('config', help='config file path')
    parser.add_argument('--no-parallel', help='disable multiprocessing',
                        action='store_true')
    args = parser.parse_args()
    # JSON text is UTF-8 by specification; do not depend on the platform's
    # locale-dependent default encoding when reading the config.
    with open(args.config, 'r', encoding='utf-8') as f:
        params = json.load(f)
    logger.info('Parse {} succeed. Start spectrum-based simulations.'.format(args.config))
    run(params, parallel=not args.no_parallel)
    logger.info('Scheduled simulations end.')
