Implement support for country subdivisions
This commit is contained in:
parent
7f9c6ab1a9
commit
fd6948a6a3
5 changed files with 28 additions and 19 deletions
|
@ -6,8 +6,8 @@ from typing import Optional
|
||||||
class LocationController:
|
class LocationController:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get(code: str):
|
def get(country_code: str, code: str):
|
||||||
return Location.find(code)
|
return Location.find(country_code, code)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_all():
|
def get_all():
|
||||||
|
|
|
@ -225,7 +225,7 @@ class ProfileController:
|
||||||
|
|
||||||
wireguard_keys = ProfileController.__generate_wireguard_keys()
|
wireguard_keys = ProfileController.__generate_wireguard_keys()
|
||||||
|
|
||||||
wireguard_configuration = ConnectionController.with_preferred_connection(profile.location.code, profile.subscription.billing_code, wireguard_keys.get('public'), task=WebServiceApiService.post_wireguard_session, connection_observer=connection_observer)
|
wireguard_configuration = ConnectionController.with_preferred_connection(profile.location.country_code, profile.location.code, profile.subscription.billing_code, wireguard_keys.get('public'), task=WebServiceApiService.post_wireguard_session, connection_observer=connection_observer)
|
||||||
|
|
||||||
if wireguard_configuration is None:
|
if wireguard_configuration is None:
|
||||||
raise InvalidSubscriptionError()
|
raise InvalidSubscriptionError()
|
||||||
|
|
|
@ -84,11 +84,11 @@ class BaseProfile:
|
||||||
|
|
||||||
if profile['location'] is not None:
|
if profile['location'] is not None:
|
||||||
|
|
||||||
location = Location.find(profile['location']['code'] or None)
|
location = Location.find(profile['location']['country_code'] or None, profile['location']['code'] or None)
|
||||||
|
|
||||||
if location is not None:
|
if location is not None:
|
||||||
|
|
||||||
if profile['location']['time_zone'] is not None:
|
if profile['location'].get('time_zone') is not None:
|
||||||
location.time_zone = profile['location']['time_zone']
|
location.time_zone = profile['location']['time_zone']
|
||||||
|
|
||||||
profile['location'] = location
|
profile['location'] = location
|
||||||
|
|
|
@ -8,18 +8,27 @@ _table_name: str = 'locations'
|
||||||
|
|
||||||
_table_definition: str = """
|
_table_definition: str = """
|
||||||
'id' int UNIQUE,
|
'id' int UNIQUE,
|
||||||
'code' varchar UNIQUE,
|
'country_code' varchar,
|
||||||
'name' varchar UNIQUE
|
'country_name' varchar,
|
||||||
|
'code' varchar,
|
||||||
|
'name' varchar,
|
||||||
|
'time_zone' varchar,
|
||||||
|
UNIQUE(code, country_code)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Location(Model):
|
class Location(Model):
|
||||||
|
country_code: str
|
||||||
code: str
|
code: str
|
||||||
id: Optional[int] = field(
|
id: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata=config(exclude=Exclude.ALWAYS)
|
metadata=config(exclude=Exclude.ALWAYS)
|
||||||
)
|
)
|
||||||
|
country_name: Optional[str] = field(
|
||||||
|
default=None,
|
||||||
|
metadata=config(exclude=Exclude.ALWAYS)
|
||||||
|
)
|
||||||
name: Optional[str] = field(
|
name: Optional[str] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata=config(exclude=Exclude.ALWAYS)
|
metadata=config(exclude=Exclude.ALWAYS)
|
||||||
|
@ -33,12 +42,12 @@ class Location(Model):
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
|
|
||||||
if self.time_zone is None:
|
if self.time_zone is None:
|
||||||
self.time_zone = country_timezones[self.code][0]
|
self.time_zone = country_timezones[self.country_code][0]
|
||||||
|
|
||||||
self.available = self.exists(self.code)
|
self.available = self.exists(self.country_code, self.code)
|
||||||
|
|
||||||
def is_available(self):
|
def is_available(self):
|
||||||
return self.exists(self.code)
|
return self.exists(self.country_code, self.code)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def find_by_id(id: int):
|
def find_by_id(id: int):
|
||||||
|
@ -46,14 +55,14 @@ class Location(Model):
|
||||||
return Model._query_one('SELECT * FROM locations WHERE id = ? LIMIT 1', Location.factory, [id])
|
return Model._query_one('SELECT * FROM locations WHERE id = ? LIMIT 1', Location.factory, [id])
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def find(code: str):
|
def find(country_code: str, code: str):
|
||||||
Model._create_table_if_not_exists(table_name=_table_name, table_definition=_table_definition)
|
Model._create_table_if_not_exists(table_name=_table_name, table_definition=_table_definition)
|
||||||
return Model._query_one('SELECT * FROM locations WHERE code = ? LIMIT 1', Location.factory, [code])
|
return Model._query_one('SELECT * FROM locations WHERE country_code = ? AND code = ? LIMIT 1', Location.factory, [country_code, code])
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def exists(code: str):
|
def exists(country_code: str, code: str):
|
||||||
Model._create_table_if_not_exists(table_name=_table_name, table_definition=_table_definition)
|
Model._create_table_if_not_exists(table_name=_table_name, table_definition=_table_definition)
|
||||||
return Model._query_exists('SELECT * FROM locations WHERE code = ?', [code])
|
return Model._query_exists('SELECT * FROM locations WHERE country_code = ? AND code = ?', [country_code, code])
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def all():
|
def all():
|
||||||
|
@ -67,7 +76,7 @@ class Location(Model):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def save_many(locations):
|
def save_many(locations):
|
||||||
Model._create_table_if_not_exists(table_name=_table_name, table_definition=_table_definition)
|
Model._create_table_if_not_exists(table_name=_table_name, table_definition=_table_definition)
|
||||||
Model._insert_many('INSERT INTO locations VALUES(?, ?, ?)', Location.tuple_factory, locations)
|
Model._insert_many('INSERT INTO locations VALUES(?, ?, ?, ?, ?, ?)', Location.tuple_factory, locations)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def factory(cursor, row):
|
def factory(cursor, row):
|
||||||
|
@ -76,4 +85,4 @@ class Location(Model):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def tuple_factory(location):
|
def tuple_factory(location):
|
||||||
return location.id, location.code, location.name
|
return location.id, location.country_code, location.country_name, location.code, location.name, location.time_zone
|
||||||
|
|
|
@ -59,7 +59,7 @@ class WebServiceApiService:
|
||||||
|
|
||||||
if response.status_code == requests.codes.ok:
|
if response.status_code == requests.codes.ok:
|
||||||
for location in response.json()['data']:
|
for location in response.json()['data']:
|
||||||
locations.append(Location(location['code'], location['id'], location['name']))
|
locations.append(Location(location['country']['code'], location['code'], location['id'], location['country']['name'], location['name'], location['time_zone']['code']))
|
||||||
|
|
||||||
return locations
|
return locations
|
||||||
|
|
||||||
|
@ -135,9 +135,9 @@ class WebServiceApiService:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def post_wireguard_session(location_code: str, billing_code: str, public_key: str, proxies: Optional[dict] = None):
|
def post_wireguard_session(country_code: str, location_code: str, billing_code: str, public_key: str, proxies: Optional[dict] = None):
|
||||||
|
|
||||||
response = WebServiceApiService.__post(f'/locations/{location_code}/wireguard-sessions', billing_code, {
|
response = WebServiceApiService.__post(f'/countries/{country_code}/locations/{location_code}/wireguard-sessions', billing_code, {
|
||||||
'public_key': public_key,
|
'public_key': public_key,
|
||||||
}, proxies)
|
}, proxies)
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue