moved imports for unix socket connections to init methods of client and server

This commit is contained in:
Daniil Fajnberg 2022-02-25 19:09:28 +01:00
parent c63f079da4
commit ed6badb088
2 changed files with 8 additions and 4 deletions

View File

@ -23,7 +23,7 @@ import json
import shutil import shutil
import sys import sys
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from asyncio.streams import StreamReader, StreamWriter, open_unix_connection from asyncio.streams import StreamReader, StreamWriter
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
@ -153,6 +153,8 @@ class UnixControlClient(ControlClient):
The `_socket_path` attribute is set to the `Path` object created from the `socket_path` argument. The `_socket_path` attribute is set to the `Path` object created from the `socket_path` argument.
""" """
from asyncio.streams import open_unix_connection
self._open_unix_connection = open_unix_connection
self._socket_path = Path(socket_path) self._socket_path = Path(socket_path)
super().__init__(**conn_kwargs) super().__init__(**conn_kwargs)
@ -164,7 +166,7 @@ class UnixControlClient(ControlClient):
otherwise, the stream-reader and -writer tuple is returned. otherwise, the stream-reader and -writer tuple is returned.
""" """
try: try:
return await open_unix_connection(self._socket_path, **kwargs) return await self._open_unix_connection(self._socket_path, **kwargs)
except FileNotFoundError: except FileNotFoundError:
print("No socket at", self._socket_path, file=sys.stderr) print("No socket at", self._socket_path, file=sys.stderr)
return None, None return None, None

View File

@ -23,7 +23,7 @@ import logging
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from asyncio import AbstractServer from asyncio import AbstractServer
from asyncio.exceptions import CancelledError from asyncio.exceptions import CancelledError
from asyncio.streams import StreamReader, StreamWriter, start_unix_server from asyncio.streams import StreamReader, StreamWriter
from asyncio.tasks import Task, create_task from asyncio.tasks import Task, create_task
from pathlib import Path from pathlib import Path
from typing import Optional, Union from typing import Optional, Union
@ -137,11 +137,13 @@ class UnixControlServer(ControlServer):
_client_class = UnixControlClient _client_class = UnixControlClient
def __init__(self, pool: SimpleTaskPool, **server_kwargs) -> None: def __init__(self, pool: SimpleTaskPool, **server_kwargs) -> None:
from asyncio.streams import start_unix_server
self._start_unix_server = start_unix_server
self._socket_path = Path(server_kwargs.pop('path')) self._socket_path = Path(server_kwargs.pop('path'))
super().__init__(pool, **server_kwargs) super().__init__(pool, **server_kwargs)
async def _get_server_instance(self, client_connected_cb: ConnectedCallbackT, **kwargs) -> AbstractServer: async def _get_server_instance(self, client_connected_cb: ConnectedCallbackT, **kwargs) -> AbstractServer:
server = await start_unix_server(client_connected_cb, self._socket_path, **kwargs) server = await self._start_unix_server(client_connected_cb, self._socket_path, **kwargs)
log.debug("Opened socket '%s'", str(self._socket_path)) log.debug("Opened socket '%s'", str(self._socket_path))
return server return server