refactor: Major upgrade to v5.1b6 - AWS SSM support & Distributed Architecture

Core & Protocols:
- Native AWS SSM support added (aws ssm start-session).
- Improved Pexpect logic for ssm, kubectl, and docker.
- Cleaned connection success messages (omitting ports for non-IP protocols).

gRPC Layer:
- Migrated gRPC modules to 'connpy/grpc_layer/'.
- Implemented dynamic node naming (e.g. ssm-i-xxxx@aws) for accurate server-side logging.
- Added automatic sys.path resolution for gRPC generated modules.
- Enhanced InteractNode response with initial connection status.

Printer & Concurrency:
- Implemented ThreadLocalStream for isolated thread-safe output.
- Self-healing Console objects to prevent 'closed file' errors in test/async environments.
- Capture clean plugin output in remote executions.

AI & Services:
- Improved tool registration and debug visualization.
- Restored native dictionary returns for AI tools to fix Web UI rendering.
- Increased backup retention to 100 copies in SyncService.
- Silenced noisy auto-sync CLI messages.

Quality & Docs:
- Total tests: 267 (all passing).
- New test suites for gRPC layer and printer concurrency.
- Updated .gitignore to exclude internal planning docs.
- Full technical documentation regenerated with pdoc.
This commit is contained in:
2026-04-24 19:23:00 -03:00
parent 287acde1e4
commit 1c814eb9fd
94 changed files with 12656 additions and 22613 deletions
+2 -2
View File
@@ -2,10 +2,10 @@
'''
## Connection manager
Connpy is a SSH, SFTP, Telnet, kubectl, and Docker pod connection manager and automation module for Linux, Mac, and Docker.
Connpy is a SSH, SFTP, Telnet, kubectl, Docker pod, and AWS SSM connection manager and automation module for Linux, Mac, and Docker.
### Features
- Manage connections using SSH, SFTP, Telnet, kubectl, and Docker exec.
- Manage connections using SSH, SFTP, Telnet, kubectl, Docker exec, and AWS SSM.
- Set contexts to manage specific nodes from specific contexts (work/home/clients/etc).
- You can generate profiles and reference them from nodes using @profilename so you don't
need to edit multiple nodes when changing passwords or other information.
+1 -1
View File
@@ -1 +1 @@
__version__ = "5.1b5"
__version__ = "5.1b6"
+117 -31
View File
@@ -31,6 +31,8 @@ from . import printer
from rich.markdown import Markdown
from rich.panel import Panel
from rich.text import Text
from rich.console import Group
from rich.rule import Rule
console = printer.console
@@ -209,14 +211,20 @@ class ai:
status_formatter (callable): Function(args_dict) -> status string.
"""
name = tool_definition["function"]["name"]
# Check if already registered to prevent duplicates
if target in ("engineer", "both"):
self.external_engineer_tools.append(tool_definition)
if not any(t["function"]["name"] == name for t in self.external_engineer_tools):
self.external_engineer_tools.append(tool_definition)
if target in ("architect", "both"):
self.external_architect_tools.append(tool_definition)
if not any(t["function"]["name"] == name for t in self.external_architect_tools):
self.external_architect_tools.append(tool_definition)
self.external_tool_handlers[name] = handler
if engineer_prompt:
if engineer_prompt and engineer_prompt not in self.engineer_prompt_extensions:
self.engineer_prompt_extensions.append(engineer_prompt)
if architect_prompt:
if architect_prompt and architect_prompt not in self.architect_prompt_extensions:
self.architect_prompt_extensions.append(architect_prompt)
if status_formatter:
self.tool_status_formatters[name] = status_formatter
@@ -448,12 +456,46 @@ class ai:
def _truncate(self, text, limit=None):
"""Truncate text to specified limit, keeping head (60%) and tail (40%)."""
if not isinstance(text, str): return str(text)
final_limit = limit or self.max_truncate
if len(text) <= final_limit: return text
head_limit = int(final_limit * 0.6)
tail_limit = int(final_limit * 0.4)
return (text[:head_limit] + f"\n\n[... OUTPUT TRUNCATED ...]\n\n" + text[-tail_limit:])
def _print_debug_observation(self, fn, obs):
"""Prints a tool observation in a readable way during debug mode."""
# Try to parse as JSON if it's a string
if isinstance(obs, str):
try:
obs_data = json.loads(obs)
except Exception:
obs_data = obs
else:
obs_data = obs
if isinstance(obs_data, dict):
elements = []
for k, v in obs_data.items():
elements.append(Text(f"{k}:", style="key"))
# Use Text for values to ensure newlines are rendered
val = str(v)
# If it's a multiline string from a delegation task, keep it clean
elements.append(Text(val))
if not elements:
content = Text("Empty data set")
else:
# Add a small spacer instead of a Rule for cleaner look
content = Group(*elements)
elif isinstance(obs_data, list):
content = Text("\n".join(f"{item}" for item in obs_data))
else:
content = Text(str(obs_data))
title = f"[bold]{fn}[/bold]"
self.console.print(Panel(content, title=title, border_style="ai_status"))
def manage_memory_tool(self, content, action="append"):
"""Save or update long-term memory. Only use when user explicitly requests it."""
if not content or not content.strip():
@@ -491,8 +533,8 @@ class ai:
ts = data.get("tags")
if isinstance(ts, dict): os_tag = ts.get("os", "unknown")
res[name] = {"os": os_tag}
return json.dumps(res)
return json.dumps({"count": len(matched_names), "nodes": matched_names, "note": "Use 'get_node_info' for details."})
return res
return {"count": len(matched_names), "nodes": matched_names, "note": "Use 'get_node_info' for details."}
except Exception as e:
return f"Error listing nodes: {str(e)}"
@@ -566,7 +608,7 @@ class ai:
if not matched_names: return "No nodes found matching filter."
thisnodes_dict = self.config.getitems(matched_names, extract=True)
result = nodes(thisnodes_dict, config=self.config).run(commands)
return self._truncate(json.dumps(result))
return result
except Exception as e:
return f"Error executing commands: {str(e)}"
@@ -575,7 +617,7 @@ class ai:
try:
d = self.config.getitem(node_name, extract=True)
if 'password' in d: d['password'] = '***'
return json.dumps(d)
return d
except Exception as e:
return f"Error getting node info: {str(e)}"
@@ -619,7 +661,7 @@ class ai:
self.console.print(f"[warning] You can press Ctrl+C to interrupt and get a summary.[/warning]")
soft_limit_warned = True
if status: status.update(f"[ai_status]Engineer: Analyzing mission... (step {iteration})")
if status and not chat_history: status.update(f"[ai_status]Engineer: Analyzing mission... (step {iteration})")
try:
safe_messages = self._sanitize_messages(messages)
@@ -642,8 +684,8 @@ class ai:
for tc in resp_msg.tool_calls:
fn, args = tc.function.name, json.loads(tc.function.arguments)
# Notificación en tiempo real de la tarea técnica
if status:
# Notificación en tiempo real de la tarea técnica (Only if not in Architect loop)
if status and not chat_history:
if fn == "list_nodes": status.update(f"[ai_status]Engineer: [SEARCH] {args.get('filter_pattern','.*')}")
elif fn == "run_commands":
cmds = args.get('commands', [])
@@ -652,7 +694,8 @@ class ai:
elif fn == "get_node_info": status.update(f"[ai_status]Engineer: [INSPECT] {args.get('node_name','')}")
elif fn in self.tool_status_formatters: status.update(self.tool_status_formatters[fn](args))
if debug: self.console.print(Panel(Text(json.dumps(args, indent=2)), title=f"[bold engineer]Engineer Tool: {fn}[/bold engineer]", border_style="engineer"))
if debug:
self._print_debug_observation(f"Decision: {fn}", args)
if fn == "list_nodes": obs = self.list_nodes_tool(**args)
elif fn == "run_commands": obs = self.run_commands_tool(**args, status=status)
@@ -660,8 +703,12 @@ class ai:
elif fn in self.external_tool_handlers: obs = self.external_tool_handlers[fn](self, **args)
else: obs = f"Error: Unknown tool '{fn}'."
if debug: self.console.print(Panel(Text(str(obs)), title=f"[bold pass]Engineer Observation: {fn}[/bold pass]", border_style="success"))
messages.append({"tool_call_id": tc.id, "role": "tool", "name": fn, "content": obs})
if debug:
self._print_debug_observation(f"Observation: {fn}", obs)
# Ensure observation is a string and truncated for the LLM
obs_str = obs if isinstance(obs, str) else json.dumps(obs)
messages.append({"tool_call_id": tc.id, "role": "tool", "name": fn, "content": self._truncate(obs_str)})
if iteration >= self.hard_limit_iterations:
self.console.print(f"[error]⛔ Engineer reached hard limit ({self.hard_limit_iterations} steps). Forcing stop.[/error]")
@@ -675,30 +722,46 @@ class ai:
def _get_engineer_tools(self):
"""Define tools available to the Engineer."""
tools = [
base_tools = [
{"type": "function", "function": {"name": "list_nodes", "description": "Lists available nodes in the inventory.", "parameters": {"type": "object", "properties": {"filter_pattern": {"type": "string", "description": "Regex to filter nodes (e.g. '.*', 'border.*')."}}}}},
{"type": "function", "function": {"name": "run_commands", "description": "Runs one or more commands on matched nodes. MANDATORY: You MUST call 'list_nodes' first to verify the target list.", "parameters": {"type": "object", "properties": {"nodes_filter": {"type": "string", "description": "Exact node name or verified filter pattern."}, "commands": {"type": "array", "items": {"type": "string"}, "description": "List of commands (e.g. ['show ip route', 'show int desc'])."}}, "required": ["nodes_filter", "commands"]}}},
{"type": "function", "function": {"name": "get_node_info", "description": "Gets full metadata for a specific node.", "parameters": {"type": "object", "properties": {"node_name": {"type": "string"}}, "required": ["node_name"]}}}
]
if self.architect_key:
tools.extend([
base_tools.extend([
{"type": "function", "function": {"name": "consult_architect", "description": "Ask the Strategic Reasoning Engine for advice on complex design, architecture, or troubleshooting decisions. You remain in control and will present the response to the user. Use this for: configuration planning, design validation, complex troubleshooting.", "parameters": {"type": "object", "properties": {"question": {"type": "string", "description": "Strategic question or decision needed."}, "technical_summary": {"type": "string", "description": "Technical findings and context gathered so far."}}, "required": ["question", "technical_summary"]}}},
{"type": "function", "function": {"name": "escalate_to_architect", "description": "Transfer full control to the Strategic Reasoning Engine. Use ONLY when the user explicitly requests the Architect or when the problem requires strategic oversight beyond consultation. After escalation, the Architect takes over the conversation.", "parameters": {"type": "object", "properties": {"reason": {"type": "string", "description": "Why you're escalating (e.g. 'User requested Architect', 'Complex multi-site design needed')."}, "context": {"type": "string", "description": "Full context and findings to hand over."}}, "required": ["reason", "context"]}}}
])
tools.extend(self.external_engineer_tools)
return tools
# Deduplicate by name to prevent Gemini BadRequestError
all_tools = base_tools + self.external_engineer_tools
seen_names = set()
unique_tools = []
for t in all_tools:
name = t["function"]["name"]
if name not in seen_names:
unique_tools.append(t)
seen_names.add(name)
return unique_tools
def _get_architect_tools(self):
"""Define tools available to the Strategic Reasoning Engine."""
tools = [
base_tools = [
{"type": "function", "function": {"name": "delegate_to_engineer", "description": "Delegates a technical mission to the Engineer.", "parameters": {"type": "object", "properties": {"task": {"type": "string", "description": "Detailed technical mission or goal."}}, "required": ["task"]}}},
{"type": "function", "function": {"name": "return_to_engineer", "description": "Return control to the Engineer. Use this when your strategic analysis is complete and the Engineer should handle the rest of the conversation.", "parameters": {"type": "object", "properties": {"summary": {"type": "string", "description": "Brief summary of your analysis to hand over to the Engineer."}}, "required": ["summary"]}}},
{"type": "function", "function": {"name": "manage_memory_tool", "description": "Saves information to long-term memory. MANDATORY: Only use this if the user explicitly asks to remember or save something.", "parameters": {"type": "object", "properties": {"content": {"type": "string"}, "action": {"type": "string", "enum": ["append", "replace"]}}, "required": ["content"]}}}
]
tools.extend(self.external_architect_tools)
return tools
all_tools = base_tools + self.external_architect_tools
seen_names = set()
unique_tools = []
for t in all_tools:
name = t["function"]["name"]
if name not in seen_names:
unique_tools.append(t)
seen_names.add(name)
return unique_tools
def _get_sessions(self):
"""Returns a list of session metadata sorted by date."""
@@ -902,12 +965,16 @@ class ai:
soft_limit_warned = True
label = "[architect][bold]Architect[/bold][/architect]" if current_brain == "architect" else "[engineer][bold]Engineer[/bold][/engineer]"
if status: status.update(f"{label} is thinking... (step {iteration})")
if status:
# Notify responder identity ONLY for web/remote clients (StatusBridge has is_web)
if getattr(status, "is_web", False):
status.update(f"__RESPONDER__:{current_brain}")
status.update(f"{label} is thinking... (step {iteration})")
streamed_response = False
try:
safe_messages = self._sanitize_messages(messages)
if stream and not debug:
if stream and (not debug or chunk_callback):
response, streamed_response = self._stream_completion(
model=model, messages=safe_messages, tools=tools, api_key=key,
status=status, label=label, debug=debug, num_retries=3,
@@ -947,7 +1014,10 @@ class ai:
messages.append(msg_dict)
if debug and resp_msg.content:
self.console.print(Panel(Markdown(resp_msg.content), title=f"{label} Reasoning", border_style="architect" if current_brain == "architect" else "engineer"))
# In CLI debug mode, only print intermediate reasoning if there are tool calls.
# If there are no tool calls, this content is the final answer and will be printed by the caller.
if resp_msg.tool_calls:
self.console.print(Panel(Markdown(resp_msg.content), title=f"[{current_brain}][bold]{label} Reasoning[/bold][/{current_brain}]", border_style="architect" if current_brain == "architect" else "engineer"))
if not resp_msg.tool_calls: break
@@ -967,7 +1037,8 @@ class ai:
if fn == "delegate_to_engineer": status.update(f"[architect]Architect: [DELEGATING MISSION] {args.get('task','')[:40]}...")
elif fn == "manage_memory_tool": status.update(f"[architect]Architect: [UPDATING MEMORY]")
if debug: self.console.print(Panel(Text(json.dumps(args, indent=2)), title=f"{label} Decision: {fn}", border_style="debug"))
if debug:
self._print_debug_observation(f"Decision: {fn}", args)
if fn == "delegate_to_engineer":
obs, eng_usage = self._engineer_loop(args["task"], status=status, debug=debug, chat_history=messages[:-1])
@@ -1025,9 +1096,13 @@ class ai:
elif fn == "manage_memory_tool": obs = self.manage_memory_tool(**args)
elif fn in self.external_tool_handlers: obs = self.external_tool_handlers[fn](self, **args)
else: obs = f"Error: {fn} unknown."
messages.append({"tool_call_id": tc.id, "role": "tool", "name": fn, "content": obs})
if debug and fn not in ["delegate_to_engineer", "consult_architect", "escalate_to_architect", "return_to_engineer"]:
self._print_debug_observation(f"Observation: {fn}", obs)
# Ensure observation is a string and truncated for the LLM
obs_str = obs if isinstance(obs, str) else json.dumps(obs)
messages.append({"tool_call_id": tc.id, "role": "tool", "name": fn, "content": self._truncate(obs_str)})
# Inject pending user message AFTER all tool responses are added
if pending_user_message:
messages.append({"role": "user", "content": pending_user_message})
@@ -1053,14 +1128,25 @@ class ai:
if last_msg.get("tool_calls"):
for tc in last_msg["tool_calls"]:
messages.append({"tool_call_id": tc.get("id"), "role": "tool", "name": tc.get("function", {}).get("name"), "content": "Operation cancelled by user."})
messages.append({"role": "user", "content": "USER INTERRUPTED. Briefly summarize what you were doing and stop."})
# Use a fresh list for the summary call to avoid history corruption
summary_messages = list(messages)
summary_messages.append({"role": "user", "content": "USER INTERRUPTED. Briefly summarize what you were doing and stop."})
try:
safe_messages = self._sanitize_messages(messages)
safe_messages = self._sanitize_messages(summary_messages)
# Use tools=None to force a text summary during interruption
response = completion(model=model, messages=safe_messages, tools=None, api_key=key)
resp_msg = response.choices[0].message
messages.append(resp_msg.model_dump(exclude_none=True))
except Exception: pass
# IMPORTANT: Manually trigger callback for the summary so Web UI sees it
if chunk_callback and resp_msg.content:
chunk_callback(resp_msg.content)
except Exception:
error_msg = "Operation interrupted by user. Summary unavailable."
messages.append({"role": "assistant", "content": error_msg})
if chunk_callback:
chunk_callback(error_msg)
finally:
# Auto-save session
self.save_session(messages, model=model)
+2 -2
View File
@@ -48,7 +48,7 @@ def stop_api():
return port
def debug_api(port=8048, config=None):
from .grpc.server import serve
from .grpc_layer.server import serve
conf = config or configfile()
server = serve(conf, port=port, debug=True)
printer.info(f"gRPC Server running in debug mode on port {port}...")
@@ -63,7 +63,7 @@ def start_server(port=8048, config=None):
if base_dir not in sys.path:
sys.path.insert(0, base_dir)
from connpy.grpc.server import serve
from connpy.grpc_layer.server import serve
conf = config or configfile()
server = serve(conf, port=port, debug=False)
_wait_for_termination()
+1
View File
@@ -32,6 +32,7 @@ Here are some important instructions and tips for configuring your new node:
- telnet
- kubectl (`kubectl exec`)
- docker (`docker exec`)
- ssm (`aws ssm start-session`)
3. **Optional Values**:
- You can leave any value empty except for the hostname/IP.
+4
View File
@@ -122,6 +122,10 @@ class NodeHandler:
printer.error(f"Node '{args.data}' already exists.")
sys.exit(1)
uniques = self.app.services.nodes.explode_unique(args.data)
# Fast fail if parent folder does not exist
self.app.services.nodes.validate_parent_folder(args.data)
printer.console.print(Markdown(get_instructions()))
new_node_data = self.forms.questions_nodes(args.data, uniques)
+4 -4
View File
@@ -14,14 +14,14 @@ class Validators:
raise inquirer.errors.ValidationError("", reason="Profile {} don't exist".format(current))
return True
def profile_protocol_validation(self, answers, current, regex = "(^ssh$|^telnet$|^kubectl$|^docker$|^$)"):
def profile_protocol_validation(self, answers, current, regex = "(^ssh$|^telnet$|^kubectl$|^docker$|^ssm$|^$)"):
if not re.match(regex, current):
raise inquirer.errors.ValidationError("", reason="Pick between ssh, telnet, kubectl, docker or leave empty")
raise inquirer.errors.ValidationError("", reason="Pick between ssh, telnet, kubectl, docker, ssm or leave empty")
return True
def protocol_validation(self, answers, current, regex = "(^ssh$|^telnet$|^kubectl$|^docker$|^$|^@.+$)"):
def protocol_validation(self, answers, current, regex = "(^ssh$|^telnet$|^kubectl$|^docker$|^ssm$|^$|^@.+$)"):
if not re.match(regex, current):
raise inquirer.errors.ValidationError("", reason="Pick between ssh, telnet, kubectl, docker leave empty or @profile")
raise inquirer.errors.ValidationError("", reason="Pick between ssh, telnet, kubectl, docker, ssm, leave empty or @profile")
if current.startswith("@"):
if current[1:] not in self.app.profiles:
raise inquirer.errors.ValidationError("", reason="Profile {} don't exist".format(current))
+12 -1
View File
@@ -445,8 +445,19 @@ class connapp:
try:
if args.subcommand in getattr(self.plugins, "remote_plugins", {}):
import json as _json
for chunk in self.services.plugins.invoke_plugin(args.subcommand, args):
print(chunk, end="", flush=True)
if "__interact__" in chunk:
try:
data = _json.loads(chunk.strip())
params = data.get("__interact__")
if params:
self.services.nodes.connect_dynamic(params, debug=getattr(args, 'debug', False))
break
except (ValueError, KeyError):
print(chunk, end="", flush=True)
else:
print(chunk, end="", flush=True)
elif args.subcommand in self.plugins.plugins:
self.plugins.plugins[args.subcommand].Entrypoint(args, self.plugins.plugin_parsers[args.subcommand].parser, self)
else:
+30 -8
View File
@@ -264,7 +264,8 @@ class node:
size = re.search('columns=([0-9]+).*lines=([0-9]+)',str(os.get_terminal_size()))
self.child.setwinsize(int(size.group(2)),int(size.group(1)))
if logger:
logger("success", "Connected to " + self.unique + " at " + self.host + (":" if self.port != '' else '') + self.port + " via: " + self.protocol)
port_str = f":{self.port}" if self.port and self.protocol not in ["ssm", "kubectl", "docker"] else ""
logger("success", f"Connected to {self.unique} at {self.host}{port_str} via: {self.protocol}")
if 'logfile' in dir(self):
# Initialize self.mylog
@@ -343,7 +344,8 @@ class node:
now = datetime.datetime.now().strftime('%Y-%m-%d_%H%M%S')
if connect == True:
if logger:
logger("success", "Connected to " + self.unique + " at " + self.host + (":" if self.port != '' else '') + self.port + " via: " + self.protocol)
port_str = f":{self.port}" if self.port and self.protocol not in ["ssm", "kubectl", "docker"] else ""
logger("success", f"Connected to {self.unique} at {self.host}{port_str} via: {self.protocol}")
# Attempt to set the terminal size
try:
@@ -444,7 +446,8 @@ class node:
connect = self._connect(timeout = timeout, logger = logger)
if connect == True:
if logger:
logger("success", "Connected to " + self.unique + " at " + self.host + (":" if self.port != '' else '') + self.port + " via: " + self.protocol)
port_str = f":{self.port}" if self.port and self.protocol not in ["ssm", "kubectl", "docker"] else ""
logger("success", f"Connected to {self.unique} at {self.host}{port_str} via: {self.protocol}")
# Attempt to set the terminal size
try:
@@ -549,6 +552,19 @@ class node:
cmd += f" {docker_command}"
return cmd
@MethodHook
def _generate_ssm_cmd(self):
region = self.tags.get("region", "") if isinstance(self.tags, dict) else ""
profile = self.tags.get("profile", "") if isinstance(self.tags, dict) else ""
cmd = f"aws ssm start-session --target {self.host}"
if region:
cmd += f" --region {region}"
if profile:
cmd += f" --profile {profile}"
if self.options:
cmd += f" {self.options}"
return cmd
@MethodHook
def _get_cmd(self):
if self.protocol in ["ssh", "sftp"]:
@@ -559,6 +575,8 @@ class node:
return self._generate_kube_cmd()
elif self.protocol == "docker":
return self._generate_docker_cmd()
elif self.protocol == "ssm":
return self._generate_ssm_cmd()
else:
printer.error(f"Invalid protocol: {self.protocol}")
sys.exit(1)
@@ -579,7 +597,8 @@ class node:
"sftp": ['yes/no', 'refused', 'supported', 'Invalid|[u|U]sage: sftp', 'ssh-keygen.*\"', 'timeout|timed.out', 'unavailable', 'closed', password_prompt, prompt, 'suspend', pexpect.EOF, pexpect.TIMEOUT, "No route to host", "resolve hostname", "no matching", "[b|B]ad (owner|permissions)"],
"telnet": ['[u|U]sername:', 'refused', 'supported', 'invalid|unrecognized option', 'ssh-keygen.*\"', 'timeout|timed.out', 'unavailable', 'closed', password_prompt, prompt, 'suspend', pexpect.EOF, pexpect.TIMEOUT, "No route to host", "resolve hostname", "no matching", "[b|B]ad (owner|permissions)"],
"kubectl": ['[u|U]sername:', '[r|R]efused', '[E|e]rror', 'DEPRECATED', pexpect.TIMEOUT, password_prompt, prompt, pexpect.EOF, "expired|invalid"],
"docker": ['[u|U]sername:', 'Cannot', '[E|e]rror', 'failed', 'not a docker command', 'unknown', 'unable to resolve', pexpect.TIMEOUT, password_prompt, prompt, pexpect.EOF]
"docker": ['[u|U]sername:', 'Cannot', '[E|e]rror', 'failed', 'not a docker command', 'unknown', 'unable to resolve', pexpect.TIMEOUT, password_prompt, prompt, pexpect.EOF],
"ssm": ['[u|U]sername:', 'Cannot', '[E|e]rror', 'failed', 'SessionManagerPlugin', 'unknown', 'unable to resolve', pexpect.TIMEOUT, password_prompt, prompt, pexpect.EOF]
}
error_indices = {
@@ -587,7 +606,8 @@ class node:
"sftp": [1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, 16],
"telnet": [1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, 16],
"kubectl": [1, 2, 3, 4, 8], # Define error indices for kube
"docker": [1, 2, 3, 4, 5, 6, 7] # Define error indices for docker
"docker": [1, 2, 3, 4, 5, 6, 7], # Define error indices for docker
"ssm": [1, 2, 3, 4, 5, 6, 7]
}
eof_indices = {
@@ -595,7 +615,8 @@ class node:
"sftp": [8, 9, 10, 11],
"telnet": [8, 9, 10, 11],
"kubectl": [5, 6, 7], # Define eof indices for kube
"docker": [8, 9, 10] # Define eof indices for docker
"docker": [8, 9, 10], # Define eof indices for docker
"ssm": [8, 9, 10]
}
initial_indices = {
@@ -603,7 +624,8 @@ class node:
"sftp": [0],
"telnet": [0],
"kubectl": [0], # Define special indices for kube
"docker": [0] # Define special indices for docker
"docker": [0], # Define special indices for docker
"ssm": [0]
}
attempts = 1
@@ -627,7 +649,7 @@ class node:
if results in initial_indices[self.protocol]:
if self.protocol in ["ssh", "sftp"]:
child.sendline('yes')
elif self.protocol in ["telnet", "kubectl", "docker"]:
elif self.protocol in ["telnet", "kubectl", "docker", "ssm"]:
if self.user:
child.sendline(self.user)
else:
View File
File diff suppressed because one or more lines are too long
+8
View File
@@ -0,0 +1,8 @@
import sys
import os
# gRPC generated files use absolute imports that assume their directory is in sys.path.
# We add this directory to sys.path to allow imports like 'import connpy_pb2' to succeed.
current_dir = os.path.dirname(os.path.abspath(__file__))
if current_dir not in sys.path:
sys.path.insert(0, current_dir)
File diff suppressed because one or more lines are too long
@@ -3,7 +3,7 @@
import grpc
import warnings
from . import connpy_pb2 as connpy__pb2
import connpy_pb2 as connpy__pb2
from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2
GRPC_GENERATED_VERSION = '1.80.0'
@@ -85,6 +85,11 @@ class NodeServiceStub(object):
request_serializer=connpy__pb2.BulkRequest.SerializeToString,
response_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString,
_registered_method=True)
self.validate_parent_folder = channel.unary_unary(
'/connpy.NodeService/validate_parent_folder',
request_serializer=connpy__pb2.IdRequest.SerializeToString,
response_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString,
_registered_method=True)
self.set_reserved_names = channel.unary_unary(
'/connpy.NodeService/set_reserved_names',
request_serializer=connpy__pb2.ListRequest.SerializeToString,
@@ -170,6 +175,12 @@ class NodeServiceServicer(object):
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def validate_parent_folder(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def set_reserved_names(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
@@ -247,6 +258,11 @@ def add_NodeServiceServicer_to_server(servicer, server):
request_deserializer=connpy__pb2.BulkRequest.FromString,
response_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString,
),
'validate_parent_folder': grpc.unary_unary_rpc_method_handler(
servicer.validate_parent_folder,
request_deserializer=connpy__pb2.IdRequest.FromString,
response_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString,
),
'set_reserved_names': grpc.unary_unary_rpc_method_handler(
servicer.set_reserved_names,
request_deserializer=connpy__pb2.ListRequest.FromString,
@@ -548,6 +564,33 @@ class NodeService(object):
metadata,
_registered_method=True)
@staticmethod
def validate_parent_folder(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/connpy.NodeService/validate_parent_folder',
connpy__pb2.IdRequest.SerializeToString,
google_dot_protobuf_dot_empty__pb2.Empty.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@staticmethod
def set_reserved_names(request,
target,
@@ -12,6 +12,7 @@ from . import connpy_pb2, connpy_pb2_grpc, remote_plugin_pb2, remote_plugin_pb2_
import json
from .utils import to_value, from_value, to_struct, from_struct
from ..services.exceptions import ConnpyError
from .. import printer
# Import local services
from ..services.node_service import NodeService
@@ -24,16 +25,34 @@ from ..services.execution_service import ExecutionService
from ..services.import_export_service import ImportExportService
def handle_errors(func):
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except ConnpyError as e:
context = kwargs.get("context") or args[-1]
context.abort(grpc.StatusCode.INTERNAL, str(e))
except Exception as e:
context = kwargs.get("context") or args[-1]
context.abort(grpc.StatusCode.UNKNOWN, str(e))
return wrapper
import inspect
if inspect.isgeneratorfunction(func):
def wrapper(*args, **kwargs):
try:
for item in func(*args, **kwargs):
yield item
except ConnpyError as e:
context = kwargs.get("context") or args[-1]
context.abort(grpc.StatusCode.INTERNAL, str(e))
except Exception as e:
context = kwargs.get("context") or args[-1]
context.abort(grpc.StatusCode.UNKNOWN, str(e))
finally:
printer.clear_thread_state()
return wrapper
else:
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except ConnpyError as e:
context = kwargs.get("context") or args[-1]
context.abort(grpc.StatusCode.INTERNAL, str(e))
except Exception as e:
context = kwargs.get("context") or args[-1]
context.abort(grpc.StatusCode.UNKNOWN, str(e))
finally:
printer.clear_thread_state()
return wrapper
class NodeServicer(connpy_pb2_grpc.NodeServiceServicer):
def __init__(self, config):
@@ -56,18 +75,72 @@ class NodeServicer(connpy_pb2_grpc.NodeServiceServicer):
unique_id = first_req.id
sftp = first_req.sftp
debug = first_req.debug
printer.console.print(f"[debug][DEBUG][/debug] gRPC interact_node request for: [bold cyan]{unique_id}[/bold cyan]")
node_data = self.service.config.getitem(unique_id, extract=False)
profile_service = ProfileService(self.service.config)
resolved_data = profile_service.resolve_node_data(node_data)
n = node(unique_id, **resolved_data, config=self.service.config)
if sftp:
n.protocol = "sftp"
if first_req.connection_params_json:
import json
params = json.loads(first_req.connection_params_json)
base_node_id = params.get("base_node")
# Valid attributes that a node object accepts
valid_attrs = ['host', 'options', 'logs', 'password', 'port', 'protocol', 'user', 'jumphost']
fallback_id = f"{unique_id}@remote"
if unique_id == "dynamic" and params.get("host"):
fallback_id = f"dynamic-{params.get('host')}@remote"
if base_node_id:
# Look up the base node in config and use its full data
nodes = self.service.config._getallnodes(base_node_id)
if nodes:
device = self.service.config.getitem(nodes[0])
# Override device properties with any passed in params
for attr in valid_attrs:
if attr in params:
device[attr] = params[attr]
if "tags" in params:
device_tags = device.get("tags", {})
if not isinstance(device_tags, dict):
device_tags = {}
device_tags.update(params["tags"])
device["tags"] = device_tags
node_name = params.get("name", base_node_id)
n = node(node_name, **device, config=self.service.config)
else:
# base_node not found, fall back to dynamic
node_name = params.get("name", fallback_id)
n = node(node_name, host=params.get("host", ""), config=self.service.config)
for attr in valid_attrs:
if attr in params:
setattr(n, attr, params[attr])
if "tags" in params:
n.tags = params["tags"]
else:
node_name = params.get("name", fallback_id)
n = node(node_name, host=params.get("host", ""), config=self.service.config)
for attr in valid_attrs:
if attr in params:
setattr(n, attr, params[attr])
if "tags" in params:
n.tags = params["tags"]
else:
node_data = self.service.config.getitem(unique_id, extract=False)
if not node_data:
context.abort(grpc.StatusCode.NOT_FOUND, f"Node {unique_id} not found")
profile_service = ProfileService(self.service.config)
resolved_data = profile_service.resolve_node_data(node_data)
n = node(unique_id, **resolved_data, config=self.service.config)
if sftp:
n.protocol = "sftp"
connect = n._connect(debug=debug)
if connect != True:
context.abort(grpc.StatusCode.INTERNAL, "Failed to connect to node")
yield connpy_pb2.InteractResponse(success=False, error_message=str(connect))
return
# Signal successful connection to the client
yield connpy_pb2.InteractResponse(success=True)
import threading
import queue
@@ -145,6 +218,11 @@ class NodeServicer(connpy_pb2_grpc.NodeServiceServicer):
def explode_unique(self, request, context):
return connpy_pb2.ValueResponse(data=to_value(self.service.explode_unique(request.id)))
@handle_errors
def validate_parent_folder(self, request, context):
self.service.validate_parent_folder(request.id)
return Empty()
@handle_errors
def generate_cache(self, request, context):
self.service.generate_cache()
@@ -446,16 +524,18 @@ class ImportExportServicer(connpy_pb2_grpc.ImportExportServiceServicer):
return Empty()
class StatusBridge:
def __init__(self, q, request_queue=None):
def __init__(self, q, request_queue=None, is_web=False):
self.q = q
self.request_queue = request_queue
self.on_interrupt = self._force_interrupt
self.thread = None
self.is_web = is_web
def _force_interrupt(self):
"""Forcefully raise KeyboardInterrupt in the target thread."""
if self.thread and self.thread.ident:
# Standard Python trick to raise an exception in a specific thread
import ctypes
ctypes.pythonapi.PyThreadState_SetAsyncExc(
ctypes.c_long(self.thread.ident),
ctypes.py_object(KeyboardInterrupt)
@@ -477,13 +557,32 @@ class StatusBridge:
def _print_to_queue(self, msg_type, *args, **kwargs):
from rich.console import Console
from rich.panel import Panel
from io import StringIO
from ..printer import connpy_theme
processed_args = list(args)
if self.is_web:
# Remove Panels to avoid box characters on web, but preserve Title
processed_args = []
for arg in args:
if isinstance(arg, Panel):
# If it has a title, prepend it to the content to allow detection
content = arg.renderable
if arg.title:
processed_args.append(f"{arg.title}\n")
processed_args.append(content)
else:
processed_args.append(arg)
buf = StringIO()
# Use a high-quality console for rendering with the app's theme
c = Console(file=buf, force_terminal=True, width=100, theme=connpy_theme)
c.print(*args, **kwargs)
self.q.put((msg_type, buf.getvalue()))
# force_terminal=False removes ANSI escape codes for Web
c = Console(file=buf, force_terminal=not self.is_web, width=100, theme=connpy_theme)
c.print(*processed_args, **kwargs)
text_content = buf.getvalue().strip()
if text_content:
self.q.put((msg_type, text_content))
def confirm(self, prompt, default="n"):
"""Bridge confirmation to the gRPC client."""
@@ -520,94 +619,108 @@ class AIServicer(connpy_pb2_grpc.AIServiceServicer):
def ask(self, request_iterator, context):
import queue
import threading
# In bidirectional mode, the first request contains the query
try:
first_request = next(request_iterator)
except StopIteration:
return
history = from_value(first_request.chat_history)
overrides = {}
if first_request.engineer_model: overrides["engineer_model"] = first_request.engineer_model
if first_request.engineer_api_key: overrides["engineer_api_key"] = first_request.engineer_api_key
if first_request.architect_model: overrides["architect_model"] = first_request.architect_model
if first_request.architect_api_key: overrides["architect_api_key"] = first_request.architect_api_key
chunk_queue = queue.Queue()
request_queue = queue.Queue()
bridge = StatusBridge(chunk_queue, request_queue=request_queue)
bridge = None
history = []
is_web = False
# Start a thread to pull subsequent requests from the client (confirmations)
def pull_requests():
try:
for req in request_iterator:
if req.interrupt and bridge.on_interrupt:
bridge.on_interrupt()
request_queue.put(req)
except Exception:
pass
finally:
request_queue.put(None)
threading.Thread(target=pull_requests, daemon=True).start()
# Dedicated event to signal AI thread to stop
ai_thread = None
agent_instance = None
def callback(chunk):
chunk_queue.put(("text", chunk))
result_container = {}
def run_ai():
def run_ai_task(input_text, session_id, debug, overrides, trust):
nonlocal history, bridge, agent_instance
try:
# Run the AI interaction (this blocks this specific thread)
res = self.service.ask(
first_request.input_text,
dryrun=first_request.dryrun,
input_text,
chat_history=history if history else None,
session_id=first_request.session_id if first_request.session_id else None,
debug=first_request.debug,
session_id=session_id,
debug=debug,
status=bridge,
console=bridge,
confirm_handler=bridge.confirm,
chunk_callback=callback,
trust=first_request.trust,
trust=trust,
**overrides
)
result_container["res"] = res
# Update history for next message
if "chat_history" in res:
history = res["chat_history"]
# Send final chunk marker
chunk_queue.put(("final_mark", res))
except Exception as e:
chunk_queue.put(("status", f"[bold fail]Error: {str(e)}[/bold fail]"))
result_container["error"] = e
import traceback
print(f"AI Task Error: {e}")
traceback.print_exc()
chunk_queue.put(("status", f"Error: {str(e)}"))
def request_listener():
nonlocal bridge, is_web, ai_thread, agent_instance
try:
for req in request_iterator:
if req.interrupt:
if bridge and bridge.on_interrupt:
bridge.on_interrupt()
continue
if req.confirmation_answer:
request_queue.put(req)
continue
if req.input_text:
is_web = "web" in (req.session_id or "").lower() or (req.session_id or "").lower().startswith("ws-")
if not bridge:
bridge = StatusBridge(chunk_queue, request_queue=request_queue, is_web=is_web)
overrides = {}
if req.engineer_model: overrides["engineer_model"] = req.engineer_model
if req.engineer_api_key: overrides["engineer_api_key"] = req.engineer_api_key
# Start AI in its own thread so we can keep listening for interrupts
ai_thread = threading.Thread(
target=run_ai_task,
args=(req.input_text, req.session_id, req.debug, overrides, req.trust),
daemon=True
)
ai_thread.start()
except Exception as e:
print(f"Request Listener Error: {e}")
finally:
chunk_queue.put(None) # Sentinel
# When client closes stream, send sentinel
chunk_queue.put((None, None))
t = threading.Thread(target=run_ai, daemon=True)
bridge.thread = t
t.start()
# Start listening for client requests/signals
threading.Thread(target=request_listener, daemon=True).start()
# Main response loop (yields to gRPC)
while True:
item = chunk_queue.get()
if item is None:
if item == (None, None):
break
msg_type, val = item
if msg_type == "text":
yield connpy_pb2.AIResponse(text_chunk=val, is_final=False)
elif msg_type == "status":
yield connpy_pb2.AIResponse(status_update=val, is_final=False)
if is_web and "is thinking" in val.lower(): continue
clean_val = val.replace("[ai_status]", "").replace("[/ai_status]", "")
yield connpy_pb2.AIResponse(status_update=clean_val, is_final=False)
elif msg_type == "debug":
yield connpy_pb2.AIResponse(debug_message=val, is_final=False)
elif msg_type == "important":
yield connpy_pb2.AIResponse(important_message=val, is_final=False)
elif msg_type == "confirm":
yield connpy_pb2.AIResponse(status_update=val, requires_confirmation=True, is_final=False)
if "error" in result_container:
raise result_container["error"]
yield connpy_pb2.AIResponse(
is_final=True,
full_result=to_struct(result_container.get("res", {}))
)
elif msg_type == "final_mark":
yield connpy_pb2.AIResponse(is_final=True, full_result=to_struct(val))
@handle_errors
def confirm(self, request, context):
@@ -663,8 +776,8 @@ class SystemServicer(connpy_pb2_grpc.SystemServiceServicer):
class LoggingInterceptor(grpc.ServerInterceptor):
def __init__(self):
from rich.console import Console
from ..printer import connpy_theme
self.console = Console(theme=connpy_theme)
from ..printer import connpy_theme, get_original_stdout
self.console = Console(theme=connpy_theme, file=get_original_stdout())
def intercept_service(self, continuation, handler_call_details):
import time
@@ -73,11 +73,112 @@ class NodeStub:
except OSError:
break
# Fetch node details for the connection message
try:
node_details = self.get_node_details(unique_id)
host = node_details.get("host", "unknown")
port = str(node_details.get("port", ""))
protocol = "sftp" if sftp else node_details.get("protocol", "ssh")
port_str = f":{port}" if port and protocol not in ["ssm", "kubectl", "docker"] else ""
conn_msg = f"Connected to {unique_id} at {host}{port_str} via: {protocol}"
except Exception:
conn_msg = f"Connected to {unique_id}"
old_tty = termios.tcgetattr(sys.stdin)
try:
tty.setraw(sys.stdin.fileno())
response_iterator = self.stub.interact_node(request_generator())
# First response is connection status
try:
first_res = next(response_iterator)
if first_res.success:
# Connection established on server, show success message
termios.tcsetattr(sys.stdin, termios.TCSADRAIN, old_tty)
printer.success(conn_msg)
tty.setraw(sys.stdin.fileno())
else:
# Connection failed on server
termios.tcsetattr(sys.stdin, termios.TCSADRAIN, old_tty)
printer.error(f"Connection failed: {first_res.error_message}")
return
except StopIteration:
return
for res in response_iterator:
if res.stdout_data:
os.write(sys.stdout.fileno(), res.stdout_data)
finally:
termios.tcsetattr(sys.stdin, termios.TCSADRAIN, old_tty)
@handle_errors
def connect_dynamic(self, connection_params, debug=False):
import sys
import select
import tty
import termios
import os
import json
params_json = json.dumps(connection_params)
def request_generator():
cols, rows = 80, 24
try:
size = os.get_terminal_size()
cols, rows = size.columns, size.lines
except OSError:
pass
yield connpy_pb2.InteractRequest(
id="dynamic", debug=debug, cols=cols, rows=rows,
connection_params_json=params_json
)
while True:
r, _, _ = select.select([sys.stdin.fileno()], [], [])
if r:
try:
data = os.read(sys.stdin.fileno(), 1024)
if not data:
break
yield connpy_pb2.InteractRequest(stdin_data=data)
except OSError:
break
# Prepare connection message
try:
node_name = connection_params.get("name", "dynamic@remote")
host = connection_params.get("host", "dynamic")
port = str(connection_params.get("port", ""))
protocol = connection_params.get("protocol", "ssh")
port_str = f":{port}" if port and protocol not in ["ssm", "kubectl", "docker"] else ""
conn_msg = f"Connected to {node_name} at {host}{port_str} via: {protocol}"
except Exception:
node_name = connection_params.get("name", "dynamic@remote") if isinstance(connection_params, dict) else "dynamic@remote"
conn_msg = f"Connected to {node_name}"
old_tty = termios.tcgetattr(sys.stdin)
try:
tty.setraw(sys.stdin.fileno())
response_iterator = self.stub.interact_node(request_generator())
# First response is connection status
try:
first_res = next(response_iterator)
if first_res.success:
# Connection established on server, show success message
termios.tcsetattr(sys.stdin, termios.TCSADRAIN, old_tty)
printer.success(conn_msg)
tty.setraw(sys.stdin.fileno())
else:
# Connection failed on server
termios.tcsetattr(sys.stdin, termios.TCSADRAIN, old_tty)
printer.error(f"Connection failed: {first_res.error_message}")
return
except StopIteration:
return
for res in response_iterator:
if res.stdout_data:
os.write(sys.stdout.fileno(), res.stdout_data)
@@ -104,6 +205,10 @@ class NodeStub:
def explode_unique(self, unique_id):
return from_value(self.stub.explode_unique(connpy_pb2.IdRequest(id=unique_id)).data)
@handle_errors
def validate_parent_folder(self, unique_id):
self.stub.validate_parent_folder(connpy_pb2.IdRequest(id=unique_id))
@handle_errors
def generate_cache(self, nodes=None, folders=None, profiles=None):
# 1. Update remote cache on server
@@ -226,6 +331,30 @@ class ProfileStub:
if self.node_stub:
self.node_stub._trigger_local_cache_sync()
class ConfigStub:
def __init__(self, channel, remote_host):
self.stub = connpy_pb2_grpc.ConfigServiceStub(channel)
self.remote_host = remote_host
@handle_errors
def get_settings(self):
return from_struct(self.stub.get_settings(Empty()).data)
@handle_errors
def update_setting(self, key, value):
self.stub.update_setting(connpy_pb2.UpdateRequest(key=key, value=to_value(value)))
@handle_errors
def get_default_dir(self):
return self.stub.get_default_dir(Empty()).value
@handle_errors
def set_config_folder(self, folder):
self.stub.set_config_folder(connpy_pb2.StringRequest(value=folder))
@handle_errors
def encrypt_password(self, password):
return self.stub.encrypt_password(connpy_pb2.StringRequest(value=password)).value
class PluginStub:
def __init__(self, channel, remote_host):
+146 -30
View File
@@ -1,7 +1,71 @@
# Lazy-loaded printer module to speed up CLI startup
_console = None
_err_console = None
_theme = None
import sys
import threading
import io
_local = threading.local()
class ThreadLocalStream:
def __init__(self, original):
self._original = original
def _get_stream(self):
s = getattr(_local, 'stream', None)
return s if s is not None else self._original
def write(self, data):
stream = self._get_stream()
if stream:
stream.write(data)
def flush(self):
stream = self._get_stream()
if stream:
stream.flush()
def isatty(self):
stream = self._get_stream()
return stream.isatty() if stream else False
def __getattr__(self, name):
# Avoid recursion during initialization or if _original is not yet set
if name in ('_original', '_get_stream'):
raise AttributeError(name)
stream = self._get_stream()
if stream:
return getattr(stream, name)
raise AttributeError(f"'NoneType' object has no attribute '{name}'")
# Patch stdout/stderr only once at module level
if not isinstance(sys.stdout, ThreadLocalStream):
sys.stdout = ThreadLocalStream(sys.stdout)
if not isinstance(sys.stderr, ThreadLocalStream):
sys.stderr = ThreadLocalStream(sys.stderr)
def _get_local():
if not hasattr(_local, 'console'):
_local.console = None
if not hasattr(_local, 'err_console'):
_local.err_console = None
if not hasattr(_local, 'theme'):
_local.theme = None
return _local
def set_thread_stream(stream):
if stream is None:
if hasattr(_local, 'stream'):
del _local.stream
else:
_local.stream = stream
def get_original_stdout():
if isinstance(sys.stdout, ThreadLocalStream):
return sys.stdout._original
return sys.stdout
def get_original_stderr():
if isinstance(sys.stderr, ThreadLocalStream):
return sys.stderr._original
return sys.stderr
# Centralized design system
STYLES = {
@@ -23,24 +87,76 @@ STYLES = {
}
def _get_console():
global _console, _theme
if _console is None:
local = _get_local()
# Self-healing patch: if sys.stdout was replaced (e.g. by pytest), re-wrap it.
if not isinstance(sys.stdout, ThreadLocalStream):
sys.stdout = ThreadLocalStream(sys.stdout)
current_out = sys.stdout
# Detect if we need to recreate the console (stream changed or closed)
needs_recreate = (local.console is None or
getattr(local, '_last_stdout', None) is not current_out)
# Extra check for closed files in test environments
if not needs_recreate and local.console is not None:
try:
if hasattr(local.console.file, 'closed') and local.console.file.closed:
needs_recreate = True
except Exception:
pass
if needs_recreate:
from rich.console import Console
from rich.theme import Theme
if _theme is None:
_theme = Theme(STYLES)
_console = Console(theme=_theme)
return _console
if local.theme is None:
local.theme = Theme(STYLES)
local.console = Console(theme=local.theme, file=current_out)
local._last_stdout = current_out
return local.console
def _get_err_console():
global _err_console, _theme
if _err_console is None:
local = _get_local()
# Self-healing patch for stderr
if not isinstance(sys.stderr, ThreadLocalStream):
sys.stderr = ThreadLocalStream(sys.stderr)
current_err = sys.stderr
needs_recreate = (local.err_console is None or
getattr(local, '_last_stderr', None) is not current_err)
if not needs_recreate and local.err_console is not None:
try:
if hasattr(local.err_console.file, 'closed') and local.err_console.file.closed:
needs_recreate = True
except Exception:
pass
if needs_recreate:
from rich.console import Console
from rich.theme import Theme
if _theme is None:
_theme = Theme(STYLES)
_err_console = Console(stderr=True, theme=_theme)
return _err_console
if local.theme is None:
local.theme = Theme(STYLES)
local.err_console = Console(stderr=True, theme=local.theme, file=current_err)
local._last_stderr = current_err
return local.err_console
def set_thread_console(console):
_get_local().console = console
def set_thread_err_console(console):
_get_local().err_console = console
def clear_thread_state():
"""Removes all thread-local printer state. Useful for gRPC thread reuse."""
for attr in ["stream", "console", "err_console", "theme", "_last_stdout", "_last_stderr"]:
if hasattr(_local, attr):
delattr(_local, attr)
@property
def console():
@@ -52,18 +168,18 @@ def err_console():
@property
def connpy_theme():
global _theme
if _theme is None:
local = _get_local()
if local.theme is None:
from rich.theme import Theme
_theme = Theme(STYLES)
return _theme
local.theme = Theme(STYLES)
return local.theme
def apply_theme(user_styles=None):
"""
Updates the global console themes with user-defined styles.
If a style is missing in user_styles, it falls back to the default in STYLES.
"""
global _theme, _console, _err_console
local = _get_local()
from rich.theme import Theme
# Start with a copy of defaults
@@ -74,11 +190,11 @@ def apply_theme(user_styles=None):
if key in active_styles:
active_styles[key] = value
_theme = Theme(active_styles)
if _console:
_console.push_theme(_theme)
if _err_console:
_err_console.push_theme(_theme)
local.theme = Theme(active_styles)
if local.console:
local.console.push_theme(local.theme)
if local.err_console:
local.err_console.push_theme(local.theme)
return active_styles
@@ -273,10 +389,10 @@ err_console = _ErrConsoleProxy()
# theme also needs to be lazy
class _ThemeProxy:
def __getattr__(self, name):
global _theme
if _theme is None:
local = _get_local()
if local.theme is None:
from rich.theme import Theme
_theme = Theme(STYLES)
return getattr(_theme, name)
local.theme = Theme(STYLES)
return getattr(local.theme, name)
connpy_theme = _ThemeProxy()
+4
View File
@@ -16,6 +16,7 @@ service NodeService {
rpc delete_node (DeleteRequest) returns (google.protobuf.Empty) {}
rpc move_node (MoveRequest) returns (google.protobuf.Empty) {}
rpc bulk_add (BulkRequest) returns (google.protobuf.Empty) {}
rpc validate_parent_folder (IdRequest) returns (google.protobuf.Empty) {}
rpc set_reserved_names (ListRequest) returns (google.protobuf.Empty) {}
rpc interact_node (stream InteractRequest) returns (stream InteractResponse) {}
rpc full_replace (FullReplaceRequest) returns (google.protobuf.Empty) {}
@@ -87,10 +88,13 @@ message InteractRequest {
bytes stdin_data = 4;
int32 cols = 5;
int32 rows = 6;
string connection_params_json = 7;
}
message InteractResponse {
bytes stdout_data = 1;
bool success = 2;
string error_message = 3;
}
message FilterRequest {
+16 -11
View File
@@ -73,10 +73,13 @@ class NodeService(BaseService):
def get_node_details(self, unique_id):
"""Return full configuration dictionary for a specific node."""
details = self.config.getitem(unique_id)
if not details:
try:
details = self.config.getitem(unique_id)
if not details:
raise NodeNotFoundError(f"Node '{unique_id}' not found.")
return details
except (KeyError, TypeError):
raise NodeNotFoundError(f"Node '{unique_id}' not found.")
return details
def explode_unique(self, unique_id):
"""Explode a unique ID into a dictionary of its parts."""
@@ -86,6 +89,14 @@ class NodeService(BaseService):
"""Generate and update the internal nodes cache."""
self.config._generate_nodes_cache(nodes=nodes, folders=folders, profiles=profiles)
def validate_parent_folder(self, unique_id):
"""Check if parent folder exists for a given node unique ID."""
node_folder = unique_id.partition("@")[2]
if node_folder:
parent_folder = f"@{node_folder}"
if parent_folder not in self.config._getallfolders():
raise NodeNotFoundError(f"Folder '{parent_folder}' not found.")
def add_node(self, unique_id, data, is_folder=False):
"""Logic for adding a new node or folder to configuration."""
@@ -104,9 +115,7 @@ class NodeService(BaseService):
# Check if parent folder exists when creating a subfolder
if "subfolder" in uniques:
parent_folder = f"@{uniques['folder']}"
if parent_folder not in all_folders:
raise NodeNotFoundError(f"Folder '{parent_folder}' not found.")
self.validate_parent_folder(unique_id)
self.config._folder_add(**uniques)
self.config._saveconfig(self.config.file)
@@ -115,11 +124,7 @@ class NodeService(BaseService):
raise NodeAlreadyExistsError(f"Node '{unique_id}' already exists.")
# Check if parent folder exists when creating a node in a folder
node_folder = unique_id.partition("@")[2]
if node_folder:
parent_folder = f"@{node_folder}"
if parent_folder not in all_folders:
raise NodeNotFoundError(f"Folder '{parent_folder}' not found.")
self.validate_parent_folder(unique_id)
# Ensure 'id' is in data for config._connections_add
if "id" not in data:
+23 -16
View File
@@ -180,6 +180,7 @@ class PluginService(BaseService):
from ..services.exceptions import InvalidConfigurationError
from connpy.plugins import Plugins
class MockApp:
is_mock = True
def __init__(self, config):
from ..core import node, nodes
from ..ai import ai
@@ -191,14 +192,20 @@ class PluginService(BaseService):
self.ai = ai
self.services = ServiceProvider(config, mode="local")
# Get settings for CLI behavior
settings = self.services.config_svc.get_settings()
self.case = settings.get("case", False)
self.fzf = settings.get("fzf", False)
try:
self.nodes_list = self.services.nodes.list_nodes()
self.folders = self.services.nodes.list_folders()
self.profiles = self.services.profiles.list_profiles()
except Exception:
self.nodes_list = {}
self.folders = {}
self.profiles = {}
self.nodes_list = []
self.folders = []
self.profiles = []
args = Namespace(**args_dict)
@@ -225,26 +232,26 @@ class PluginService(BaseService):
from .. import printer
from rich.console import Console
from rich.console import Console
buf = io.StringIO()
old_console = printer.console
old_err_console = printer.err_console
old_console = printer._get_console()
old_err_console = printer._get_err_console()
printer.console = Console(file=buf, theme=printer.connpy_theme, force_terminal=True)
printer.err_console = Console(file=buf, theme=printer.connpy_theme, force_terminal=True)
old_stdout = sys.stdout
sys.stdout = buf
printer.set_thread_console(Console(file=buf, theme=printer.connpy_theme, force_terminal=True))
printer.set_thread_err_console(Console(file=buf, theme=printer.connpy_theme, force_terminal=True))
printer.set_thread_stream(buf)
try:
if hasattr(module, "Entrypoint"):
module.Entrypoint(args, parser, app)
except Exception as e:
import traceback
printer.err_console.print(traceback.format_exc())
except BaseException as e:
if not isinstance(e, SystemExit):
import traceback
printer.err_console.print(traceback.format_exc())
finally:
sys.stdout = old_stdout
printer.console = old_console
printer.err_console = old_err_console
printer.set_thread_console(old_console)
printer.set_thread_err_console(old_err_console)
printer.set_thread_stream(None)
for line in buf.getvalue().splitlines(keepends=True):
yield line
+1 -1
View File
@@ -58,7 +58,7 @@ class ServiceProvider:
raise InvalidConfigurationError("Remote host must be specified in remote mode")
import grpc
from ..grpc.stubs import NodeStub, ProfileStub, PluginStub, AIStub, ExecutionStub, ImportExportStub, SystemStub
from ..grpc_layer.stubs import NodeStub, ProfileStub, PluginStub, AIStub, ExecutionStub, ImportExportStub, SystemStub
channel = grpc.insecure_channel(self.remote_host)
+3 -3
View File
@@ -157,9 +157,9 @@ class SyncService(BaseService):
if os.path.exists(self.config.key):
zipf.write(self.config.key, ".osk")
# Manage retention (max 10 backups)
# Manage retention (max 100 backups)
backups = self.list_backups()
if len(backups) >= 10:
if len(backups) >= 100:
oldest = min(backups, key=lambda x: x['timestamp'] or '0')
self.delete_backup(oldest['id'])
@@ -360,7 +360,7 @@ class SyncService(BaseService):
if not sync_enabled: return
printer.info("Triggering auto-sync...")
if self.check_login_status() != True:
printer.warning("Auto-sync: Not logged in to Google Drive.")
return
+3 -3
View File
@@ -269,16 +269,16 @@ class TestToolMethods:
def test_list_nodes_tool_found(self, myai):
result = myai.list_nodes_tool("router.*")
parsed = json.loads(result)
parsed = json.loads(result) if isinstance(result, str) else result
assert "router1" in str(parsed)
def test_list_nodes_tool_not_found(self, myai):
result = myai.list_nodes_tool("nonexistent_pattern_xyz")
assert "No nodes found" in result
assert "No nodes found" in str(result)
def test_get_node_info_masks_password(self, myai):
result = myai.get_node_info_tool("router1")
parsed = json.loads(result)
parsed = json.loads(result) if isinstance(result, str) else result
assert parsed["password"] == "***"
def test_is_safe_command_show(self, myai):
+17
View File
@@ -99,6 +99,23 @@ class TestCommandGeneration:
assert "telnet 10.0.0.1" in cmd
assert "23" in cmd
def test_ssm_cmd_basic(self):
n = self._make_node(protocol="ssm", host="i-12345")
cmd = n._get_cmd()
assert "aws ssm start-session" in cmd
assert "--target i-12345" in cmd
def test_ssm_cmd_tags(self):
n = self._make_node(protocol="ssm", host="i-12345", tags={"region": "us-west-2", "profile": "prod"})
cmd = n._get_cmd()
assert "--region us-west-2" in cmd
assert "--profile prod" in cmd
def test_ssm_cmd_options(self):
n = self._make_node(protocol="ssm", host="i-12345", options="--document-name AWS-StartInteractiveCommand")
cmd = n._get_cmd()
assert "--document-name AWS-StartInteractiveCommand" in cmd
def test_kubectl_cmd(self):
n = self._make_node(protocol="kubectl", host="my-pod", tags={"kube_command": "/bin/sh"})
cmd = n._get_cmd()
+202
View File
@@ -0,0 +1,202 @@
import pytest
import grpc
import json
import os
import threading
from unittest.mock import MagicMock, patch
from concurrent import futures
from connpy.grpc_layer import server, connpy_pb2, connpy_pb2_grpc, stubs
from connpy.services.exceptions import ConnpyError
class MockContext:
def abort(self, code, details):
raise Exception(f"gRPC Abort: {code} - {details}")
# --- UNIT TESTS (with mocks) ---
class TestNodeServicerNaming:
@pytest.fixture
def servicer(self, populated_config):
return server.NodeServicer(populated_config)
@patch("connpy.core.node")
def test_interact_node_uses_passed_name(self, mock_node, servicer):
# Setup request with custom name
params = {"name": "custom-node-name@test", "host": "1.2.3.4", "protocol": "ssh"}
request = connpy_pb2.InteractRequest(
id="dynamic",
connection_params_json=json.dumps(params)
)
# Mock node to allow _connect
mock_node_instance = MagicMock()
mock_node_instance._connect.return_value = True
mock_node.return_value = mock_node_instance
# We only need the first iteration of the generator to check naming
gen = servicer.interact_node(iter([request]), MockContext())
next(gen) # Skip the success response
# Verify that node() was called with the custom name
mock_node.assert_called()
found = False
for call in mock_node.call_args_list:
if call.args[0] == "custom-node-name@test":
found = True
break
assert found
@patch("connpy.core.node")
def test_interact_node_fallback_naming(self, mock_node, servicer):
# Setup request without custom name but with host
params = {"host": "my-instance", "protocol": "ssm"}
request = connpy_pb2.InteractRequest(
id="dynamic",
connection_params_json=json.dumps(params)
)
mock_node_instance = MagicMock()
mock_node_instance._connect.return_value = True
mock_node.return_value = mock_node_instance
gen = servicer.interact_node(iter([request]), MockContext())
next(gen)
# Verify fallback name: dynamic-{host}@remote
found = False
for call in mock_node.call_args_list:
if call.args[0] == "dynamic-my-instance@remote":
found = True
break
assert found
class TestStubsMessageFormatting:
@patch("termios.tcsetattr")
@patch("termios.tcgetattr")
@patch("tty.setraw")
@patch("os.read")
@patch("select.select")
def test_connect_dynamic_msg_formatting_ssm(self, mock_select, mock_read, mock_setraw, mock_getattr, mock_setattr):
from connpy.grpc_layer.stubs import NodeStub
mock_getattr.return_value = [0, 0, 0, 0, 0, 0, [0] * 32]
mock_channel = MagicMock()
stub = NodeStub(mock_channel, "localhost:8048")
mock_resp = MagicMock()
mock_resp.success = True
stub.stub.interact_node.return_value = iter([mock_resp])
with patch("connpy.printer.success") as mock_success:
with patch("sys.stdin.fileno", return_value=0):
mock_select.return_value = ([], [], [])
params = {"protocol": "ssm", "host": "i-12345", "name": "my-ssm-node@aws"}
with patch("select.select", side_effect=KeyboardInterrupt):
try:
stub.connect_dynamic(params)
except KeyboardInterrupt:
pass
mock_success.assert_called()
msg = mock_success.call_args[0][0]
assert "Connected to my-ssm-node@aws" in msg
assert "at i-12345" in msg
assert ":22" not in msg
assert "via: ssm" in msg
# --- INTEGRATION TESTS (Real Server/Stub Communication) ---
class TestGRPCIntegration:
@pytest.fixture
def grpc_server(self, populated_config):
"""Starts a local gRPC server for integration testing."""
srv = grpc.server(futures.ThreadPoolExecutor(max_workers=5))
# Register services
connpy_pb2_grpc.add_NodeServiceServicer_to_server(server.NodeServicer(populated_config), srv)
connpy_pb2_grpc.add_ProfileServiceServicer_to_server(server.ProfileServicer(populated_config), srv)
connpy_pb2_grpc.add_ConfigServiceServicer_to_server(server.ConfigServicer(populated_config), srv)
connpy_pb2_grpc.add_ExecutionServiceServicer_to_server(server.ExecutionServicer(populated_config), srv)
connpy_pb2_grpc.add_ImportExportServiceServicer_to_server(server.ImportExportServicer(populated_config), srv)
port = srv.add_insecure_port('127.0.0.1:0')
srv.start()
yield f"127.0.0.1:{port}"
srv.stop(0)
@pytest.fixture
def channel(self, grpc_server):
with grpc.insecure_channel(grpc_server) as channel:
yield channel
@pytest.fixture
def node_stub(self, channel):
return stubs.NodeStub(channel, "localhost")
@pytest.fixture
def profile_stub(self, channel):
return stubs.ProfileStub(channel, "localhost")
@pytest.fixture
def config_stub(self, channel):
return stubs.ConfigStub(channel, "localhost")
def test_list_nodes_integration(self, node_stub):
nodes = node_stub.list_nodes()
assert "router1" in nodes
assert "server1@office" in nodes
def test_get_node_details_integration(self, node_stub):
details = node_stub.get_node_details("router1")
assert details["host"] == "10.0.0.1"
def test_node_not_found_integration(self, node_stub):
with pytest.raises(ConnpyError) as exc:
node_stub.get_node_details("non-existent")
assert "Node 'non-existent' not found." in str(exc.value)
def test_list_profiles_integration(self, profile_stub):
profiles = profile_stub.list_profiles()
assert "office-user" in profiles
def test_get_settings_integration(self, config_stub):
settings = config_stub.get_settings()
assert "idletime" in settings
def test_update_setting_integration(self, config_stub):
config_stub.update_setting("idletime", 99)
settings = config_stub.get_settings()
assert settings["idletime"] == 99
def test_add_delete_node_integration(self, node_stub):
node_stub.add_node("integration-test-node", {"host": "9.9.9.9"})
assert "integration-test-node" in node_stub.list_nodes()
node_stub.delete_node("integration-test-node")
assert "integration-test-node" not in node_stub.list_nodes()
def test_import_yaml_integration(self, channel, node_stub):
import yaml
from connpy.grpc_layer import stubs
stub = stubs.ImportExportStub(channel, "localhost")
# ImportExportService expects a flat dict of nodes, not a full config structure
inventory = {
"imported-node": {"host": "8.8.8.8", "protocol": "ssh", "type": "connection"}
}
yaml_content = yaml.dump(inventory)
import tempfile
with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
f.write(yaml_content)
temp_path = f.name
try:
stub.import_from_file(temp_path)
# Verify the node was imported and is visible via NodeStub
nodes = node_stub.list_nodes()
assert "imported-node" in nodes
finally:
if os.path.exists(temp_path):
os.remove(temp_path)
+65
View File
@@ -0,0 +1,65 @@
import threading
import io
import time
import sys
import pytest
from connpy import printer
def test_printer_thread_isolation():
"""Verify that printer output is isolated per thread when using set_thread_stream."""
num_threads = 5
iterations = 20
results = {}
def worker(thread_id):
# Create a private buffer for this thread
buf = io.StringIO()
printer.set_thread_stream(buf)
# Ensure we have a clean console for this thread
# In a real gRPC request, this happens automatically as it's a new thread
printer.set_thread_console(None)
# Each thread prints its own ID
expected_msg = f"Thread-{thread_id}"
for _ in range(iterations):
printer.info(expected_msg)
time.sleep(0.01)
results[thread_id] = buf.getvalue()
printer.set_thread_stream(None)
threads = []
for i in range(num_threads):
t = threading.Thread(target=worker, args=(i,))
threads.append(t)
t.start()
for t in threads:
t.join()
# Validation
for thread_id, output in results.items():
expected_msg = f"Thread-{thread_id}"
assert expected_msg in output
# Ensure no leaks
for other_id in range(num_threads):
if other_id == thread_id: continue
assert f"Thread-{other_id}" not in output
def test_printer_manual_stream():
"""Verify that setting a thread stream correctly captures printer output in the current thread."""
buf = io.StringIO()
# We must clear the thread-local console to force it to pick up the new sys.stdout proxy
printer.set_thread_console(None)
printer.set_thread_stream(buf)
printer.info("Captured-Message")
output = buf.getvalue()
printer.set_thread_stream(None)
printer.set_thread_console(None)
assert "Captured-Message" in output