Add support for WireGuard endpoint verification
This commit is contained in:
parent
70d3905117
commit
18f50c67fb
7 changed files with 128 additions and 24 deletions
|
|
@ -84,3 +84,7 @@ class ConnectionUnprotectedError(Exception):
|
||||||
|
|
||||||
class FileIntegrityError(Exception):
|
class FileIntegrityError(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class EndpointVerificationError(Exception):
|
||||||
|
pass
|
||||||
|
|
|
||||||
|
|
@ -54,6 +54,23 @@ class ConfigurationController:
|
||||||
configuration.auto_sync_enabled = enable_auto_sync
|
configuration.auto_sync_enabled = enable_auto_sync
|
||||||
configuration.save()
|
configuration.save()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_endpoint_verification_enabled():
|
||||||
|
|
||||||
|
configuration = ConfigurationController.get()
|
||||||
|
|
||||||
|
if configuration is None:
|
||||||
|
return True
|
||||||
|
|
||||||
|
return configuration.endpoint_verification_enabled
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def set_endpoint_verification_enabled(enable_endpoint_verification: Optional[bool] = None):
|
||||||
|
|
||||||
|
configuration = ConfigurationController.get_or_new()
|
||||||
|
configuration.endpoint_verification_enabled = enable_endpoint_verification
|
||||||
|
configuration.save()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_last_synced_at():
|
def get_last_synced_at():
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -102,14 +102,14 @@ class ConnectionController:
|
||||||
if profile.is_system_profile():
|
if profile.is_system_profile():
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return ConnectionController.establish_system_connection(profile, connection_observer=connection_observer)
|
return ConnectionController.establish_system_connection(profile, ignore=ignore, connection_observer=connection_observer)
|
||||||
|
|
||||||
except ConnectionError:
|
except ConnectionError:
|
||||||
|
|
||||||
if ConnectionController.__should_renegotiate(profile):
|
if ConnectionController.__should_renegotiate(profile):
|
||||||
|
|
||||||
ProfileController.register_wireguard_session(profile, connection_observer=connection_observer)
|
ProfileController.register_wireguard_session(profile, connection_observer=connection_observer)
|
||||||
return ConnectionController.establish_system_connection(profile, connection_observer=connection_observer)
|
return ConnectionController.establish_system_connection(profile, ignore=ignore, connection_observer=connection_observer)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise ConnectionError('The connection could not be established.')
|
raise ConnectionError('The connection could not be established.')
|
||||||
|
|
@ -142,6 +142,9 @@ class ConnectionController:
|
||||||
|
|
||||||
elif profile.connection.code == 'wireguard':
|
elif profile.connection.code == 'wireguard':
|
||||||
|
|
||||||
|
if ConfigurationController.get_endpoint_verification_enabled():
|
||||||
|
ProfileController.verify_wireguard_endpoint(profile, ignore=ignore)
|
||||||
|
|
||||||
port_number = ConnectionController.get_random_available_port_number()
|
port_number = ConnectionController.get_random_available_port_number()
|
||||||
ConnectionController.establish_wireguard_session_connection(profile, session_directory, port_number)
|
ConnectionController.establish_wireguard_session_connection(profile, session_directory, port_number)
|
||||||
session_state.network_port_numbers.append(port_number)
|
session_state.network_port_numbers.append(port_number)
|
||||||
|
|
@ -162,7 +165,7 @@ class ConnectionController:
|
||||||
return proxy_port_number or port_number
|
return proxy_port_number or port_number
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def establish_system_connection(profile: SystemProfile, connection_observer: Optional[ConnectionObserver] = None):
|
def establish_system_connection(profile: SystemProfile, ignore: tuple[type[Exception]] = (), connection_observer: Optional[ConnectionObserver] = None):
|
||||||
|
|
||||||
if shutil.which('pkexec') is None:
|
if shutil.which('pkexec') is None:
|
||||||
raise CommandNotFoundError('pkexec')
|
raise CommandNotFoundError('pkexec')
|
||||||
|
|
@ -170,6 +173,9 @@ class ConnectionController:
|
||||||
if shutil.which('wg-quick') is None:
|
if shutil.which('wg-quick') is None:
|
||||||
raise CommandNotFoundError('wg-quick')
|
raise CommandNotFoundError('wg-quick')
|
||||||
|
|
||||||
|
if ConfigurationController.get_endpoint_verification_enabled():
|
||||||
|
ProfileController.verify_wireguard_endpoint(profile, ignore=ignore)
|
||||||
|
|
||||||
process = subprocess.Popen(('pkexec', 'wg-quick', 'up', profile.get_wireguard_configuration_path()), stdout=subprocess.DEVNULL, stderr=subprocess.STDOUT)
|
process = subprocess.Popen(('pkexec', 'wg-quick', 'up', profile.get_wireguard_configuration_path()), stdout=subprocess.DEVNULL, stderr=subprocess.STDOUT)
|
||||||
completed_successfully = not bool(os.waitpid(process.pid, 0)[1] >> 8)
|
completed_successfully = not bool(os.waitpid(process.pid, 0)[1] >> 8)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
from core.Errors import ProfileStateConflictError, InvalidSubscriptionError, MissingSubscriptionError, ConnectionUnprotectedError, ConnectionTerminationError, ProfileActivationError, ProfileDeactivationError, MissingLocationError
|
from core.Errors import InvalidSubscriptionError, MissingSubscriptionError, ConnectionTerminationError, ProfileActivationError, ProfileDeactivationError, MissingLocationError, ConnectionUnprotectedError, EndpointVerificationError, ProfileStateConflictError
|
||||||
from core.controllers.ApplicationController import ApplicationController
|
from core.controllers.ApplicationController import ApplicationController
|
||||||
from core.controllers.ApplicationVersionController import ApplicationVersionController
|
from core.controllers.ApplicationVersionController import ApplicationVersionController
|
||||||
from core.controllers.SessionStateController import SessionStateController
|
from core.controllers.SessionStateController import SessionStateController
|
||||||
|
|
@ -240,6 +240,46 @@ class ProfileController:
|
||||||
def has_wireguard_configuration(profile: Union[SessionProfile, SystemProfile]):
|
def has_wireguard_configuration(profile: Union[SessionProfile, SystemProfile]):
|
||||||
return profile.has_wireguard_configuration()
|
return profile.has_wireguard_configuration()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def verify_wireguard_endpoint(profile: Union[SessionProfile, SystemProfile], ignore: tuple[type[Exception]] = ()):
|
||||||
|
|
||||||
|
try:
|
||||||
|
ProfileController.__verify_wireguard_endpoint(profile)
|
||||||
|
|
||||||
|
except EndpointVerificationError as error:
|
||||||
|
|
||||||
|
if not EndpointVerificationError in ignore:
|
||||||
|
raise error
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def __verify_wireguard_endpoint(profile: Union[SessionProfile, SystemProfile]):
|
||||||
|
|
||||||
|
from cryptography.hazmat.primitives.asymmetric import ed25519
|
||||||
|
import base64
|
||||||
|
|
||||||
|
signature = profile.get_wireguard_configuration_metadata('Signature')
|
||||||
|
wireguard_public_keys = profile.get_wireguard_public_keys()
|
||||||
|
operator = profile.location.operator
|
||||||
|
|
||||||
|
if signature is None:
|
||||||
|
raise EndpointVerificationError('The WireGuard endpoint\'s signature could not be determined.')
|
||||||
|
|
||||||
|
if not wireguard_public_keys:
|
||||||
|
raise EndpointVerificationError('The WireGuard endpoint\'s public key could not be determined.')
|
||||||
|
|
||||||
|
if operator is None:
|
||||||
|
raise EndpointVerificationError('The WireGuard endpoint\'s operator could not be determined.')
|
||||||
|
|
||||||
|
try:
|
||||||
|
|
||||||
|
operator_public_key = ed25519.Ed25519PublicKey.from_public_bytes(bytes.fromhex(operator.public_key))
|
||||||
|
|
||||||
|
for wireguard_public_key in wireguard_public_keys:
|
||||||
|
operator_public_key.verify(base64.b64decode(signature), wireguard_public_key.encode('utf-8'))
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
raise EndpointVerificationError('The WireGuard endpoint could not be verified.')
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def __generate_wireguard_keys():
|
def __generate_wireguard_keys():
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,4 @@
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
from core.Constants import Constants
|
from core.Constants import Constants
|
||||||
from core.models.Location import Location
|
from core.models.Location import Location
|
||||||
from core.models.Subscription import Subscription
|
from core.models.Subscription import Subscription
|
||||||
|
|
@ -14,7 +15,7 @@ import shutil
|
||||||
|
|
||||||
@dataclass_json
|
@dataclass_json
|
||||||
@dataclass
|
@dataclass
|
||||||
class BaseProfile:
|
class BaseProfile(ABC):
|
||||||
id: int = field(
|
id: int = field(
|
||||||
metadata=config(exclude=Exclude.ALWAYS)
|
metadata=config(exclude=Exclude.ALWAYS)
|
||||||
)
|
)
|
||||||
|
|
@ -22,6 +23,10 @@ class BaseProfile:
|
||||||
subscription: Optional[Subscription]
|
subscription: Optional[Subscription]
|
||||||
location: Optional[Location]
|
location: Optional[Location]
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_wireguard_configuration_path(self):
|
||||||
|
pass
|
||||||
|
|
||||||
def get_config_path(self):
|
def get_config_path(self):
|
||||||
return BaseProfile.__get_config_path(self.id)
|
return BaseProfile.__get_config_path(self.id)
|
||||||
|
|
||||||
|
|
@ -61,6 +66,49 @@ class BaseProfile:
|
||||||
shutil.rmtree(self.get_config_path(), ignore_errors=True)
|
shutil.rmtree(self.get_config_path(), ignore_errors=True)
|
||||||
self.delete_data()
|
self.delete_data()
|
||||||
|
|
||||||
|
def get_wireguard_configuration_metadata(self, key):
|
||||||
|
|
||||||
|
configuration = self.get_wireguard_configuration()
|
||||||
|
|
||||||
|
if configuration is not None:
|
||||||
|
|
||||||
|
for line in configuration.splitlines():
|
||||||
|
match = re.match(r'^# {} = (.*)$'.format(re.escape(key)), line)
|
||||||
|
|
||||||
|
if match:
|
||||||
|
return re.sub(r'[^a-zA-Z0-9+=\-_ /]', '', match.group(1).strip())
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_wireguard_public_keys(self):
|
||||||
|
|
||||||
|
import configparser
|
||||||
|
|
||||||
|
wireguard_public_keys = set()
|
||||||
|
|
||||||
|
configuration = self.get_wireguard_configuration()
|
||||||
|
parsed_configuration = configparser.ConfigParser()
|
||||||
|
|
||||||
|
if configuration is not None:
|
||||||
|
|
||||||
|
parsed_configuration.read_string(configuration)
|
||||||
|
|
||||||
|
for section in parsed_configuration.sections():
|
||||||
|
|
||||||
|
if parsed_configuration.has_option(section, 'PublicKey'):
|
||||||
|
wireguard_public_keys.add(parsed_configuration.get(section, 'PublicKey'))
|
||||||
|
|
||||||
|
return tuple(wireguard_public_keys)
|
||||||
|
|
||||||
|
def get_wireguard_configuration(self):
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(self.get_wireguard_configuration_path(), 'r') as file:
|
||||||
|
return file.read()
|
||||||
|
|
||||||
|
except FileNotFoundError:
|
||||||
|
return None
|
||||||
|
|
||||||
def _get_dirty_keys(self: Self):
|
def _get_dirty_keys(self: Self):
|
||||||
|
|
||||||
reference = BaseProfile.find_by_id(self.id)
|
reference = BaseProfile.find_by_id(self.id)
|
||||||
|
|
|
||||||
|
|
@ -28,6 +28,13 @@ class Configuration:
|
||||||
exclude=lambda value: value is None
|
exclude=lambda value: value is None
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
endpoint_verification_enabled: Optional[bool] = field(
|
||||||
|
default=False,
|
||||||
|
metadata=config(
|
||||||
|
undefined=dataclasses_json.Undefined.EXCLUDE,
|
||||||
|
exclude=lambda value: value is None
|
||||||
|
)
|
||||||
|
)
|
||||||
last_synced_at: Optional[datetime] = field(
|
last_synced_at: Optional[datetime] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata=config(
|
metadata=config(
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,6 @@ from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import re
|
|
||||||
import shutil
|
import shutil
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -102,7 +101,7 @@ class SessionProfile(BaseProfile):
|
||||||
elif self.connection.needs_wireguard_configuration():
|
elif self.connection.needs_wireguard_configuration():
|
||||||
|
|
||||||
if self.has_wireguard_configuration():
|
if self.has_wireguard_configuration():
|
||||||
time_zone = self.__get_wireguard_configuration_metadata('TZ')
|
time_zone = self.get_wireguard_configuration_metadata('TZ')
|
||||||
|
|
||||||
if time_zone is None and self.has_location():
|
if time_zone is None and self.has_location():
|
||||||
time_zone = self.location.time_zone
|
time_zone = self.location.time_zone
|
||||||
|
|
@ -112,23 +111,6 @@ class SessionProfile(BaseProfile):
|
||||||
|
|
||||||
return time_zone
|
return time_zone
|
||||||
|
|
||||||
def __get_wireguard_configuration(self):
|
|
||||||
|
|
||||||
try:
|
|
||||||
with open(self.get_wireguard_configuration_path(), 'r') as file:
|
|
||||||
return file.readlines()
|
|
||||||
|
|
||||||
except FileNotFoundError:
|
|
||||||
return None
|
|
||||||
|
|
||||||
def __get_wireguard_configuration_metadata(self, key):
|
|
||||||
|
|
||||||
for line in self.__get_wireguard_configuration():
|
|
||||||
match = re.match(r'^# {} = (.*)$'.format(re.escape(key)), line)
|
|
||||||
|
|
||||||
if match:
|
|
||||||
return re.sub(r'[^a-zA-Z0-9+\-_ /]', '', match.group(1).strip())
|
|
||||||
|
|
||||||
def __delete_proxy_configuration(self):
|
def __delete_proxy_configuration(self):
|
||||||
Path(self.get_proxy_configuration_path()).unlink(missing_ok=True)
|
Path(self.get_proxy_configuration_path()).unlink(missing_ok=True)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue