Add support for WireGuard endpoint verification

This commit is contained in:
codeking 2025-11-05 02:24:21 +01:00
parent 70d3905117
commit 18f50c67fb
7 changed files with 128 additions and 24 deletions

View file

@ -84,3 +84,7 @@ class ConnectionUnprotectedError(Exception):
class FileIntegrityError(Exception):
pass
class EndpointVerificationError(Exception):
pass

View file

@ -54,6 +54,23 @@ class ConfigurationController:
configuration.auto_sync_enabled = enable_auto_sync
configuration.save()
@staticmethod
def get_endpoint_verification_enabled():
configuration = ConfigurationController.get()
if configuration is None:
return True
return configuration.endpoint_verification_enabled
@staticmethod
def set_endpoint_verification_enabled(enable_endpoint_verification: Optional[bool] = None):
configuration = ConfigurationController.get_or_new()
configuration.endpoint_verification_enabled = enable_endpoint_verification
configuration.save()
@staticmethod
def get_last_synced_at():

View file

@ -102,14 +102,14 @@ class ConnectionController:
if profile.is_system_profile():
try:
return ConnectionController.establish_system_connection(profile, connection_observer=connection_observer)
return ConnectionController.establish_system_connection(profile, ignore=ignore, connection_observer=connection_observer)
except ConnectionError:
if ConnectionController.__should_renegotiate(profile):
ProfileController.register_wireguard_session(profile, connection_observer=connection_observer)
return ConnectionController.establish_system_connection(profile, connection_observer=connection_observer)
return ConnectionController.establish_system_connection(profile, ignore=ignore, connection_observer=connection_observer)
else:
raise ConnectionError('The connection could not be established.')
@ -142,6 +142,9 @@ class ConnectionController:
elif profile.connection.code == 'wireguard':
if ConfigurationController.get_endpoint_verification_enabled():
ProfileController.verify_wireguard_endpoint(profile, ignore=ignore)
port_number = ConnectionController.get_random_available_port_number()
ConnectionController.establish_wireguard_session_connection(profile, session_directory, port_number)
session_state.network_port_numbers.append(port_number)
@ -162,7 +165,7 @@ class ConnectionController:
return proxy_port_number or port_number
@staticmethod
def establish_system_connection(profile: SystemProfile, connection_observer: Optional[ConnectionObserver] = None):
def establish_system_connection(profile: SystemProfile, ignore: tuple[type[Exception]] = (), connection_observer: Optional[ConnectionObserver] = None):
if shutil.which('pkexec') is None:
raise CommandNotFoundError('pkexec')
@ -170,6 +173,9 @@ class ConnectionController:
if shutil.which('wg-quick') is None:
raise CommandNotFoundError('wg-quick')
if ConfigurationController.get_endpoint_verification_enabled():
ProfileController.verify_wireguard_endpoint(profile, ignore=ignore)
process = subprocess.Popen(('pkexec', 'wg-quick', 'up', profile.get_wireguard_configuration_path()), stdout=subprocess.DEVNULL, stderr=subprocess.STDOUT)
completed_successfully = not bool(os.waitpid(process.pid, 0)[1] >> 8)

View file

@ -1,4 +1,4 @@
from core.Errors import ProfileStateConflictError, InvalidSubscriptionError, MissingSubscriptionError, ConnectionUnprotectedError, ConnectionTerminationError, ProfileActivationError, ProfileDeactivationError, MissingLocationError
from core.Errors import InvalidSubscriptionError, MissingSubscriptionError, ConnectionTerminationError, ProfileActivationError, ProfileDeactivationError, MissingLocationError, ConnectionUnprotectedError, EndpointVerificationError, ProfileStateConflictError
from core.controllers.ApplicationController import ApplicationController
from core.controllers.ApplicationVersionController import ApplicationVersionController
from core.controllers.SessionStateController import SessionStateController
@ -240,6 +240,46 @@ class ProfileController:
def has_wireguard_configuration(profile: Union[SessionProfile, SystemProfile]):
return profile.has_wireguard_configuration()
@staticmethod
def verify_wireguard_endpoint(profile: Union[SessionProfile, SystemProfile], ignore: tuple[type[Exception]] = ()):
try:
ProfileController.__verify_wireguard_endpoint(profile)
except EndpointVerificationError as error:
if not EndpointVerificationError in ignore:
raise error
@staticmethod
def __verify_wireguard_endpoint(profile: Union[SessionProfile, SystemProfile]):
from cryptography.hazmat.primitives.asymmetric import ed25519
import base64
signature = profile.get_wireguard_configuration_metadata('Signature')
wireguard_public_keys = profile.get_wireguard_public_keys()
operator = profile.location.operator
if signature is None:
raise EndpointVerificationError('The WireGuard endpoint\'s signature could not be determined.')
if not wireguard_public_keys:
raise EndpointVerificationError('The WireGuard endpoint\'s public key could not be determined.')
if operator is None:
raise EndpointVerificationError('The WireGuard endpoint\'s operator could not be determined.')
try:
operator_public_key = ed25519.Ed25519PublicKey.from_public_bytes(bytes.fromhex(operator.public_key))
for wireguard_public_key in wireguard_public_keys:
operator_public_key.verify(base64.b64decode(signature), wireguard_public_key.encode('utf-8'))
except Exception:
raise EndpointVerificationError('The WireGuard endpoint could not be verified.')
@staticmethod
def __generate_wireguard_keys():

View file

@ -1,3 +1,4 @@
from abc import ABC, abstractmethod
from core.Constants import Constants
from core.models.Location import Location
from core.models.Subscription import Subscription
@ -14,7 +15,7 @@ import shutil
@dataclass_json
@dataclass
class BaseProfile:
class BaseProfile(ABC):
id: int = field(
metadata=config(exclude=Exclude.ALWAYS)
)
@ -22,6 +23,10 @@ class BaseProfile:
subscription: Optional[Subscription]
location: Optional[Location]
@abstractmethod
def get_wireguard_configuration_path(self):
pass
def get_config_path(self):
return BaseProfile.__get_config_path(self.id)
@ -61,6 +66,49 @@ class BaseProfile:
shutil.rmtree(self.get_config_path(), ignore_errors=True)
self.delete_data()
def get_wireguard_configuration_metadata(self, key):
configuration = self.get_wireguard_configuration()
if configuration is not None:
for line in configuration.splitlines():
match = re.match(r'^# {} = (.*)$'.format(re.escape(key)), line)
if match:
return re.sub(r'[^a-zA-Z0-9+=\-_ /]', '', match.group(1).strip())
return None
def get_wireguard_public_keys(self):
import configparser
wireguard_public_keys = set()
configuration = self.get_wireguard_configuration()
parsed_configuration = configparser.ConfigParser()
if configuration is not None:
parsed_configuration.read_string(configuration)
for section in parsed_configuration.sections():
if parsed_configuration.has_option(section, 'PublicKey'):
wireguard_public_keys.add(parsed_configuration.get(section, 'PublicKey'))
return tuple(wireguard_public_keys)
def get_wireguard_configuration(self):
try:
with open(self.get_wireguard_configuration_path(), 'r') as file:
return file.read()
except FileNotFoundError:
return None
def _get_dirty_keys(self: Self):
reference = BaseProfile.find_by_id(self.id)

View file

@ -28,6 +28,13 @@ class Configuration:
exclude=lambda value: value is None
)
)
endpoint_verification_enabled: Optional[bool] = field(
default=False,
metadata=config(
undefined=dataclasses_json.Undefined.EXCLUDE,
exclude=lambda value: value is None
)
)
last_synced_at: Optional[datetime] = field(
default=None,
metadata=config(

View file

@ -10,7 +10,6 @@ from pathlib import Path
from typing import Optional
import json
import os
import re
import shutil
@ -102,7 +101,7 @@ class SessionProfile(BaseProfile):
elif self.connection.needs_wireguard_configuration():
if self.has_wireguard_configuration():
time_zone = self.__get_wireguard_configuration_metadata('TZ')
time_zone = self.get_wireguard_configuration_metadata('TZ')
if time_zone is None and self.has_location():
time_zone = self.location.time_zone
@ -112,23 +111,6 @@ class SessionProfile(BaseProfile):
return time_zone
def __get_wireguard_configuration(self):
try:
with open(self.get_wireguard_configuration_path(), 'r') as file:
return file.readlines()
except FileNotFoundError:
return None
def __get_wireguard_configuration_metadata(self, key):
for line in self.__get_wireguard_configuration():
match = re.match(r'^# {} = (.*)$'.format(re.escape(key)), line)
if match:
return re.sub(r'[^a-zA-Z0-9+\-_ /]', '', match.group(1).strip())
def __delete_proxy_configuration(self):
Path(self.get_proxy_configuration_path()).unlink(missing_ok=True)