import codecs
import pickle
import datetime
import time
import traceback
from multiprocessing import Process, Queue
from multiprocessing.pool import ThreadPool

from celery import Celery
from celery.result import AsyncResult
from celery.worker.control import revoke

from core.messages import RedisMessageQueue, message_handler_registry
from core.context import PartitionOutputMessage

from core.models import ProcessedNode, Proxy
from core.tree import Node


class Layer:
    """Description of one depth level of the crawl tree.

    :param name: name of the layer; the root layer's name is used as the
        name of the tree's root node.
    :param process_function: callable invoked with the crawl context to
        process a node belonging to this layer.
    :param multi_worker_processing: when true, nodes of this layer are handed
        to a worker pool instead of being processed in-line. Stored as the
        ``partition`` attribute.
    """

    def __init__(self, name, process_function, multi_worker_processing=False):
        self.name, self.process_function, self.partition = (
            name,
            process_function,
            multi_worker_processing,
        )


class CrawlerWorker:
    """Depth-first crawler that walks ``context.tree`` inside the current process.

    The root node is processed first. Afterwards the worker repeatedly takes
    the node reached by following the first child from the tree root and runs
    its layer's ``process_function``. Nodes of a ``partition`` layer are
    handed, together with their siblings, to a :class:`CrawlerWorkersPool`.

    :param context: crawl context (settings, tree, logging hooks).
    :param root_node: node the crawl starts from; the walk stops once it is
        the only node left to pick.
    """

    def __init__(self, context, root_node):
        self.context = context
        self.root_node = root_node

    def crawl(self):
        """Process the root, then every remaining node, then call ``context.finish()``.

        A processed node that produced no children is removed from its parent
        so the next lookup moves on to another node.
        """
        self.context.settings.init_function(self.context)
        self.context.current_node = self.root_node
        self.context.tree.set_root_node(self.root_node)

        self.try_process_n_times(
            self.context.settings.layers[self.root_node.layer].process_function
        )

        deepest_node = self.find_unprocessed_node()
        while deepest_node != self.root_node:
            layer = self.context.settings.layers[deepest_node.layer]
            if layer.partition:
                # The pool crawls this node and its siblings and empties parent.children.
                self.create_workers_pool(deepest_node)
            else:
                self.context.current_node = deepest_node

                print("--" * deepest_node.layer, deepest_node.name, "==", deepest_node.parent.name)

                self.try_process_n_times(layer.process_function)

                # A node that produced no children is finished: detach it from the tree.
                if not deepest_node.children:
                    deepest_node.parent.remove_child(deepest_node)
            deepest_node = self.find_unprocessed_node()
        self.context.finish()

    def try_process_n_times(self, process_function):
        """Call ``process_function(self.context)``, retrying when it raises.

        At least one attempt is made; at most ``settings.max_retries`` attempts
        are made in total. After the final failed attempt
        ``context.log_error_node()`` is called and the exception is NOT
        propagated.

        :param process_function: callable taking the crawl context.
        """
        attempts = 0
        self.context.log_start_node_processing()
        while True:
            try:
                process_function(self.context)
                return
            except Exception:
                attempts += 1
                # Print the traceback so failures inside the retry loop stay diagnosable.
                print(traceback.format_exc())
                if attempts >= self.context.settings.max_retries:
                    self.context.log_error_node()
                    return

    def create_workers_pool(self, deepest_node):
        """Crawl the not-yet-processed siblings of ``deepest_node`` with a worker pool.

        Children of ``deepest_node.parent`` whose names already appear in
        ``ProcessedNode`` rows for that parent are skipped. The parent's
        ``children`` list is emptied before the pool runs.
        """
        parent = deepest_node.parent
        # A set gives O(1) membership checks while filtering the children below.
        processed_names = {
            processed_node.name
            for processed_node in ProcessedNode.select().where(ProcessedNode.parent_id == parent.node_id)
        }

        nodes_to_process = [child for child in parent.children if child.name not in processed_names]

        workers_pool = CrawlerWorkersPool(context=self.context, nodes_to_process=nodes_to_process)
        parent.children = []
        workers_pool.crawl()

    def find_unprocessed_node(self):
        """Return the node reached by following the first child from the tree root until a leaf."""
        node = self.context.tree.root
        while node.children:
            node = node.children[0]
        return node

    # Backward-compatible alias for the original (misspelled) method name.
    find_unpocessed_node = find_unprocessed_node


class MultiprocessingCrawlerWorker(CrawlerWorker):
    """Runs the inherited ``CrawlerWorker.crawl`` for one node in a child process.

    The child uses a partition context built from
    ``settings.context_class.get_partition_context_class()`` and bound to
    ``self.queue``; the parent process drains that queue in :meth:`handle`.

    :param base_context: the parent's context; output lines are written to it.
    :param node_to_process: node used as the child crawl's root.
    :param proxy: proxy passed to the partition context; unlocked when the
        child signals completion.
    """

    def __init__(self, base_context, node_to_process, proxy):
        self.queue = Queue()
        partition_context_class = base_context.settings.context_class.get_partition_context_class()
        process_context = partition_context_class(queue=self.queue,
                                                  settings=base_context.settings,
                                                  proxy=proxy)
        super().__init__(root_node=node_to_process, context=process_context)
        self.base_context = base_context
        self.process = None
        self.alive = True

    def parse(self):
        """Start a child process running the inherited ``crawl``.

        NOTE(review): the process is never joined in this class.
        """
        self.process = Process(target=super().crawl, args=())
        self.process.start()

    def handle(self):
        """Drain every message currently available on the queue.

        ``PartitionOutputMessage`` instances are written to the base context.
        A ``-1`` marks the worker as no longer alive and unlocks its proxy.
        """
        # multiprocessing.Queue.qsize() raises NotImplementedError on macOS, so
        # drain with get_nowait() until the queue reports empty instead.
        from queue import Empty

        while True:
            try:
                message = self.queue.get_nowait()
            except Empty:
                break
            if isinstance(message, PartitionOutputMessage):
                self.base_context.writeline(code=message.filecode, data=message.data)
            elif message == -1:
                # Presumably sent by the partition context when the child finishes — verify.
                self.alive = False
                self.context.proxy.unlock()

    def is_alive(self):
        """Return False once the ``-1`` message has been received."""
        return self.alive


class CeleryCrawlerWorker(CrawlerWorker):
    """Crawls a single node by dispatching it to the remote Celery task
    ``crawler_node.celery.process_node``.

    The remote settings, the node and the proxy are pickled and
    base64-encoded as task arguments. :meth:`handle` reads pending messages
    from a ``RedisMessageQueue`` bound to the task id, dispatches each one to
    the handler registered for its ``message_type``, and marks the worker
    done once the Celery result is ready and ``result.get()`` succeeds.

    :param context: crawl context; ``context.settings`` supplies the broker
        URL and ``create_remote_settings()``.
    :param node_to_process: node crawled by the remote task.
    :param proxy: proxy handed to the remote task; unlocked here when the
        worker gives up.
    :param max_restarts: how many times a failing task is re-dispatched before
        the worker is flagged with ``is_error`` (default 1).
    """

    def __init__(self, context, node_to_process: Node, proxy: Proxy, max_restarts: int = 1):
        # CrawlerWorker.__init__ is not called: this worker has no local root node.
        self.context = context
        self.settings = context.settings.create_remote_settings()

        # The broker URL is also used as the result backend.
        self.celery_app = Celery('crawler_node', broker=context.settings.celery_broker_url,
                                 backend=context.settings.celery_broker_url)

        self.node_to_process = node_to_process
        self.proxy = proxy

        self.task = None
        self.is_result_processed = False
        self.is_error = False
        self.message_queue = None

        self.start_time = None
        self.end_time = None

        self.restart_count = 0
        self.max_restarts = max_restarts

    def crawl(self):
        """Send the task and open a message queue keyed by the new task id."""
        self.task = self.celery_app.signature('crawler_node.celery.process_node')

        settings = codecs.encode(pickle.dumps(self.settings), "base64").decode()
        node_to_process = codecs.encode(pickle.dumps(self.node_to_process), "base64").decode()
        proxy = codecs.encode(pickle.dumps(self.proxy), "base64").decode()

        self.task = self.task.delay(settings, node_to_process, proxy)
        self.message_queue = RedisMessageQueue(task_id=self.task.id, redis_client=self.task.backend.client)
        self.start_time = datetime.datetime.now()
        self.is_result_processed = False

    def restart(self):
        """Re-dispatch the node as a new task, closing the previous task's message queue."""
        self.restart_count += 1
        # crawl() replaces self.message_queue; close the old one so it is not leaked.
        self._close_message_queue()
        self.crawl()

    def _close_message_queue(self):
        """Close the current message queue, if any; a failing close is printed, not raised."""
        if self.message_queue is None:
            return
        try:
            self.message_queue.close()
        except Exception:
            print(traceback.format_exc())

    def handle(self):
        """Poll the task once; on failure restart it or, after ``max_restarts``, give up.

        Giving up unlocks the proxy, closes the message queue and sets
        ``is_error`` so the owning pool can drop this worker.
        """
        try:
            self.handle_()
        except Exception as ex:
            trb = traceback.format_exc()
            print(trb)
            print(str(ex), "Restart", self.node_to_process)
            if self.restart_count < self.max_restarts:
                self.restart()
            else:
                self.proxy.unlock()
                # Best-effort close so is_error is always set and the pool can drop this worker.
                self._close_message_queue()
                self.is_error = True

    def handle_(self):
        """Dispatch pending messages and, if the Celery result is ready, finish the worker.

        ``result.get()`` is called before the worker is marked processed, so an
        exception it raises reaches :meth:`handle`.
        """
        if self.is_result_processed:
            return

        requests = self.message_queue.listen()

        for request in requests:
            processor = message_handler_registry.get_handler_for(message_type=request.message_type)(self.context)
            processor.handle_message(request)

        result = AsyncResult(self.task.id, app=self.celery_app)
        if result.ready():
            result.get()
            self.is_result_processed = True
            self.message_queue.clear()
            self.end_time = datetime.datetime.now()

    def is_alive(self):
        """Return True until the task result has been processed."""
        return not self.is_result_processed

    def get_execution_time(self):
        """Return the timedelta between the last dispatch and completion (valid only once done)."""
        return self.end_time - self.start_time


class CrawlerWorkersPool:
    """Crawls a list of nodes with at most ``settings.max_workers`` concurrent
    ``CeleryCrawlerWorker`` instances.

    :param context: crawl context (settings, ``crawling`` identifier).
    :param nodes_to_process: nodes still to crawl; consumed from the front.
    :param poll_interval: seconds to sleep after each polling round (falsy
        disables the sleep). Each round calls ``worker.handle()``, which for
        Celery workers queries Redis and the result backend, so without a
        pause the loop hammers both.
    """

    def __init__(self, context, nodes_to_process, poll_interval=0.1):
        self.context = context
        self.nodes_to_process = nodes_to_process
        self.workers = []
        self.poll_interval = poll_interval

    def crawl(self):
        """Block until every node has been handed out and no worker is still alive."""
        for node in self.nodes_to_process:
            node.cache_ancestor_chain()

        while not self.is_all_nodes_processed():
            if self.is_allocation_needed():
                self.allocate_workers()

            for worker in self.workers:
                worker.handle()

            # Drop workers that reported an unrecoverable error (in place, so the list identity is kept).
            self.workers[:] = [worker for worker in self.workers if not worker.is_error]

            done_workers = [worker for worker in self.workers if not worker.is_alive()]
            for done_worker in done_workers:
                done_worker.handle()
                if done_worker.is_result_processed:
                    self.end_node_processing(done_worker)

            if self.poll_interval:
                time.sleep(self.poll_interval)

    def end_node_processing(self, worker):
        """Release a finished worker's resources and record its node as a ``ProcessedNode``."""
        self.workers.remove(worker)
        worker.proxy.unlock()
        worker.message_queue.close()
        processed_node = worker.node_to_process
        ProcessedNode.create(name=processed_node.name,
                             ancestor_path=processed_node.get_ancestor_path(),
                             parent_id=processed_node.parent.node_id,
                             crawling=self.context.crawling)
        print("Done", worker.node_to_process, "Execution time:", worker.get_execution_time(), worker.node_to_process.page_url)

    def is_all_nodes_processed(self):
        """Return True when no nodes are queued and no worker reports itself alive."""
        return not self.nodes_to_process and not any(worker.is_alive() for worker in self.workers)

    def allocate_workers(self):
        """Start workers for queued nodes until the pool is full or the queue is empty."""
        while self.is_allocation_needed():
            node_to_process = self.nodes_to_process.pop(0)
            proxy = self.context.settings.get_unlocked_proxy()
            print("Start", node_to_process, node_to_process.page_url)
            partition_parser_worker = CeleryCrawlerWorker(context=self.context,
                                                          node_to_process=node_to_process,
                                                          proxy=proxy)
            partition_parser_worker.crawl()
            self.workers.append(partition_parser_worker)

    def is_allocation_needed(self):
        """Return True when there is a free worker slot and a node waiting for it."""
        return len(self.workers) < self.context.settings.max_workers and bool(self.nodes_to_process)


class Crawler:
    """Top-level entry point: builds the crawl context and tree root, then crawls.

    :param settings: crawler settings; provides ``context_class``, ``layers``,
        ``base_url``, ``outputs`` and ``get_unlocked_proxy()``.
    """
    ROOT_LAYER_INDEX = 0

    def __init__(self, settings):
        self.settings = settings
        # Not used within this class; kept as part of the public attribute set.
        self.workers = []

        self.context = self.settings.context_class(settings=settings)

        self.root_node = self.context.tree.add_root_node(name=self.settings.layers[Crawler.ROOT_LAYER_INDEX].name,
                                                         page_url=self.context.settings.base_url)
        self.context.current_node = self.root_node
        self.base_parser_worker = CrawlerWorker(self.context, root_node=self.root_node)

    def crawl(self, crawling, conf):
        """Run the crawl for ``crawling`` with configuration ``conf``.

        Any exception, including ``KeyboardInterrupt``, is printed and
        swallowed; in that case every configured output is force-closed. A
        failing ``force_close()`` is printed and does not stop the remaining
        outputs from being closed.

        NOTE(review): swallowing KeyboardInterrupt means Ctrl-C returns
        normally from this method — confirm callers rely on that.
        """
        try:
            self.context.crawling = crawling
            self.context.conf = conf
            self.context.initialize_client(proxy=self.settings.get_unlocked_proxy())
            self.base_parser_worker.crawl()
        except (KeyboardInterrupt, Exception):
            print(traceback.format_exc())
            for output in self.settings.outputs.values():
                # Isolate each close so one broken output cannot leave the others open.
                try:
                    output.force_close()
                except Exception:
                    print(traceback.format_exc())

