From 6a066ce3f4ce334644fdd7d76d4dd38a7c80c499 Mon Sep 17 00:00:00 2001 From: codeking Date: Thu, 19 Sep 2024 17:54:56 +0200 Subject: [PATCH] Update and refactor existing codebase --- core/controllers/ApplicationController.py | 4 +-- .../ApplicationVersionController.py | 8 ++--- core/controllers/ConnectionController.py | 32 +++++++++---------- core/controllers/InvoiceController.py | 4 +-- core/models/BaseProfile.py | 26 +++++++-------- core/models/session/SessionState.py | 16 +++++----- core/models/system/SystemState.py | 12 +++---- core/services/WebServiceApiService.py | 24 +++++++------- 8 files changed, 63 insertions(+), 63 deletions(-) diff --git a/core/controllers/ApplicationController.py b/core/controllers/ApplicationController.py index 80c0799..3167c74 100644 --- a/core/controllers/ApplicationController.py +++ b/core/controllers/ApplicationController.py @@ -33,7 +33,7 @@ class ApplicationController: if not os.path.isdir(persistent_state_path) or len(os.listdir(persistent_state_path)) == 0: shutil.copytree(f'{version.installation_path}/resources/initial-state', persistent_state_path) - display = ApplicationController._find_unused_display() + display = ApplicationController.__find_unused_display() base_initialization_file_template = open(f'/{Constants.SP_DATA_HOME}/.init.ptpl', 'r').read() base_initialization_file_contents = base_initialization_file_template.format(display=display, time_zone=profile.location.time_zone, sp_data_home=Constants.SP_DATA_HOME) @@ -73,7 +73,7 @@ class ApplicationController: Application.save_many(applications) @staticmethod - def _find_unused_display(): + def __find_unused_display(): file_names = os.listdir('/tmp/.X11-unix') active_displays = [] diff --git a/core/controllers/ApplicationVersionController.py b/core/controllers/ApplicationVersionController.py index ebfd516..9f48d15 100644 --- a/core/controllers/ApplicationVersionController.py +++ b/core/controllers/ApplicationVersionController.py @@ -25,7 +25,7 @@ class ApplicationVersionController: raise UnsupportedApplicationVersionError('The application version in question is not supported.') from core.controllers.ConnectionController import ConnectionController - ConnectionController.with_preferred_connection(application_version, task=ApplicationVersionController._install, application_version_observer=application_version_observer, connection_observer=connection_observer) + ConnectionController.with_preferred_connection(application_version, task=ApplicationVersionController.__install, application_version_observer=application_version_observer, connection_observer=connection_observer) @staticmethod def uninstall(application_version: ApplicationVersion): @@ -48,7 +48,7 @@ class ApplicationVersionController: ApplicationVersion.save_many(application_versions) @staticmethod - def _install(application_version: ApplicationVersion, application_version_observer: Optional[ApplicationVersionObserver] = None, proxies: Optional[dict] = None): + def __install(application_version: ApplicationVersion, application_version_observer: Optional[ApplicationVersionObserver] = None, proxies: Optional[dict] = None): if application_version_observer is not None: application_version_observer.notify('downloading', application_version) @@ -60,7 +60,7 @@ class ApplicationVersionController: if response.status_code == 200: - file_hash = ApplicationVersionController._calculate_file_hash(BytesIO(response.content)) + file_hash = ApplicationVersionController.__calculate_file_hash(BytesIO(response.content)) if file_hash != application_version.file_hash: raise FileIntegrityError('Application version file integrity could not be verified.') @@ -72,7 +72,7 @@ class ApplicationVersionController: raise ConnectionError('The application version could not be downloaded.') @staticmethod - def _calculate_file_hash(file): + def __calculate_file_hash(file): hasher = hashlib.sha3_512() buffer = file.read(65536) diff --git a/core/controllers/ConnectionController.py b/core/controllers/ConnectionController.py index 975e74b..1419aee 100644 --- a/core/controllers/ConnectionController.py +++ b/core/controllers/ConnectionController.py @@ -31,20 +31,7 @@ class ConnectionController: return task(*args, **kwargs) elif connection == 'tor': - return ConnectionController._with_tor_connection(*args, task=task, connection_observer=connection_observer, **kwargs) - - @staticmethod - def _with_tor_connection(*args, task: callable, connection_observer: Optional[ConnectionObserver] = None, **kwargs): - - session_directory = tempfile.mkdtemp(prefix='sp-') - port_number = ConnectionController.get_random_available_port_number() - process = ConnectionController.establish_tor_session_connection(session_directory, port_number) - - ConnectionController.await_connection(port_number, connection_observer=connection_observer) - task_output = task(*args, proxies=ConnectionController.get_proxies(port_number), **kwargs) - process.terminate() - - return task_output + return ConnectionController.__with_tor_connection(*args, task=task, connection_observer=connection_observer, **kwargs) @staticmethod def establish_connection(profile: Union[SessionProfile, SystemProfile], force: bool = False, connection_observer: Optional[ConnectionObserver] = None): @@ -289,7 +276,7 @@ class ConnectionController: try: - ConnectionController._test_proxy_connection(port_number) + ConnectionController.__test_proxy_connection(port_number) return except ConnectionError: @@ -311,7 +298,20 @@ class ConnectionController: return bool(re.search('dev wg', str(process_output))) @staticmethod - def _test_proxy_connection(port_number: int, timeout: float = 10.0): + def __with_tor_connection(*args, task: callable, connection_observer: Optional[ConnectionObserver] = None, **kwargs): + + session_directory = tempfile.mkdtemp(prefix='sp-') + port_number = ConnectionController.get_random_available_port_number() + process = ConnectionController.establish_tor_session_connection(session_directory, port_number) + + ConnectionController.await_connection(port_number, connection_observer=connection_observer) + task_output = task(*args, proxies=ConnectionController.get_proxies(port_number), **kwargs) + process.terminate() + + return task_output + + @staticmethod + def __test_proxy_connection(port_number: int, timeout: float = 10.0): timeout = float(timeout) diff --git a/core/controllers/InvoiceController.py b/core/controllers/InvoiceController.py index d57883d..214a113 100644 --- a/core/controllers/InvoiceController.py +++ b/core/controllers/InvoiceController.py @@ -16,10 +16,10 @@ class InvoiceController: def handle_payment(billing_code: str, invoice_observer: InvoiceObserver = None, connection_observer: ConnectionObserver = None): from core.controllers.ConnectionController import ConnectionController - return ConnectionController.with_preferred_connection(billing_code, task=InvoiceController._handle_payment, invoice_observer=invoice_observer, connection_observer=connection_observer) + return ConnectionController.with_preferred_connection(billing_code, task=InvoiceController.__handle_payment, invoice_observer=invoice_observer, connection_observer=connection_observer) @staticmethod - def _handle_payment(billing_code: str, invoice_observer: Optional[InvoiceObserver] = None, proxies: Optional[dict] = None): + def __handle_payment(billing_code: str, invoice_observer: Optional[InvoiceObserver] = None, proxies: Optional[dict] = None): invoice = None diff --git a/core/models/BaseProfile.py b/core/models/BaseProfile.py index 07b6690..cd95375 100644 --- a/core/models/BaseProfile.py +++ b/core/models/BaseProfile.py @@ -23,10 +23,10 @@ class BaseProfile: location: Optional[Location] def get_config_path(self): - return BaseProfile._get_config_path(self.id) + return BaseProfile.__get_config_path(self.id) def get_data_path(self): - return BaseProfile._get_data_path(self.id) + return BaseProfile.__get_data_path(self.id) def has_subscription(self): return self.subscription is not None @@ -52,7 +52,7 @@ class BaseProfile: def find_by_id(id: int): try: - config_file_contents = open(BaseProfile._get_config_path(id) + '/config.json', 'r').read() + config_file_contents = open(BaseProfile.__get_config_path(id) + '/config.json', 'r').read() except FileNotFoundError: return None @@ -96,7 +96,7 @@ class BaseProfile: @staticmethod def exists(id: int): - return os.path.isdir(BaseProfile._get_config_path(id)) and re.match(r'^\d+$', str(id)) + return os.path.isdir(BaseProfile.__get_config_path(id)) and re.match(r'^\d+$', str(id)) @staticmethod def all(): @@ -107,7 +107,7 @@ class BaseProfile: if BaseProfile.exists(id): - if os.path.exists(BaseProfile._get_config_path(id) + '/config.json'): + if os.path.exists(BaseProfile.__get_config_path(id) + '/config.json'): profile = BaseProfile.find_by_id(id) @@ -121,29 +121,29 @@ class BaseProfile: if profile.is_session_profile() and profile.application_version is not None: - persistent_state_path = f'{BaseProfile._get_data_path(profile.id)}/persistent-state' + persistent_state_path = f'{BaseProfile.__get_data_path(profile.id)}/persistent-state' if os.path.isdir(persistent_state_path): shutil.rmtree(persistent_state_path) config_file_contents = profile.to_json(indent=4) + '\n' - os.makedirs(BaseProfile._get_config_path(profile.id), exist_ok=True) - os.makedirs(BaseProfile._get_data_path(profile.id), exist_ok=True) + os.makedirs(BaseProfile.__get_config_path(profile.id), exist_ok=True) + os.makedirs(BaseProfile.__get_data_path(profile.id), exist_ok=True) - text_io_wrapper = open(BaseProfile._get_config_path(profile.id) + '/config.json', 'w') + text_io_wrapper = open(BaseProfile.__get_config_path(profile.id) + '/config.json', 'w') text_io_wrapper.write(config_file_contents) @staticmethod def delete(profile): - shutil.rmtree(BaseProfile._get_config_path(profile.id)) - shutil.rmtree(BaseProfile._get_data_path(profile.id)) + shutil.rmtree(BaseProfile.__get_config_path(profile.id)) + shutil.rmtree(BaseProfile.__get_data_path(profile.id)) @staticmethod - def _get_config_path(id: int): + def __get_config_path(id: int): return Constants.SP_PROFILE_CONFIG_HOME + '/' + str(id) @staticmethod - def _get_data_path(id: int): + def __get_data_path(id: int): return Constants.SP_PROFILE_DATA_HOME + '/' + str(id) diff --git a/core/models/session/SessionState.py b/core/models/session/SessionState.py index adbe188..b904a7c 100644 --- a/core/models/session/SessionState.py +++ b/core/models/session/SessionState.py @@ -20,13 +20,13 @@ class SessionState: process_ids: list[int] = field(default_factory=list) def get_state_path(self): - return SessionState._get_state_path(self.id) + return SessionState.__get_state_path(self.id) @staticmethod def find_by_id(id: int): try: - session_state_file_contents = open(SessionState._get_state_path(id) + '/state.json', 'r').read() + session_state_file_contents = open(SessionState.__get_state_path(id) + '/state.json', 'r').read() except FileNotFoundError: return None @@ -42,15 +42,15 @@ class SessionState: @staticmethod def exists(id: int): - return os.path.isdir(SessionState._get_state_path(id)) and re.match(r'^\d+$', str(id)) + return os.path.isdir(SessionState.__get_state_path(id)) and re.match(r'^\d+$', str(id)) @staticmethod def save(session_state): session_state_file_contents = session_state.to_json(indent=4) + '\n' - os.makedirs(SessionState._get_state_path(session_state.id), exist_ok=True, mode=0o700) + os.makedirs(SessionState.__get_state_path(session_state.id), exist_ok=True, mode=0o700) - session_state_file_path = SessionState._get_state_path(session_state.id) + '/state.json' + session_state_file_path = SessionState.__get_state_path(session_state.id) + '/state.json' Path(session_state_file_path).touch(exist_ok=True, mode=0o600) text_io_wrapper = open(session_state_file_path, 'w') @@ -65,11 +65,11 @@ class SessionState: session_state_path = session_state.get_state_path() - SessionState._kill_associated_processes(session_state) + SessionState.__kill_associated_processes(session_state) shutil.rmtree(session_state_path, ignore_errors=True) @staticmethod - def _kill_associated_processes(session_state): + def __kill_associated_processes(session_state): associated_process_ids = list(session_state.process_ids) network_connections = psutil.net_connections() @@ -98,5 +98,5 @@ class SessionState: process.kill() @staticmethod - def _get_state_path(id: int): + def __get_state_path(id: int): return Constants.SP_SESSION_STATE_HOME + '/' + str(id) diff --git a/core/models/system/SystemState.py b/core/models/system/SystemState.py index e3218c6..caceb8b 100644 --- a/core/models/system/SystemState.py +++ b/core/models/system/SystemState.py @@ -16,7 +16,7 @@ class SystemState: def get(): try: - system_state_file_contents = open(SystemState._get_state_path() + '/system.json', 'r').read() + system_state_file_contents = open(SystemState.__get_state_path() + '/system.json', 'r').read() except FileNotFoundError: return None @@ -30,15 +30,15 @@ class SystemState: @staticmethod def exists(): - return os.path.isfile(SystemState._get_state_path() + '/system.json') + return os.path.isfile(SystemState.__get_state_path() + '/system.json') @staticmethod def save(system_state): system_state_file_contents = system_state.to_json(indent=4) + '\n' - os.makedirs(SystemState._get_state_path(), exist_ok=True, mode=0o700) + os.makedirs(SystemState.__get_state_path(), exist_ok=True, mode=0o700) - system_state_file_path = SystemState._get_state_path() + '/system.json' + system_state_file_path = SystemState.__get_state_path() + '/system.json' Path(system_state_file_path).touch(exist_ok=True, mode=0o600) text_io_wrapper = open(system_state_file_path, 'w') @@ -51,9 +51,9 @@ class SystemState: if system_state is not None: - system_state_path = SystemState._get_state_path() + '/system.json' + system_state_path = SystemState.__get_state_path() + '/system.json' pathlib.Path.unlink(Path(system_state_path), missing_ok=True) @staticmethod - def _get_state_path(): + def __get_state_path(): return Constants.SP_STATE_HOME diff --git a/core/services/WebServiceApiService.py b/core/services/WebServiceApiService.py index 0e68ab7..f54aeac 100644 --- a/core/services/WebServiceApiService.py +++ b/core/services/WebServiceApiService.py @@ -18,7 +18,7 @@ class WebServiceApiService: @staticmethod def get_applications(proxies: Optional[dict] = None): - response = WebServiceApiService._get('/platforms/linux-x86_64/applications', None, proxies) + response = WebServiceApiService.__get('/platforms/linux-x86_64/applications', None, proxies) applications = [] if response.status_code == requests.codes.ok: @@ -30,7 +30,7 @@ class WebServiceApiService: @staticmethod def get_application_versions(code: str, proxies: Optional[dict] = None): - response = WebServiceApiService._get('/platforms/linux-x86_64/applications/' + code + '/application-versions', None, proxies) + response = WebServiceApiService.__get('/platforms/linux-x86_64/applications/' + code + '/application-versions', None, proxies) application_versions = [] if response.status_code == requests.codes.ok: @@ -42,7 +42,7 @@ class WebServiceApiService: @staticmethod def get_client_versions(proxies: Optional[dict] = None): - response = WebServiceApiService._get('/platforms/linux-x86_64/client-versions', None, proxies) + response = WebServiceApiService.__get('/platforms/linux-x86_64/client-versions', None, proxies) client_versions = [] if response.status_code == requests.codes.ok: @@ -54,7 +54,7 @@ class WebServiceApiService: @staticmethod def get_locations(proxies: Optional[dict] = None): - response = WebServiceApiService._get('/locations', None, proxies) + response = WebServiceApiService.__get('/locations', None, proxies) locations = [] if response.status_code == requests.codes.ok: @@ -66,7 +66,7 @@ class WebServiceApiService: @staticmethod def get_subscription_plans(proxies: Optional[dict] = None): - response = WebServiceApiService._get('/subscription-plans', None, proxies) + response = WebServiceApiService.__get('/subscription-plans', None, proxies) subscription_plans = [] if response.status_code == requests.codes.ok: @@ -78,7 +78,7 @@ class WebServiceApiService: @staticmethod def post_subscription(subscription_plan_id, location_id, proxies: Optional[dict] = None): - response = WebServiceApiService._post('/subscriptions', None, { + response = WebServiceApiService.__post('/subscriptions', None, { 'subscription_plan_id': subscription_plan_id, 'location_id': location_id }, proxies) @@ -95,7 +95,7 @@ class WebServiceApiService: billing_code_fragments = re.findall('....?', billing_code) billing_code = '-'.join(billing_code_fragments) - response = WebServiceApiService._get('/subscriptions/current', billing_code, proxies) + response = WebServiceApiService.__get('/subscriptions/current', billing_code, proxies) if response.status_code == requests.codes.ok: @@ -105,7 +105,7 @@ class WebServiceApiService: @staticmethod def get_invoice(billing_code: str, proxies: Optional[dict] = None): - response = WebServiceApiService._get('/invoices/current', billing_code, proxies) + response = WebServiceApiService.__get('/invoices/current', billing_code, proxies) if response.status_code == requests.codes.ok: @@ -126,7 +126,7 @@ class WebServiceApiService: @staticmethod def get_proxy_configuration(billing_code: str, proxies: Optional[dict] = None): - response = WebServiceApiService._get('/proxy-configurations/current', billing_code, proxies) + response = WebServiceApiService.__get('/proxy-configurations/current', billing_code, proxies) if response.status_code == requests.codes.ok: @@ -139,7 +139,7 @@ class WebServiceApiService: @staticmethod def post_wireguard_session(location_code: str, billing_code: str, proxies: Optional[dict] = None): - response = WebServiceApiService._post('/locations/' + location_code + '/wireguard-sessions', billing_code, proxies) + response = WebServiceApiService.__post('/locations/' + location_code + '/wireguard-sessions', billing_code, proxies) if response.status_code == requests.codes.created: return response.text @@ -147,7 +147,7 @@ class WebServiceApiService: return None @staticmethod - def _get(path, billing_code: Optional[str] = None, proxies: Optional[dict] = None): + def __get(path, billing_code: Optional[str] = None, proxies: Optional[dict] = None): if billing_code is not None: headers = {'X-Billing-Code': billing_code} @@ -157,7 +157,7 @@ class WebServiceApiService: return requests.get(Constants.SP_API_BASE_URL + path, headers=headers, proxies=proxies) @staticmethod - def _post(path, billing_code: Optional[str] = None, body: Optional[dict] = None, proxies: Optional[dict] = None): + def __post(path, billing_code: Optional[str] = None, body: Optional[dict] = None, proxies: Optional[dict] = None): if billing_code is not None: headers = {'X-Billing-Code': billing_code}