# Copyright (C) 2017 Jiaan Dai

"""Simulator."""

import json
import numpy as np
from . import logger
from .loaders import load_peptides
from .denovo import infer_tags
from .search import search
from .search import cosine_similarity as score
from .specgen import SimSpecGen


class Simulator:
    """Simulator responsible for one simulating setting.

    Life cycle: construct the object, call :meth:`init` (typically inside
    the worker process) to seed the RNG and load the databases, then call
    :meth:`run` to generate ``size`` artificial spectra, search each of them
    against the target and decoy databases and hand the records to the
    writer.
    """

    def __init__(self, mu, rho, num_mods, tag_length, size, local_seed,
                 writer, params):
        """Store the settings of this simulation; no resources are loaded.

        Args:
            mu: Recorded as ``missing_proportion``; forwarded to SimSpecGen.
            rho: Recorded as ``signal_level``; forwarded to SimSpecGen.
            num_mods: Poisson rate for the number of modifications per
                spectrum; a value <= 0 disables modifications.
            tag_length: Length of the tags inferred and matched.
            size: Number of spectra to simulate.
            local_seed: Seed passed to ``np.random.seed`` in :meth:`init`.
            writer: Object exposing ``append(record)`` and ``write(header)``.
            params: Mapping of global parameters (``peptide_mass_range``,
                ``target_database``, ``decoy_database``, ``ms1tol``,
                ``ms2tol``, ``max_mass_shift``).
        """
        # local params
        self._mu = mu
        self._rho = rho
        self._num_mods = num_mods
        self._tag_length = tag_length
        self._size = size
        self._local_seed = local_seed

        # JSON description of this setting; logged at start and passed to
        # the writer when the results are flushed.
        self._local_params = json.dumps({
            'missing_proportion': self._mu,
            'signal_level': self._rho,
            'num_mods': self._num_mods,
            'tag_length': self._tag_length,
            'size': self._size,
            'seed': self._local_seed
        })

        self._writer = writer
        self._params = params
        self._dbprefix = 'data/'

        # Populated by init(); kept as None until then so that run() can
        # report a missing init() call explicitly.
        self._targets = None
        self._decoys = None
        self._specgen = None
        self._max_length = None

    def _load_database(self, filename):
        """Load peptides from ``filename`` and keep those in the mass range.

        Format: ([mass], [peptide], [tags]). The three lists are parallel:
        entry ``i`` of each list describes the same peptide, and the tags
        are the theoretical tags of that peptide at ``self._tag_length``.
        """
        lower, upper = self._params['peptide_mass_range'][:2]
        peptides = [p for p in load_peptides(self._dbprefix + filename)
                    if lower <= p.mass <= upper]
        peptide_mass = [p.mass for p in peptides]
        peptide_tags = [p.get_theo_tags(self._tag_length) for p in peptides]
        return (peptide_mass, peptides, peptide_tags)

    def _get_spectrum(self, scan_num):
        """Generate one artificial spectrum from a random target peptide.

        Returns:
            A tuple ``(spectrum, idx, mods)`` where ``idx`` is the index of
            the chosen peptide in the target lists and ``mods`` is the
            mapping produced by ``gen_mods`` (its values are mass shifts).

        NOTE: the order of the random draws below determines the results
        obtained for a given seed; do not reorder them.
        """
        if self._num_mods <= 0:
            # No Poisson draw at all for non-positive settings
            # (np.random.poisson rejects a negative rate).
            num_mods = 0
        else:
            # Rejection sampling: redraw until the count fits into the
            # longest target peptide.
            num_mods = np.random.poisson(self._num_mods)
            while num_mods > self._max_length:
                num_mods = np.random.poisson(self._num_mods)

        # Rejection sampling: redraw until the peptide is long enough to
        # carry num_mods modifications.
        candidates = self._targets[1]
        idx = np.random.randint(len(candidates))
        p = candidates[idx]
        while len(p.sequence) < num_mods:
            idx = np.random.randint(len(candidates))
            p = candidates[idx]

        mods = self._specgen.gen_mods(p.sequence, num_mods,
                                      self._params['max_mass_shift'])
        # The precursor mass is the peptide mass plus all mass shifts.
        total_mass_shift = sum(mods.values())
        peaks = self._specgen.gen_modified_theo_peaks(p.sequence, mods)
        artificial_spec = self._specgen.gen_artificial_spec(
            p, scan_num, peaks, p.mass + total_mass_shift)
        return artificial_spec, idx, mods

    def init(self):
        """Prepare resources after the worker process is set up.

        Seeds NumPy's global RNG with the local seed, loads the target and
        decoy databases and creates the spectrum generator.

        Raises:
            ValueError: If no target peptide lies inside
                ``peptide_mass_range``.
        """
        logger.info('Preparing resources...')
        np.random.seed(self._local_seed)
        self._targets = self._load_database(self._params['target_database'])
        if not self._targets[1]:
            # max() below would fail with an unhelpful "empty sequence"
            # message; say what is actually wrong instead.
            raise ValueError(
                'No target peptide of {!r} lies within peptide_mass_range '
                '{!r}.'.format(self._params['target_database'],
                               self._params['peptide_mass_range']))
        self._decoys = self._load_database(self._params['decoy_database'])
        self._specgen = SimSpecGen(self._mu, self._rho, self._params['ms2tol'])
        self._max_length = max(len(p.sequence) for p in self._targets[1])
        logger.info('Resources are ready.')

    def run(self):
        """Simulate ``size`` spectra, search them and write the records.

        For every spectrum the inferred tags are scored against the peptide
        the spectrum was generated from and searched against the target and
        the decoy database; one record per spectrum is appended to the
        writer, which is flushed at the end.

        Raises:
            RuntimeError: If :meth:`init` has not been called.
        """
        if (self._targets is None or self._decoys is None
                or self._specgen is None):
            raise RuntimeError(
                'Simulator.init() must be called before Simulator.run().')

        logger.info('Start simulation on ' + self._local_params)
        target_mass, target_peptides, target_tags = self._targets
        decoy_mass, decoy_peptides, decoy_tags = self._decoys

        for i in range(0, self._size):
            spec, idx, mods = self._get_spectrum(i)
            peptide = target_peptides[idx]
            # Theoretical tags of the true peptide were already computed
            # when the database was loaded; reuse them instead of
            # recomputing them for every spectrum.
            true_tags = target_tags[idx]
            spec_tags = infer_tags(spec, [self._tag_length],
                                   self._params['ms2tol'])[0]

            self_score = score(spec_tags, true_tags)  # original scoring
            target_score, target_idx = search(
                spec_tags, target_mass, target_tags,
                self._params['ms1tol'], spec.precursor_mass, score)
            decoy_score, decoy_idx = search(
                spec_tags, decoy_mass, decoy_tags,
                self._params['ms1tol'], spec.precursor_mass, score)

            if i % 5000 == 0 and i > 0:
                logger.info('Progress {}/{} spectra done.'.format(
                    i, self._size))

            # build record
            record = {
                'scan_num': i,
                'random_index': idx,
                'label': peptide.sequence,
                'mods': mods,  # index: shift, where index starts from zero
                'extracted_tags': list(spec_tags),
                'self_match': {
                    'score': self_score,
                    'match_num': len(spec_tags.intersection(true_tags))
                },
                'target_match': [{
                    'sequence': target_peptides[target_idx].sequence,
                    'rank': 1,
                    'score': target_score,
                    'match_num': len(spec_tags.intersection(
                        target_tags[target_idx]))
                }],
                'decoy_match': [{
                    'sequence': decoy_peptides[decoy_idx].sequence,
                    'rank': 1,
                    'score': decoy_score,
                    'match_num': len(spec_tags.intersection(
                        decoy_tags[decoy_idx]))
                }]
            }
            self._writer.append(record)

        logger.info('End simulation. Writing data to disk...')
        self._writer.write(self._local_params)  # write to disk
