From fd6948a6a31936ad109640cb8051002397b25495 Mon Sep 17 00:00:00 2001 From: codeking Date: Wed, 12 Mar 2025 22:04:57 +0100 Subject: [PATCH] Implement support for country subdivisions --- core/controllers/LocationController.py | 4 ++-- core/controllers/ProfileController.py | 2 +- core/models/BaseProfile.py | 4 ++-- core/models/Location.py | 31 +++++++++++++++++--------- core/services/WebServiceApiService.py | 6 ++--- 5 files changed, 28 insertions(+), 19 deletions(-) diff --git a/core/controllers/LocationController.py b/core/controllers/LocationController.py index 7ecf844..c9d84b2 100644 --- a/core/controllers/LocationController.py +++ b/core/controllers/LocationController.py @@ -6,8 +6,8 @@ from typing import Optional class LocationController: @staticmethod - def get(code: str): - return Location.find(code) + def get(country_code: str, code: str): + return Location.find(country_code, code) @staticmethod def get_all(): diff --git a/core/controllers/ProfileController.py b/core/controllers/ProfileController.py index 2c4be14..52e8a5c 100644 --- a/core/controllers/ProfileController.py +++ b/core/controllers/ProfileController.py @@ -225,7 +225,7 @@ class ProfileController: wireguard_keys = ProfileController.__generate_wireguard_keys() - wireguard_configuration = ConnectionController.with_preferred_connection(profile.location.code, profile.subscription.billing_code, wireguard_keys.get('public'), task=WebServiceApiService.post_wireguard_session, connection_observer=connection_observer) + wireguard_configuration = ConnectionController.with_preferred_connection(profile.location.country_code, profile.location.code, profile.subscription.billing_code, wireguard_keys.get('public'), task=WebServiceApiService.post_wireguard_session, connection_observer=connection_observer) if wireguard_configuration is None: raise InvalidSubscriptionError() diff --git a/core/models/BaseProfile.py b/core/models/BaseProfile.py index 78fbc7b..f4d76f1 100644 --- a/core/models/BaseProfile.py +++ b/core/models/BaseProfile.py @@ -84,11 +84,11 @@ class BaseProfile: if profile['location'] is not None: - location = Location.find(profile['location']['code'] or None) + location = Location.find(profile['location']['country_code'] or None, profile['location']['code'] or None) if location is not None: - if profile['location']['time_zone'] is not None: + if profile['location'].get('time_zone') is not None: location.time_zone = profile['location']['time_zone'] profile['location'] = location diff --git a/core/models/Location.py b/core/models/Location.py index f80a855..69b37de 100644 --- a/core/models/Location.py +++ b/core/models/Location.py @@ -8,18 +8,27 @@ _table_name: str = 'locations' _table_definition: str = """ 'id' int UNIQUE, - 'code' varchar UNIQUE, - 'name' varchar UNIQUE + 'country_code' varchar, + 'country_name' varchar, + 'code' varchar, + 'name' varchar, + 'time_zone' varchar, + UNIQUE(code, country_code) """ @dataclass class Location(Model): + country_code: str code: str id: Optional[int] = field( default=None, metadata=config(exclude=Exclude.ALWAYS) ) + country_name: Optional[str] = field( + default=None, + metadata=config(exclude=Exclude.ALWAYS) + ) name: Optional[str] = field( default=None, metadata=config(exclude=Exclude.ALWAYS) @@ -33,12 +42,12 @@ class Location(Model): def __post_init__(self): if self.time_zone is None: - self.time_zone = country_timezones[self.code][0] + self.time_zone = country_timezones[self.country_code][0] - self.available = self.exists(self.code) + self.available = self.exists(self.country_code, self.code) def is_available(self): - return self.exists(self.code) + return self.exists(self.country_code, self.code) @staticmethod def find_by_id(id: int): @@ -46,14 +55,14 @@ class Location(Model): return Model._query_one('SELECT * FROM locations WHERE id = ? LIMIT 1', Location.factory, [id]) @staticmethod - def find(code: str): + def find(country_code: str, code: str): Model._create_table_if_not_exists(table_name=_table_name, table_definition=_table_definition) - return Model._query_one('SELECT * FROM locations WHERE code = ? LIMIT 1', Location.factory, [code]) + return Model._query_one('SELECT * FROM locations WHERE country_code = ? AND code = ? LIMIT 1', Location.factory, [country_code, code]) @staticmethod - def exists(code: str): + def exists(country_code: str, code: str): Model._create_table_if_not_exists(table_name=_table_name, table_definition=_table_definition) - return Model._query_exists('SELECT * FROM locations WHERE code = ?', [code]) + return Model._query_exists('SELECT * FROM locations WHERE country_code = ? AND code = ?', [country_code, code]) @staticmethod def all(): @@ -67,7 +76,7 @@ class Location(Model): @staticmethod def save_many(locations): Model._create_table_if_not_exists(table_name=_table_name, table_definition=_table_definition) - Model._insert_many('INSERT INTO locations VALUES(?, ?, ?)', Location.tuple_factory, locations) + Model._insert_many('INSERT INTO locations VALUES(?, ?, ?, ?, ?, ?)', Location.tuple_factory, locations) @staticmethod def factory(cursor, row): @@ -76,4 +85,4 @@ class Location(Model): @staticmethod def tuple_factory(location): - return location.id, location.code, location.name + return location.id, location.country_code, location.country_name, location.code, location.name, location.time_zone diff --git a/core/services/WebServiceApiService.py b/core/services/WebServiceApiService.py index 0fa9b0a..19f7ab6 100644 --- a/core/services/WebServiceApiService.py +++ b/core/services/WebServiceApiService.py @@ -59,7 +59,7 @@ class WebServiceApiService: if response.status_code == requests.codes.ok: for location in response.json()['data']: - locations.append(Location(location['code'], location['id'], location['name'])) + locations.append(Location(location['country']['code'], location['code'], location['id'], location['country']['name'], location['name'], location['time_zone']['code'])) return locations @@ -135,9 +135,9 @@ class WebServiceApiService: return None @staticmethod - def post_wireguard_session(location_code: str, billing_code: str, public_key: str, proxies: Optional[dict] = None): + def post_wireguard_session(country_code: str, location_code: str, billing_code: str, public_key: str, proxies: Optional[dict] = None): - response = WebServiceApiService.__post(f'/locations/{location_code}/wireguard-sessions', billing_code, { + response = WebServiceApiService.__post(f'/countries/{country_code}/locations/{location_code}/wireguard-sessions', billing_code, { 'public_key': public_key, }, proxies)