# sp-hydra-veil-core/core/controllers/ProfileController.py
# (file listing metadata: 304 lines, 12 KiB, Python)

from core.Errors import InvalidSubscriptionError, MissingSubscriptionError, ConnectionTerminationError, ProfileActivationError, ProfileDeactivationError, MissingLocationError, ConnectionUnprotectedError, EndpointVerificationError, ProfileStateConflictError
from core.controllers.ApplicationController import ApplicationController
from core.controllers.ApplicationVersionController import ApplicationVersionController
from core.controllers.SessionStateController import SessionStateController
from core.controllers.SystemStateController import SystemStateController
from core.models.BaseProfile import BaseProfile as Profile
from core.models.Subscription import Subscription
from core.models.session.SessionProfile import SessionProfile
from core.models.system.SystemProfile import SystemProfile
from core.observers.ApplicationVersionObserver import ApplicationVersionObserver
from core.observers.ConnectionObserver import ConnectionObserver
from core.observers.ProfileObserver import ProfileObserver
from core.services.WebServiceApiService import WebServiceApiService
from typing import Union, Optional
import base64
import re
import time
class ProfileController:
    """Static entry points for managing session and system profiles.

    Every method is a ``@staticmethod``; the class only acts as a namespace.
    ``ConnectionController`` is imported inside the methods that use it
    instead of at module level (presumably to avoid a circular import;
    confirm before hoisting it).
    """

    @staticmethod
    def get(id: int) -> Union[SessionProfile, SystemProfile, None]:
        """Return whatever ``Profile.find_by_id`` yields for ``id``."""
        return Profile.find_by_id(id)

    @staticmethod
    def get_all():
        """Return the result of ``Profile.all()``.

        ``disable`` calls ``.values()`` on this result, so it is expected to
        behave like a mapping.
        """
        return Profile.all()

    @staticmethod
    def create(profile: Union[SessionProfile, SystemProfile], profile_observer: Optional[ProfileObserver] = None) -> None:
        """Save ``profile`` and, if an observer is given, notify it of ``'created'``."""
        profile.save()

        if profile_observer is not None:
            profile_observer.notify('created', profile)

    @staticmethod
    def update(profile: Union[SessionProfile, SystemProfile], profile_observer: Optional[ProfileObserver] = None) -> None:
        """Save ``profile`` and, if an observer is given, notify it of ``'updated'``."""
        profile.save()

        if profile_observer is not None:
            profile_observer.notify('updated', profile)

    @staticmethod
    def enable(profile: Union[SessionProfile, SystemProfile], ignore: tuple[type[Exception], ...] = (), pristine: bool = False, asynchronous: bool = False, profile_observer: Optional[ProfileObserver] = None, application_version_observer: Optional[ApplicationVersionObserver] = None, connection_observer: Optional[ConnectionObserver] = None) -> None:
        """Establish the profile's connection and, for session profiles, launch its application.

        Args:
            profile: The profile to enable.
            ignore: Exception types to tolerate. If it contains
                ``ProfileStateConflictError`` an already-enabled profile is
                disabled first instead of raising. The tuple is also passed
                on to ``ConnectionController.establish_connection``.
            pristine: When true, ``profile.delete_data()`` is called before
                the connection is established.
            asynchronous: Forwarded to ``ApplicationController.launch``.
            profile_observer: Notified of ``'enabled'``; also forwarded to
                ``ApplicationController.launch``.
            application_version_observer: Forwarded to
                ``ApplicationVersionController.install``.
            connection_observer: Forwarded to the install and connection calls.

        Raises:
            ProfileStateConflictError: The profile is already enabled and the
                error type is not listed in ``ignore``.
            ProfileActivationError: Establishing the connection raised
                ``ConnectionError``.
        """
        from core.controllers.ConnectionController import ConnectionController

        if ProfileController.is_enabled(profile):
            if ProfileStateConflictError not in ignore:
                raise ProfileStateConflictError('The profile is already enabled or its session was not properly terminated.')

            # The conflict is tolerated: tear the stale state down first.
            ProfileController.disable(profile)

        if pristine:
            profile.delete_data()

        if profile.is_session_profile():
            application_version = profile.application_version

            if not application_version.is_installed():
                ApplicationVersionController.install(application_version, application_version_observer=application_version_observer, connection_observer=connection_observer)

            # NOTE(review): no project-level ConnectionError is imported in
            # this module, so this catches the built-in one - confirm that is
            # what establish_connection raises.
            try:
                port_number = ConnectionController.establish_connection(profile, ignore=ignore, connection_observer=connection_observer)
            except ConnectionError as error:
                raise ProfileActivationError('The profile could not be enabled.') from error

            if profile_observer is not None:
                profile_observer.notify('enabled', profile)

            ApplicationController.launch(application_version, profile, port_number, asynchronous=asynchronous, profile_observer=profile_observer)

        if profile.is_system_profile():
            try:
                ConnectionController.establish_connection(profile, ignore=ignore, connection_observer=connection_observer)
            except ConnectionError as error:
                raise ProfileActivationError('The profile could not be enabled.') from error

            if profile_observer is not None:
                profile_observer.notify('enabled', profile)

    @staticmethod
    def disable(profile: Union[SessionProfile, SystemProfile], explicitly: bool = True, ignore: tuple[type[Exception], ...] = (), profile_observer: Optional[ProfileObserver] = None) -> None:
        """Tear down the profile's session state or system connection.

        Args:
            profile: The profile to disable.
            explicitly: Passed to the observer inside the ``'disabled'``
                notification payload.
            ignore: Exception types to tolerate. If it contains
                ``ConnectionUnprotectedError`` a system profile is disabled
                even when enabled session profiles report an unprotected
                connection.
            profile_observer: Notified of ``'disabled'`` when given.

        Raises:
            ConnectionUnprotectedError: Disabling a system profile while an
                enabled session profile has an unprotected connection, unless
                ignored.
            ProfileDeactivationError: Terminating the system connection raised
                ``ConnectionTerminationError``.

        The method always ends with a one-second ``time.sleep``; the reason
        for the pause is not visible in this file.
        """
        from core.controllers.ConnectionController import ConnectionController

        if profile.is_session_profile() and SessionStateController.exists(profile.id):
            session_state = SessionStateController.get(profile.id)

            if session_state is not None:
                session_state.dissolve(session_state.id)

        if profile.is_system_profile():
            for subject in ProfileController.get_all().values():
                if not subject.is_session_profile():
                    continue

                if subject.connection.is_unprotected() and ProfileController.is_enabled(subject) and ConnectionUnprotectedError not in ignore:
                    raise ConnectionUnprotectedError('Disabling this system connection would leave one or more sessions exposed.')

            if SystemStateController.exists():
                try:
                    ConnectionController.terminate_system_connection(profile)
                except ConnectionTerminationError as error:
                    raise ProfileDeactivationError('The profile could not be disabled.') from error

        if profile_observer is not None:
            profile_observer.notify('disabled', profile, dict(
                explicitly=explicitly,
            ))

        time.sleep(1.0)

    @staticmethod
    def destroy(profile: Union[SessionProfile, SystemProfile], profile_observer: Optional[ProfileObserver] = None) -> None:
        """Disable ``profile`` (with default arguments), delete it and notify ``'destroyed'``.

        Any exception raised by ``disable`` propagates and prevents deletion.
        """
        ProfileController.disable(profile)
        profile.delete()

        if profile_observer is not None:
            profile_observer.notify('destroyed', profile)

    @staticmethod
    def attach_subscription(profile: Union[SessionProfile, SystemProfile], subscription: Subscription) -> None:
        """Assign ``subscription`` to ``profile`` and save the profile."""
        profile.subscription = subscription
        profile.save()

    @staticmethod
    def activate_subscription(profile: Union[SessionProfile, SystemProfile], connection_observer: Optional[ConnectionObserver] = None) -> None:
        """Re-fetch the profile's subscription by billing code and store the result.

        Raises:
            MissingSubscriptionError: ``profile.has_subscription()`` is false.
            InvalidSubscriptionError: The lookup returned ``None``.
        """
        from core.controllers.ConnectionController import ConnectionController

        if not profile.has_subscription():
            raise MissingSubscriptionError()

        subscription = ConnectionController.with_preferred_connection(profile.subscription.billing_code, task=WebServiceApiService.get_subscription, connection_observer=connection_observer)

        if subscription is None:
            raise InvalidSubscriptionError()

        profile.subscription = subscription
        profile.save()

    @staticmethod
    def is_enabled(profile: Union[SessionProfile, SystemProfile]):
        """Report whether ``profile`` currently counts as enabled.

        A session profile is enabled when its session state lists at least
        one network port number or process id. A system profile is enabled
        when the stored system state refers to this profile's id and
        ``ConnectionController.system_uses_wireguard_interface()`` says so.
        Anything else yields ``False``.
        """
        from core.controllers.ConnectionController import ConnectionController

        if profile.is_session_profile():
            session_state = SessionStateController.get_or_new(profile.id)
            return len(session_state.network_port_numbers) > 0 or len(session_state.process_ids) > 0

        if profile.is_system_profile():
            system_state = SystemStateController.get()

            # Compare ids by value; identity comparison only happens to work
            # for integers that the interpreter caches.
            if system_state is not None and system_state.profile_id == profile.id:
                return ConnectionController.system_uses_wireguard_interface()

        return False

    @staticmethod
    def get_invoice(profile: Union[SessionProfile, SystemProfile]):
        """Return the invoice for the profile's billing code, or ``None`` without a subscription."""
        if not profile.has_subscription():
            return None

        return WebServiceApiService.get_invoice(profile.subscription.billing_code)

    @staticmethod
    def attach_proxy_configuration(profile: Union[SessionProfile, SystemProfile]) -> None:
        """Fetch and attach a proxy configuration for a subscribed session profile.

        Does nothing for other profiles or when the service returns ``None``.
        """
        if not (profile.is_session_profile() and profile.has_subscription()):
            return

        proxy_configuration = WebServiceApiService.get_proxy_configuration(profile.subscription.billing_code)

        if proxy_configuration is not None:
            profile.attach_proxy_configuration(proxy_configuration)

    @staticmethod
    def get_proxy_configuration(profile: Union[SessionProfile, SystemProfile]):
        """Return the session profile's proxy configuration; ``None`` for non-session profiles."""
        if profile.is_session_profile():
            return profile.get_proxy_configuration()

        return None

    @staticmethod
    def has_proxy_configuration(profile: Union[SessionProfile, SystemProfile]):
        """Return the result of ``profile.has_proxy_configuration()``."""
        return profile.has_proxy_configuration()

    @staticmethod
    def register_wireguard_session(profile: Union[SessionProfile, SystemProfile], connection_observer: Optional[ConnectionObserver] = None) -> None:
        """Register a WireGuard session for the profile and attach the resulting configuration.

        A fresh key pair is generated locally. The public key is sent with the
        location and billing code; the private key is then written into the
        returned configuration's empty ``PrivateKey =`` line.

        Raises:
            MissingSubscriptionError: The profile has no subscription.
            MissingLocationError: The profile has no location.
            InvalidSubscriptionError: The service call returned ``None``.
        """
        from core.controllers.ConnectionController import ConnectionController

        if not profile.has_subscription():
            raise MissingSubscriptionError()

        if not profile.has_location():
            raise MissingLocationError()

        wireguard_keys = ProfileController.__generate_wireguard_keys()
        private_key = wireguard_keys['private']

        wireguard_configuration = ConnectionController.with_preferred_connection(profile.location.country_code, profile.location.code, profile.subscription.billing_code, wireguard_keys['public'], task=WebServiceApiService.post_wireguard_session, connection_observer=connection_observer)

        if wireguard_configuration is None:
            raise InvalidSubscriptionError()

        expression = re.compile(r'^(PrivateKey =)\s?$', re.MULTILINE)

        # A callable replacement inserts the key verbatim, so it is never
        # interpreted as a replacement template.
        wireguard_configuration = expression.sub(lambda match: match.group(1) + ' ' + private_key, wireguard_configuration)

        profile.attach_wireguard_configuration(wireguard_configuration)

    @staticmethod
    def get_wireguard_configuration_path(profile: Union[SessionProfile, SystemProfile]):
        """Return ``profile.get_wireguard_configuration_path()``."""
        return profile.get_wireguard_configuration_path()

    @staticmethod
    def has_wireguard_configuration(profile: Union[SessionProfile, SystemProfile]):
        """Return ``profile.has_wireguard_configuration()``."""
        return profile.has_wireguard_configuration()

    @staticmethod
    def verify_wireguard_endpoint(profile: Union[SessionProfile, SystemProfile], ignore: tuple[type[Exception], ...] = ()) -> None:
        """Verify the profile's WireGuard endpoint signature.

        On failure ``profile.address_security_incident()`` is called and the
        ``EndpointVerificationError`` is re-raised, unless that error type is
        listed in ``ignore``, in which case the failure is swallowed.
        """
        try:
            ProfileController.__verify_wireguard_endpoint(profile)
        except EndpointVerificationError:
            if EndpointVerificationError not in ignore:
                profile.address_security_incident()
                raise

    @staticmethod
    def __verify_wireguard_endpoint(profile: Union[SessionProfile, SystemProfile]) -> None:
        """Check every WireGuard public key against the operator's Ed25519 signature.

        Raises:
            EndpointVerificationError: The signature, the public keys or the
                operator is missing, or any step of the verification fails.
        """
        from cryptography.hazmat.primitives.asymmetric import ed25519

        signature = profile.get_wireguard_configuration_metadata('Signature')
        wireguard_public_keys = profile.get_wireguard_public_keys()
        operator = profile.location.operator

        if signature is None:
            raise EndpointVerificationError('The WireGuard endpoint\'s signature could not be determined.')

        if not wireguard_public_keys:
            raise EndpointVerificationError('The WireGuard endpoint\'s public key could not be determined.')

        if operator is None:
            raise EndpointVerificationError('The WireGuard endpoint\'s operator could not be determined.')

        # Deliberately broad: malformed hex, malformed base64 and a signature
        # mismatch are all reported as the same verification failure.
        try:
            operator_public_key = ed25519.Ed25519PublicKey.from_public_bytes(bytes.fromhex(operator.public_key))

            for wireguard_public_key in wireguard_public_keys:
                operator_public_key.verify(base64.b64decode(signature), wireguard_public_key.encode('utf-8'))
        except Exception as error:
            raise EndpointVerificationError('The WireGuard endpoint could not be verified.') from error

    @staticmethod
    def __generate_wireguard_keys():
        """Generate an X25519 key pair and return it base64-encoded.

        Returns:
            A dict with the keys ``'private'`` and ``'public'``, each holding
            the base64 text of the raw key bytes.
        """
        from cryptography.hazmat.primitives import serialization
        from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey

        raw_private_key = X25519PrivateKey.generate()

        public_key = raw_private_key.public_key().public_bytes(
            encoding=serialization.Encoding.Raw,
            format=serialization.PublicFormat.Raw,
        )

        private_key = raw_private_key.private_bytes(
            encoding=serialization.Encoding.Raw,
            format=serialization.PrivateFormat.Raw,
            encryption_algorithm=serialization.NoEncryption(),
        )

        return dict(
            private=base64.b64encode(private_key).decode(),
            public=base64.b64encode(public_key).decode(),
        )