diff --git a/src/asyncio_taskpool/constants.py b/src/asyncio_taskpool/constants.py index b3dbce2..8713894 100644 --- a/src/asyncio_taskpool/constants.py +++ b/src/asyncio_taskpool/constants.py @@ -4,4 +4,5 @@ CMD_START = 'start' CMD_STOP = 'stop' CMD_STOP_ALL = 'stop_all' CMD_SIZE = 'size' +CMD_FUNC = 'func' CLIENT_EXIT = 'exit' diff --git a/src/asyncio_taskpool/pool.py b/src/asyncio_taskpool/pool.py index 87f01e1..421bb5d 100644 --- a/src/asyncio_taskpool/pool.py +++ b/src/asyncio_taskpool/pool.py @@ -11,8 +11,16 @@ log = logging.getLogger(__name__) class TaskPool: + _pools: List['TaskPool'] = [] + + @classmethod + def _add_pool(cls, pool: 'TaskPool') -> int: + cls._pools.append(pool) + return len(cls._pools) - 1 + def __init__(self, func: CoroutineFunc, args: Iterable[Any] = (), kwargs: Mapping[str, Any] = None, - final_callback: FinalCallbackT = None, cancel_callback: CancelCallbackT = None) -> None: + final_callback: FinalCallbackT = None, cancel_callback: CancelCallbackT = None, + name: str = None) -> None: self._func: CoroutineFunc = func self._args: Iterable[Any] = args self._kwargs: Mapping[str, Any] = kwargs if kwargs is not None else {} @@ -20,6 +28,9 @@ class TaskPool: self._cancel_callback: CancelCallbackT = cancel_callback self._tasks: List[Task] = [] self._cancelled: List[Task] = [] + self._idx: int = self._add_pool(self) + self._name: str = name + log.debug("%s initialized", repr(self)) @property def func_name(self) -> str: @@ -29,8 +40,11 @@ class TaskPool: def size(self) -> int: return len(self._tasks) + def __str__(self) -> str: + return f'{self.__class__.__name__}-{self._name or self._idx}' + def __repr__(self) -> str: - return f'<{self.__class__.__name__} func={self.func_name} size={self.size}>' + return f'<{self} func={self.func_name}>' def _task_name(self, i: int) -> str: return f'{self.func_name}_pool_task_{i}' diff --git a/src/asyncio_taskpool/server.py b/src/asyncio_taskpool/server.py index 72a8cbf..7f57ac3 100644 --- a/src/asyncio_taskpool/server.py +++ b/src/asyncio_taskpool/server.py @@ -70,6 +70,10 @@ class ControlServer(ABC): log.debug("%s requests pool size", self.client_class.__name__) writer.write(f'{self._pool.size}'.encode()) + def _pool_func(self, writer: StreamWriter) -> None: + log.debug("%s requests pool function", self.client_class.__name__) + writer.write(self._pool.func_name.encode()) + async def _listen(self, reader: StreamReader, writer: StreamWriter) -> None: while self._server.is_serving(): msg = (await reader.read(constants.MSG_BYTES)).decode().strip() @@ -85,6 +89,8 @@ class ControlServer(ABC): self._stop_all_tasks(writer) elif cmd == constants.CMD_SIZE: self._pool_size(writer) + elif cmd == constants.CMD_FUNC: + self._pool_func(writer) else: log.debug("%s sent invalid command: %s", self.client_class.__name__, msg) writer.write(b"Invalid command!")