Implement support for client-side key generation

This commit is contained in:
codeking 2025-02-25 23:00:02 +01:00
parent 5468fe4998
commit ac125b9cb7
3 changed files with 34 additions and 3 deletions

View file

@ -12,6 +12,8 @@ from core.observers.ConnectionObserver import ConnectionObserver
from core.observers.ProfileObserver import ProfileObserver from core.observers.ProfileObserver import ProfileObserver
from core.services.WebServiceApiService import WebServiceApiService from core.services.WebServiceApiService import WebServiceApiService
from typing import Union, Optional from typing import Union, Optional
import base64
import re
import time import time
@ -221,11 +223,16 @@ class ProfileController:
if profile.location is None: if profile.location is None:
raise MissingLocationError() raise MissingLocationError()
wireguard_configuration = ConnectionController.with_preferred_connection(profile.location.code, profile.subscription.billing_code, task=WebServiceApiService.post_wireguard_session, connection_observer=connection_observer) wireguard_keys = ProfileController.__generate_wireguard_keys()
wireguard_configuration = ConnectionController.with_preferred_connection(profile.location.code, profile.subscription.billing_code, wireguard_keys.get('public'), task=WebServiceApiService.post_wireguard_session, connection_observer=connection_observer)
if wireguard_configuration is None: if wireguard_configuration is None:
raise InvalidSubscriptionError() raise InvalidSubscriptionError()
expression = re.compile(r'^(PrivateKey =)\s?$', re.MULTILINE)
wireguard_configuration = re.sub(expression, r'\1 ' + wireguard_keys.get('private'), wireguard_configuration)
profile.attach_wireguard_configuration(wireguard_configuration) profile.attach_wireguard_configuration(wireguard_configuration)
@staticmethod @staticmethod
@ -235,3 +242,24 @@ class ProfileController:
@staticmethod @staticmethod
def has_wireguard_configuration(profile: Union[SessionProfile, SystemProfile]): def has_wireguard_configuration(profile: Union[SessionProfile, SystemProfile]):
return profile.has_wireguard_configuration() return profile.has_wireguard_configuration()
@staticmethod
def __generate_wireguard_keys():
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey
raw_private_key = X25519PrivateKey.generate()
public_key = raw_private_key.public_key().public_bytes(
encoding=serialization.Encoding.Raw, format=serialization.PublicFormat.Raw
)
private_key = raw_private_key.private_bytes(
encoding=serialization.Encoding.Raw, format=serialization.PrivateFormat.Raw, encryption_algorithm=serialization.NoEncryption()
)
return dict(
private=base64.b64encode(private_key).decode(),
public=base64.b64encode(public_key).decode()
)

View file

@ -135,9 +135,11 @@ class WebServiceApiService:
return None return None
@staticmethod @staticmethod
def post_wireguard_session(location_code: str, billing_code: str, proxies: Optional[dict] = None): def post_wireguard_session(location_code: str, billing_code: str, public_key: str, proxies: Optional[dict] = None):
response = WebServiceApiService.__post(f'/locations/{location_code}/wireguard-sessions', billing_code, proxies) response = WebServiceApiService.__post(f'/locations/{location_code}/wireguard-sessions', billing_code, {
'public_key': public_key,
}, proxies)
if response.status_code == requests.codes.created: if response.status_code == requests.codes.created:
return response.text return response.text

View file

@ -12,6 +12,7 @@ classifiers = [
"Operating System :: POSIX :: Linux", "Operating System :: POSIX :: Linux",
] ]
dependencies = [ dependencies = [
"cryptography ~= 44.0.1",
"dataclasses-json ~= 0.6.4", "dataclasses-json ~= 0.6.4",
"marshmallow ~= 3.21.1", "marshmallow ~= 3.21.1",
"psutil ~= 5.9.8", "psutil ~= 5.9.8",