import io
import subprocess

import paramiko
from enum import Enum
from core.db import db_proxy
from peewee import *


class Machine(Model):
    """Database-backed description of a host that commands are run on.

    Constructing a ``Machine`` re-classes the instance in place: when the
    stored ``ip`` equals the first address reported by ``hostname -I`` the
    object becomes a ``LocalMachine``, otherwise a ``RemoteMachine``. The
    transport-specific methods at the bottom of this class are deliberate
    no-ops that those subclasses override where they have an implementation.
    """

    class Meta:
        database = db_proxy
        table_name = "machine"

    # cgroup version -> path of the file that is read for the memory usage.
    USAGE_FILE_PATH = {
        1: "/sys/fs/cgroup/memory/memory.usage_in_bytes",
        2: "/sys/fs/cgroup/memory.current"
    }

    class Type(Enum):
        """Machine roles; ``restart_node`` checks the ``type`` column against them."""
        MAIN = 0
        NODE = 1
        FILE_STORAGE = 2

    name = CharField(max_length=32, unique=True)
    ip = CharField(max_length=39, unique=True)
    username = CharField(max_length=32)
    password = CharField(max_length=100, null=True)
    type = IntegerField()

    def __init__(self, *args, **kwargs):
        """Initialise the model, then switch to the local or remote variant.

        All arguments are forwarded unchanged to the peewee ``Model``
        constructor.
        """
        super().__init__(*args, **kwargs)

        self.cgroup_version = None
        self.private_key = None

        # NOTE: this spawns ``hostname -I`` once per constructed instance.
        current_ip = self._get_current_ip()
        # A host that reports no address can never match, so it is remote.
        is_local = current_ip is not None and current_ip == self.ip
        self.__class__ = LocalMachine if is_local else RemoteMachine
        # After the class swap this resolves to the subclass' zero-argument
        # initialiser, which sets up the variant-specific state.
        self.__init__()

    def __str__(self):
        return f"<Machine {self.ip}>"

    def hard_update(self):
        """Run ``hg update`` in ``/home/crawler/`` and print its output."""
        hg_update_out = self.execute_command("""
            cd /home/crawler/;
            hg update
        """, read=True)
        print(hg_update_out)

    def restart_node(self):
        """Restart ``crawler_node.service`` unless this is the MAIN machine."""
        # ``type`` is an IntegerField, so it normally holds a plain int. A
        # plain ``Enum`` member never compares equal to an int, hence both the
        # member and its integer value are treated as "MAIN" here.
        main = Machine.Type.MAIN
        if self.type not in (main, main.value):
            self.execute_command("systemctl restart crawler_node.service")

    def stop_node(self):
        """Stop ``crawler_node.service`` on this machine."""
        self.execute_command("systemctl stop crawler_node.service")

    def reload_daemon(self):
        """Run ``systemctl daemon-reload`` on this machine."""
        self.execute_command("systemctl daemon-reload")

    def set_private_key(self, private_key):
        """Store the private key that ``RemoteMachine.connect`` passes as ``key_filename``."""
        self.private_key = private_key

    @staticmethod
    def _get_current_ip():
        """Return the first address printed by ``hostname -I``.

        Returns ``None`` when the command prints no address at all, instead
        of failing with an ``IndexError``.
        """
        addresses = subprocess.check_output(['hostname', '-I']).decode().split()
        return addresses[0] if addresses else None

    # ------------------------------------------------------------------
    # Transport hooks. The base versions intentionally do nothing and return
    # ``None``; subclasses override the ones they support.
    # ------------------------------------------------------------------

    def execute_command(self, command, read=False):
        """Run ``command`` on the machine (no-op in the base class)."""

    def send_to_machine(self, local_path, remote_path):
        """Copy ``local_path`` to ``remote_path`` on the machine (no-op here)."""

    def download_from_machine(self, remote_path, local_path):
        """Copy ``remote_path`` from the machine to ``local_path`` (no-op here)."""

    def write_file(self, file_path, file_bytes):
        """Write the file-like ``file_bytes`` to ``file_path`` (no-op here)."""

    def close(self):
        """Release any open connections (no-op here)."""

    def get_memory_usage(self) -> int:
        """Return the memory usage; the base class returns ``None``."""

    def get_cgroup_version(self) -> int:
        """Return the cgroup version; the base class returns ``None``."""


class RemoteMachine(Machine):
    """Machine that is driven over SSH/SFTP through paramiko.

    ``Machine.__init__`` turns an instance into this class by swapping
    ``__class__`` and then calling the zero-argument ``__init__`` below.
    Connections are opened lazily on first use.
    """

    class Meta:
        database = db_proxy
        table_name = "machine"

    def __init__(self):
        """Start without any open SSH or SFTP connection."""
        self.ssh_connection = None
        self.sftp_connection = None

    def get_ssh_address(self):
        """Return the ``ssh://user@ip/`` URL of this machine."""
        return f"ssh://{self.username}@{self.ip}/"

    def connect(self):
        """Open the SSH connection using ``private_key`` as the key file.

        Unknown host keys are accepted automatically (``AutoAddPolicy``).
        NOTE(review): the model's ``password`` column is not used here;
        confirm that key-based authentication is the only intended mode.
        """
        client = paramiko.SSHClient()
        client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        try:
            client.connect(self.ip, username=self.username, key_filename=self.private_key)
        except Exception:
            # Do not keep a half-initialised client around: check_connection()
            # would otherwise treat it as an established connection.
            client.close()
            raise
        self.ssh_connection = client

    def connect_sftp(self):
        """Make sure an SFTP channel exists, connecting SSH first if needed."""
        self.check_connection()
        if self.sftp_connection is None:
            self.sftp_connection = self.ssh_connection.open_sftp()

    def check_connection(self):
        """Open the SSH connection if there is none yet."""
        if self.ssh_connection is None:
            self.connect()

    def update(self, conf=None, **fields):
        """Update the node packages and push a new configuration.

        Called as ``update(conf)`` this runs the two ``yum update`` commands
        and then ``update_config(conf)``.

        This method shadows peewee's ``Model.update`` classmethod, which
        ``Model.save()`` invokes as ``self.update(**fields)`` when it writes
        an existing row. To keep saving functional, a call with keyword
        field values (or without ``conf``) is forwarded to peewee and the
        resulting update query is returned.
        """
        if fields or conf is None:
            return super().update(conf, **fields)

        self.execute_command("yum update -y crawler-node", read=True)
        self.execute_command("yum update -y firefox", read=True)
        self.update_config(conf)

    def update_config(self, config):
        """Write the redis settings to the node's config file and restart it.

        ``config`` must expose ``redis_config.ip`` and ``redis_config.port``.
        """
        config_file = f"""
                    redis:
                        ip: {config.redis_config.ip}
                        port: {config.redis_config.port}
                    """

        self.write_file("/home/crawler_node/config.yaml", io.BytesIO(config_file.encode()))
        self.restart_node()

    def execute_command(self, command, read=False):
        """Run ``command`` over SSH.

        With ``read=True`` the decoded, stripped standard output is returned,
        or ``None`` when it is empty. Otherwise the
        ``(stdin, stdout, stderr)`` triple from paramiko's ``exec_command``
        is returned unread.
        """
        self.check_connection()
        stdin, stdout, stderr = self.ssh_connection.exec_command(command)
        if not read:
            return stdin, stdout, stderr
        output = stdout.read().decode().strip()
        return output or None

    def send_to_machine(self, local_path, remote_path):
        """Upload ``local_path`` to ``remote_path`` via SFTP."""
        self.connect_sftp()
        self.sftp_connection.put(local_path, remote_path)

    def download_from_machine(self, remote_path, local_path):
        """Download ``remote_path`` to ``local_path`` via SFTP."""
        self.connect_sftp()
        self.sftp_connection.get(remote_path, local_path)

    def write_file(self, file_path, file_bytes):
        """Upload the file-like object ``file_bytes`` to ``file_path`` via SFTP."""
        self.connect_sftp()
        self.sftp_connection.putfo(file_bytes, file_path)

    def close(self):
        """Close the SFTP channel and the SSH connection, if they are open.

        Both handles are reset to ``None`` so that the next operation
        reconnects instead of reusing a closed client. Calling this on a
        machine that never connected is a no-op.
        """
        sftp, ssh = self.sftp_connection, self.ssh_connection
        self.sftp_connection = None
        self.ssh_connection = None
        try:
            if sftp is not None:
                sftp.close()
        finally:
            if ssh is not None:
                ssh.close()

    def get_memory_usage(self) -> int:
        """Return the integer read from the cgroup memory usage file.

        Raises:
            ValueError: if the remote command printed nothing.
        """
        usage_file = self.USAGE_FILE_PATH[self.get_cgroup_version()]
        output = self.execute_command(f"echo \"$(($(cat {usage_file})))\"", read=True)
        if output is None:
            raise ValueError(f"no memory usage reported by {self}")
        return int(output)

    def get_cgroup_version(self) -> int:
        """Return 2 if ``/sys/fs/cgroup/cgroup.controllers`` can be opened, else 1.

        The result is cached on the instance after the first probe.
        """
        if self.cgroup_version:
            return self.cgroup_version
        self.connect_sftp()
        try:
            probe = self.sftp_connection.open("/sys/fs/cgroup/cgroup.controllers")
        except IOError:
            # The file could not be opened on the remote side.
            self.cgroup_version = 1
        else:
            # Only existence matters; release the remote handle right away.
            probe.close()
            self.cgroup_version = 2
        return self.cgroup_version

    def hard_update(self):
        """Run ``hg update`` in ``/home/crawler_node/`` and return its output."""
        return self.execute_command("""
            cd /home/crawler_node/;
            hg update
        """, read=True)


class LocalMachine(Machine):
    """Machine variant for the host this process runs on.

    Commands are executed through a local shell and files are written
    directly to the local filesystem. Every other transport method is
    inherited from ``Machine`` and does nothing.
    """

    class Meta:
        database = db_proxy
        # peewee does not inherit ``table_name``; without this a re-classed
        # local instance would be saved to a default-named table instead of
        # the shared "machine" table (RemoteMachine declares the same Meta).
        table_name = "machine"

    def __init__(self):
        """Nothing to set up; invoked by ``Machine.__init__`` after the class swap."""

    def execute_command(self, command, read=False):
        """Run ``command`` in a local shell.

        Returns the decoded, stripped standard output when ``read`` is true,
        otherwise ``None``.
        """
        result = subprocess.run(command, shell=True, capture_output=True)
        if read:
            return result.stdout.decode().strip()

    def write_file(self, file_path, file_bytes):
        """Write the content of the file-like ``file_bytes`` to ``file_path``."""
        with open(file_path, "wb") as file:
            file.write(file_bytes.read())
