# Copyright (C) 2017 Jiaan Dai

"""FDR Controlling utility."""

import json
import gzip


class CompactRecord:
    """Slim view of one parsed JSON record, kept only for FDR control and export.

    Attributes:
        scan_num, label: copied verbatim from the raw record.
        self_score: score of the raw record's ``self_match``.
        is_decoy: True when the first decoy match strictly outscores the
            first target match (ties are resolved in favour of the target).
        match_seq, match_score: sequence and score of the winning match.
        match_correct: True only when the target wins and its sequence
            equals ``label``; always False for a decoy win.
    """

    def __init__(self, raw_record):
        """Build the compact view from a parsed JSON ``raw_record`` dict."""
        best_target = raw_record['target_match'][0]
        best_decoy = raw_record['decoy_match'][0]
        self.scan_num = raw_record['scan_num']
        self.label = raw_record['label']
        self.self_score = raw_record['self_match']['score']
        # Strict comparison: a decoy must score higher to win; ties go to target.
        self.is_decoy = best_target['score'] < best_decoy['score']
        winner = best_decoy if self.is_decoy else best_target
        self.match_seq = winner['sequence']
        self.match_score = winner['score']
        # A decoy hit can never count as a correct identification.
        self.match_correct = (not self.is_decoy) and self.label == self.match_seq


class FdrController:
    """Control FDR given a merged result file.

    ``<result_filename>.dat`` is read as a gzip-compressed, line-oriented JSON
    file: one global header line (skipped), followed by blocks that each
    consist of a local header line and ``size`` record lines. For every block
    the fraction of correct target identifications accepted at the q-value
    threshold is reported as one CSV row.
    """

    def __init__(self, result_filename, fdr_threshold=0.05):
        """Create a controller.

        Args:
            result_filename: path prefix of the result file; ``'.dat'`` is
                appended when reading.
            fdr_threshold: q-value cutoff used when counting accepted
                identifications (default 0.05).
        """
        self._result_filename = result_filename
        self._fdr_threshold = fdr_threshold
        # CSV lines produced by the most recent run(), header first.
        self._buffer = []

    def _control_fdr(self, records):
        """Attach ``fdr`` and ``q_value`` attributes to each record.

        Records are sorted by ``match_score``, best first. At each rank the
        FDR is estimated as ``2 * #decoy / (#decoy + #target)`` over the
        records seen so far, capped at 1.0. The q-value at a rank is the
        minimum FDR from that rank to the end of the list, so q-values are
        non-decreasing along the sorted order.

        Returns:
            A new list holding the same record objects, sorted best-first.
        """
        records = sorted(records, key=lambda x: x.match_score, reverse=True)
        num_target = 0
        num_decoy = 0
        for record in records:
            if record.is_decoy:
                num_decoy += 1
            else:
                num_target += 1
            record.fdr = min(1.0, 2.0 * num_decoy / (num_decoy + num_target))
        current_min = 1.0
        for record in reversed(records):
            current_min = min(current_min, record.fdr)
            record.q_value = current_min
        return records

    def _calculate_probability(self, records, fdr_threshold=None):
        """Return the fraction of records that are accepted correct targets.

        ``records`` is expected to be the output of ``_control_fdr`` (sorted,
        with q-values). The accepted set is the prefix ending at the last
        record whose q-value is <= the threshold. Correct target matches in
        that prefix are counted and divided by the total number of records.

        Args:
            records: output of ``_control_fdr``.
            fdr_threshold: q-value cutoff; defaults to the controller's
                ``fdr_threshold``.

        Returns:
            A float in [0, 1]; 0.0 for an empty block.
        """
        if not records:
            # Avoid ZeroDivisionError for blocks declaring size == 0.
            return 0.0
        if fdr_threshold is None:
            fdr_threshold = self._fdr_threshold
        cutoff = len(records) - 1
        while cutoff >= 0 and records[cutoff].q_value > fdr_threshold:
            cutoff -= 1
        # cutoff == -1 means nothing passes and the slice below is empty.
        correct_id = sum(1 for r in records[:cutoff + 1]
                         if not r.is_decoy and r.match_correct)
        return float(correct_id) / len(records)

    def run(self, output_path=None):
        """Process every block of the result file and emit a CSV summary.

        Columns: ``mu`` (local ``missing_proportion``), ``rho`` (local
        ``signal_level``), ``num_mods`` and ``prob`` (see
        ``_calculate_probability``).

        Args:
            output_path: when given, the CSV is written to
                ``output_path + '.csv'``; otherwise it is printed to stdout.

        Raises:
            ValueError: if the file ends before a block's declared ``size``
                records were read, or a line is not valid JSON.
        """
        # Start from a fresh buffer so repeated calls do not duplicate rows.
        self._buffer = ['{},{},{},{}'.format('mu', 'rho', 'num_mods', 'prob')]
        with gzip.open(self._result_filename + '.dat', 'rb') as f:
            f.readline()  # global header, not needed here
            while True:
                local_header = f.readline()
                if not local_header:
                    break  # EOF
                if not local_header.strip():
                    continue  # tolerate blank lines, e.g. a trailing newline
                local_params = json.loads(local_header.decode('utf-8'))
                sample_size = local_params['size']
                records = []
                for _ in range(sample_size):
                    line = f.readline()
                    if not line:
                        raise ValueError(
                            'unexpected end of file: block declared {} '
                            'records but only {} were found'.format(
                                sample_size, len(records)))
                    raw_record = json.loads(line.decode('utf-8'))
                    records.append(CompactRecord(raw_record))

                records = self._control_fdr(records)
                prob = self._calculate_probability(records)

                self._buffer.append('{},{},{},{}'.format(
                    local_params['missing_proportion'],
                    local_params['signal_level'],
                    local_params['num_mods'], prob))
        output = '\n'.join(self._buffer)
        if output_path:
            with open(output_path + '.csv', 'w') as of:
                of.write(output)
        else:
            print(output)
