Implement support for client-side key generation

This commit is contained in:
codeking 2025-02-25 23:00:02 +01:00
parent 5468fe4998
commit ac125b9cb7
3 changed files with 34 additions and 3 deletions

View file

@ -12,6 +12,8 @@ from core.observers.ConnectionObserver import ConnectionObserver
from core.observers.ProfileObserver import ProfileObserver
from core.services.WebServiceApiService import WebServiceApiService
from typing import Union, Optional
import base64
import re
import time
@ -221,11 +223,16 @@ class ProfileController:
if profile.location is None:
raise MissingLocationError()
wireguard_configuration = ConnectionController.with_preferred_connection(profile.location.code, profile.subscription.billing_code, task=WebServiceApiService.post_wireguard_session, connection_observer=connection_observer)
wireguard_keys = ProfileController.__generate_wireguard_keys()
wireguard_configuration = ConnectionController.with_preferred_connection(profile.location.code, profile.subscription.billing_code, wireguard_keys.get('public'), task=WebServiceApiService.post_wireguard_session, connection_observer=connection_observer)
if wireguard_configuration is None:
raise InvalidSubscriptionError()
expression = re.compile(r'^(PrivateKey =)\s?$', re.MULTILINE)
wireguard_configuration = re.sub(expression, r'\1 ' + wireguard_keys.get('private'), wireguard_configuration)
profile.attach_wireguard_configuration(wireguard_configuration)
@staticmethod
@ -235,3 +242,24 @@ class ProfileController:
@staticmethod
def has_wireguard_configuration(profile: Union[SessionProfile, SystemProfile]):
return profile.has_wireguard_configuration()
@staticmethod
def __generate_wireguard_keys():
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey
raw_private_key = X25519PrivateKey.generate()
public_key = raw_private_key.public_key().public_bytes(
encoding=serialization.Encoding.Raw, format=serialization.PublicFormat.Raw
)
private_key = raw_private_key.private_bytes(
encoding=serialization.Encoding.Raw, format=serialization.PrivateFormat.Raw, encryption_algorithm=serialization.NoEncryption()
)
return dict(
private=base64.b64encode(private_key).decode(),
public=base64.b64encode(public_key).decode()
)

View file

@ -135,9 +135,11 @@ class WebServiceApiService:
return None
@staticmethod
def post_wireguard_session(location_code: str, billing_code: str, proxies: Optional[dict] = None):
def post_wireguard_session(location_code: str, billing_code: str, public_key: str, proxies: Optional[dict] = None):
response = WebServiceApiService.__post(f'/locations/{location_code}/wireguard-sessions', billing_code, proxies)
response = WebServiceApiService.__post(f'/locations/{location_code}/wireguard-sessions', billing_code, {
'public_key': public_key,
}, proxies)
if response.status_code == requests.codes.created:
return response.text

View file

@ -12,6 +12,7 @@ classifiers = [
"Operating System :: POSIX :: Linux",
]
dependencies = [
"cryptography ~= 44.0.1",
"dataclasses-json ~= 0.6.4",
"marshmallow ~= 3.21.1",
"psutil ~= 5.9.8",