from core.Errors import ProfileStateConflictError, InvalidSubscriptionError, MissingSubscriptionError, ConnectionUnprotectedError, ConnectionTerminationError, ProfileActivationError, ProfileDeactivationError, MissingLocationError
from core.controllers.ApplicationController import ApplicationController
from core.controllers.ApplicationVersionController import ApplicationVersionController
from core.controllers.SessionStateController import SessionStateController
from core.controllers.SystemStateController import SystemStateController
from core.models.BaseProfile import BaseProfile as Profile
from core.models.Subscription import Subscription
from core.models.session.SessionProfile import SessionProfile
from core.models.system.SystemProfile import SystemProfile
from core.observers.ApplicationVersionObserver import ApplicationVersionObserver
from core.observers.ConnectionObserver import ConnectionObserver
from core.observers.ProfileObserver import ProfileObserver
from core.services.WebServiceApiService import WebServiceApiService
from typing import Union, Optional
import base64
import re
import time


class ProfileController:

    @staticmethod
    def get(id: int) -> Union[SessionProfile, SystemProfile, None]:
        """Look up a single profile by its identifier.

        Delegates to ``Profile.find_by_id``. Per the return annotation the
        result is a session profile, a system profile, or ``None``.
        """
        return Profile.find_by_id(id)

    @staticmethod
    def get_all():
        """Return all profiles exactly as provided by ``Profile.all()``.

        NOTE(review): ``disable`` calls ``.values()`` on this result, so it is
        presumably a mapping whose values are profiles — confirm against
        ``BaseProfile.all``.
        """
        return Profile.all()

    @staticmethod
    def create(profile: Union[SessionProfile, SystemProfile], profile_observer: ProfileObserver = None):

        profile.save()

        if profile_observer is not None:
            profile_observer.notify('created', profile)

    @staticmethod
    def update(profile: Union[SessionProfile, SystemProfile], profile_observer: ProfileObserver = None):

        profile.save()

        if profile_observer is not None:
            profile_observer.notify('updated', profile)

    @staticmethod
    def enable(profile: Union[SessionProfile, SystemProfile], force: bool = False, pristine: bool = False, asynchronous: bool = False, profile_observer: ProfileObserver = None, application_version_observer: ApplicationVersionObserver = None, connection_observer: ConnectionObserver = None):
        """Enable a profile by establishing its connection.

        For a session profile the application version is installed first when
        it is not installed yet, and the application is launched once the
        connection is up. For a system profile only the connection is
        established.

        Args:
            profile: The profile to enable.
            force: When the profile is already enabled, disable it first
                instead of raising. Also forwarded to
                ``ConnectionController.establish_connection``.
            pristine: When true, ``profile.delete_data()`` is called before
                the profile is enabled.
            asynchronous: Forwarded to ``ApplicationController.launch``.
            profile_observer: Optional observer notified with ``'enabled'``;
                also forwarded to ``ApplicationController.launch``.
            application_version_observer: Forwarded to
                ``ApplicationVersionController.install``.
            connection_observer: Forwarded to the install and connection calls.

        Raises:
            ProfileStateConflictError: The profile is already enabled and
                ``force`` is false.
            ProfileActivationError: Establishing the connection raised
                ``ConnectionError`` (chained as the cause).
        """
        # Imported locally, as elsewhere in this class.
        from core.controllers.ConnectionController import ConnectionController

        if ProfileController.is_enabled(profile):

            if not force:
                raise ProfileStateConflictError('The profile is already enabled or its session was not properly terminated.')

            ProfileController.disable(profile)

        if pristine:
            profile.delete_data()

        if profile.is_session_profile():

            application_version = profile.application_version

            if not application_version.is_installed():
                ApplicationVersionController.install(application_version, application_version_observer=application_version_observer, connection_observer=connection_observer)

            try:
                port_number = ConnectionController.establish_connection(profile, force=force, connection_observer=connection_observer)
            except ConnectionError as error:
                # Keep the underlying failure attached for debugging.
                raise ProfileActivationError('The profile could not be enabled.') from error

            if profile_observer is not None:
                profile_observer.notify('enabled', profile)

            ApplicationController.launch(application_version, profile, port_number, asynchronous=asynchronous, profile_observer=profile_observer)

        if profile.is_system_profile():

            try:
                ConnectionController.establish_connection(profile, force=force, connection_observer=connection_observer)
            except ConnectionError as error:
                raise ProfileActivationError('The profile could not be enabled.') from error

            if profile_observer is not None:
                profile_observer.notify('enabled', profile)

    @staticmethod
    def disable(profile: Union[SessionProfile, SystemProfile], explicitly: bool = True, force: bool = False, profile_observer: ProfileObserver = None):
        """Disable a profile and notify the observer, if any.

        For a session profile the session state is dissolved when one exists.
        For a system profile the system connection is terminated when a system
        state exists, unless doing so would expose an enabled session whose
        connection is unprotected.

        Args:
            profile: The profile to disable.
            explicitly: Passed to the observer in the ``'disabled'``
                notification payload.
            force: Skip the exposed-session check for system profiles.
            profile_observer: Optional observer notified with ``'disabled'``.

        Raises:
            ConnectionUnprotectedError: A system profile is being disabled
                without ``force`` while an enabled session profile has an
                unprotected connection.
            ProfileDeactivationError: Terminating the system connection raised
                ``ConnectionTerminationError`` (chained as the cause).
        """
        # Imported locally, as elsewhere in this class.
        from core.controllers.ConnectionController import ConnectionController

        if profile.is_session_profile():

            if SessionStateController.exists(profile.id):

                session_state = SessionStateController.get(profile.id)
                session_state.dissolve(session_state.id)

        if profile.is_system_profile():

            # The exposure check can only ever raise when not forced, so skip
            # the whole scan (and its per-profile state lookups) otherwise.
            if not force:

                for subject in ProfileController.get_all().values():

                    if subject.is_session_profile() and subject.connection.is_unprotected() and ProfileController.is_enabled(subject):
                        raise ConnectionUnprotectedError('Disabling this system connection would leave one or more sessions exposed.')

            if SystemStateController.exists():

                try:
                    ConnectionController.terminate_system_connection(profile)
                except ConnectionTerminationError as error:
                    # Keep the underlying failure attached for debugging.
                    raise ProfileDeactivationError('The profile could not be disabled.') from error

        if profile_observer is not None:

            profile_observer.notify('disabled', profile, dict(
                explicitly=explicitly,
            ))

        # NOTE(review): presumably gives the teardown a moment to settle
        # before callers continue — confirm why a fixed delay is needed.
        time.sleep(1.0)

    @staticmethod
    def destroy(profile: Union[SessionProfile, SystemProfile], profile_observer: ProfileObserver = None):
        """Disable a profile, delete it, and announce the removal.

        Args:
            profile: The profile to remove. It is disabled with the default
                ``disable`` arguments before ``profile.delete()`` is called.
            profile_observer: Optional observer that receives a
                ``'destroyed'`` notification carrying the profile.
        """
        ProfileController.disable(profile)
        profile.delete()

        # Nothing left to do when nobody is listening.
        if profile_observer is None:
            return

        profile_observer.notify('destroyed', profile)

    @staticmethod
    def attach_subscription(profile: Union[SessionProfile, SystemProfile], subscription: Subscription):
        """Assign a subscription to a profile and save the profile.

        Args:
            profile: The profile that receives the subscription.
            subscription: The subscription to store on ``profile.subscription``.
        """

        profile.subscription = subscription
        profile.save()

    @staticmethod
    def activate_subscription(profile: Union[SessionProfile, SystemProfile], connection_observer: Optional[ConnectionObserver] = None):
        """Fetch the profile's subscription by billing code and store the result.

        The lookup runs ``WebServiceApiService.get_subscription`` through
        ``ConnectionController.with_preferred_connection``. The returned
        subscription replaces ``profile.subscription`` and the profile is saved.

        Args:
            profile: The profile whose subscription is looked up.
            connection_observer: Forwarded to the connection helper.

        Raises:
            MissingSubscriptionError: The profile has no subscription.
            InvalidSubscriptionError: The lookup returned ``None``.
        """
        # Imported locally, as elsewhere in this class.
        from core.controllers.ConnectionController import ConnectionController

        if not profile.has_subscription():
            raise MissingSubscriptionError()

        fetched_subscription = ConnectionController.with_preferred_connection(
            profile.subscription.billing_code,
            task=WebServiceApiService.get_subscription,
            connection_observer=connection_observer,
        )

        if fetched_subscription is None:
            raise InvalidSubscriptionError()

        profile.subscription = fetched_subscription
        profile.save()

    @staticmethod
    def is_enabled(profile: Union[SessionProfile, SystemProfile]):
        """Report whether a profile currently counts as enabled.

        A session profile is enabled when its session state lists at least one
        network port number or process id. A system profile is enabled when
        the system state refers to this profile and
        ``ConnectionController.system_uses_wireguard_interface()`` says so.

        Args:
            profile: The profile to inspect.

        Returns:
            ``True`` when the profile is enabled as described above; otherwise
            ``False`` (for the system branch, whatever
            ``system_uses_wireguard_interface()`` returns).
        """
        # Imported locally, as elsewhere in this class.
        from core.controllers.ConnectionController import ConnectionController

        if profile.is_session_profile():

            session_state = SessionStateController.get_or_new(profile.id)
            return len(session_state.network_port_numbers) > 0 or len(session_state.process_ids) > 0

        if profile.is_system_profile():

            system_state = SystemStateController.get()

            # Compare identifiers by value: ``is`` tests object identity and
            # only happens to work for small integers that CPython caches.
            if system_state is not None and system_state.profile_id == profile.id:
                return ConnectionController.system_uses_wireguard_interface()

        return False

    @staticmethod
    def get_invoice(profile: Union[SessionProfile, SystemProfile]):

        if profile.has_subscription():
            return WebServiceApiService.get_invoice(profile.subscription.billing_code)
        else:
            return None

    @staticmethod
    def attach_proxy_configuration(profile: Union[SessionProfile, SystemProfile]):

        if profile.is_session_profile() and profile.has_subscription():

            proxy_configuration = WebServiceApiService.get_proxy_configuration(profile.subscription.billing_code)

            if proxy_configuration is not None:
                profile.attach_proxy_configuration(proxy_configuration)

    @staticmethod
    def get_proxy_configuration(profile: Union[SessionProfile, SystemProfile]):

        if profile.is_session_profile():
            return profile.get_proxy_configuration()
        else:
            return None

    @staticmethod
    def has_proxy_configuration(profile: Union[SessionProfile, SystemProfile]):
        profile.has_proxy_configuration()

    @staticmethod
    def register_wireguard_session(profile: Union[SessionProfile, SystemProfile], connection_observer: Optional[ConnectionObserver] = None):
        """Register a WireGuard session and attach the resulting configuration.

        A fresh key pair is generated and the public key is posted together
        with the profile's location and billing code. The private key is then
        written into the empty ``PrivateKey =`` line of the returned
        configuration before it is attached to the profile.

        Args:
            profile: The profile to register a session for.
            connection_observer: Forwarded to the connection helper.

        Raises:
            MissingSubscriptionError: The profile has no subscription.
            MissingLocationError: The profile has no location.
            InvalidSubscriptionError: The web service call returned ``None``.
        """
        # Imported locally, as elsewhere in this class.
        from core.controllers.ConnectionController import ConnectionController

        if not profile.has_subscription():
            raise MissingSubscriptionError()

        if not profile.has_location():
            raise MissingLocationError()

        key_pair = ProfileController.__generate_wireguard_keys()

        received_configuration = ConnectionController.with_preferred_connection(
            profile.location.country_code,
            profile.location.code,
            profile.subscription.billing_code,
            key_pair.get('public'),
            task=WebServiceApiService.post_wireguard_session,
            connection_observer=connection_observer,
        )

        if received_configuration is None:
            raise InvalidSubscriptionError()

        # Matches a "PrivateKey =" line that has no value yet.
        empty_private_key_line = re.compile(r'^(PrivateKey =)\s?$', re.MULTILINE)
        filled_private_key_line = r'\1 ' + key_pair.get('private')
        completed_configuration = empty_private_key_line.sub(filled_private_key_line, received_configuration)

        profile.attach_wireguard_configuration(completed_configuration)

    @staticmethod
    def get_wireguard_configuration_path(profile: Union[SessionProfile, SystemProfile]):
        """Return the result of ``profile.get_wireguard_configuration_path()``."""
        return profile.get_wireguard_configuration_path()

    @staticmethod
    def has_wireguard_configuration(profile: Union[SessionProfile, SystemProfile]):
        """Return the result of ``profile.has_wireguard_configuration()``."""
        return profile.has_wireguard_configuration()

    @staticmethod
    def __generate_wireguard_keys():
        """Generate a new X25519 key pair encoded as base64 text.

        Returns:
            A dict with the keys ``'private'`` and ``'public'``, each holding
            the base64-encoded raw key bytes as a string.
        """
        from cryptography.hazmat.primitives import serialization
        from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey

        key_pair = X25519PrivateKey.generate()

        public_bytes = key_pair.public_key().public_bytes(
            encoding=serialization.Encoding.Raw,
            format=serialization.PublicFormat.Raw,
        )

        private_bytes = key_pair.private_bytes(
            encoding=serialization.Encoding.Raw,
            format=serialization.PrivateFormat.Raw,
            encryption_algorithm=serialization.NoEncryption(),
        )

        encoded_keys = {
            'private': base64.b64encode(private_bytes).decode(),
            'public': base64.b64encode(public_bytes).decode(),
        }

        return encoded_keys