from core.Constants import Constants
from dataclasses import dataclass, field
from dataclasses_json import config, Exclude, dataclass_json
from json import JSONDecodeError
from pathlib import Path
import json
import os
import psutil
import re
import shutil


@dataclass_json
@dataclass
class SessionState:
    """Persisted bookkeeping for a session, stored as ``<state path>/state.json``.

    The ``id`` field is never serialized (``Exclude.ALWAYS``); when loading,
    it is re-attached from the id used to locate the state directory.
    """

    # Session identifier; also the final component of the state directory path.
    id: int = field(
        metadata=config(exclude=Exclude.ALWAYS)
    )
    # Ports used to discover additional session processes: any process owning a
    # connection whose local or remote port is listed here is treated as part
    # of the session when it is dissolved.
    network_port_numbers: list[int] = field(default_factory=list)
    # Process IDs explicitly registered to the session.
    process_ids: list[int] = field(default_factory=list)

    def get_state_path(self) -> str:
        """Return the state directory path for this session."""
        return SessionState._get_state_path(self.id)

    @staticmethod
    def find_by_id(id: int):
        """Load the session state stored for ``id``.

        Returns:
            A ``SessionState`` with ``id`` set, or ``None`` when ``state.json``
            does not exist, is not valid JSON, or does not hold a JSON object.
        """
        session_state_file_path = SessionState._get_state_path(id) + '/state.json'

        try:
            # Context manager so the handle is closed instead of leaked.
            with open(session_state_file_path, 'r') as session_state_file:
                session_state_file_contents = session_state_file.read()
        except FileNotFoundError:
            return None

        try:
            session_state = json.loads(session_state_file_contents)
        except JSONDecodeError:
            return None

        # Valid JSON that is not an object (e.g. a list) cannot be a state;
        # treat it like any other unreadable state rather than crashing below.
        if not isinstance(session_state, dict):
            return None

        session_state['id'] = id

        # noinspection PyUnresolvedReferences
        return SessionState.from_dict(session_state)

    @staticmethod
    def exists(id: int) -> bool:
        """Return True when ``id`` is purely numeric and its state directory exists."""
        # Validate the identifier before using it to build a filesystem path.
        # fullmatch (unlike ``^...$``) rejects a trailing newline.
        if re.fullmatch(r'\d+', str(id)) is None:
            return False

        return os.path.isdir(SessionState._get_state_path(id))

    @staticmethod
    def save(session_state) -> None:
        """Write ``session_state`` as indented JSON to its ``state.json``.

        The state directory is requested with mode 0o700 and the file with
        0o600 (both subject to the umask, and only applied on creation).
        """
        session_state_file_contents = session_state.to_json(indent=4) + '\n'
        session_state_path = SessionState._get_state_path(session_state.id)

        os.makedirs(session_state_path, exist_ok=True, mode=0o700)

        session_state_file_path = session_state_path + '/state.json'
        # Create the file with restrictive permissions before any content lands in it.
        Path(session_state_file_path).touch(exist_ok=True, mode=0o600)

        # Closing the file (via `with`) guarantees the buffered contents are
        # flushed; the previous code never closed the handle.
        with open(session_state_file_path, 'w') as session_state_file:
            session_state_file.write(session_state_file_contents)

    @staticmethod
    def dissolve(id: int) -> None:
        """Kill the session's associated processes and delete its state directory.

        Does nothing when ``find_by_id`` cannot load a state for ``id``.
        """
        session_state = SessionState.find_by_id(id)

        if session_state is None:
            # NOTE(review): a state directory whose state.json is missing or
            # unparsable is left in place here — confirm that is intended.
            return

        session_state_path = session_state.get_state_path()
        SessionState._kill_associated_processes(session_state)
        shutil.rmtree(session_state_path, ignore_errors=True)

    @staticmethod
    def _kill_associated_processes(session_state) -> None:
        """Kill every process tied to the session, together with its descendants.

        A process is tied to the session when its PID is in ``process_ids`` or
        it owns a connection (as reported by ``psutil.net_connections()``)
        whose local or remote port is in ``network_port_numbers``. The calling
        process is never killed.
        """
        own_process_id = os.getpid()
        associated_process_ids = set(session_state.process_ids)
        network_port_numbers = set(session_state.network_port_numbers)

        for network_connection in psutil.net_connections():
            if network_connection.pid is None:
                continue

            # laddr / raddr are empty tuples when that endpoint is not set.
            local_port_matches = (
                network_connection.laddr != tuple()
                and network_connection.laddr.port in network_port_numbers
            )
            remote_port_matches = (
                network_connection.raddr != tuple()
                and network_connection.raddr.port in network_port_numbers
            )

            if local_port_matches or remote_port_matches:
                associated_process_ids.add(network_connection.pid)

        # Killing ourselves would abort the caller (e.g. dissolve) before the
        # state directory is removed.
        associated_process_ids.discard(own_process_id)

        for process in psutil.process_iter():
            if process.pid not in associated_process_ids:
                continue

            try:
                if not process.is_running():
                    continue

                # Kill descendants first so they are not left orphaned.
                for child_process in process.children(recursive=True):
                    if child_process.pid == own_process_id:
                        continue
                    try:
                        child_process.kill()
                    except psutil.NoSuchProcess:
                        # Already exited between listing and killing — goal reached.
                        pass

                process.kill()
            except psutil.NoSuchProcess:
                # The process exited while we were working on it — goal reached.
                pass

    @staticmethod
    def _get_state_path(id: int) -> str:
        """Return the state directory path for session ``id``."""
        return Constants.SP_SESSION_STATE_HOME + '/' + str(id)