Implement support for country subdivisions
This commit is contained in:
		
							parent
							
								
									7f9c6ab1a9
								
							
						
					
					
						commit
						fd6948a6a3
					
				
					 5 changed files with 28 additions and 19 deletions
				
			
		| 
						 | 
					@ -6,8 +6,8 @@ from typing import Optional
 | 
				
			||||||
class LocationController:
 | 
					class LocationController:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @staticmethod
 | 
					    @staticmethod
 | 
				
			||||||
    def get(code: str):
 | 
					    def get(country_code: str, code: str):
 | 
				
			||||||
        return Location.find(code)
 | 
					        return Location.find(country_code, code)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @staticmethod
 | 
					    @staticmethod
 | 
				
			||||||
    def get_all():
 | 
					    def get_all():
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -225,7 +225,7 @@ class ProfileController:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        wireguard_keys = ProfileController.__generate_wireguard_keys()
 | 
					        wireguard_keys = ProfileController.__generate_wireguard_keys()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        wireguard_configuration = ConnectionController.with_preferred_connection(profile.location.code, profile.subscription.billing_code, wireguard_keys.get('public'), task=WebServiceApiService.post_wireguard_session, connection_observer=connection_observer)
 | 
					        wireguard_configuration = ConnectionController.with_preferred_connection(profile.location.country_code, profile.location.code, profile.subscription.billing_code, wireguard_keys.get('public'), task=WebServiceApiService.post_wireguard_session, connection_observer=connection_observer)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if wireguard_configuration is None:
 | 
					        if wireguard_configuration is None:
 | 
				
			||||||
            raise InvalidSubscriptionError()
 | 
					            raise InvalidSubscriptionError()
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -84,11 +84,11 @@ class BaseProfile:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if profile['location'] is not None:
 | 
					        if profile['location'] is not None:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            location = Location.find(profile['location']['code'] or None)
 | 
					            location = Location.find(profile['location']['country_code'] or None, profile['location']['code'] or None)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            if location is not None:
 | 
					            if location is not None:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                if profile['location']['time_zone'] is not None:
 | 
					                if profile['location'].get('time_zone') is not None:
 | 
				
			||||||
                    location.time_zone = profile['location']['time_zone']
 | 
					                    location.time_zone = profile['location']['time_zone']
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                profile['location'] = location
 | 
					                profile['location'] = location
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -8,18 +8,27 @@ _table_name: str = 'locations'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
_table_definition: str = """
 | 
					_table_definition: str = """
 | 
				
			||||||
    'id' int UNIQUE,
 | 
					    'id' int UNIQUE,
 | 
				
			||||||
    'code' varchar UNIQUE,
 | 
					    'country_code' varchar,
 | 
				
			||||||
    'name' varchar UNIQUE
 | 
					    'country_name' varchar,
 | 
				
			||||||
 | 
					    'code' varchar,
 | 
				
			||||||
 | 
					    'name' varchar,
 | 
				
			||||||
 | 
					    'time_zone' varchar,
 | 
				
			||||||
 | 
					    UNIQUE(code, country_code)
 | 
				
			||||||
"""
 | 
					"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@dataclass
 | 
					@dataclass
 | 
				
			||||||
class Location(Model):
 | 
					class Location(Model):
 | 
				
			||||||
 | 
					    country_code: str
 | 
				
			||||||
    code: str
 | 
					    code: str
 | 
				
			||||||
    id: Optional[int] = field(
 | 
					    id: Optional[int] = field(
 | 
				
			||||||
        default=None,
 | 
					        default=None,
 | 
				
			||||||
        metadata=config(exclude=Exclude.ALWAYS)
 | 
					        metadata=config(exclude=Exclude.ALWAYS)
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					    country_name: Optional[str] = field(
 | 
				
			||||||
 | 
					        default=None,
 | 
				
			||||||
 | 
					        metadata=config(exclude=Exclude.ALWAYS)
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
    name: Optional[str] = field(
 | 
					    name: Optional[str] = field(
 | 
				
			||||||
        default=None,
 | 
					        default=None,
 | 
				
			||||||
        metadata=config(exclude=Exclude.ALWAYS)
 | 
					        metadata=config(exclude=Exclude.ALWAYS)
 | 
				
			||||||
| 
						 | 
					@ -33,12 +42,12 @@ class Location(Model):
 | 
				
			||||||
    def __post_init__(self):
 | 
					    def __post_init__(self):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if self.time_zone is None:
 | 
					        if self.time_zone is None:
 | 
				
			||||||
            self.time_zone = country_timezones[self.code][0]
 | 
					            self.time_zone = country_timezones[self.country_code][0]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.available = self.exists(self.code)
 | 
					        self.available = self.exists(self.country_code, self.code)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def is_available(self):
 | 
					    def is_available(self):
 | 
				
			||||||
        return self.exists(self.code)
 | 
					        return self.exists(self.country_code, self.code)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @staticmethod
 | 
					    @staticmethod
 | 
				
			||||||
    def find_by_id(id: int):
 | 
					    def find_by_id(id: int):
 | 
				
			||||||
| 
						 | 
					@ -46,14 +55,14 @@ class Location(Model):
 | 
				
			||||||
        return Model._query_one('SELECT * FROM locations WHERE id = ? LIMIT 1', Location.factory, [id])
 | 
					        return Model._query_one('SELECT * FROM locations WHERE id = ? LIMIT 1', Location.factory, [id])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @staticmethod
 | 
					    @staticmethod
 | 
				
			||||||
    def find(code: str):
 | 
					    def find(country_code: str, code: str):
 | 
				
			||||||
        Model._create_table_if_not_exists(table_name=_table_name, table_definition=_table_definition)
 | 
					        Model._create_table_if_not_exists(table_name=_table_name, table_definition=_table_definition)
 | 
				
			||||||
        return Model._query_one('SELECT * FROM locations WHERE code = ? LIMIT 1', Location.factory, [code])
 | 
					        return Model._query_one('SELECT * FROM locations WHERE country_code = ? AND code = ? LIMIT 1', Location.factory, [country_code, code])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @staticmethod
 | 
					    @staticmethod
 | 
				
			||||||
    def exists(code: str):
 | 
					    def exists(country_code: str, code: str):
 | 
				
			||||||
        Model._create_table_if_not_exists(table_name=_table_name, table_definition=_table_definition)
 | 
					        Model._create_table_if_not_exists(table_name=_table_name, table_definition=_table_definition)
 | 
				
			||||||
        return Model._query_exists('SELECT * FROM locations WHERE code = ?', [code])
 | 
					        return Model._query_exists('SELECT * FROM locations WHERE country_code = ? AND code = ?', [country_code, code])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @staticmethod
 | 
					    @staticmethod
 | 
				
			||||||
    def all():
 | 
					    def all():
 | 
				
			||||||
| 
						 | 
					@ -67,7 +76,7 @@ class Location(Model):
 | 
				
			||||||
    @staticmethod
 | 
					    @staticmethod
 | 
				
			||||||
    def save_many(locations):
 | 
					    def save_many(locations):
 | 
				
			||||||
        Model._create_table_if_not_exists(table_name=_table_name, table_definition=_table_definition)
 | 
					        Model._create_table_if_not_exists(table_name=_table_name, table_definition=_table_definition)
 | 
				
			||||||
        Model._insert_many('INSERT INTO locations VALUES(?, ?, ?)', Location.tuple_factory, locations)
 | 
					        Model._insert_many('INSERT INTO locations VALUES(?, ?, ?, ?, ?, ?)', Location.tuple_factory, locations)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @staticmethod
 | 
					    @staticmethod
 | 
				
			||||||
    def factory(cursor, row):
 | 
					    def factory(cursor, row):
 | 
				
			||||||
| 
						 | 
					@ -76,4 +85,4 @@ class Location(Model):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @staticmethod
 | 
					    @staticmethod
 | 
				
			||||||
    def tuple_factory(location):
 | 
					    def tuple_factory(location):
 | 
				
			||||||
        return location.id, location.code, location.name
 | 
					        return location.id, location.country_code, location.country_name, location.code, location.name, location.time_zone
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -59,7 +59,7 @@ class WebServiceApiService:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if response.status_code == requests.codes.ok:
 | 
					        if response.status_code == requests.codes.ok:
 | 
				
			||||||
            for location in response.json()['data']:
 | 
					            for location in response.json()['data']:
 | 
				
			||||||
                locations.append(Location(location['code'], location['id'], location['name']))
 | 
					                locations.append(Location(location['country']['code'], location['code'], location['id'], location['country']['name'], location['name'], location['time_zone']['code']))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        return locations
 | 
					        return locations
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -135,9 +135,9 @@ class WebServiceApiService:
 | 
				
			||||||
            return None
 | 
					            return None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @staticmethod
 | 
					    @staticmethod
 | 
				
			||||||
    def post_wireguard_session(location_code: str, billing_code: str, public_key: str, proxies: Optional[dict] = None):
 | 
					    def post_wireguard_session(country_code: str, location_code: str, billing_code: str, public_key: str, proxies: Optional[dict] = None):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        response = WebServiceApiService.__post(f'/locations/{location_code}/wireguard-sessions', billing_code, {
 | 
					        response = WebServiceApiService.__post(f'/countries/{country_code}/locations/{location_code}/wireguard-sessions', billing_code, {
 | 
				
			||||||
            'public_key': public_key,
 | 
					            'public_key': public_key,
 | 
				
			||||||
        }, proxies)
 | 
					        }, proxies)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in a new issue