Update and refactor existing codebase

This commit is contained in:
codeking 2024-09-19 17:54:56 +02:00
parent 864c9c3242
commit 6a066ce3f4
8 changed files with 63 additions and 63 deletions

View file

@ -33,7 +33,7 @@ class ApplicationController:
if not os.path.isdir(persistent_state_path) or len(os.listdir(persistent_state_path)) == 0: if not os.path.isdir(persistent_state_path) or len(os.listdir(persistent_state_path)) == 0:
shutil.copytree(f'{version.installation_path}/resources/initial-state', persistent_state_path) shutil.copytree(f'{version.installation_path}/resources/initial-state', persistent_state_path)
display = ApplicationController._find_unused_display() display = ApplicationController.__find_unused_display()
base_initialization_file_template = open(f'/{Constants.SP_DATA_HOME}/.init.ptpl', 'r').read() base_initialization_file_template = open(f'/{Constants.SP_DATA_HOME}/.init.ptpl', 'r').read()
base_initialization_file_contents = base_initialization_file_template.format(display=display, time_zone=profile.location.time_zone, sp_data_home=Constants.SP_DATA_HOME) base_initialization_file_contents = base_initialization_file_template.format(display=display, time_zone=profile.location.time_zone, sp_data_home=Constants.SP_DATA_HOME)
@ -73,7 +73,7 @@ class ApplicationController:
Application.save_many(applications) Application.save_many(applications)
@staticmethod @staticmethod
def _find_unused_display(): def __find_unused_display():
file_names = os.listdir('/tmp/.X11-unix') file_names = os.listdir('/tmp/.X11-unix')
active_displays = [] active_displays = []

View file

@ -25,7 +25,7 @@ class ApplicationVersionController:
raise UnsupportedApplicationVersionError('The application version in question is not supported.') raise UnsupportedApplicationVersionError('The application version in question is not supported.')
from core.controllers.ConnectionController import ConnectionController from core.controllers.ConnectionController import ConnectionController
ConnectionController.with_preferred_connection(application_version, task=ApplicationVersionController._install, application_version_observer=application_version_observer, connection_observer=connection_observer) ConnectionController.with_preferred_connection(application_version, task=ApplicationVersionController.__install, application_version_observer=application_version_observer, connection_observer=connection_observer)
@staticmethod @staticmethod
def uninstall(application_version: ApplicationVersion): def uninstall(application_version: ApplicationVersion):
@ -48,7 +48,7 @@ class ApplicationVersionController:
ApplicationVersion.save_many(application_versions) ApplicationVersion.save_many(application_versions)
@staticmethod @staticmethod
def _install(application_version: ApplicationVersion, application_version_observer: Optional[ApplicationVersionObserver] = None, proxies: Optional[dict] = None): def __install(application_version: ApplicationVersion, application_version_observer: Optional[ApplicationVersionObserver] = None, proxies: Optional[dict] = None):
if application_version_observer is not None: if application_version_observer is not None:
application_version_observer.notify('downloading', application_version) application_version_observer.notify('downloading', application_version)
@ -60,7 +60,7 @@ class ApplicationVersionController:
if response.status_code == 200: if response.status_code == 200:
file_hash = ApplicationVersionController._calculate_file_hash(BytesIO(response.content)) file_hash = ApplicationVersionController.__calculate_file_hash(BytesIO(response.content))
if file_hash != application_version.file_hash: if file_hash != application_version.file_hash:
raise FileIntegrityError('Application version file integrity could not be verified.') raise FileIntegrityError('Application version file integrity could not be verified.')
@ -72,7 +72,7 @@ class ApplicationVersionController:
raise ConnectionError('The application version could not be downloaded.') raise ConnectionError('The application version could not be downloaded.')
@staticmethod @staticmethod
def _calculate_file_hash(file): def __calculate_file_hash(file):
hasher = hashlib.sha3_512() hasher = hashlib.sha3_512()
buffer = file.read(65536) buffer = file.read(65536)

View file

@ -31,20 +31,7 @@ class ConnectionController:
return task(*args, **kwargs) return task(*args, **kwargs)
elif connection == 'tor': elif connection == 'tor':
return ConnectionController._with_tor_connection(*args, task=task, connection_observer=connection_observer, **kwargs) return ConnectionController.__with_tor_connection(*args, task=task, connection_observer=connection_observer, **kwargs)
@staticmethod
def _with_tor_connection(*args, task: callable, connection_observer: Optional[ConnectionObserver] = None, **kwargs):
session_directory = tempfile.mkdtemp(prefix='sp-')
port_number = ConnectionController.get_random_available_port_number()
process = ConnectionController.establish_tor_session_connection(session_directory, port_number)
ConnectionController.await_connection(port_number, connection_observer=connection_observer)
task_output = task(*args, proxies=ConnectionController.get_proxies(port_number), **kwargs)
process.terminate()
return task_output
@staticmethod @staticmethod
def establish_connection(profile: Union[SessionProfile, SystemProfile], force: bool = False, connection_observer: Optional[ConnectionObserver] = None): def establish_connection(profile: Union[SessionProfile, SystemProfile], force: bool = False, connection_observer: Optional[ConnectionObserver] = None):
@ -289,7 +276,7 @@ class ConnectionController:
try: try:
ConnectionController._test_proxy_connection(port_number) ConnectionController.__test_proxy_connection(port_number)
return return
except ConnectionError: except ConnectionError:
@ -311,7 +298,20 @@ class ConnectionController:
return bool(re.search('dev wg', str(process_output))) return bool(re.search('dev wg', str(process_output)))
@staticmethod @staticmethod
def _test_proxy_connection(port_number: int, timeout: float = 10.0): def __with_tor_connection(*args, task: callable, connection_observer: Optional[ConnectionObserver] = None, **kwargs):
session_directory = tempfile.mkdtemp(prefix='sp-')
port_number = ConnectionController.get_random_available_port_number()
process = ConnectionController.establish_tor_session_connection(session_directory, port_number)
ConnectionController.await_connection(port_number, connection_observer=connection_observer)
task_output = task(*args, proxies=ConnectionController.get_proxies(port_number), **kwargs)
process.terminate()
return task_output
@staticmethod
def __test_proxy_connection(port_number: int, timeout: float = 10.0):
timeout = float(timeout) timeout = float(timeout)

View file

@ -16,10 +16,10 @@ class InvoiceController:
def handle_payment(billing_code: str, invoice_observer: InvoiceObserver = None, connection_observer: ConnectionObserver = None): def handle_payment(billing_code: str, invoice_observer: InvoiceObserver = None, connection_observer: ConnectionObserver = None):
from core.controllers.ConnectionController import ConnectionController from core.controllers.ConnectionController import ConnectionController
return ConnectionController.with_preferred_connection(billing_code, task=InvoiceController._handle_payment, invoice_observer=invoice_observer, connection_observer=connection_observer) return ConnectionController.with_preferred_connection(billing_code, task=InvoiceController.__handle_payment, invoice_observer=invoice_observer, connection_observer=connection_observer)
@staticmethod @staticmethod
def _handle_payment(billing_code: str, invoice_observer: Optional[InvoiceObserver] = None, proxies: Optional[dict] = None): def __handle_payment(billing_code: str, invoice_observer: Optional[InvoiceObserver] = None, proxies: Optional[dict] = None):
invoice = None invoice = None

View file

@ -23,10 +23,10 @@ class BaseProfile:
location: Optional[Location] location: Optional[Location]
def get_config_path(self): def get_config_path(self):
return BaseProfile._get_config_path(self.id) return BaseProfile.__get_config_path(self.id)
def get_data_path(self): def get_data_path(self):
return BaseProfile._get_data_path(self.id) return BaseProfile.__get_data_path(self.id)
def has_subscription(self): def has_subscription(self):
return self.subscription is not None return self.subscription is not None
@ -52,7 +52,7 @@ class BaseProfile:
def find_by_id(id: int): def find_by_id(id: int):
try: try:
config_file_contents = open(BaseProfile._get_config_path(id) + '/config.json', 'r').read() config_file_contents = open(BaseProfile.__get_config_path(id) + '/config.json', 'r').read()
except FileNotFoundError: except FileNotFoundError:
return None return None
@ -96,7 +96,7 @@ class BaseProfile:
@staticmethod @staticmethod
def exists(id: int): def exists(id: int):
return os.path.isdir(BaseProfile._get_config_path(id)) and re.match(r'^\d+$', str(id)) return os.path.isdir(BaseProfile.__get_config_path(id)) and re.match(r'^\d+$', str(id))
@staticmethod @staticmethod
def all(): def all():
@ -107,7 +107,7 @@ class BaseProfile:
if BaseProfile.exists(id): if BaseProfile.exists(id):
if os.path.exists(BaseProfile._get_config_path(id) + '/config.json'): if os.path.exists(BaseProfile.__get_config_path(id) + '/config.json'):
profile = BaseProfile.find_by_id(id) profile = BaseProfile.find_by_id(id)
@ -121,29 +121,29 @@ class BaseProfile:
if profile.is_session_profile() and profile.application_version is not None: if profile.is_session_profile() and profile.application_version is not None:
persistent_state_path = f'{BaseProfile._get_data_path(profile.id)}/persistent-state' persistent_state_path = f'{BaseProfile.__get_data_path(profile.id)}/persistent-state'
if os.path.isdir(persistent_state_path): if os.path.isdir(persistent_state_path):
shutil.rmtree(persistent_state_path) shutil.rmtree(persistent_state_path)
config_file_contents = profile.to_json(indent=4) + '\n' config_file_contents = profile.to_json(indent=4) + '\n'
os.makedirs(BaseProfile._get_config_path(profile.id), exist_ok=True) os.makedirs(BaseProfile.__get_config_path(profile.id), exist_ok=True)
os.makedirs(BaseProfile._get_data_path(profile.id), exist_ok=True) os.makedirs(BaseProfile.__get_data_path(profile.id), exist_ok=True)
text_io_wrapper = open(BaseProfile._get_config_path(profile.id) + '/config.json', 'w') text_io_wrapper = open(BaseProfile.__get_config_path(profile.id) + '/config.json', 'w')
text_io_wrapper.write(config_file_contents) text_io_wrapper.write(config_file_contents)
@staticmethod @staticmethod
def delete(profile): def delete(profile):
shutil.rmtree(BaseProfile._get_config_path(profile.id)) shutil.rmtree(BaseProfile.__get_config_path(profile.id))
shutil.rmtree(BaseProfile._get_data_path(profile.id)) shutil.rmtree(BaseProfile.__get_data_path(profile.id))
@staticmethod @staticmethod
def _get_config_path(id: int): def __get_config_path(id: int):
return Constants.SP_PROFILE_CONFIG_HOME + '/' + str(id) return Constants.SP_PROFILE_CONFIG_HOME + '/' + str(id)
@staticmethod @staticmethod
def _get_data_path(id: int): def __get_data_path(id: int):
return Constants.SP_PROFILE_DATA_HOME + '/' + str(id) return Constants.SP_PROFILE_DATA_HOME + '/' + str(id)

View file

@ -20,13 +20,13 @@ class SessionState:
process_ids: list[int] = field(default_factory=list) process_ids: list[int] = field(default_factory=list)
def get_state_path(self): def get_state_path(self):
return SessionState._get_state_path(self.id) return SessionState.__get_state_path(self.id)
@staticmethod @staticmethod
def find_by_id(id: int): def find_by_id(id: int):
try: try:
session_state_file_contents = open(SessionState._get_state_path(id) + '/state.json', 'r').read() session_state_file_contents = open(SessionState.__get_state_path(id) + '/state.json', 'r').read()
except FileNotFoundError: except FileNotFoundError:
return None return None
@ -42,15 +42,15 @@ class SessionState:
@staticmethod @staticmethod
def exists(id: int): def exists(id: int):
return os.path.isdir(SessionState._get_state_path(id)) and re.match(r'^\d+$', str(id)) return os.path.isdir(SessionState.__get_state_path(id)) and re.match(r'^\d+$', str(id))
@staticmethod @staticmethod
def save(session_state): def save(session_state):
session_state_file_contents = session_state.to_json(indent=4) + '\n' session_state_file_contents = session_state.to_json(indent=4) + '\n'
os.makedirs(SessionState._get_state_path(session_state.id), exist_ok=True, mode=0o700) os.makedirs(SessionState.__get_state_path(session_state.id), exist_ok=True, mode=0o700)
session_state_file_path = SessionState._get_state_path(session_state.id) + '/state.json' session_state_file_path = SessionState.__get_state_path(session_state.id) + '/state.json'
Path(session_state_file_path).touch(exist_ok=True, mode=0o600) Path(session_state_file_path).touch(exist_ok=True, mode=0o600)
text_io_wrapper = open(session_state_file_path, 'w') text_io_wrapper = open(session_state_file_path, 'w')
@ -65,11 +65,11 @@ class SessionState:
session_state_path = session_state.get_state_path() session_state_path = session_state.get_state_path()
SessionState._kill_associated_processes(session_state) SessionState.__kill_associated_processes(session_state)
shutil.rmtree(session_state_path, ignore_errors=True) shutil.rmtree(session_state_path, ignore_errors=True)
@staticmethod @staticmethod
def _kill_associated_processes(session_state): def __kill_associated_processes(session_state):
associated_process_ids = list(session_state.process_ids) associated_process_ids = list(session_state.process_ids)
network_connections = psutil.net_connections() network_connections = psutil.net_connections()
@ -98,5 +98,5 @@ class SessionState:
process.kill() process.kill()
@staticmethod @staticmethod
def _get_state_path(id: int): def __get_state_path(id: int):
return Constants.SP_SESSION_STATE_HOME + '/' + str(id) return Constants.SP_SESSION_STATE_HOME + '/' + str(id)

View file

@ -16,7 +16,7 @@ class SystemState:
def get(): def get():
try: try:
system_state_file_contents = open(SystemState._get_state_path() + '/system.json', 'r').read() system_state_file_contents = open(SystemState.__get_state_path() + '/system.json', 'r').read()
except FileNotFoundError: except FileNotFoundError:
return None return None
@ -30,15 +30,15 @@ class SystemState:
@staticmethod @staticmethod
def exists(): def exists():
return os.path.isfile(SystemState._get_state_path() + '/system.json') return os.path.isfile(SystemState.__get_state_path() + '/system.json')
@staticmethod @staticmethod
def save(system_state): def save(system_state):
system_state_file_contents = system_state.to_json(indent=4) + '\n' system_state_file_contents = system_state.to_json(indent=4) + '\n'
os.makedirs(SystemState._get_state_path(), exist_ok=True, mode=0o700) os.makedirs(SystemState.__get_state_path(), exist_ok=True, mode=0o700)
system_state_file_path = SystemState._get_state_path() + '/system.json' system_state_file_path = SystemState.__get_state_path() + '/system.json'
Path(system_state_file_path).touch(exist_ok=True, mode=0o600) Path(system_state_file_path).touch(exist_ok=True, mode=0o600)
text_io_wrapper = open(system_state_file_path, 'w') text_io_wrapper = open(system_state_file_path, 'w')
@ -51,9 +51,9 @@ class SystemState:
if system_state is not None: if system_state is not None:
system_state_path = SystemState._get_state_path() + '/system.json' system_state_path = SystemState.__get_state_path() + '/system.json'
pathlib.Path.unlink(Path(system_state_path), missing_ok=True) pathlib.Path.unlink(Path(system_state_path), missing_ok=True)
@staticmethod @staticmethod
def _get_state_path(): def __get_state_path():
return Constants.SP_STATE_HOME return Constants.SP_STATE_HOME

View file

@ -18,7 +18,7 @@ class WebServiceApiService:
@staticmethod @staticmethod
def get_applications(proxies: Optional[dict] = None): def get_applications(proxies: Optional[dict] = None):
response = WebServiceApiService._get('/platforms/linux-x86_64/applications', None, proxies) response = WebServiceApiService.__get('/platforms/linux-x86_64/applications', None, proxies)
applications = [] applications = []
if response.status_code == requests.codes.ok: if response.status_code == requests.codes.ok:
@ -30,7 +30,7 @@ class WebServiceApiService:
@staticmethod @staticmethod
def get_application_versions(code: str, proxies: Optional[dict] = None): def get_application_versions(code: str, proxies: Optional[dict] = None):
response = WebServiceApiService._get('/platforms/linux-x86_64/applications/' + code + '/application-versions', None, proxies) response = WebServiceApiService.__get('/platforms/linux-x86_64/applications/' + code + '/application-versions', None, proxies)
application_versions = [] application_versions = []
if response.status_code == requests.codes.ok: if response.status_code == requests.codes.ok:
@ -42,7 +42,7 @@ class WebServiceApiService:
@staticmethod @staticmethod
def get_client_versions(proxies: Optional[dict] = None): def get_client_versions(proxies: Optional[dict] = None):
response = WebServiceApiService._get('/platforms/linux-x86_64/client-versions', None, proxies) response = WebServiceApiService.__get('/platforms/linux-x86_64/client-versions', None, proxies)
client_versions = [] client_versions = []
if response.status_code == requests.codes.ok: if response.status_code == requests.codes.ok:
@ -54,7 +54,7 @@ class WebServiceApiService:
@staticmethod @staticmethod
def get_locations(proxies: Optional[dict] = None): def get_locations(proxies: Optional[dict] = None):
response = WebServiceApiService._get('/locations', None, proxies) response = WebServiceApiService.__get('/locations', None, proxies)
locations = [] locations = []
if response.status_code == requests.codes.ok: if response.status_code == requests.codes.ok:
@ -66,7 +66,7 @@ class WebServiceApiService:
@staticmethod @staticmethod
def get_subscription_plans(proxies: Optional[dict] = None): def get_subscription_plans(proxies: Optional[dict] = None):
response = WebServiceApiService._get('/subscription-plans', None, proxies) response = WebServiceApiService.__get('/subscription-plans', None, proxies)
subscription_plans = [] subscription_plans = []
if response.status_code == requests.codes.ok: if response.status_code == requests.codes.ok:
@ -78,7 +78,7 @@ class WebServiceApiService:
@staticmethod @staticmethod
def post_subscription(subscription_plan_id, location_id, proxies: Optional[dict] = None): def post_subscription(subscription_plan_id, location_id, proxies: Optional[dict] = None):
response = WebServiceApiService._post('/subscriptions', None, { response = WebServiceApiService.__post('/subscriptions', None, {
'subscription_plan_id': subscription_plan_id, 'subscription_plan_id': subscription_plan_id,
'location_id': location_id 'location_id': location_id
}, proxies) }, proxies)
@ -95,7 +95,7 @@ class WebServiceApiService:
billing_code_fragments = re.findall('....?', billing_code) billing_code_fragments = re.findall('....?', billing_code)
billing_code = '-'.join(billing_code_fragments) billing_code = '-'.join(billing_code_fragments)
response = WebServiceApiService._get('/subscriptions/current', billing_code, proxies) response = WebServiceApiService.__get('/subscriptions/current', billing_code, proxies)
if response.status_code == requests.codes.ok: if response.status_code == requests.codes.ok:
@ -105,7 +105,7 @@ class WebServiceApiService:
@staticmethod @staticmethod
def get_invoice(billing_code: str, proxies: Optional[dict] = None): def get_invoice(billing_code: str, proxies: Optional[dict] = None):
response = WebServiceApiService._get('/invoices/current', billing_code, proxies) response = WebServiceApiService.__get('/invoices/current', billing_code, proxies)
if response.status_code == requests.codes.ok: if response.status_code == requests.codes.ok:
@ -126,7 +126,7 @@ class WebServiceApiService:
@staticmethod @staticmethod
def get_proxy_configuration(billing_code: str, proxies: Optional[dict] = None): def get_proxy_configuration(billing_code: str, proxies: Optional[dict] = None):
response = WebServiceApiService._get('/proxy-configurations/current', billing_code, proxies) response = WebServiceApiService.__get('/proxy-configurations/current', billing_code, proxies)
if response.status_code == requests.codes.ok: if response.status_code == requests.codes.ok:
@ -139,7 +139,7 @@ class WebServiceApiService:
@staticmethod @staticmethod
def post_wireguard_session(location_code: str, billing_code: str, proxies: Optional[dict] = None): def post_wireguard_session(location_code: str, billing_code: str, proxies: Optional[dict] = None):
response = WebServiceApiService._post('/locations/' + location_code + '/wireguard-sessions', billing_code, proxies) response = WebServiceApiService.__post('/locations/' + location_code + '/wireguard-sessions', billing_code, proxies)
if response.status_code == requests.codes.created: if response.status_code == requests.codes.created:
return response.text return response.text
@ -147,7 +147,7 @@ class WebServiceApiService:
return None return None
@staticmethod @staticmethod
def _get(path, billing_code: Optional[str] = None, proxies: Optional[dict] = None): def __get(path, billing_code: Optional[str] = None, proxies: Optional[dict] = None):
if billing_code is not None: if billing_code is not None:
headers = {'X-Billing-Code': billing_code} headers = {'X-Billing-Code': billing_code}
@ -157,7 +157,7 @@ class WebServiceApiService:
return requests.get(Constants.SP_API_BASE_URL + path, headers=headers, proxies=proxies) return requests.get(Constants.SP_API_BASE_URL + path, headers=headers, proxies=proxies)
@staticmethod @staticmethod
def _post(path, billing_code: Optional[str] = None, body: Optional[dict] = None, proxies: Optional[dict] = None): def __post(path, billing_code: Optional[str] = None, body: Optional[dict] = None, proxies: Optional[dict] = None):
if billing_code is not None: if billing_code is not None:
headers = {'X-Billing-Code': billing_code} headers = {'X-Billing-Code': billing_code}