feat: Implement Smart Tunnel async architecture (v6.0.0b1)
- Replaced blocking pexpect interact with asyncio-based multiplexing. - Implemented LocalStream and RemoteStream for agnostic I/O handling. - Real-time SIGWINCH window resizing support. - 'Ghost buffer' mitigation for clean, artifact-free session handovers. - Upgraded _logclean into a mini terminal emulator to accurately process ANSI, backspaces, and inline clears. - Continuous auto-saving for logs without blocking the main thread. - Bumped version to 6.0.0b1 and regenerated pdoc documentation.
This commit is contained in:
+1
-1
@@ -1 +1 @@
|
||||
__version__ = "5.1b6"
|
||||
__version__ = "6.0.0b1"
|
||||
|
||||
+225
-46
@@ -14,7 +14,10 @@ from pathlib import Path
|
||||
from copy import deepcopy
|
||||
from .hooks import ClassHook, MethodHook
|
||||
import io
|
||||
import asyncio
|
||||
import fcntl
|
||||
from . import printer
|
||||
from .tunnels import LocalStream
|
||||
|
||||
|
||||
#functions and classes
|
||||
@@ -189,23 +192,54 @@ class node:
|
||||
|
||||
@MethodHook
|
||||
def _logclean(self, logfile, var = False):
|
||||
#Remove special ascii characters and other stuff from logfile.
|
||||
# Remove special ascii characters and process terminal cursor movements to clean logs.
|
||||
if var == False:
|
||||
t = open(logfile, "r").read()
|
||||
else:
|
||||
t = logfile
|
||||
while t.find("\b") != -1:
|
||||
t = re.sub('[^\b]\b', '', t)
|
||||
t = t.replace("\n","",1)
|
||||
t = t.replace("\a","")
|
||||
t = t.replace('\n\n', '\n')
|
||||
t = re.sub(r'.\[K', '', t)
|
||||
ansi_escape = re.compile(r'\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/ ]*[@-~])')
|
||||
t = ansi_escape.sub('', t)
|
||||
t = t.lstrip(" \n\r")
|
||||
t = t.replace("\r","")
|
||||
t = t.replace("\x0E","")
|
||||
t = t.replace("\x0F","")
|
||||
|
||||
lines = t.split('\n')
|
||||
cleaned_lines = []
|
||||
|
||||
# Regex to capture: ANSI sequences, control characters (\r, \b, etc), and plain text chunks
|
||||
token_re = re.compile(r'(\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/ ]*[@-~])|\r|\b|\x7f|[\x00-\x1F]|[^\x1B\r\b\x7f\x00-\x1F]+)')
|
||||
|
||||
for line in lines:
|
||||
buffer = []
|
||||
cursor = 0
|
||||
|
||||
for token in token_re.findall(line):
|
||||
if token == '\r':
|
||||
cursor = 0
|
||||
elif token in ('\b', '\x7f'):
|
||||
if cursor > 0:
|
||||
cursor -= 1
|
||||
elif token == '\x1B[D': # Left Arrow
|
||||
if cursor > 0:
|
||||
cursor -= 1
|
||||
elif token == '\x1B[C': # Right Arrow
|
||||
if cursor < len(buffer):
|
||||
cursor += 1
|
||||
elif token == '\x1B[K': # Clear to end of line
|
||||
buffer = buffer[:cursor]
|
||||
elif token.startswith('\x1B'):
|
||||
# Ignore other ANSI sequences (colors, etc)
|
||||
continue
|
||||
elif len(token) == 1 and ord(token) < 32:
|
||||
# Ignore other non-printable control chars
|
||||
continue
|
||||
else:
|
||||
# Regular printable text
|
||||
for char in token:
|
||||
if cursor == len(buffer):
|
||||
buffer.append(char)
|
||||
else:
|
||||
buffer[cursor] = char
|
||||
cursor += 1
|
||||
cleaned_lines.append("".join(buffer))
|
||||
|
||||
t = "\n".join(cleaned_lines).replace('\n\n', '\n').strip()
|
||||
|
||||
if var == False:
|
||||
d = open(logfile, "w")
|
||||
d.write(t)
|
||||
@@ -248,48 +282,193 @@ class node:
|
||||
sleep(1)
|
||||
|
||||
|
||||
@MethodHook
|
||||
def interact(self, debug = False, logger = None):
|
||||
'''
|
||||
Allow user to interact with the node directly, mostly used by connection manager.
|
||||
def _setup_interact_environment(self, debug=False, logger=None, async_mode=False):
|
||||
size = re.search('columns=([0-9]+).*lines=([0-9]+)',str(os.get_terminal_size()))
|
||||
self.child.setwinsize(int(size.group(2)),int(size.group(1)))
|
||||
if logger:
|
||||
port_str = f":{self.port}" if self.port and self.protocol not in ["ssm", "kubectl", "docker"] else ""
|
||||
logger("success", f"Connected to {self.unique} at {self.host}{port_str} via: {self.protocol}")
|
||||
|
||||
### Optional Parameters:
|
||||
|
||||
- debug (bool): If True, display all the connecting information
|
||||
before interact. Default False.
|
||||
- logger (callable): Optional callback for status reporting.
|
||||
'''
|
||||
connect = self._connect(debug = debug, logger = logger)
|
||||
if connect == True:
|
||||
size = re.search('columns=([0-9]+).*lines=([0-9]+)',str(os.get_terminal_size()))
|
||||
self.child.setwinsize(int(size.group(2)),int(size.group(1)))
|
||||
if logger:
|
||||
port_str = f":{self.port}" if self.port and self.protocol not in ["ssm", "kubectl", "docker"] else ""
|
||||
logger("success", f"Connected to {self.unique} at {self.host}{port_str} via: {self.protocol}")
|
||||
|
||||
if 'logfile' in dir(self):
|
||||
# Initialize self.mylog
|
||||
if not 'mylog' in dir(self):
|
||||
self.mylog = io.BytesIO()
|
||||
if 'logfile' in dir(self):
|
||||
# Initialize self.mylog
|
||||
if not 'mylog' in dir(self):
|
||||
self.mylog = io.BytesIO()
|
||||
if not async_mode:
|
||||
self.child.logfile_read = self.mylog
|
||||
|
||||
# Start the _savelog thread
|
||||
log_thread = threading.Thread(target=self._savelog)
|
||||
log_thread.daemon = True
|
||||
log_thread.start()
|
||||
if 'missingtext' in dir(self):
|
||||
print(self.child.after.decode(), end='')
|
||||
if self.idletime > 0:
|
||||
x = threading.Thread(target=self._keepalive)
|
||||
x.daemon = True
|
||||
x.start()
|
||||
if debug:
|
||||
if 'missingtext' in dir(self):
|
||||
print(self.child.after.decode(), end='')
|
||||
if self.idletime > 0 and not async_mode:
|
||||
x = threading.Thread(target=self._keepalive)
|
||||
x.daemon = True
|
||||
x.start()
|
||||
if debug:
|
||||
if 'mylog' in dir(self):
|
||||
print(self.mylog.getvalue().decode())
|
||||
self.child.interact(input_filter=self._filter)
|
||||
if 'logfile' in dir(self):
|
||||
with open(self.logfile, "w") as f:
|
||||
f.write(self._logclean(self.mylog.getvalue().decode(), True))
|
||||
|
||||
def _teardown_interact_environment(self):
|
||||
if 'logfile' in dir(self) and hasattr(self, 'mylog'):
|
||||
with open(self.logfile, "w") as f:
|
||||
f.write(self._logclean(self.mylog.getvalue().decode(), True))
|
||||
|
||||
async def _async_interact_loop(self, local_stream, resize_callback):
|
||||
local_stream.setup(resize_callback=resize_callback)
|
||||
try:
|
||||
child_fd = self.child.child_fd
|
||||
|
||||
# 1. Flush ghost buffer (Clean UX)
|
||||
ghost_buffer = b''
|
||||
if getattr(self, 'missingtext', False):
|
||||
# If we are missing the password, we MUST show the password prompt
|
||||
ghost_buffer = (self.child.after or b'') + (self.child.buffer or b'')
|
||||
else:
|
||||
# We auto-logged in. Hide the messy password negotiation and just keep any pending live stream.
|
||||
ghost_buffer = self.child.buffer or b''
|
||||
|
||||
# Fix user's pet peeve: Strip leading newlines to avoid the empty lines
|
||||
# the router echoes after receiving the password or blank line.
|
||||
if not getattr(self, 'missingtext', False):
|
||||
ghost_buffer = ghost_buffer.lstrip(b'\r\n ')
|
||||
|
||||
if ghost_buffer:
|
||||
# Add a single clean newline so it doesn't merge with the Connected message
|
||||
await local_stream.write(b'\r\n' + ghost_buffer)
|
||||
if hasattr(self, 'mylog'):
|
||||
self.mylog.write(b'\n' + ghost_buffer)
|
||||
|
||||
self.child.buffer = b''
|
||||
self.child.before = b''
|
||||
|
||||
# 2. Set child fd non-blocking
|
||||
flags = fcntl.fcntl(child_fd, fcntl.F_GETFL)
|
||||
fcntl.fcntl(child_fd, fcntl.F_SETFL, flags | os.O_NONBLOCK)
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
child_reader_queue = asyncio.Queue()
|
||||
|
||||
def _child_read_ready():
|
||||
try:
|
||||
data = os.read(child_fd, 4096)
|
||||
if data:
|
||||
child_reader_queue.put_nowait(data)
|
||||
else:
|
||||
child_reader_queue.put_nowait(b'')
|
||||
except BlockingIOError:
|
||||
pass
|
||||
except OSError:
|
||||
child_reader_queue.put_nowait(b'')
|
||||
|
||||
loop.add_reader(child_fd, _child_read_ready)
|
||||
self.lastinput = time()
|
||||
|
||||
async def ingress_task():
|
||||
while True:
|
||||
data = await local_stream.read()
|
||||
if not data:
|
||||
break
|
||||
try:
|
||||
os.write(child_fd, data)
|
||||
except OSError:
|
||||
break
|
||||
self.lastinput = time()
|
||||
|
||||
async def egress_task():
|
||||
# Continue stripping newlines from the live stream until we hit real text
|
||||
skip_newlines = not getattr(self, 'missingtext', False) and not ghost_buffer
|
||||
while True:
|
||||
data = await child_reader_queue.get()
|
||||
if not data:
|
||||
break
|
||||
|
||||
if skip_newlines:
|
||||
stripped = data.lstrip(b'\r\n')
|
||||
if stripped:
|
||||
skip_newlines = False
|
||||
data = stripped
|
||||
else:
|
||||
continue
|
||||
|
||||
await local_stream.write(data)
|
||||
if hasattr(self, 'mylog'):
|
||||
self.mylog.write(data)
|
||||
|
||||
async def keepalive_task():
|
||||
if self.idletime <= 0:
|
||||
return
|
||||
while True:
|
||||
await asyncio.sleep(1)
|
||||
if time() - self.lastinput >= self.idletime:
|
||||
try:
|
||||
self.child.sendcontrol("e")
|
||||
self.lastinput = time()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def savelog_task():
|
||||
if not hasattr(self, 'logfile') or not hasattr(self, 'mylog'):
|
||||
return
|
||||
prev_size = 0
|
||||
while True:
|
||||
await asyncio.sleep(5)
|
||||
current_size = self.mylog.tell()
|
||||
if current_size != prev_size:
|
||||
try:
|
||||
with open(self.logfile, "w") as f:
|
||||
f.write(self._logclean(self.mylog.getvalue().decode(), True))
|
||||
prev_size = current_size
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
# gather runs until any task completes (or we just let them run until EOF breaks them)
|
||||
# Ingress breaks on user EOF. Egress breaks on child EOF.
|
||||
# We want to exit if either happens, so return_exceptions=False, but we need to cancel the others.
|
||||
tasks = [
|
||||
asyncio.create_task(ingress_task()),
|
||||
asyncio.create_task(egress_task()),
|
||||
asyncio.create_task(keepalive_task()),
|
||||
asyncio.create_task(savelog_task())
|
||||
]
|
||||
done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
|
||||
for p in pending:
|
||||
p.cancel()
|
||||
finally:
|
||||
loop.remove_reader(child_fd)
|
||||
try:
|
||||
flags = fcntl.fcntl(child_fd, fcntl.F_GETFL)
|
||||
fcntl.fcntl(child_fd, fcntl.F_SETFL, flags & ~os.O_NONBLOCK)
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
local_stream.teardown()
|
||||
|
||||
|
||||
@MethodHook
|
||||
def interact(self, debug=False, logger=None):
|
||||
'''
|
||||
Asynchronous interactive session using Smart Tunnel architecture.
|
||||
Allows multiplexing I/O and handling SIGWINCH events locally without blocking.
|
||||
'''
|
||||
connect = self._connect(debug=debug, logger=logger)
|
||||
if connect == True:
|
||||
try:
|
||||
self._setup_interact_environment(debug=debug, logger=logger, async_mode=True)
|
||||
|
||||
local_stream = LocalStream()
|
||||
|
||||
def resize_callback(rows, cols):
|
||||
try:
|
||||
self.child.setwinsize(rows, cols)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
asyncio.run(self._async_interact_loop(local_stream, resize_callback))
|
||||
finally:
|
||||
self._teardown_interact_environment()
|
||||
else:
|
||||
if logger:
|
||||
logger("error", str(connect))
|
||||
|
||||
+38
-55
@@ -61,10 +61,13 @@ class NodeServicer(connpy_pb2_grpc.NodeServiceServicer):
|
||||
@handle_errors
|
||||
def interact_node(self, request_iterator, context):
|
||||
import sys
|
||||
import select
|
||||
import os
|
||||
import asyncio
|
||||
from connpy.core import node
|
||||
from ..services.profile_service import ProfileService
|
||||
from connpy.tunnels import RemoteStream
|
||||
import queue
|
||||
import threading
|
||||
|
||||
# Fetch first setup packet
|
||||
try:
|
||||
@@ -83,11 +86,11 @@ class NodeServicer(connpy_pb2_grpc.NodeServiceServicer):
|
||||
base_node_id = params.get("base_node")
|
||||
# Valid attributes that a node object accepts
|
||||
valid_attrs = ['host', 'options', 'logs', 'password', 'port', 'protocol', 'user', 'jumphost']
|
||||
|
||||
|
||||
fallback_id = f"{unique_id}@remote"
|
||||
if unique_id == "dynamic" and params.get("host"):
|
||||
fallback_id = f"dynamic-{params.get('host')}@remote"
|
||||
|
||||
|
||||
if base_node_id:
|
||||
# Look up the base node in config and use its full data
|
||||
nodes = self.service.config._getallnodes(base_node_id)
|
||||
@@ -97,14 +100,14 @@ class NodeServicer(connpy_pb2_grpc.NodeServiceServicer):
|
||||
for attr in valid_attrs:
|
||||
if attr in params:
|
||||
device[attr] = params[attr]
|
||||
|
||||
|
||||
if "tags" in params:
|
||||
device_tags = device.get("tags", {})
|
||||
if not isinstance(device_tags, dict):
|
||||
device_tags = {}
|
||||
device_tags.update(params["tags"])
|
||||
device["tags"] = device_tags
|
||||
|
||||
|
||||
node_name = params.get("name", base_node_id)
|
||||
n = node(node_name, **device, config=self.service.config)
|
||||
else:
|
||||
@@ -138,34 +141,10 @@ class NodeServicer(connpy_pb2_grpc.NodeServiceServicer):
|
||||
if connect != True:
|
||||
yield connpy_pb2.InteractResponse(success=False, error_message=str(connect))
|
||||
return
|
||||
|
||||
|
||||
# Signal successful connection to the client
|
||||
yield connpy_pb2.InteractResponse(success=True)
|
||||
|
||||
import threading
|
||||
import queue
|
||||
|
||||
stdin_queue = queue.Queue()
|
||||
running = True
|
||||
|
||||
def read_requests():
|
||||
try:
|
||||
for req in request_iterator:
|
||||
if not running:
|
||||
break
|
||||
if req.cols > 0 and req.rows > 0:
|
||||
try:
|
||||
n.child.setwinsize(req.rows, req.cols)
|
||||
except Exception:
|
||||
pass
|
||||
if req.stdin_data:
|
||||
stdin_queue.put(req.stdin_data)
|
||||
except grpc.RpcError:
|
||||
pass
|
||||
|
||||
t = threading.Thread(target=read_requests, daemon=True)
|
||||
t.start()
|
||||
|
||||
# Set initial window size if provided
|
||||
if first_req.cols > 0 and first_req.rows > 0:
|
||||
try:
|
||||
@@ -173,32 +152,34 @@ class NodeServicer(connpy_pb2_grpc.NodeServiceServicer):
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
while n.child.isalive() and running:
|
||||
r, _, _ = select.select([n.child.child_fd], [], [], 0.05)
|
||||
if r:
|
||||
try:
|
||||
data = os.read(n.child.child_fd, 4096)
|
||||
if not data:
|
||||
break
|
||||
yield connpy_pb2.InteractResponse(stdout_data=data)
|
||||
except OSError:
|
||||
break
|
||||
|
||||
while not stdin_queue.empty():
|
||||
data = stdin_queue.get_nowait()
|
||||
try:
|
||||
os.write(n.child.child_fd, data)
|
||||
except OSError:
|
||||
running = False
|
||||
break
|
||||
finally:
|
||||
running = False
|
||||
try:
|
||||
n.child.terminate(force=True)
|
||||
except Exception:
|
||||
pass
|
||||
response_queue = queue.Queue()
|
||||
remote_stream = RemoteStream(request_iterator, response_queue)
|
||||
|
||||
def run_async_loop():
|
||||
try:
|
||||
n._setup_interact_environment(debug=debug, logger=None, async_mode=True)
|
||||
def resize_callback(rows, cols):
|
||||
try:
|
||||
n.child.setwinsize(rows, cols)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
asyncio.run(n._async_interact_loop(remote_stream, resize_callback))
|
||||
except Exception as e:
|
||||
pass
|
||||
finally:
|
||||
n._teardown_interact_environment()
|
||||
response_queue.put(None) # Signal EOF
|
||||
|
||||
t_loop = threading.Thread(target=run_async_loop, daemon=True)
|
||||
t_loop.start()
|
||||
|
||||
while True:
|
||||
data = response_queue.get()
|
||||
if data is None:
|
||||
printer.console.print(f"[debug][DEBUG][/debug] gRPC interact_node session closed for: [bold cyan]{unique_id}[/bold cyan]")
|
||||
break
|
||||
yield connpy_pb2.InteractResponse(stdout_data=data)
|
||||
@handle_errors
|
||||
def list_nodes(self, request, context):
|
||||
f = request.filter_str if request.filter_str else None
|
||||
@@ -691,6 +672,8 @@ class AIServicer(connpy_pb2_grpc.AIServiceServicer):
|
||||
daemon=True
|
||||
)
|
||||
ai_thread.start()
|
||||
except grpc.RpcError:
|
||||
pass
|
||||
except Exception as e:
|
||||
print(f"Request Listener Error: {e}")
|
||||
finally:
|
||||
|
||||
@@ -182,7 +182,7 @@ class NodeService(BaseService):
|
||||
n = node(unique_id, **resolved_data, config=self.config)
|
||||
if sftp:
|
||||
n.protocol = "sftp"
|
||||
|
||||
|
||||
n.interact(debug=debug, logger=logger)
|
||||
|
||||
def move_node(self, src_id, dst_id, copy=False):
|
||||
|
||||
@@ -0,0 +1,171 @@
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
import termios
|
||||
import tty
|
||||
import signal
|
||||
import struct
|
||||
import fcntl
|
||||
|
||||
class LocalStream:
|
||||
"""
|
||||
Asynchronous stream wrapper for local stdin/stdout.
|
||||
Handles terminal raw mode, async I/O, and SIGWINCH signals.
|
||||
"""
|
||||
def __init__(self):
|
||||
self.stdin_fd = sys.stdin.fileno()
|
||||
self.stdout_fd = sys.stdout.fileno()
|
||||
self.original_tty_settings = None
|
||||
self.resize_callback = None
|
||||
self._reader_queue = asyncio.Queue()
|
||||
self._loop = None
|
||||
|
||||
def setup(self, resize_callback=None):
|
||||
self._loop = asyncio.get_running_loop()
|
||||
self.resize_callback = resize_callback
|
||||
|
||||
# Save original terminal settings
|
||||
try:
|
||||
self.original_tty_settings = termios.tcgetattr(self.stdin_fd)
|
||||
tty.setraw(self.stdin_fd)
|
||||
except termios.error:
|
||||
# Not a TTY, maybe piped or redirected
|
||||
pass
|
||||
|
||||
# Set stdin non-blocking
|
||||
flags = fcntl.fcntl(self.stdin_fd, fcntl.F_GETFL)
|
||||
fcntl.fcntl(self.stdin_fd, fcntl.F_SETFL, flags | os.O_NONBLOCK)
|
||||
|
||||
# Setup read callback
|
||||
self._loop.add_reader(self.stdin_fd, self._read_ready)
|
||||
|
||||
# Register SIGWINCH
|
||||
if resize_callback:
|
||||
try:
|
||||
self._loop.add_signal_handler(signal.SIGWINCH, self._handle_winch)
|
||||
except (NotImplementedError, RuntimeError):
|
||||
# signal handling not supported on some loops (e.g., Windows Proactor)
|
||||
pass
|
||||
|
||||
def teardown(self):
|
||||
if self._loop:
|
||||
try:
|
||||
self._loop.remove_reader(self.stdin_fd)
|
||||
except Exception:
|
||||
pass
|
||||
if self.resize_callback:
|
||||
try:
|
||||
self._loop.remove_signal_handler(signal.SIGWINCH)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Restore terminal settings
|
||||
if self.original_tty_settings is not None:
|
||||
try:
|
||||
termios.tcsetattr(self.stdin_fd, termios.TCSADRAIN, self.original_tty_settings)
|
||||
except termios.error:
|
||||
pass
|
||||
|
||||
# Restore blocking mode for stdin
|
||||
try:
|
||||
flags = fcntl.fcntl(self.stdin_fd, fcntl.F_GETFL)
|
||||
fcntl.fcntl(self.stdin_fd, fcntl.F_SETFL, flags & ~os.O_NONBLOCK)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _read_ready(self):
|
||||
try:
|
||||
# Read whatever is available
|
||||
data = os.read(self.stdin_fd, 4096)
|
||||
if data:
|
||||
self._reader_queue.put_nowait(data)
|
||||
else:
|
||||
self._reader_queue.put_nowait(b'') # EOF
|
||||
except BlockingIOError:
|
||||
pass
|
||||
except OSError:
|
||||
self._reader_queue.put_nowait(b'') # EOF on error
|
||||
|
||||
async def read(self) -> bytes:
|
||||
"""Asynchronously read bytes from stdin."""
|
||||
return await self._reader_queue.get()
|
||||
|
||||
async def write(self, data: bytes):
|
||||
"""Asynchronously write bytes to stdout."""
|
||||
if not data:
|
||||
return
|
||||
|
||||
try:
|
||||
os.write(self.stdout_fd, data)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
def _handle_winch(self):
|
||||
if self.resize_callback:
|
||||
try:
|
||||
# Use ioctl to get the current window size
|
||||
s = struct.pack("HHHH", 0, 0, 0, 0)
|
||||
a = fcntl.ioctl(self.stdout_fd, termios.TIOCGWINSZ, s)
|
||||
rows, cols, _, _ = struct.unpack("HHHH", a)
|
||||
|
||||
# We schedule the callback safely inside the asyncio loop
|
||||
# instead of running it raw in the signal handler
|
||||
self._loop.call_soon(self.resize_callback, rows, cols)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
import threading
|
||||
|
||||
class RemoteStream:
|
||||
"""
|
||||
Asynchronous stream wrapper for gRPC remote connections.
|
||||
Bridges the blocking gRPC iterators with the async _async_interact_loop.
|
||||
"""
|
||||
def __init__(self, request_iterator, response_queue):
|
||||
self.request_iterator = request_iterator
|
||||
self.response_queue = response_queue
|
||||
self.running = True
|
||||
self._reader_queue = asyncio.Queue()
|
||||
self.resize_callback = None
|
||||
self._loop = None
|
||||
self.t = None
|
||||
|
||||
def setup(self, resize_callback=None):
|
||||
self._loop = asyncio.get_running_loop()
|
||||
self.resize_callback = resize_callback
|
||||
|
||||
def read_requests():
|
||||
try:
|
||||
for req in self.request_iterator:
|
||||
if not self.running:
|
||||
break
|
||||
if req.cols > 0 and req.rows > 0:
|
||||
if self.resize_callback:
|
||||
self._loop.call_soon_threadsafe(self.resize_callback, req.rows, req.cols)
|
||||
if req.stdin_data:
|
||||
self._loop.call_soon_threadsafe(self._reader_queue.put_nowait, req.stdin_data)
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
if self._loop and not self._loop.is_closed():
|
||||
try:
|
||||
self._loop.call_soon_threadsafe(self._reader_queue.put_nowait, b'')
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
self.t = threading.Thread(target=read_requests, daemon=True)
|
||||
self.t.start()
|
||||
|
||||
def teardown(self):
|
||||
self.running = False
|
||||
self.response_queue.put(None) # Signal EOF
|
||||
|
||||
async def read(self) -> bytes:
|
||||
"""Asynchronously read bytes from the gRPC iterator queue."""
|
||||
return await self._reader_queue.get()
|
||||
|
||||
async def write(self, data: bytes):
|
||||
"""Asynchronously write bytes to the gRPC response queue."""
|
||||
if data:
|
||||
self.response_queue.put(data)
|
||||
Reference in New Issue
Block a user