#!/opt/pypy39-v7.3.9/bin/pypy3.9
import json
import os
import io
import re
import time
import argparse
import yaml

from collections import namedtuple
from multiprocessing.pool import ThreadPool

from PIL import Image
from celery.utils.threads import Local
from peewee import MySQLDatabase, SqliteDatabase
from selenium.webdriver import Keys
from selenium.webdriver.common.by import By
from selenium.webdriver.support import expected_conditions as EC
from selenium.webdriver.support.wait import WebDriverWait

from core.crawler import CrawlerWorker
from core.firefox import FirefoxClient
from core.machine import Machine, LocalMachine
from core.settings import ClusterSettings
from core.tree import Node
from core.utils import ResourceController

import subprocess
import matplotlib.pyplot as plt

from core.db import db_proxy
from core.models import Proxy, Crawling, ErrorNode, ProcessedNode, ErrorText

# Immutable records for the parsed sections of config.yaml (built in
# Config.parse_config) and for the error-handling cluster settings.
# DevelopmentRepositoryConf is not referenced elsewhere in this file.
LocalConf = namedtuple('LocalConf', 'local_repo_path ssh_key_path')
RedisConf = namedtuple('RedisConf', 'ip port')
MysqlConf = namedtuple('MysqlConf', 'ip port user password')
SqliteConf = namedtuple('SqliteConf', 'filename')
ErrorHandlingConf = namedtuple('ErrorHandlingConf', 'screenshots_dir dom_dir storage_machine')
DevelopmentRepositoryConf = namedtuple('DevelopmentRepositoryConf', 'ip user')


class Config:
    """Admin-tool configuration: YAML settings, database and cluster machines.

    Constructing a ``Config`` has side effects. It:
    - reads the YAML file at *path*;
    - initializes ``db_proxy`` with a MySQL or SQLite database and creates the model tables;
    - loads every ``Machine`` row and sets its SSH private key;
    - makes sure a ``ClusterSettings`` row exists for each ``ClusterSettings.SettingDefaults`` member;
    - reads the error-handling settings from ``ClusterSettings``.

    NOTE: ``parse_config`` and ``machine_selector`` read the module-level
    ``args`` (parsed command-line options), which must exist beforehand.
    """

    def __init__(self, path):
        self.path = path

        # Raw document returned by yaml.safe_load.
        self.conf = None

        self.redis_config = None
        self.local_config = None
        self.error_handling_config = None

        # Either may be set; configure_db prefers MySQL when both are.
        self.mysql_config = None
        self.sqlite_config = None

        self.db = None

        self.parse_config()
        self.configure_db()

        self.machines = list(Machine.select())

        for machine in self.machines:
            machine.set_private_key(self.local_config.ssh_key_path)

        self.local_machine = self.get_local_machine()

        # Must run before the ClusterSettings reads below so every setting row exists.
        self.check_cluster_settings()

        self.error_handling_config = ErrorHandlingConf(
            storage_machine=ClusterSettings.get_value_of(ClusterSettings.SettingDefaults.ERROR_STORAGE_MACHINE),
            screenshots_dir=ClusterSettings.get_value_of(ClusterSettings.SettingDefaults.ERROR_SCREENSHOTS_DIR),
            dom_dir=ClusterSettings.get_value_of(ClusterSettings.SettingDefaults.ERROR_DOM_DIR)
        )

    def get_local_machine(self) -> LocalMachine:
        """Return the first loaded machine that is a ``LocalMachine``.

        If there is none, returns ``Machine(ip=Machine._get_current_ip())``.
        NOTE(review): that fallback is a plain ``Machine``, not a
        ``LocalMachine``, despite the return annotation.
        """
        local_machine = list(filter(lambda x: isinstance(x, LocalMachine), self.machines))
        if local_machine:
            return local_machine[0]
        else:
            return Machine(ip=Machine._get_current_ip())

    def parse_config(self):
        """Read the YAML file at ``self.path`` into the ``*_config`` records.

        NUL characters are stripped before parsing with ``yaml.safe_load``.
        The ``local`` section is required. So is ``redis``, or ``redis-debug``
        when ``--debug`` is given. The ``mysql`` and ``sqlite`` sections are
        optional. A missing required key raises ``KeyError``.
        """
        with open(self.path, encoding="utf-8") as conf:

            self.conf = yaml.safe_load(conf.read().replace("\x00", ''))

            local_conf = self.conf['local']
            self.local_config = LocalConf(local_repo_path=local_conf['local_repo_path'], ssh_key_path=local_conf.get('ssh_key_path'))

            redis_conf = self.conf['redis-debug'] if args.debug else self.conf['redis']
            self.redis_config = RedisConf(ip=redis_conf['ip'], port=redis_conf['port'])

            mysql_conf = self.conf.get('mysql')
            if mysql_conf:
                self.mysql_config = MysqlConf(ip=mysql_conf['ip'], port=mysql_conf['port'], user=mysql_conf['user'],
                                              password=mysql_conf['password'])

            sqlite_conf = self.conf.get('sqlite')
            if sqlite_conf:
                self.sqlite_config = SqliteConf(filename=sqlite_conf['filename'])

    def check_cluster_settings(self):
        """Ensure a ``ClusterSettings`` row exists for every ``SettingDefaults`` member."""
        for setting_default in ClusterSettings.SettingDefaults:
            ClusterSettings.get_or_create(setting_name=setting_default.name)

    def configure_db(self):
        """Bind ``db_proxy`` to MySQL (preferred) or SQLite and create the tables.

        Raises:
            ValueError: if the config file has neither a ``mysql`` nor a
                ``sqlite`` section. Previously this surfaced as an opaque error
                from ``with None``.
        """
        if self.mysql_config:
            self.db = MySQLDatabase("crawler", host=self.mysql_config.ip, port=self.mysql_config.port,
                                    user=self.mysql_config.user, password=self.mysql_config.password,
                                    connect_timeout=10)
        elif self.sqlite_config:
            self.db = SqliteDatabase(self.sqlite_config.filename)
        else:
            raise ValueError(f"{self.path}: no 'mysql' or 'sqlite' database section configured")

        db_proxy.initialize(self.db)

        with self.db:
            self.db.create_tables([Proxy, Crawling, ErrorNode, ProcessedNode, ErrorText, Machine,
                                   ClusterSettings])

    def get_machine_with_name(self, name):
        """Return the loaded machine called *name*; raises ``IndexError`` if none matches."""
        return list(filter(lambda x: x.name == name, self.machines))[0]

    def get_storage_machine(self) -> Machine:
        """Return the machine named by the error-storage cluster setting."""
        return self.get_machine_with_name(self.error_handling_config.storage_machine)

    def get_main_machine(self) -> Machine:
        """Return the first machine of type MAIN; raises ``IndexError`` if none exists."""
        return list(filter(lambda x: x.type == Machine.Type.MAIN.value, self.machines))[0]

    def machine_selector(self):
        """Return the machines selected on the command line.

        With ``--debug``, returns only the local machine. Otherwise returns
        every machine named in ``-m/--machines``, or all machines when
        ``all`` is given. ``all`` is matched both as the default string and
        as an explicit ``-m all``; the latter parses to ``["all"]``.
        """
        if args.debug:
            return [self.local_machine]
        select_all = "all" in args.machines
        return [machine for machine in self.machines
                if select_all or machine.name in args.machines]


def init(conf):
    """Report database initialization.

    The tables are created while the ``Config`` passed in is constructed
    (``Config.configure_db``), so this command only prints a confirmation;
    *conf* is unused.
    """
    message = "Database initialized"
    print(message)


def commit(dir_path):
    """Run ``hg addremove`` and ``hg commit -m 'update'`` inside *dir_path*.

    The shell aborts if ``cd`` fails. Previously the ``;`` separator let the
    hg commands run (and commit) in whatever the current directory was. The
    path is shell-quoted so spaces or metacharacters cannot break the
    command. Output is captured and discarded, and hg failures are not
    raised, matching the original best-effort behaviour.
    """
    import shlex

    command = f"cd {shlex.quote(dir_path)} || exit 1; hg addremove; hg commit -m 'update'"
    subprocess.run(command, capture_output=True, shell=True)


def update(machine, dir_path):
    """Run ``hg update`` in *dir_path* on *machine* and print the command output.

    ``&&`` ensures ``hg update`` does not run in the wrong directory when
    ``cd`` fails. *dir_path* is shell-quoted.
    """
    import shlex

    command = f"cd {shlex.quote(dir_path)} && hg update"
    print(machine.execute_command(command, read=True))


def push(machine, conf):
    """Push the local repository to ``/home/crawler_node/`` on *machine* over its SSH address.

    ``&&`` prevents ``hg push`` from running outside the repository when
    ``cd`` fails. Both the path and the push target are shell-quoted. Output
    is captured and discarded, as before.
    """
    import shlex

    repo_path = shlex.quote(conf.local_config.local_repo_path)
    target = shlex.quote(f"{machine.get_ssh_address()}/home/crawler_node/")
    command = f"cd {repo_path} && hg push {target}"
    subprocess.run(command, capture_output=True, shell=True)


def sync(conf: Config):
    """Create a release on the main machine and update the selected worker machines.

    Steps:
    1. Runs ``hg incom`` on the local machine. Aborts if its output mentions
       a ``changeset`` (incoming changes must be merged first).
    2. Runs ``/home/create_release.py`` on the MAIN machine.
    3. Calls ``machine.update(conf)`` on each selected non-MAIN machine,
       with up to 8 in parallel threads.

    Selection follows ``-m/--machines`` (``all`` selects every machine). The
    MAIN machine is always excluded from step 3, even when named explicitly;
    previously operator precedence let a named MAIN machine slip through.
    All loaded machines are closed afterwards, even if an update fails.
    """
    print("Start synchronization")
    print("Checking for incoming changes")
    print("=" * 100)
    result = conf.local_machine.execute_command("hg incom", read=True)
    print(result)
    print("=" * 100)

    if "changeset" in result:
        print("You must merge incoming changes before syncing")
        return

    print("Creating release...")

    result = conf.get_main_machine().execute_command("python /home/create_release.py", read=True)
    print(result)

    print("Updating...")

    select_all = "all" in args.machines
    machines_to_update = [
        machine
        for machine in conf.machines
        if (select_all or machine.name in args.machines)
        and machine.type != Machine.Type.MAIN.value
    ]

    try:
        with ThreadPool(processes=8) as pool:
            pool.map(lambda machine: machine.update(conf), machines_to_update)
    finally:
        for machine in conf.machines:
            machine.close()


def deploy(conf):
    """Install the ``crawler-node`` package with yum on the selected machines.

    First installs the crawler-release repository RPM, then ``crawler-node``,
    printing each command's output. Selection follows ``-m/--machines``.
    ``all`` selects every machine, whether it is the default string or given
    explicitly as ``-m all`` (which parses to ``["all"]`` and was previously
    ignored).
    """
    select_all = "all" in args.machines
    for machine in conf.machines:
        if not (select_all or machine.name in args.machines):
            continue
        print(machine)
        install_repo_out = machine.execute_command(
            "yum install -y https://crw.welard.com/crawler/9/noarch/crawler-release-1-0.el9.noarch.rpm", read=True)
        print(install_repo_out)
        install_node_out = machine.execute_command("yum install -y crawler-node", read=True)
        print(install_node_out)


def uninstall(conf):
    """Remove all ``crawler-*`` packages with yum on the selected machines.

    Selection follows ``-m/--machines``. ``all`` selects every machine,
    including an explicit ``-m all``, which previously matched nothing.
    """
    select_all = "all" in args.machines
    for machine in conf.machines:
        if not (select_all or machine.name in args.machines):
            continue
        print(machine)
        remove_out = machine.execute_command("yum remove -y crawler-*", read=True)
        print(remove_out)


def update_config(conf: Config):
    """Call ``machine.update_config(conf)`` on each selected non-MAIN machine.

    Selection follows ``-m/--machines``. ``all`` selects every machine,
    including an explicit ``-m all``, which previously matched nothing.
    """
    select_all = "all" in args.machines
    for machine in conf.machines:
        # The MAIN machine is never reconfigured by this command.
        if machine.type == Machine.Type.MAIN.value:
            continue
        if not (select_all or machine.name in args.machines):
            continue

        print(machine)
        machine.update_config(conf)


def read_logs(conf):
    """Follow ``/var/log/crawler/crawler_node.log`` on a machine, like ``tail -f``.

    The machine is the first name given in ``-m/--machines``. NOTE: with the
    default ``-m`` value (the string ``"all"``) this looks for a machine named
    ``"a"``. Raises ``IndexError`` if no machine matches.

    Starts at the end of the file and polls every 0.1 s for new lines until
    interrupted with Ctrl+C. The remote file and the machine connection are
    always closed. Errors other than KeyboardInterrupt are no longer
    silently swallowed; they propagate after cleanup.
    """
    machine = list(filter(lambda x: x.name == args.machines[0], conf.machines))[0]

    machine.connect_sftp()
    try:
        log_file = machine.sftp_connection.open('/var/log/crawler/crawler_node.log')
        try:
            log_file.seek(0, os.SEEK_END)
            while True:
                position = log_file.tell()
                line = log_file.readline()
                if line:
                    print(line.rstrip())
                else:
                    # Nothing new yet: wait, then rewind to where the read started.
                    time.sleep(0.1)
                    log_file.seek(position)
        finally:
            log_file.close()
    except KeyboardInterrupt:
        pass  # Ctrl+C is the normal way to stop following the log
    finally:
        machine.close()
    

def restart(conf: Config):
    """Restart the crawler node on every machine chosen by ``conf.machine_selector()``.

    Prints each machine before restarting it and ``OK`` afterwards.
    """
    for target in conf.machine_selector():
        print(target)
        target.restart_node()
        print("OK")


def stop(conf):
    """Stop ``crawler_node.service`` via systemctl on the selected machines.

    Prints the machine, any command output, and ``OK``. Selection follows
    ``-m/--machines``. ``all`` selects every machine, including an explicit
    ``-m all``, which previously matched nothing.
    """
    select_all = "all" in args.machines
    for machine in conf.machines:
        if not (select_all or machine.name in args.machines):
            continue
        print(machine)
        stdout = machine.execute_command("systemctl stop crawler_node.service", read=True)
        if stdout:
            print(stdout)
        print("OK")


def truncate(conf):
    """After a ``y`` confirmation, empty every regular file in ``<local_repo_path>/output``.

    Subdirectories are left untouched. The output directory is built with
    ``os.path.join``, so ``local_repo_path`` no longer has to end with a slash.
    """
    confirmation = input("Are you sure [y/N]:")
    if confirmation != "y":
        return

    output_dir = os.path.join(conf.local_config.local_repo_path, "output")
    for file_name in os.listdir(output_dir):
        file_path = os.path.join(output_dir, file_name)
        if os.path.isfile(file_path):
            # Opening in "w" mode truncates the file to zero length.
            with open(file_path, "w"):
                pass


def update_service(conf):
    """Install and restart ``crawler_node.service`` from the local repo on the selected machines.

    On each machine this:
    1. uploads ``<local_repo_path>/install/crawler_node.service`` to
       ``/home/crawler/install/``;
    2. replaces the unit in ``/etc/systemd/system/``;
    3. runs ``systemctl daemon-reload`` and restarts the service.

    Each command's output is printed. Selection follows ``-m/--machines``.
    ``all`` selects every machine, including an explicit ``-m all``, which
    previously matched nothing.
    """
    service_file = os.path.join(conf.local_config.local_repo_path, "install", "crawler_node.service")
    remote_commands = (
        "rm -rf /etc/systemd/system/crawler_node.service",
        "cp /home/crawler/install/crawler_node.service /etc/systemd/system/",
        "systemctl daemon-reload",
        "systemctl restart crawler_node.service",
    )

    select_all = "all" in args.machines
    for machine in conf.machines:
        if not (select_all or machine.name in args.machines):
            continue
        print(machine)
        print("Sending service")
        machine.send_to_machine(local_path=service_file,
                                remote_path="/home/crawler/install/crawler_node.service")

        for command in remote_commands:
            stdout = machine.execute_command(command, read=True)
            print(stdout)


def mem_stat(conf):
    """Print and/or plot the memory usage of the selected machines.

    Machines are selected by ``-m/--machines`` (``all`` selects every one).
    Each sample calls ``machine.get_memory_usage()``; the value is converted
    to GiB (divided by 1024**3) for the plot.

    Without ``-p/--plot``, one sample per machine is printed and the figure is
    drawn once. With it, samples are taken every second and the last 60 are
    shown until Ctrl+C. The per-machine history is bounded to those 60
    samples; previously it grew forever in plot mode. Dashed reference lines
    are drawn at 16 GB and at 16 GB minus ``ResourceController.SUSPEND_MARGIN``
    and ``REGULATING_MARGIN``.
    """
    from collections import deque

    window = 60  # number of samples (seconds) shown on the plot

    machines = [machine for machine in conf.machines
                if machine.name in args.machines or "all" in args.machines]

    ram_consumption = {machine.name: deque(maxlen=window) for machine in machines}

    plt.ion()

    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.set_ylim([0, 20])

    plt.ylabel("RAM consumption, GB")
    plt.xlabel("Seconds")
    lines = {machine.name: ax.plot(range(window), [0] * window, label=machine.name)[0]
             for machine in machines}
    ax.plot(range(window), [16] * window, label="max_memo", linestyle="--")
    ax.plot(range(window), [16 - ResourceController.SUSPEND_MARGIN] * window, label="suspend", linestyle="--")
    ax.plot(range(window), [16 - ResourceController.REGULATING_MARGIN] * window, label="regulate", linestyle="--")
    ax.legend(loc="upper left")
    try:
        while True:
            for machine in machines:
                total_usage = machine.get_memory_usage()

                ram = float(total_usage) / 1024 / 1024 / 1024

                if not args.plot:
                    print(machine, total_usage)
                ram_consumption[machine.name].append(ram)

            for machine_name, history in ram_consumption.items():
                samples = list(history)
                # Pad with None so every line always has `window` points.
                samples += [None] * (window - len(samples))
                lines[machine_name].set_ydata(samples)
            fig.canvas.draw()
            fig.canvas.flush_events()

            if not args.plot:
                break
            time.sleep(1)

    except KeyboardInterrupt:
        pass


def create_crawling(conf):
    """Placeholder command: only prints the ``Crawling`` model class; *conf* is unused."""
    model = Crawling
    print(model)


def crawl(conf):
    """Run the ``Crawling`` named by ``-c/--crawling`` with the crawler it provides."""
    selected = Crawling.get(Crawling.name == args.crawling)
    selected.get_crawler().crawl(selected, conf)


# One proxy per line: [http(s)://]user:password@ip:port
_PROXY_RE = re.compile(
    r'^(?P<proto>https?://)?(?P<user>[^:]+):(?P<password>[^@]+)@(?P<ip>[^:]+):(?P<port>\d+)$'
)


def _parse_proxy_line(line):
    """Parse one line of the proxies file.

    Returns a dict with the string fields ``ip``, ``port``, ``user`` and
    ``password``, or ``None`` for a blank line.

    Raises:
        ValueError: if a non-blank line does not match ``_PROXY_RE``.
    """
    line = line.strip()
    if not line:
        return None
    match = _PROXY_RE.match(line)
    if match is None:
        raise ValueError(f"malformed proxy line: {line!r}")
    return {key: match[key] for key in ("ip", "port", "user", "password")}


def load_proxies(conf):
    """Create a ``Proxy`` row (scheme ``https``) for each line of the ``proxies`` file.

    The file is read from the current working directory. Blank lines are
    skipped; they used to crash with a ``TypeError``. A malformed line raises
    ``ValueError`` naming the line.
    """
    with open("proxies", "r", encoding="utf-8") as proxies_file:
        for line in proxies_file:
            fields = _parse_proxy_line(line)
            if fields is not None:
                Proxy.create(**fields, scheme="https")


def clean(conf):
   for machine in conf.machines:
       print(machine)
       machine.execute_command("rm -rf /tmp/*")


def load_machines(conf):
    """Create ``Machine`` rows from the CSV file given with ``-f/--file``.

    The first line is a header and is skipped. Every following non-blank line
    must have exactly five comma-separated fields,
    ``name,ip,username,password,type``, where *type* is a member name of
    ``Machine.Type``. Fields are split on raw commas, so no value may contain
    one. Blank lines, such as a trailing empty line, are ignored instead of
    raising ``ValueError``.
    """
    with open(args.file, 'r', encoding="utf-8") as machines_file:
        next(machines_file, None)  # skip the header row
        for line in machines_file:
            if not line.strip():
                continue
            name, ip, username, password, machine_type = line.split(",")
            Machine.create(name=name, ip=ip, username=username, password=password,
                           type=Machine.Type[machine_type.strip()].value)


def analyze_proxies(conf: Config):
    """Look up each stored proxy IP on ipgeolocation.io and save its country.

    The browser itself is routed through the first stored proxy. For every
    ``Proxy`` row the IP is typed into the site's search box, the search is
    submitted, and the text of ``span.ip-info-right.countryName`` is stored in
    ``proxy.country``.
    NOTE(review): the search is submitted twice (button click plus ENTER), and
    the country is read without waiting for the result to change. Confirm
    this against the live site.
    """
    browser_proxy = Proxy.select().first()
    input_locator = (By.CSS_SELECTOR, "input#inputIPAddress")

    with FirefoxClient(proxy=browser_proxy, debug=True) as client:
        client.get("http://ipgeolocation.io")

        for proxy in Proxy.select():
            # Wait until the input is clickable and use the element the wait
            # returns. Looking it up before waiting can fail or return a
            # not-yet-ready element.
            input_element = WebDriverWait(client, 10).until(EC.element_to_be_clickable(input_locator))
            input_element.clear()
            input_element.send_keys(str(proxy.ip))
            client.find_element(By.CSS_SELECTOR, "button.button.navbar__submit__input").click()
            input_element.send_keys(Keys.ENTER)

            country = client.find_element(By.CSS_SELECTOR, "span.ip-info-right.countryName").text
            proxy.country = country
            proxy.save()


def show_error(conf: Config):
    """Show an ``ErrorNode``'s screenshot, fetch its DOM dump and print its error text.

    The node is chosen with ``-id/--id``.

    With a storage machine configured, the screenshot is read over SFTP and
    the DOM file is downloaded into ``<local_repo_path>/errors/dom/``.
    Otherwise the screenshot is opened from the local filesystem and no
    download is attempted. Previously the local case looked up a storage
    machine anyway, which fails when none is configured.
    """
    error_node = ErrorNode.get(ErrorNode.id == args.id)
    storage_machine = conf.get_storage_machine() if conf.error_handling_config.storage_machine else None

    if storage_machine is not None:
        storage_machine.connect_sftp()
        screenshot_file = storage_machine.sftp_connection.open(error_node.screenshot_path)
        try:
            image = Image.open(io.BytesIO(screenshot_file.read()))
        finally:
            screenshot_file.close()
    else:
        # Local mode: use the same screenshot_path as the remote branch.
        # NOTE(review): confirm against the ErrorNode model (the old code read
        # a `screenshot_machine` attribute here).
        image = Image.open(error_node.screenshot_path)
    image.show()

    if storage_machine is not None:
        local_dom_file_path = os.path.join(conf.local_config.local_repo_path, "errors/dom/",
                                           os.path.split(error_node.dom_file_path)[1])
        storage_machine.download_from_machine(error_node.dom_file_path, local_dom_file_path)

    print(error_node.error_text.text)


def test_error_node(conf):
    """Re-crawl the page recorded in an ``ErrorNode`` in debug mode.

    Uses the crawling named by ``-c/--crawling`` and the node chosen by
    ``-id/--id``. The node's ``/``-separated path is rebuilt into a chain of
    ``Node`` objects. Each node gets the matching JSON entry of
    ``properties_chain`` (split on ``ErrorNode.DELIMITER``) and a 1-based
    ``layer``. The last node is crawled with a single retry, debug enabled
    and partitioning disabled, through the error's proxy.
    """
    crawling = Crawling.get(Crawling.name == args.crawling)
    base_crawler = crawling.get_crawler()

    error_node = ErrorNode.get(ErrorNode.id == args.id)

    settings = base_crawler.settings
    settings.enable_debug()
    settings.disable_partitioning()
    settings.max_retries = 1

    names = error_node.path.split('/')
    serialized_properties = error_node.properties_chain.split(ErrorNode.DELIMITER)

    chain = []
    for depth, name in enumerate(names, start=1):
        current = Node(name=name)
        current.properties = json.loads(serialized_properties[depth - 1])
        current.layer = depth
        if chain:
            parent = chain[-1]
            parent.children.append(current)
            current.parent = parent
        chain.append(current)

    leaf = chain[-1]
    leaf.page_url = error_node.url
    context = settings.context_class(settings=settings)
    context.initialize_client(proxy=error_node.proxy)
    worker = CrawlerWorker(context=context, root_node=leaf)
    worker.crawl()


def test_node(conf):
    """Crawl a single ad-hoc node in debug mode.

    Uses the crawling named by ``-c/--crawling``. The node is built from
    ``-n/--name``, ``-u/--url`` and ``-l/--layer``. The crawl runs with one
    retry, debug enabled and partitioning disabled.
    """
    crawling = Crawling.get(Crawling.name == args.crawling)
    settings = crawling.get_crawler().settings
    settings.enable_debug()
    settings.disable_partitioning()
    settings.max_retries = 1

    root = Node(name=args.name, page_url=args.url, layer=args.layer)

    context = settings.context_class(settings=settings)
    context.initialize_client()
    CrawlerWorker(context=context, root_node=root).crawl()


def _parse_flag(value):
    """Convert a ``-p/--plot`` argument to a bool.

    argparse hands over the raw string, which is always truthy (so
    ``-p False`` used to enable plotting). An empty value, ``0``, ``false``,
    ``no`` and ``off`` (any case) now mean disabled; anything else means
    enabled.
    """
    return value.strip().lower() not in ("", "0", "false", "no", "off")


parser = argparse.ArgumentParser(
                    prog='crawler admin',
                    description='',
                    epilog='')

# Names of COMMANDS entries to run, in the given order.
parser.add_argument("-e", "--execute", nargs='+', default=[])
# Machine names to act on; the default is the string "all".
parser.add_argument("-m", "--machines", default="all", nargs="+")
# NOTE(review): parsed but not read anywhere in this file.
parser.add_argument("-em", "--except-machines", default="all", nargs="+")

parser.add_argument("-p", "--plot", default=False, type=_parse_flag)
parser.add_argument('-d', '--debug', action="store_true")
parser.add_argument("-f", "--file", default="")

parser.add_argument("-c", "--crawling", default=False)
parser.add_argument("-id", "--id", default=0)

parser.add_argument("-n",  "--name")
parser.add_argument("-u", "--url")
parser.add_argument("-l", "--layer", type=int)


args = parser.parse_args()


# Command name accepted by -e/--execute -> handler called with the Config.
COMMANDS = dict(
    init=init,

    sync=sync,

    install=deploy,
    uninstall=uninstall,

    read_logs=read_logs,
    restart=restart,
    stop=stop,
    truncate=truncate,
    update_service=update_service,
    update_config=update_config,
    mem_stat=mem_stat,
    load_machines=load_machines,
    clean=clean,

    create_crawling=create_crawling,
    crawl=crawl,

    show_error=show_error,
    test_error_node=test_error_node,

    test_node=test_node,

    load_proxies=load_proxies,
    analyze_proxies=analyze_proxies,
)

# Entry point: run every -e/--execute command in order against a single Config.
commands = args.execute

# Reject unknown command names before Config connects to the database.
# Previously a typo surfaced as a bare KeyError after setup.
unknown_commands = [command for command in commands if command not in COMMANDS]
if unknown_commands:
    parser.error(f"unknown command(s): {', '.join(unknown_commands)}; "
                 f"choose from: {', '.join(COMMANDS)}")

conf = Config("config.yaml")

for command in commands:
    COMMANDS[command](conf)





