import json
import os.path
import time
import base64
import traceback
import requests

from redis import Redis

from core.firefox import FirefoxClient
from core.models import ErrorNode
from core.tree import Node, Tree
from core.messages import RedisMessageQueue, ErrorNodeMessageHandler, WriteLineMessageHandler
from core.utils import timeit


class PartitionOutputMessage:
    """Lightweight pair of an output code and the payload destined for it.

    Both values are stored by reference, without validation or copying.
    """

    def __init__(self, filecode, data):
        # ``filecode`` presumably matches a key of ``settings.outputs``
        # (TODO confirm with the producers of this message); ``data`` is the
        # payload itself.
        self.filecode, self.data = filecode, data


class Context:
    """Shared crawl state plus plain-HTTP request helpers.

    Holds the crawl settings, the node currently being processed, the node
    tree, and a simple throttle driven by ``settings.queries_per_second``
    that every request made through this context goes through.
    """

    def __init__(self, settings, proxy=None):
        """Create a context.

        Args:
            settings: Crawl settings; this class reads ``queries_per_second``
                and ``outputs`` from it.
            proxy: Optional proxy object exposing ``scheme`` and
                ``get_full_address()``; used by :meth:`make_request`.
        """
        self.settings = settings
        # Store the proxy: make_request reads self.proxy, and without this
        # assignment a plain Context failed with AttributeError.
        self.proxy = proxy
        self.current_node: Node = None
        # time.time() of the most recent request start; 0 means "never".
        self.last_fetch = 0
        self.tree = Tree()

        # Set externally; their purpose is not visible in this class.
        self.crawling = None
        self.conf = None

    def is_request_permitted(self):
        """Return True if more than ``1 / queries_per_second`` seconds have
        elapsed since ``last_fetch``."""
        delay = 1 / self.settings.queries_per_second
        return time.time() - self.last_fetch > delay

    def _wait_for_request_slot(self):
        """Sleep until the throttle permits a request, then record its start."""
        delay = 1 / self.settings.queries_per_second
        while not self.is_request_permitted():
            # Sleep out the remaining interval instead of spinning the CPU;
            # the small floor covers float rounding right at the boundary.
            time.sleep(max(delay - (time.time() - self.last_fetch), 0.001))
        # Without this the throttle never engaged (last_fetch stayed 0).
        self.last_fetch = time.time()

    def make_request(self, url):
        """GET ``url`` through the throttle, the optional proxy and retries.

        Network errors (``requests.RequestException``) are retried with a
        linearly growing back-off (2s, 4s, ...). After the 11th failed
        attempt, ``url`` is appended to ``errors.txt`` and ``None`` is
        returned.

        Returns:
            The ``requests.Response`` (whatever its status code), or ``None``
            if every attempt failed.
        """
        self._wait_for_request_slot()

        proxies = {}
        if self.proxy:
            proxies[self.proxy.scheme] = self.proxy.get_full_address()

        counter = 0
        while True:
            try:
                return requests.get(url, timeout=10, proxies=proxies)
            except requests.RequestException:
                counter += 1
                if counter > 10:
                    # Append (not truncate) so earlier failures are kept; log
                    # the requested URL (Context has no page_url attribute).
                    with open('errors.txt', 'a', encoding='utf-8') as f:
                        f.write(f"{url}\n")
                    return None
                time.sleep(2 * counter)

    def download_file(self, filelink, path):
        """Download ``filelink`` into ``path`` unless ``path`` already exists.

        Always uses :meth:`Context.make_request` (plain HTTP), even in
        subclasses that override ``make_request``. Failed requests and
        non-200 responses leave ``path`` untouched.
        """
        # Check first: fetching a file only to discard it wastes bandwidth
        # and a throttle slot.
        if os.path.isfile(path):
            return
        response = Context.make_request(self, filelink)

        if response is not None and response.status_code == 200:
            with open(path, 'wb') as f:
                # The body is already fully loaded (no stream=True).
                f.write(response.content)

    def write(self, code, node):
        """Write ``node``'s ancestor chain to output ``code`` as one line.

        The chain is converted by the output's ``format_function``, and the
        resulting items are joined with commas.
        """
        ancestor_chain = node.get_ancestor_chain()
        output = self.settings.outputs[code]
        write_data = output.format_function(list(ancestor_chain))
        output.writeline(','.join(write_data))

    def writeline(self, code, data):
        """Write ``data`` verbatim to output ``code``."""
        output = self.settings.outputs[code]
        output.writeline(data)

    def finish(self):
        """Close every configured output."""
        for output in self.settings.outputs.values():
            output.close()

    def log_error_node(self):
        """Hook for reporting a failed node; a no-op in the base class."""
        pass


class SeleniumContext(Context):
    """Context that loads pages through a :class:`FirefoxClient` browser."""

    def __init__(self, *args, **kwargs):
        """Forward all arguments to :class:`Context`.

        The browser client and the proxy are attached later by
        :meth:`initialize_client`.
        """
        super().__init__(*args, **kwargs)
        self.client = None
        # Reset explicitly: the proxy is only set via initialize_client.
        self.proxy = None

    @timeit
    def make_request(self, url) -> FirefoxClient:
        """Navigate the browser to ``url`` and return the client.

        Honours ``settings.queries_per_second`` by sleeping until
        :meth:`is_request_permitted` allows the request.
        """
        delay = 1 / self.settings.queries_per_second
        while not self.is_request_permitted():
            # Sleep out the remaining interval instead of busy-spinning the
            # CPU; the small floor covers float rounding at the boundary.
            time.sleep(max(delay - (time.time() - self.last_fetch), 0.001))
        # Record the request start; is_request_permitted measures from
        # last_fetch, so without this the throttle never engages.
        self.last_fetch = time.time()
        self.client.get(url)
        return self.client

    @timeit
    def initialize_client(self, proxy=None):
        """Create the browser client, optionally behind ``proxy``.

        ``debug`` and ``languages`` are taken from the settings.
        """
        self.proxy = proxy
        self.client = FirefoxClient(
            proxy=proxy,
            debug=self.settings.debug,
            languages=self.settings.languages
        )

    def log_start_node_processing(self):
        """Hook called when a node starts processing; a no-op here."""
        pass

    def log_error_node(self):
        """Print the current node and the traceback of the active exception."""
        print(f"Error on {self.current_node}")
        print(traceback.format_exc())


class CelerySeleniumContext(SeleniumContext):
    """Selenium context for Celery tasks that reports through a Redis queue.

    Output lines and error reports are sent as messages on a
    :class:`RedisMessageQueue` bound to ``task_id`` instead of being written
    to local outputs. Counts of processed and failed nodes are kept so that
    :meth:`is_crawling_successful` can judge the crawl.
    """

    def __init__(self, task_id, redis_client: Redis, *args, **kwargs):
        """Create the context.

        Args:
            task_id: Identifier passed to the message queue.
            redis_client: Redis connection used by the message queue.
            *args, **kwargs: Forwarded to :class:`SeleniumContext`.
        """
        super().__init__(*args, **kwargs)
        self.message_queue = RedisMessageQueue(task_id=task_id, redis_client=redis_client)

        self.total_processed_nodes = 0
        self.total_error_nodes = 0

    def writeline(self, code, data):
        """Send ``data`` as a ``"write_to_db"`` message without waiting.

        ``code`` is accepted for compatibility with the base-class signature
        but is not used.
        """
        self.message_queue.send_message(request_type="write_to_db",
                                        request_body=data,
                                        wait_response=False)

    def log_start_node_processing(self):
        """Count one more node whose processing has started."""
        self.total_processed_nodes += 1

    def log_error_node(self):
        """Report a failure of the current node.

        Prints the error (via the parent class), increments the error count,
        and sends a non-blocking message containing the following fields of
        the current node:

        - name
        - URL
        - ancestor path
        - current traceback
        - page DOM
        - screenshot
        - properties chain
        - proxy id

        Presumably called from inside an ``except`` block, so that
        ``traceback.format_exc()`` describes the failure.
        """
        super().log_error_node()
        self.total_error_nodes += 1

        dom = self.get_error_dom()
        screenshot = self.get_error_screenshot()
        properties_chain = self.get_error_properties_chain()

        error_node = {"name": self.current_node.name,
                      "url": self.current_node.page_url,
                      "path": self.current_node.get_ancestor_path(),
                      "error_text":  traceback.format_exc(),
                      "dom": dom,
                      "screenshot": screenshot,
                      "properties_chain": properties_chain,
                      # initialize_client defaults to proxy=None; a missing
                      # proxy must not crash the error reporter itself.
                      "proxy_id": getattr(self.proxy, "id", None),
                      }
        self.message_queue.send_message(request_type=ErrorNodeMessageHandler.MESSAGE_TYPE,
                                        request_body=error_node,
                                        wait_response=False)

    def get_error_properties_chain(self):
        """Return the ancestors' ``properties`` as JSON, joined by ``ErrorNode.DELIMITER``.

        Best effort: returns ``""`` if anything goes wrong.
        """
        try:
            return ErrorNode.DELIMITER.join(
                json.dumps(node.properties, ensure_ascii=False)
                for node in self.current_node.get_ancestor_chain()
            )
        except Exception:
            return ""

    def get_error_dom(self):
        """Return the current page DOM as a string, or ``""`` on failure."""
        try:
            dom = str(self.client.bs4())
        except Exception:
            dom = ""
        return dom

    def get_error_screenshot(self):
        """Return a base64-encoded PNG screenshot, or ``""`` on failure."""
        try:
            screenshot = base64.b64encode(self.client.get_screenshot_as_png()).decode()
        except Exception:
            screenshot = ""
        return screenshot

    def is_crawling_successful(self, max_error_ratio=0.25):
        """Return True if fewer than ``max_error_ratio`` of processed nodes failed.

        When no node has been processed there is no ratio to compute; the
        crawl then counts as successful only if no error was logged.
        """
        if self.total_processed_nodes == 0:
            return self.total_error_nodes == 0
        return self.total_error_nodes / self.total_processed_nodes < max_error_ratio

    def finish(self):
        """Close the browser client, if one was created."""
        if self.client is not None:
            self.client.close()

