# Copyright (C) 2017 Jiaan Dai

"""Entry script for parallel simulations."""

import os
import argparse
import json
import random
from multiprocessing import Pool
from datetime import datetime

import sys
sys.path.append('.')
from PowerSimTool import Simulator, RecordWriter, RecordMerger, logger

__version__ = '0.5'


def mapper(task):
    """Drive a single simulation task: call ``task.init()``, then ``task.run()``.

    Kept as a module-level function so it can be handed to ``Pool.map``.
    Nothing is returned.
    """
    # Look up each method only right before calling it, in order.
    for phase in ('init', 'run'):
        getattr(task, phase)()


def run(params, parallel):
    """Build one simulator per configured setting, run them, merge the records.

    Args:
        params: Configuration mapping. Keys read here:
            ``'output_path'`` (appended to ``'output/'``),
            ``'seed'`` (optional; seeds the per-task seed generator),
            ``'missing_proportions'``, ``'signal_levels'`` and ``'num_mods'``
            (equal-length sequences, one entry per simulation),
            ``'tag_length'``, ``'size'``, and ``'num_threads'`` (only read
            when ``parallel`` is true). The whole mapping is also passed to
            every ``Simulator`` and to the ``RecordMerger``.
        parallel: If true, tasks are dispatched through a
            ``multiprocessing.Pool``; otherwise they run one after another
            in the current process.

    Raises:
        ValueError: If the three per-simulation sequences differ in length.
        KeyError: If a required key is missing from ``params``.
    """
    # Validate with a real exception before touching the filesystem or the
    # RNG: ``assert`` is removed under ``python -O``, which would let
    # mismatched lists slip through.
    list_keys = ('missing_proportions', 'signal_levels', 'num_mods')
    lengths = [len(params[key]) for key in list_keys]
    if len(set(lengths)) > 1:
        detail = ', '.join('{}={}'.format(key, count)
                           for key, count in zip(list_keys, lengths))
        raise ValueError('Per-simulation config lists must have equal '
                         'lengths, got ' + detail)

    output_dir = 'output/'
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(output_dir, exist_ok=True)

    output_path = output_dir + params['output_path']
    tempprefix = output_path + '.' + datetime.now().strftime('PST%Y%m%d')
    tempfilenames = []

    # A single master seed yields one independent 32-bit seed per task, drawn
    # in task order.
    random.seed(params.get('seed', None))
    tasks = []
    settings = zip(params['missing_proportions'],
                   params['signal_levels'],
                   params['num_mods'])
    for index, (mu, rho, num_mods) in enumerate(settings):
        local_seed = random.randint(0, 4294967295)
        temp_path = '{}.Temp{:03d}'.format(tempprefix, index)
        tempfilenames.append(temp_path)
        tasks.append(Simulator(mu, rho, num_mods, params['tag_length'],
                               params['size'], local_seed,
                               RecordWriter(temp_path), params))

    if parallel:
        logger.info('Multiprocessing is enabled. Simulations run in parallel '
                    'with {} threads.'.format(params['num_threads']))
        # One task per chunk and a fresh worker process per task.
        with Pool(processes=params['num_threads'], maxtasksperchild=1) as pool:
            pool.map(mapper, tasks, chunksize=1)
    else:
        logger.info('Multiprocessing is disabled. '
                    'Simulations run sequentially.')
        # Drop the reference to each simulator as soon as it has run, to
        # spare memory. Reversing once lets pop() take tasks in their
        # original order without shifting the list on every removal.
        tasks.reverse()
        while tasks:
            mapper(tasks.pop())

    logger.info('Workers closed. Merging files...')
    description = 'Generated by PST version ' + __version__
    merger = RecordMerger(output_path, description, params, tempfilenames)
    merger.run()


if __name__ == '__main__':
    # Command-line entry point: read a JSON config file and run the
    # simulations it describes, in parallel unless --no-parallel is given.
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('config', help='config file path')
    parser.add_argument('--no-parallel', action='store_true',
                        help='disable multiprocessing')
    args = parser.parse_args()

    # Parse the JSON configuration straight from the open file handle.
    with open(args.config, 'r') as config_file:
        params = json.load(config_file)
    logger.info('Parse {} succeed. Start simulations.'.format(args.config))

    run(params, parallel=not args.no_parallel)
    logger.info('Scheduled simulations end.')
