Add support for WireGuard endpoint verification

This commit is contained in:
codeking 2025-11-05 02:24:21 +01:00
parent 70d3905117
commit 18f50c67fb
7 changed files with 128 additions and 24 deletions

View file

@ -84,3 +84,7 @@ class ConnectionUnprotectedError(Exception):
class FileIntegrityError(Exception): class FileIntegrityError(Exception):
pass pass
class EndpointVerificationError(Exception):
pass

View file

@ -54,6 +54,23 @@ class ConfigurationController:
configuration.auto_sync_enabled = enable_auto_sync configuration.auto_sync_enabled = enable_auto_sync
configuration.save() configuration.save()
@staticmethod
def get_endpoint_verification_enabled():
configuration = ConfigurationController.get()
if configuration is None:
return True
return configuration.endpoint_verification_enabled
@staticmethod
def set_endpoint_verification_enabled(enable_endpoint_verification: Optional[bool] = None):
configuration = ConfigurationController.get_or_new()
configuration.endpoint_verification_enabled = enable_endpoint_verification
configuration.save()
@staticmethod @staticmethod
def get_last_synced_at(): def get_last_synced_at():

View file

@ -102,14 +102,14 @@ class ConnectionController:
if profile.is_system_profile(): if profile.is_system_profile():
try: try:
return ConnectionController.establish_system_connection(profile, connection_observer=connection_observer) return ConnectionController.establish_system_connection(profile, ignore=ignore, connection_observer=connection_observer)
except ConnectionError: except ConnectionError:
if ConnectionController.__should_renegotiate(profile): if ConnectionController.__should_renegotiate(profile):
ProfileController.register_wireguard_session(profile, connection_observer=connection_observer) ProfileController.register_wireguard_session(profile, connection_observer=connection_observer)
return ConnectionController.establish_system_connection(profile, connection_observer=connection_observer) return ConnectionController.establish_system_connection(profile, ignore=ignore, connection_observer=connection_observer)
else: else:
raise ConnectionError('The connection could not be established.') raise ConnectionError('The connection could not be established.')
@ -142,6 +142,9 @@ class ConnectionController:
elif profile.connection.code == 'wireguard': elif profile.connection.code == 'wireguard':
if ConfigurationController.get_endpoint_verification_enabled():
ProfileController.verify_wireguard_endpoint(profile, ignore=ignore)
port_number = ConnectionController.get_random_available_port_number() port_number = ConnectionController.get_random_available_port_number()
ConnectionController.establish_wireguard_session_connection(profile, session_directory, port_number) ConnectionController.establish_wireguard_session_connection(profile, session_directory, port_number)
session_state.network_port_numbers.append(port_number) session_state.network_port_numbers.append(port_number)
@ -162,7 +165,7 @@ class ConnectionController:
return proxy_port_number or port_number return proxy_port_number or port_number
@staticmethod @staticmethod
def establish_system_connection(profile: SystemProfile, connection_observer: Optional[ConnectionObserver] = None): def establish_system_connection(profile: SystemProfile, ignore: tuple[type[Exception]] = (), connection_observer: Optional[ConnectionObserver] = None):
if shutil.which('pkexec') is None: if shutil.which('pkexec') is None:
raise CommandNotFoundError('pkexec') raise CommandNotFoundError('pkexec')
@ -170,6 +173,9 @@ class ConnectionController:
if shutil.which('wg-quick') is None: if shutil.which('wg-quick') is None:
raise CommandNotFoundError('wg-quick') raise CommandNotFoundError('wg-quick')
if ConfigurationController.get_endpoint_verification_enabled():
ProfileController.verify_wireguard_endpoint(profile, ignore=ignore)
process = subprocess.Popen(('pkexec', 'wg-quick', 'up', profile.get_wireguard_configuration_path()), stdout=subprocess.DEVNULL, stderr=subprocess.STDOUT) process = subprocess.Popen(('pkexec', 'wg-quick', 'up', profile.get_wireguard_configuration_path()), stdout=subprocess.DEVNULL, stderr=subprocess.STDOUT)
completed_successfully = not bool(os.waitpid(process.pid, 0)[1] >> 8) completed_successfully = not bool(os.waitpid(process.pid, 0)[1] >> 8)

View file

@ -1,4 +1,4 @@
from core.Errors import ProfileStateConflictError, InvalidSubscriptionError, MissingSubscriptionError, ConnectionUnprotectedError, ConnectionTerminationError, ProfileActivationError, ProfileDeactivationError, MissingLocationError from core.Errors import InvalidSubscriptionError, MissingSubscriptionError, ConnectionTerminationError, ProfileActivationError, ProfileDeactivationError, MissingLocationError, ConnectionUnprotectedError, EndpointVerificationError, ProfileStateConflictError
from core.controllers.ApplicationController import ApplicationController from core.controllers.ApplicationController import ApplicationController
from core.controllers.ApplicationVersionController import ApplicationVersionController from core.controllers.ApplicationVersionController import ApplicationVersionController
from core.controllers.SessionStateController import SessionStateController from core.controllers.SessionStateController import SessionStateController
@ -240,6 +240,46 @@ class ProfileController:
def has_wireguard_configuration(profile: Union[SessionProfile, SystemProfile]): def has_wireguard_configuration(profile: Union[SessionProfile, SystemProfile]):
return profile.has_wireguard_configuration() return profile.has_wireguard_configuration()
@staticmethod
def verify_wireguard_endpoint(profile: Union[SessionProfile, SystemProfile], ignore: tuple[type[Exception]] = ()):
try:
ProfileController.__verify_wireguard_endpoint(profile)
except EndpointVerificationError as error:
if not EndpointVerificationError in ignore:
raise error
@staticmethod
def __verify_wireguard_endpoint(profile: Union[SessionProfile, SystemProfile]):
from cryptography.hazmat.primitives.asymmetric import ed25519
import base64
signature = profile.get_wireguard_configuration_metadata('Signature')
wireguard_public_keys = profile.get_wireguard_public_keys()
operator = profile.location.operator
if signature is None:
raise EndpointVerificationError('The WireGuard endpoint\'s signature could not be determined.')
if not wireguard_public_keys:
raise EndpointVerificationError('The WireGuard endpoint\'s public key could not be determined.')
if operator is None:
raise EndpointVerificationError('The WireGuard endpoint\'s operator could not be determined.')
try:
operator_public_key = ed25519.Ed25519PublicKey.from_public_bytes(bytes.fromhex(operator.public_key))
for wireguard_public_key in wireguard_public_keys:
operator_public_key.verify(base64.b64decode(signature), wireguard_public_key.encode('utf-8'))
except Exception:
raise EndpointVerificationError('The WireGuard endpoint could not be verified.')
@staticmethod @staticmethod
def __generate_wireguard_keys(): def __generate_wireguard_keys():

View file

@ -1,3 +1,4 @@
from abc import ABC, abstractmethod
from core.Constants import Constants from core.Constants import Constants
from core.models.Location import Location from core.models.Location import Location
from core.models.Subscription import Subscription from core.models.Subscription import Subscription
@ -14,7 +15,7 @@ import shutil
@dataclass_json @dataclass_json
@dataclass @dataclass
class BaseProfile: class BaseProfile(ABC):
id: int = field( id: int = field(
metadata=config(exclude=Exclude.ALWAYS) metadata=config(exclude=Exclude.ALWAYS)
) )
@ -22,6 +23,10 @@ class BaseProfile:
subscription: Optional[Subscription] subscription: Optional[Subscription]
location: Optional[Location] location: Optional[Location]
@abstractmethod
def get_wireguard_configuration_path(self):
pass
def get_config_path(self): def get_config_path(self):
return BaseProfile.__get_config_path(self.id) return BaseProfile.__get_config_path(self.id)
@ -61,6 +66,49 @@ class BaseProfile:
shutil.rmtree(self.get_config_path(), ignore_errors=True) shutil.rmtree(self.get_config_path(), ignore_errors=True)
self.delete_data() self.delete_data()
def get_wireguard_configuration_metadata(self, key):
configuration = self.get_wireguard_configuration()
if configuration is not None:
for line in configuration.splitlines():
match = re.match(r'^# {} = (.*)$'.format(re.escape(key)), line)
if match:
return re.sub(r'[^a-zA-Z0-9+=\-_ /]', '', match.group(1).strip())
return None
def get_wireguard_public_keys(self):
import configparser
wireguard_public_keys = set()
configuration = self.get_wireguard_configuration()
parsed_configuration = configparser.ConfigParser()
if configuration is not None:
parsed_configuration.read_string(configuration)
for section in parsed_configuration.sections():
if parsed_configuration.has_option(section, 'PublicKey'):
wireguard_public_keys.add(parsed_configuration.get(section, 'PublicKey'))
return tuple(wireguard_public_keys)
def get_wireguard_configuration(self):
try:
with open(self.get_wireguard_configuration_path(), 'r') as file:
return file.read()
except FileNotFoundError:
return None
def _get_dirty_keys(self: Self): def _get_dirty_keys(self: Self):
reference = BaseProfile.find_by_id(self.id) reference = BaseProfile.find_by_id(self.id)

View file

@ -28,6 +28,13 @@ class Configuration:
exclude=lambda value: value is None exclude=lambda value: value is None
) )
) )
endpoint_verification_enabled: Optional[bool] = field(
default=False,
metadata=config(
undefined=dataclasses_json.Undefined.EXCLUDE,
exclude=lambda value: value is None
)
)
last_synced_at: Optional[datetime] = field( last_synced_at: Optional[datetime] = field(
default=None, default=None,
metadata=config( metadata=config(

View file

@ -10,7 +10,6 @@ from pathlib import Path
from typing import Optional from typing import Optional
import json import json
import os import os
import re
import shutil import shutil
@ -102,7 +101,7 @@ class SessionProfile(BaseProfile):
elif self.connection.needs_wireguard_configuration(): elif self.connection.needs_wireguard_configuration():
if self.has_wireguard_configuration(): if self.has_wireguard_configuration():
time_zone = self.__get_wireguard_configuration_metadata('TZ') time_zone = self.get_wireguard_configuration_metadata('TZ')
if time_zone is None and self.has_location(): if time_zone is None and self.has_location():
time_zone = self.location.time_zone time_zone = self.location.time_zone
@ -112,23 +111,6 @@ class SessionProfile(BaseProfile):
return time_zone return time_zone
def __get_wireguard_configuration(self):
try:
with open(self.get_wireguard_configuration_path(), 'r') as file:
return file.readlines()
except FileNotFoundError:
return None
def __get_wireguard_configuration_metadata(self, key):
for line in self.__get_wireguard_configuration():
match = re.match(r'^# {} = (.*)$'.format(re.escape(key)), line)
if match:
return re.sub(r'[^a-zA-Z0-9+\-_ /]', '', match.group(1).strip())
def __delete_proxy_configuration(self): def __delete_proxy_configuration(self):
Path(self.get_proxy_configuration_path()).unlink(missing_ok=True) Path(self.get_proxy_configuration_path()).unlink(missing_ok=True)