moved imports for unix socket connections to init methods of client and server

This commit is contained in:
Daniil Fajnberg 2022-02-25 19:09:28 +01:00
parent c63f079da4
commit ed6badb088
2 changed files with 8 additions and 4 deletions

View File

@ -23,7 +23,7 @@ import json
import shutil
import sys
from abc import ABC, abstractmethod
from asyncio.streams import StreamReader, StreamWriter, open_unix_connection
from asyncio.streams import StreamReader, StreamWriter
from pathlib import Path
from typing import Optional
@ -153,6 +153,8 @@ class UnixControlClient(ControlClient):
The `_socket_path` attribute is set to the `Path` object created from the `socket_path` argument.
"""
from asyncio.streams import open_unix_connection
self._open_unix_connection = open_unix_connection
self._socket_path = Path(socket_path)
super().__init__(**conn_kwargs)
@ -164,7 +166,7 @@ class UnixControlClient(ControlClient):
otherwise, the stream-reader and -writer tuple is returned.
"""
try:
return await open_unix_connection(self._socket_path, **kwargs)
return await self._open_unix_connection(self._socket_path, **kwargs)
except FileNotFoundError:
print("No socket at", self._socket_path, file=sys.stderr)
return None, None

View File

@ -23,7 +23,7 @@ import logging
from abc import ABC, abstractmethod
from asyncio import AbstractServer
from asyncio.exceptions import CancelledError
from asyncio.streams import StreamReader, StreamWriter, start_unix_server
from asyncio.streams import StreamReader, StreamWriter
from asyncio.tasks import Task, create_task
from pathlib import Path
from typing import Optional, Union
@ -137,11 +137,13 @@ class UnixControlServer(ControlServer):
_client_class = UnixControlClient
def __init__(self, pool: SimpleTaskPool, **server_kwargs) -> None:
from asyncio.streams import start_unix_server
self._start_unix_server = start_unix_server
self._socket_path = Path(server_kwargs.pop('path'))
super().__init__(pool, **server_kwargs)
async def _get_server_instance(self, client_connected_cb: ConnectedCallbackT, **kwargs) -> AbstractServer:
server = await start_unix_server(client_connected_cb, self._socket_path, **kwargs)
server = await self._start_unix_server(client_connected_cb, self._socket_path, **kwargs)
log.debug("Opened socket '%s'", str(self._socket_path))
return server