diff --git a/core/Errors.py b/core/Errors.py index 08adbc7..7131434 100644 --- a/core/Errors.py +++ b/core/Errors.py @@ -84,3 +84,7 @@ class ConnectionUnprotectedError(Exception): class FileIntegrityError(Exception): pass + + +class EndpointVerificationError(Exception): + pass diff --git a/core/controllers/ConfigurationController.py b/core/controllers/ConfigurationController.py index 98424a9..3dc0d09 100644 --- a/core/controllers/ConfigurationController.py +++ b/core/controllers/ConfigurationController.py @@ -54,6 +54,23 @@ class ConfigurationController: configuration.auto_sync_enabled = enable_auto_sync configuration.save() + @staticmethod + def get_endpoint_verification_enabled(): + + configuration = ConfigurationController.get() + + if configuration is None: + return True + + return configuration.endpoint_verification_enabled + + @staticmethod + def set_endpoint_verification_enabled(enable_endpoint_verification: Optional[bool] = None): + + configuration = ConfigurationController.get_or_new() + configuration.endpoint_verification_enabled = enable_endpoint_verification + configuration.save() + @staticmethod def get_last_synced_at(): diff --git a/core/controllers/ConnectionController.py b/core/controllers/ConnectionController.py index deeaedc..a1d5a85 100644 --- a/core/controllers/ConnectionController.py +++ b/core/controllers/ConnectionController.py @@ -102,14 +102,14 @@ class ConnectionController: if profile.is_system_profile(): try: - return ConnectionController.establish_system_connection(profile, connection_observer=connection_observer) + return ConnectionController.establish_system_connection(profile, ignore=ignore, connection_observer=connection_observer) except ConnectionError: if ConnectionController.__should_renegotiate(profile): ProfileController.register_wireguard_session(profile, connection_observer=connection_observer) - return ConnectionController.establish_system_connection(profile, connection_observer=connection_observer) + return ConnectionController.establish_system_connection(profile, ignore=ignore, connection_observer=connection_observer) else: raise ConnectionError('The connection could not be established.') @@ -142,6 +142,9 @@ class ConnectionController: elif profile.connection.code == 'wireguard': + if ConfigurationController.get_endpoint_verification_enabled(): + ProfileController.verify_wireguard_endpoint(profile, ignore=ignore) + port_number = ConnectionController.get_random_available_port_number() ConnectionController.establish_wireguard_session_connection(profile, session_directory, port_number) session_state.network_port_numbers.append(port_number) @@ -162,7 +165,7 @@ class ConnectionController: return proxy_port_number or port_number @staticmethod - def establish_system_connection(profile: SystemProfile, connection_observer: Optional[ConnectionObserver] = None): + def establish_system_connection(profile: SystemProfile, ignore: tuple[type[Exception]] = (), connection_observer: Optional[ConnectionObserver] = None): if shutil.which('pkexec') is None: raise CommandNotFoundError('pkexec') @@ -170,6 +173,9 @@ class ConnectionController: if shutil.which('wg-quick') is None: raise CommandNotFoundError('wg-quick') + if ConfigurationController.get_endpoint_verification_enabled(): + ProfileController.verify_wireguard_endpoint(profile, ignore=ignore) + process = subprocess.Popen(('pkexec', 'wg-quick', 'up', profile.get_wireguard_configuration_path()), stdout=subprocess.DEVNULL, stderr=subprocess.STDOUT) completed_successfully = not bool(os.waitpid(process.pid, 0)[1] >> 8) diff --git a/core/controllers/ProfileController.py b/core/controllers/ProfileController.py index 91376a6..e64f16a 100644 --- a/core/controllers/ProfileController.py +++ b/core/controllers/ProfileController.py @@ -1,4 +1,4 @@ -from core.Errors import ProfileStateConflictError, InvalidSubscriptionError, MissingSubscriptionError, ConnectionUnprotectedError, ConnectionTerminationError, ProfileActivationError, ProfileDeactivationError, MissingLocationError +from core.Errors import InvalidSubscriptionError, MissingSubscriptionError, ConnectionTerminationError, ProfileActivationError, ProfileDeactivationError, MissingLocationError, ConnectionUnprotectedError, EndpointVerificationError, ProfileStateConflictError from core.controllers.ApplicationController import ApplicationController from core.controllers.ApplicationVersionController import ApplicationVersionController from core.controllers.SessionStateController import SessionStateController @@ -240,6 +240,46 @@ class ProfileController: def has_wireguard_configuration(profile: Union[SessionProfile, SystemProfile]): return profile.has_wireguard_configuration() + @staticmethod + def verify_wireguard_endpoint(profile: Union[SessionProfile, SystemProfile], ignore: tuple[type[Exception]] = ()): + + try: + ProfileController.__verify_wireguard_endpoint(profile) + + except EndpointVerificationError as error: + + if not EndpointVerificationError in ignore: + raise error + + @staticmethod + def __verify_wireguard_endpoint(profile: Union[SessionProfile, SystemProfile]): + + from cryptography.hazmat.primitives.asymmetric import ed25519 + import base64 + + signature = profile.get_wireguard_configuration_metadata('Signature') + wireguard_public_keys = profile.get_wireguard_public_keys() + operator = profile.location.operator + + if signature is None: + raise EndpointVerificationError('The WireGuard endpoint\'s signature could not be determined.') + + if not wireguard_public_keys: + raise EndpointVerificationError('The WireGuard endpoint\'s public key could not be determined.') + + if operator is None: + raise EndpointVerificationError('The WireGuard endpoint\'s operator could not be determined.') + + try: + + operator_public_key = ed25519.Ed25519PublicKey.from_public_bytes(bytes.fromhex(operator.public_key)) + + for wireguard_public_key in wireguard_public_keys: + operator_public_key.verify(base64.b64decode(signature), wireguard_public_key.encode('utf-8')) + + except Exception: + raise EndpointVerificationError('The WireGuard endpoint could not be verified.') + @staticmethod def __generate_wireguard_keys(): diff --git a/core/models/BaseProfile.py b/core/models/BaseProfile.py index 5dd92e7..0949be4 100644 --- a/core/models/BaseProfile.py +++ b/core/models/BaseProfile.py @@ -1,3 +1,4 @@ +from abc import ABC, abstractmethod from core.Constants import Constants from core.models.Location import Location from core.models.Subscription import Subscription @@ -14,7 +15,7 @@ import shutil @dataclass_json @dataclass -class BaseProfile: +class BaseProfile(ABC): id: int = field( metadata=config(exclude=Exclude.ALWAYS) ) @@ -22,6 +23,10 @@ class BaseProfile: subscription: Optional[Subscription] location: Optional[Location] + @abstractmethod + def get_wireguard_configuration_path(self): + pass + def get_config_path(self): return BaseProfile.__get_config_path(self.id) @@ -61,6 +66,49 @@ class BaseProfile: shutil.rmtree(self.get_config_path(), ignore_errors=True) self.delete_data() + def get_wireguard_configuration_metadata(self, key): + + configuration = self.get_wireguard_configuration() + + if configuration is not None: + + for line in configuration.splitlines(): + match = re.match(r'^# {} = (.*)$'.format(re.escape(key)), line) + + if match: + return re.sub(r'[^a-zA-Z0-9+=\-_ /]', '', match.group(1).strip()) + + return None + + def get_wireguard_public_keys(self): + + import configparser + + wireguard_public_keys = set() + + configuration = self.get_wireguard_configuration() + parsed_configuration = configparser.ConfigParser() + + if configuration is not None: + + parsed_configuration.read_string(configuration) + + for section in parsed_configuration.sections(): + + if parsed_configuration.has_option(section, 'PublicKey'): + wireguard_public_keys.add(parsed_configuration.get(section, 'PublicKey')) + + return tuple(wireguard_public_keys) + + def get_wireguard_configuration(self): + + try: + with open(self.get_wireguard_configuration_path(), 'r') as file: + return file.read() + + except FileNotFoundError: + return None + def _get_dirty_keys(self: Self): reference = BaseProfile.find_by_id(self.id) diff --git a/core/models/Configuration.py b/core/models/Configuration.py index 3272cc0..b80593c 100644 --- a/core/models/Configuration.py +++ b/core/models/Configuration.py @@ -28,6 +28,13 @@ class Configuration: exclude=lambda value: value is None ) ) + endpoint_verification_enabled: Optional[bool] = field( + default=False, + metadata=config( + undefined=dataclasses_json.Undefined.EXCLUDE, + exclude=lambda value: value is None + ) + ) last_synced_at: Optional[datetime] = field( default=None, metadata=config( diff --git a/core/models/session/SessionProfile.py b/core/models/session/SessionProfile.py index 9fc6ce4..dc6ec35 100644 --- a/core/models/session/SessionProfile.py +++ b/core/models/session/SessionProfile.py @@ -10,7 +10,6 @@ from pathlib import Path from typing import Optional import json import os -import re import shutil @@ -102,7 +101,7 @@ class SessionProfile(BaseProfile): elif self.connection.needs_wireguard_configuration(): if self.has_wireguard_configuration(): - time_zone = self.__get_wireguard_configuration_metadata('TZ') + time_zone = self.get_wireguard_configuration_metadata('TZ') if time_zone is None and self.has_location(): time_zone = self.location.time_zone @@ -112,23 +111,6 @@ class SessionProfile(BaseProfile): return time_zone - def __get_wireguard_configuration(self): - - try: - with open(self.get_wireguard_configuration_path(), 'r') as file: - return file.readlines() - - except FileNotFoundError: - return None - - def __get_wireguard_configuration_metadata(self, key): - - for line in self.__get_wireguard_configuration(): - match = re.match(r'^# {} = (.*)$'.format(re.escape(key)), line) - - if match: - return re.sub(r'[^a-zA-Z0-9+\-_ /]', '', match.group(1).strip()) - def __delete_proxy_configuration(self): Path(self.get_proxy_configuration_path()).unlink(missing_ok=True)