diff --git a/core/controllers/ProfileController.py b/core/controllers/ProfileController.py index 5f6bedf..7809a6d 100644 --- a/core/controllers/ProfileController.py +++ b/core/controllers/ProfileController.py @@ -12,6 +12,8 @@ from core.observers.ConnectionObserver import ConnectionObserver from core.observers.ProfileObserver import ProfileObserver from core.services.WebServiceApiService import WebServiceApiService from typing import Union, Optional +import base64 +import re import time @@ -221,11 +223,16 @@ class ProfileController: if profile.location is None: raise MissingLocationError() - wireguard_configuration = ConnectionController.with_preferred_connection(profile.location.code, profile.subscription.billing_code, task=WebServiceApiService.post_wireguard_session, connection_observer=connection_observer) + wireguard_keys = ProfileController.__generate_wireguard_keys() + + wireguard_configuration = ConnectionController.with_preferred_connection(profile.location.code, profile.subscription.billing_code, wireguard_keys.get('public'), task=WebServiceApiService.post_wireguard_session, connection_observer=connection_observer) if wireguard_configuration is None: raise InvalidSubscriptionError() + expression = re.compile(r'^(PrivateKey =)\s?$', re.MULTILINE) + wireguard_configuration = re.sub(expression, r'\1 ' + wireguard_keys.get('private'), wireguard_configuration) + profile.attach_wireguard_configuration(wireguard_configuration) @staticmethod @@ -235,3 +242,24 @@ class ProfileController: @staticmethod def has_wireguard_configuration(profile: Union[SessionProfile, SystemProfile]): return profile.has_wireguard_configuration() + + @staticmethod + def __generate_wireguard_keys(): + + from cryptography.hazmat.primitives import serialization + from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey + + raw_private_key = X25519PrivateKey.generate() + + public_key = raw_private_key.public_key().public_bytes( + encoding=serialization.Encoding.Raw, format=serialization.PublicFormat.Raw + ) + + private_key = raw_private_key.private_bytes( + encoding=serialization.Encoding.Raw, format=serialization.PrivateFormat.Raw, encryption_algorithm=serialization.NoEncryption() + ) + + return dict( + private=base64.b64encode(private_key).decode(), + public=base64.b64encode(public_key).decode() + ) diff --git a/core/services/WebServiceApiService.py b/core/services/WebServiceApiService.py index c1e7583..0fa9b0a 100644 --- a/core/services/WebServiceApiService.py +++ b/core/services/WebServiceApiService.py @@ -135,9 +135,11 @@ class WebServiceApiService: return None @staticmethod - def post_wireguard_session(location_code: str, billing_code: str, proxies: Optional[dict] = None): + def post_wireguard_session(location_code: str, billing_code: str, public_key: str, proxies: Optional[dict] = None): - response = WebServiceApiService.__post(f'/locations/{location_code}/wireguard-sessions', billing_code, proxies) + response = WebServiceApiService.__post(f'/locations/{location_code}/wireguard-sessions', billing_code, { + 'public_key': public_key, + }, proxies) if response.status_code == requests.codes.created: return response.text diff --git a/pyproject.toml b/pyproject.toml index 339895e..c359af0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,6 +12,7 @@ classifiers = [ "Operating System :: POSIX :: Linux", ] dependencies = [ + "cryptography ~= 44.0.1", "dataclasses-json ~= 0.6.4", "marshmallow ~= 3.21.1", "psutil ~= 5.9.8",