from core.Constants import Constants
from dataclasses import dataclass, field
from dataclasses_json import config, Exclude, dataclass_json
from json import JSONDecodeError
from pathlib import Path
import json
import os
import psutil
import re
import shutil


@dataclass_json
@dataclass
class SessionState:
    """Persisted state of a session: its network ports and process ids.

    Each session owns a directory ``<Constants.HV_SESSION_STATE_HOME>/<id>``
    that holds a ``state.json`` file with the serialized state.
    """

    # The id is configured with Exclude.ALWAYS, so it is not part of the
    # serialized form; find_by_id() re-injects it from its argument.
    id: int = field(
        metadata=config(exclude=Exclude.ALWAYS)
    )
    # Port numbers used to find additional processes when dissolving.
    network_port_numbers: list[int] = field(default_factory=list)
    # Process ids that are killed when the session is dissolved.
    process_ids: list[int] = field(default_factory=list)

    def get_state_path(self):
        """Return the path of this session's state directory."""
        return SessionState.__get_state_path(self.id)

    @staticmethod
    def find_by_id(id: int):
        """Load the state of session ``id`` from its ``state.json``.

        Returns:
            The ``SessionState``, or ``None`` when the state file does not
            exist, is not valid JSON, or does not contain a JSON object.
        """
        session_state_file_path = f'{SessionState.__get_state_path(id)}/state.json'
        try:
            # The context manager closes the handle deterministically.
            with open(session_state_file_path, 'r') as session_state_file:
                session_state_file_contents = session_state_file.read()
        except FileNotFoundError:
            return None

        try:
            session_state = json.loads(session_state_file_contents)
        except JSONDecodeError:
            return None

        # Valid JSON that is not an object (e.g. a list or a number) cannot
        # carry the state fields; treat it like any other corrupt file.
        if not isinstance(session_state, dict):
            return None

        session_state['id'] = id
        # noinspection PyUnresolvedReferences
        return SessionState.from_dict(session_state)

    @staticmethod
    def exists(id: int) -> bool:
        """Return True when ``id`` is all digits and its state directory exists."""
        # Validate the id before it is used to build a filesystem path.
        if re.match(r'^\d+$', str(id)) is None:
            return False
        return os.path.isdir(SessionState.__get_state_path(id))

    @staticmethod
    def save(session_state):
        """Write ``session_state`` to ``state.json`` in its state directory.

        The directory is created with mode 0o700 and the file with mode
        0o600 when they do not exist yet.
        """
        session_state_file_contents = f'{session_state.to_json(indent=4)}\n'
        session_state_path = SessionState.__get_state_path(session_state.id)
        os.makedirs(session_state_path, exist_ok=True, mode=0o700)
        session_state_file_path = f'{session_state_path}/state.json'
        # Create the file with restrictive permissions before writing to it.
        Path(session_state_file_path).touch(exist_ok=True, mode=0o600)
        with open(session_state_file_path, 'w') as session_state_file:
            session_state_file.write(session_state_file_contents)

    @staticmethod
    def dissolve(id: int):
        """Kill the processes of session ``id`` and remove its state directory.

        Does nothing when the state of the session cannot be loaded.
        """
        session_state = SessionState.find_by_id(id)
        if session_state is not None:
            session_state_path = session_state.get_state_path()
            SessionState.__kill_associated_processes(session_state)
            shutil.rmtree(session_state_path, ignore_errors=True)

    @staticmethod
    def __kill_associated_processes(session_state):
        """Kill every process associated with ``session_state``.

        A process is associated when its pid is listed in ``process_ids`` or
        when it owns a connection whose local or remote port is one of
        ``network_port_numbers``. Children are killed before their parent.
        """
        associated_process_ids = set(session_state.process_ids)
        network_port_numbers = set(session_state.network_port_numbers)

        # Only enumerate connections when there is a port to look for.
        if network_port_numbers:
            for network_connection in psutil.net_connections():
                if network_connection.pid is None:
                    continue
                # An address is an empty tuple when that side is not set.
                for address in (network_connection.laddr, network_connection.raddr):
                    if address and address.port in network_port_numbers:
                        associated_process_ids.add(network_connection.pid)

        if not associated_process_ids:
            return

        for process in psutil.process_iter():
            if process.pid not in associated_process_ids:
                continue
            try:
                if not process.is_running():
                    continue
                child_processes = process.children(recursive=True)
            except psutil.NoSuchProcess:
                # The process exited while it was being inspected.
                continue
            for child_process in child_processes:
                SessionState.__kill_if_alive(child_process)
            SessionState.__kill_if_alive(process)

    @staticmethod
    def __kill_if_alive(process):
        """Kill ``process``, ignoring the case where it has already exited."""
        try:
            process.kill()
        except psutil.NoSuchProcess:
            # Already gone, which is the outcome we wanted anyway.
            pass

    @staticmethod
    def __get_state_path(id: int):
        """Return the state directory path for session ``id``."""
        return f'{Constants.HV_SESSION_STATE_HOME}/{str(id)}'