import time
import json
import msgpack
import pika
import os, sys
import dill
from elasticsearch import Elasticsearch, NotFoundError

# Put the parent of this script's directory on sys.path so that the
# `lib.*` imports below can be resolved.
fileDir = os.path.dirname(os.path.abspath(__file__))
parentDir = os.path.dirname(fileDir)
newPath = parentDir  # os.path.join() with a single argument returns it unchanged
sys.path.append(newPath)
from lib.my_logging import logging
from lib.exit_gracefully import exit_gracefully
from lib.locker import unlock
from lib.postgis_helper import Remote

class NotEmptyQueueException(Exception):
    """Marker exception with no extra behavior beyond ``Exception``.

    Not raised by any active code in this module. Presumably meant to signal
    that upstream work (a queue / source index) is not finished yet; verify
    before relying on it.
    """


def create_sampling_task(cfg, channel, uuid):
    """Publish a "full -> sample" task for dataset ``uuid`` to the sample generator.

    The message is the msgpack encoding of ``{'header': {'cfg': cfg}, 'body': uuid}``.
    It is published with ``delivery_mode=2`` on exchange ``cfg['rabbitmq']['exchange']``
    using routing key ``cfg['rabbitmq']['routing_key_6']``. Before publishing, the
    direct exchange, the durable queue ``cfg['rabbitmq']['queue_name_6']`` (with
    ``x-message-ttl`` set to ``cfg['rabbitmq']['ttl']``) and their binding are
    declared on every call.

    :param cfg: configuration dict; must contain a ``'rabbitmq'`` section.
    :param channel: pika channel used for the declarations and the publish.
    :param uuid: identifier of the dataset to sample.
    :returns: None
    """
    payload = msgpack.packb({'header': {'cfg': cfg}, 'body': uuid}, use_bin_type=True)

    rabbit = cfg['rabbitmq']
    exchange = rabbit['exchange']
    queue_name = rabbit['queue_name_6']
    routing_key = rabbit['routing_key_6']

    # Make sure the target topology exists before publishing to it.
    channel.exchange_declare(exchange=exchange, exchange_type='direct')
    channel.queue_declare(queue=queue_name, durable=True, arguments={'x-message-ttl': rabbit['ttl']})
    channel.queue_bind(exchange=exchange, queue=queue_name, routing_key=routing_key)

    persistent = pika.BasicProperties(delivery_mode=2)
    channel.basic_publish(exchange=exchange, routing_key=routing_key, body=payload, properties=persistent)

def _count_running_tasks(es, action):
    """Return how many Elasticsearch tasks the Tasks API lists for ``action``.

    :param es: an ``Elasticsearch`` client.
    :param action: Tasks API action filter, e.g. ``"indices:data/write/reindex"``.
    :returns: number of matching tasks, summed over every node in the response.
    """
    rep = es.tasks.list(actions=action)
    # The response groups tasks per node: {'nodes': {<node_id>: {'tasks': {<task_id>: ...}}}}
    return sum(len(node_info['tasks']) for node_info in rep['nodes'].values())


def _wait_for_delete_by_query(es, task_id):
    """Block until the delete_by_query task ``task_id`` reports ``completed``.

    Polls the Tasks API, sleeping 1 s, then 2 s, then 3 s, ... between polls.
    """
    # NOTE(review): time.sleep() does not service a pika BlockingConnection, so very
    # long waits here may miss AMQP heartbeats; consider the connection's sleep().
    seconds_to_sleep_for = 1
    while not es.tasks.get(task_id=task_id)['completed']:
        logging.info('Waiting for delete_by_query to complete: sleeping for %i seconds...' % seconds_to_sleep_for)
        time.sleep(seconds_to_sleep_for)
        seconds_to_sleep_for += 1


def on_msg_callback(channel, method, properties, body):
    """Handle one reindex request taken from RabbitMQ.

    Steps:
      1. Decode the message.
      2. Refresh and count the dataset in the source index (logged for diagnostics).
      3. Requeue the message if the destination cluster is already running
         reindex or delete_by_query tasks.
      4. Delete the dataset's documents from the destination index.
      5. Start an asynchronous reindex from the source index.
      6. On success, ack the message and publish a sampling task.

    :param channel: pika channel the message was delivered on (used to ack/nack
        and to publish the sampling task).
    :param method: delivery info; only ``delivery_tag`` is used.
    :param properties: message properties (unused).
    :param body: msgpack payload with keys ``header.cfg``,
        ``header.serialized_deferred_count`` and ``body`` (the dataset uuid).
    :returns: None. Errors from decoding, the source-index calls, the busy checks
        or ``es.reindex`` are not caught and propagate to the caller.
    """
    decoded_body = msgpack.unpackb(body, raw=False)
    cfg = decoded_body['header']['cfg']
    uuid = decoded_body['body']

    # SECURITY: dill.loads can execute arbitrary code embedded in the message.
    # This is only acceptable if every producer publishing to this queue is trusted.
    deferred_count = dill.loads(decoded_body['header']['serialized_deferred_count'])
    count_ref = deferred_count()

    reindexer_cfg = cfg['reindexer']
    uuid_query = {'query': {'term': {'uuid.keyword': '{0}'.format(uuid)}}}

    # Read from a remote source cluster if one is configured, else from the destination cluster.
    if 'source_url' in reindexer_cfg:
        source_url = reindexer_cfg['source_url']
    else:
        source_url = reindexer_cfg['destination_url']
    es_source = Elasticsearch([source_url], timeout=60)

    # Refresh so that recently indexed documents are visible to the count/reindex.
    es_source.indices.refresh(index=reindexer_cfg['source_index'])
    count_es = es_source.count(reindexer_cfg['source_index'], body=uuid_query).get('count')
    # The consistency check "count_es == count_ref" is currently disabled; log both for diagnostics.
    logging.debug('uuid = %s: count_es = %s; count_ref = %s', uuid, count_es, count_ref)

    es = Elasticsearch([reindexer_cfg['destination_url']], timeout=60)

    # -1. Do not start while Elasticsearch is busy with other reindexation tasks.
    if _count_running_tasks(es, "indices:data/write/reindex") > 0:
        logging.info("Elasticsearch is already busy with reindexation tasks. Sleeping for 5 seconds before retrying...")
        time.sleep(5)
        channel.basic_nack(delivery_tag=method.delivery_tag, requeue=1)
        return

    # 0. Do not start while Elasticsearch is busy with other delete_by_query tasks.
    if _count_running_tasks(es, "indices:data/write/delete/byquery") > 0:
        logging.info("Elasticsearch is already busy with delete_by_query tasks. Sleeping for 5 seconds before retrying...")
        time.sleep(5)
        channel.basic_nack(delivery_tag=method.delivery_tag, requeue=1)
        return

    # 1. Remove already existing docs of this dataset from the destination index.
    logging.info("Removing dataset with uuid = %s from the destination index..." % uuid)
    index = reindexer_cfg['destination_index']

    try:
        es.indices.refresh(index=index)
    except NotFoundError:
        pass  # the destination index may not exist yet

    try:
        res = es.delete_by_query(index, doc_type='_doc', body=uuid_query, conflicts='proceed',
                                 refresh=True, wait_for_completion=False)
        _wait_for_delete_by_query(es, res['task'])
    except NotFoundError:
        pass  # nothing to delete (e.g. missing destination index)
    except Exception:
        logging.exception("delete_by_query failed for uuid = %s; requeueing", uuid)
        channel.basic_nack(delivery_tag=method.delivery_tag, requeue=1)
        return

    # 2. Trigger an asynchronous reindexation of this dataset.
    body = {
        "source": {
            "index": reindexer_cfg['source_index'],
            "query": {
                "term": {"uuid.keyword": '{0}'.format(uuid)}
            },
            "type": "_doc",
            "size": 1000
        },
        "dest": {
            "index": reindexer_cfg['destination_index'],
            "type": "_doc"
        }
    }

    if 'source_url' in reindexer_cfg:
        body['source']['remote'] = {'host': reindexer_cfg['source_url']}

    rep = es.reindex(body, wait_for_completion=False)
    logging.debug(rep)

    if 'task' not in rep:
        channel.basic_nack(delivery_tag=method.delivery_tag, requeue=1)
        logging.error("Failed")
        return

    # NOTE(review): the message is acked before the sampling task is published, so a
    # failure in create_sampling_task loses this request; confirm this is intended.
    channel.basic_ack(delivery_tag=method.delivery_tag)
    reindex_task_url = "{0}/_tasks/{1}".format(reindexer_cfg['destination_url'], rep['task'])
    logging.info("Created reindex task: {0}".format(reindex_task_url))

    # 3. Create the sampling task (full -> sample).
    create_sampling_task(cfg, channel, uuid)
    logging.info("Created sampling task.")


def main(cfg):
    """Consume reindex requests from RabbitMQ until consuming stops or fails.

    One message is processed at a time (``prefetch_count=1``), each handled by
    ``on_msg_callback``. The connection is always closed on exit, including when
    ``start_consuming`` raises. The caller retries ``main`` in a loop, so an
    unclosed connection would otherwise leak on every retry.

    :param cfg: dict with ``'rabbitmq_host'``, ``'rabbitmq_port'`` and
        ``'rabbitmq_queue'`` (the queue to consume from). ``'rabbitmq_exchange'``
        may be present but is not used here.
    :raises pika.exceptions.AMQPError: propagated from pika (the caller handles
        ``ChannelClosed`` / ``AMQPConnectionError``).
    """
    connection = pika.BlockingConnection(
        pika.ConnectionParameters(host=cfg['rabbitmq_host'], port=cfg['rabbitmq_port']))
    try:
        channel = connection.channel()
        channel.basic_qos(prefetch_count=1)
        channel.basic_consume(queue=cfg['rabbitmq_queue'], on_message_callback=on_msg_callback)
        channel.start_consuming()
    finally:
        # Closing an already-closed pika connection raises, so check first.
        if connection.is_open:
            connection.close()


if __name__ == '__main__':

    import yaml  # NOTE(review): not used anywhere in this file; confirm before removing
    import time
    import signal
    import argparse

    # Route SIGINT (Ctrl+C) to lib.exit_gracefully.exit_gracefully.
    signal.signal(signal.SIGINT, exit_gracefully)

    parser = argparse.ArgumentParser(description='Incremental reindexer')
    parser.add_argument('--host', dest='host', help='the RabbitMQ host', type=str, required=True)
    parser.add_argument('--port', dest='port', help='the RabbitMQ port', type=int, default=5672)
    parser.add_argument('--exchange', dest='exchange', help='the RabbitMQ exchange', type=str, required=True)
    parser.add_argument('--queue', dest='queue', help='the RabbitMQ queue', type=str, required=True)
    parser.add_argument('--loglevel', dest='loglevel', help='the log level', default="INFO", type=str, choices=['INFO', 'DEBUG', 'WARN', 'CRITICAL', 'ERROR'])

    args = parser.parse_args()

    cfg = {
        'rabbitmq_host': args.host,
        'rabbitmq_port': args.port,
        'rabbitmq_exchange': args.exchange,
        'rabbitmq_queue': args.queue,
    }

    logging.getLogger().setLevel(args.loglevel)
    logging.info('Starting...')

    # Reconnect forever on RabbitMQ channel/connection errors; exit on anything else.
    while True:
        try:
            main(cfg)
        except pika.exceptions.ChannelClosed:
            logging.info("Waiting for tasks...")
            time.sleep(5)
        except pika.exceptions.AMQPConnectionError:
            logging.info('Waiting for RabbitMQ to be reachable...')
            time.sleep(5)
        except Exception as e:
            # logging.exception keeps the traceback that logging.error(e) dropped.
            logging.exception(e)
            time.sleep(5)
            # sys.exit instead of the site-injected exit(), which is meant for the
            # interactive shell and is absent when Python runs with -S.
            sys.exit(1)