Implement support for country subdivisions

This commit is contained in:
codeking 2025-03-12 22:04:57 +01:00
parent 7f9c6ab1a9
commit fd6948a6a3
5 changed files with 28 additions and 19 deletions

View file

@ -6,8 +6,8 @@ from typing import Optional
class LocationController: class LocationController:
@staticmethod @staticmethod
def get(code: str): def get(country_code: str, code: str):
return Location.find(code) return Location.find(country_code, code)
@staticmethod @staticmethod
def get_all(): def get_all():

View file

@ -225,7 +225,7 @@ class ProfileController:
wireguard_keys = ProfileController.__generate_wireguard_keys() wireguard_keys = ProfileController.__generate_wireguard_keys()
wireguard_configuration = ConnectionController.with_preferred_connection(profile.location.code, profile.subscription.billing_code, wireguard_keys.get('public'), task=WebServiceApiService.post_wireguard_session, connection_observer=connection_observer) wireguard_configuration = ConnectionController.with_preferred_connection(profile.location.country_code, profile.location.code, profile.subscription.billing_code, wireguard_keys.get('public'), task=WebServiceApiService.post_wireguard_session, connection_observer=connection_observer)
if wireguard_configuration is None: if wireguard_configuration is None:
raise InvalidSubscriptionError() raise InvalidSubscriptionError()

View file

@ -84,11 +84,11 @@ class BaseProfile:
if profile['location'] is not None: if profile['location'] is not None:
location = Location.find(profile['location']['code'] or None) location = Location.find(profile['location']['country_code'] or None, profile['location']['code'] or None)
if location is not None: if location is not None:
if profile['location']['time_zone'] is not None: if profile['location'].get('time_zone') is not None:
location.time_zone = profile['location']['time_zone'] location.time_zone = profile['location']['time_zone']
profile['location'] = location profile['location'] = location

View file

@ -8,18 +8,27 @@ _table_name: str = 'locations'
_table_definition: str = """ _table_definition: str = """
'id' int UNIQUE, 'id' int UNIQUE,
'code' varchar UNIQUE, 'country_code' varchar,
'name' varchar UNIQUE 'country_name' varchar,
'code' varchar,
'name' varchar,
'time_zone' varchar,
UNIQUE(code, country_code)
""" """
@dataclass @dataclass
class Location(Model): class Location(Model):
country_code: str
code: str code: str
id: Optional[int] = field( id: Optional[int] = field(
default=None, default=None,
metadata=config(exclude=Exclude.ALWAYS) metadata=config(exclude=Exclude.ALWAYS)
) )
country_name: Optional[str] = field(
default=None,
metadata=config(exclude=Exclude.ALWAYS)
)
name: Optional[str] = field( name: Optional[str] = field(
default=None, default=None,
metadata=config(exclude=Exclude.ALWAYS) metadata=config(exclude=Exclude.ALWAYS)
@ -33,12 +42,12 @@ class Location(Model):
def __post_init__(self): def __post_init__(self):
if self.time_zone is None: if self.time_zone is None:
self.time_zone = country_timezones[self.code][0] self.time_zone = country_timezones[self.country_code][0]
self.available = self.exists(self.code) self.available = self.exists(self.country_code, self.code)
def is_available(self): def is_available(self):
return self.exists(self.code) return self.exists(self.country_code, self.code)
@staticmethod @staticmethod
def find_by_id(id: int): def find_by_id(id: int):
@ -46,14 +55,14 @@ class Location(Model):
return Model._query_one('SELECT * FROM locations WHERE id = ? LIMIT 1', Location.factory, [id]) return Model._query_one('SELECT * FROM locations WHERE id = ? LIMIT 1', Location.factory, [id])
@staticmethod @staticmethod
def find(code: str): def find(country_code: str, code: str):
Model._create_table_if_not_exists(table_name=_table_name, table_definition=_table_definition) Model._create_table_if_not_exists(table_name=_table_name, table_definition=_table_definition)
return Model._query_one('SELECT * FROM locations WHERE code = ? LIMIT 1', Location.factory, [code]) return Model._query_one('SELECT * FROM locations WHERE country_code = ? AND code = ? LIMIT 1', Location.factory, [country_code, code])
@staticmethod @staticmethod
def exists(code: str): def exists(country_code: str, code: str):
Model._create_table_if_not_exists(table_name=_table_name, table_definition=_table_definition) Model._create_table_if_not_exists(table_name=_table_name, table_definition=_table_definition)
return Model._query_exists('SELECT * FROM locations WHERE code = ?', [code]) return Model._query_exists('SELECT * FROM locations WHERE country_code = ? AND code = ?', [country_code, code])
@staticmethod @staticmethod
def all(): def all():
@ -67,7 +76,7 @@ class Location(Model):
@staticmethod @staticmethod
def save_many(locations): def save_many(locations):
Model._create_table_if_not_exists(table_name=_table_name, table_definition=_table_definition) Model._create_table_if_not_exists(table_name=_table_name, table_definition=_table_definition)
Model._insert_many('INSERT INTO locations VALUES(?, ?, ?)', Location.tuple_factory, locations) Model._insert_many('INSERT INTO locations VALUES(?, ?, ?, ?, ?, ?)', Location.tuple_factory, locations)
@staticmethod @staticmethod
def factory(cursor, row): def factory(cursor, row):
@ -76,4 +85,4 @@ class Location(Model):
@staticmethod @staticmethod
def tuple_factory(location): def tuple_factory(location):
return location.id, location.code, location.name return location.id, location.country_code, location.country_name, location.code, location.name, location.time_zone

View file

@ -59,7 +59,7 @@ class WebServiceApiService:
if response.status_code == requests.codes.ok: if response.status_code == requests.codes.ok:
for location in response.json()['data']: for location in response.json()['data']:
locations.append(Location(location['code'], location['id'], location['name'])) locations.append(Location(location['country']['code'], location['code'], location['id'], location['country']['name'], location['name'], location['time_zone']['code']))
return locations return locations
@ -135,9 +135,9 @@ class WebServiceApiService:
return None return None
@staticmethod @staticmethod
def post_wireguard_session(location_code: str, billing_code: str, public_key: str, proxies: Optional[dict] = None): def post_wireguard_session(country_code: str, location_code: str, billing_code: str, public_key: str, proxies: Optional[dict] = None):
response = WebServiceApiService.__post(f'/locations/{location_code}/wireguard-sessions', billing_code, { response = WebServiceApiService.__post(f'/countries/{country_code}/locations/{location_code}/wireguard-sessions', billing_code, {
'public_key': public_key, 'public_key': public_key,
}, proxies) }, proxies)