import io
import re
import os.path
import time
import json
import base64

from redis import Redis
from PIL import Image

from core.models import ErrorNode, ErrorText


class RedisMessageQueue:
    """Request/response channel for one task, stored as plain Redis keys.

    A request is written to ``{task_id}_node_request_{n}``. A reply, if one is
    sent, goes to the matching ``{task_id}_node_response_{n}`` key, which the
    sender polls.
    """

    def __init__(self, task_id, redis_client: Redis):
        # task_id namespaces every key this queue reads or writes.
        self.task_id = task_id
        self.redis_client = redis_client

    def close(self):
        """Close the underlying Redis client."""
        self.redis_client.close()

    def clear(self):
        """Delete the Celery result metadata key for this task."""
        self.redis_client.delete(f"celery-task-meta-{self.task_id}")

    def send_message(self, request_type, request_body, wait_response=True):
        """Publish a request and optionally wait for its response.

        :param request_type: message type the listener dispatches on.
        :param request_body: JSON-serialisable payload for the handler.
        :param wait_response: when true, poll for up to 60 s for the reply.
        :return: the decoded JSON reply; ``[]`` if no reply arrived within
            60 s; ``None`` when ``wait_response`` is false.
        """
        # INCR hands out a unique, monotonically increasing sequence number.
        # Deriving it from the count of pending request keys could reuse a
        # number that is still pending while the listener deletes keys one by
        # one, overwriting an unread request.
        # NOTE(review): the sequence key is never removed; consider deleting it
        # once the task has finished.
        sequence = self.redis_client.incr(f"{self.task_id}_node_sequence")
        new_key = f"{self.task_id}_node_request_{sequence}"
        request_content = {"request_type": request_type, "request_body": request_body}
        self.redis_client.set(name=new_key, value=json.dumps(request_content, ensure_ascii=False))

        if not wait_response:
            return None

        response_key = new_key.replace('request', "response")
        start = time.time()
        while time.time() - start < 60:
            time.sleep(0.05)
            response = self.redis_client.get(response_key)
            if response:
                self.redis_client.delete(response_key)
                return json.loads(response.decode())
        return []

    def create_response(self, response_key, response_content):
        """Store ``response_content`` as JSON under ``response_key``."""
        self.redis_client.set(response_key, json.dumps(response_content))

    @staticmethod
    def _request_sequence(key):
        """Sort key for a request key: its trailing ``_<n>`` as an int.

        Keys without a numeric suffix sort after all numbered ones.
        """
        suffix = key.decode().rsplit("_", 1)[-1]
        return (0, int(suffix)) if suffix.isdecimal() else (1, 0)

    def _response_callback(self, response_key):
        """Return a callback that stores its argument under ``response_key``."""
        def respond(response_content):
            self.create_response(response_key=response_key, response_content=response_content)
        return respond

    def listen(self):
        """Drain every pending request for this task.

        Each request key is read, decoded, and deleted. The returned
        ``Message`` objects are ordered by sequence number, because ``KEYS``
        returns keys in arbitrary order. Each message's ``send_response``
        writes to the response key paired with its own request.
        """
        keys = self.redis_client.keys(f"{self.task_id}_node_request*")
        messages = []
        for key in sorted(keys, key=self._request_sequence):
            raw = self.redis_client.get(key)
            if raw is None:
                # The key vanished between KEYS and GET (consumed elsewhere).
                continue
            request_key = key.decode()
            message_content = json.loads(raw)
            self.redis_client.delete(request_key)

            response_key = request_key.replace('request', "response")
            # The factory binds response_key per request. A lambda closing over
            # the loop variable would make every message answer on the last key.
            messages.append(Message(message_content,
                                    response_callback=self._response_callback(response_key)))

        return messages


class MessageHandlersRegistry:
    """Maps a message type string to the handler class that processes it."""

    def __init__(self):
        # message_type -> handler class
        self.request_processors = dict()

    def register_handler(self, message_type, handler_class):
        """Associate ``handler_class`` with ``message_type``; a later call replaces an earlier one."""
        self.request_processors.update({message_type: handler_class})

    def get_handler_for(self, message_type):
        """Return the handler class for ``message_type``; raises KeyError if none is registered."""
        processors = self.request_processors
        return processors[message_type]


class Message:
    """A single request taken from the queue, plus the means to answer it."""

    def __init__(self, message_content, response_callback):
        # message_content is the decoded request dict written by the sender.
        self.message_type, self.message_body = (
            message_content['request_type'],
            message_content['request_body'],
        )
        self.response_callback = response_callback

    def send_response(self, response_content):
        """Pass ``response_content`` to the stored response callback."""
        callback = self.response_callback
        callback(response_content)


class MessageHandler:
    """Base class for queue message handlers; subclasses override ``handle_message``."""

    def __init__(self, context):
        # Context object supplied by the caller and shared with the handler.
        self.context = context

    def handle_message(self, request):
        """Process ``request``. The base implementation does nothing."""


class WriteLineMessageHandler(MessageHandler):
    """Handles ``write_line`` messages by writing one line to a named output."""

    MESSAGE_TYPE = "write_line"

    def handle_message(self, message):
        """Write ``message_body['data']`` to the output registered under ``message_body['code']``."""
        body = message.message_body
        target = self.context.settings.outputs[body['code']]
        target.writeline(body['data'])


class ErrorNodeMessageHandler(MessageHandler):
    """Persists a reported error: screenshot file, DOM file, and an ``ErrorNode`` row."""

    MESSAGE_TYPE = "handle_error_node"

    def handle_message(self, request):
        """Save the error's artifacts and create an ``ErrorNode`` record.

        NOTE(review): the body is unpacked positionally from ``dict.values()``,
        so this relies on the sender inserting exactly these eight keys in this
        order. A reordered key silently misassigns fields; a missing or extra
        key raises ValueError. Consider reading each field by key once the
        sender's key names are confirmed.
        """
        (name, url, path, error_text, dom,
         screenshot, properties_chain, proxy_id) = request.message_body.values()

        screenshot_path = self.save_screenshot(screenshot)
        dom_path = self.save_dom(dom)
        cleaned_error = self.error_text_to_object(error_text=error_text)

        ErrorNode.create(
            name=name,
            url=url,
            path=path,
            dom_file_path=dom_path,
            error_text=cleaned_error,
            screenshot_path=screenshot_path,
            properties_chain=properties_chain,
            proxy_id=proxy_id,
        )

    def error_text_to_object(self, error_text) -> ErrorText:
        """Strip Selenium exception sections from ``error_text`` and return its ``ErrorText`` row.

        A line containing ``selenium.common`` starts a dropped run. That run
        continues until the next empty line, which is kept again along with
        everything after it. The cleaned text is passed to
        ``ErrorText.get_or_create`` and the first element of its result is
        returned.
        """
        kept = []
        keeping = True
        for line in error_text.split("\n"):
            if not line:
                keeping = True
            elif "selenium.common" in line:
                keeping = False
            if keeping:
                kept.append(line)

        return ErrorText.get_or_create(text='\n'.join(kept))[0]

    def _write_artifact(self, directory, extension, payload):
        """Write ``payload`` bytes to a timestamp-named file in ``directory``; return the path.

        NOTE(review): names come from ``time.time_ns()``. Two artifacts with
        the same extension, written to the same directory within one clock
        tick, would share a path.
        """
        target_path = os.path.join(directory, f"{time.time_ns()}{extension}")
        storage = self.context.conf.get_storage_machine()
        storage.write_file(file_path=target_path, file_bytes=io.BytesIO(payload))
        return target_path

    def save_dom(self, dom):
        """Store ``dom`` (encoded with str.encode's default) as an ``.html`` file in the DOM directory; return its path."""
        payload = dom.encode()
        return self._write_artifact(self.context.conf.error_handling_config.dom_dir, ".html", payload)

    def save_screenshot(self, screenshot):
        """Decode the base64 ``screenshot`` and store it as a ``.png`` file; return its path."""
        payload = base64.b64decode(screenshot.encode())
        return self._write_artifact(self.context.conf.error_handling_config.screenshots_dir, ".png", payload)


# Module-wide registry mapping each message type to its handler class.
message_handler_registry = MessageHandlersRegistry()

for _handler_class in (WriteLineMessageHandler, ErrorNodeMessageHandler):
    message_handler_registry.register_handler(
        message_type=_handler_class.MESSAGE_TYPE,
        handler_class=_handler_class,
    )
del _handler_class
