refactor: Major upgrade to v5.1b6 - AWS SSM support & Distributed Architecture
Core & Protocols: - Native AWS SSM support added (aws ssm start-session). - Improved Pexpect logic for ssm, kubectl, and docker. - Cleaned connection success messages (omitting ports for non-IP protocols). gRPC Layer: - Migrated gRPC modules to 'connpy/grpc_layer/'. - Implemented dynamic node naming (e.g. ssm-i-xxxx@aws) for accurate server-side logging. - Added automatic sys.path resolution for gRPC generated modules. - Enhanced InteractNode response with initial connection status. Printer & Concurrency: - Implemented ThreadLocalStream for isolated thread-safe output. - Self-healing Console objects to prevent 'closed file' errors in test/async environments. - Capture clean plugin output in remote executions. AI & Services: - Improved tool registration and debug visualization. - Restored native dictionary returns for AI tools to fix Web UI rendering. - Increased backup retention to 100 copies in SyncService. - Silenced noisy auto-sync CLI messages. Quality & Docs: - Total tests: 267 (all passing). - New test suites for gRPC layer and printer concurrency. - Updated .gitignore to exclude internal planning docs. - Full technical documentation regenerated with pdoc.
This commit is contained in:
+2
-2
@@ -2,10 +2,10 @@
|
||||
'''
|
||||
## Connection manager
|
||||
|
||||
Connpy is a SSH, SFTP, Telnet, kubectl, and Docker pod connection manager and automation module for Linux, Mac, and Docker.
|
||||
Connpy is a SSH, SFTP, Telnet, kubectl, Docker pod, and AWS SSM connection manager and automation module for Linux, Mac, and Docker.
|
||||
|
||||
### Features
|
||||
- Manage connections using SSH, SFTP, Telnet, kubectl, and Docker exec.
|
||||
- Manage connections using SSH, SFTP, Telnet, kubectl, Docker exec, and AWS SSM.
|
||||
- Set contexts to manage specific nodes from specific contexts (work/home/clients/etc).
|
||||
- You can generate profiles and reference them from nodes using @profilename so you don't
|
||||
need to edit multiple nodes when changing passwords or other information.
|
||||
|
||||
+1
-1
@@ -1 +1 @@
|
||||
__version__ = "5.1b5"
|
||||
__version__ = "5.1b6"
|
||||
|
||||
+117
-31
@@ -31,6 +31,8 @@ from . import printer
|
||||
from rich.markdown import Markdown
|
||||
from rich.panel import Panel
|
||||
from rich.text import Text
|
||||
from rich.console import Group
|
||||
from rich.rule import Rule
|
||||
|
||||
console = printer.console
|
||||
|
||||
@@ -209,14 +211,20 @@ class ai:
|
||||
status_formatter (callable): Function(args_dict) -> status string.
|
||||
"""
|
||||
name = tool_definition["function"]["name"]
|
||||
|
||||
# Check if already registered to prevent duplicates
|
||||
if target in ("engineer", "both"):
|
||||
self.external_engineer_tools.append(tool_definition)
|
||||
if not any(t["function"]["name"] == name for t in self.external_engineer_tools):
|
||||
self.external_engineer_tools.append(tool_definition)
|
||||
if target in ("architect", "both"):
|
||||
self.external_architect_tools.append(tool_definition)
|
||||
if not any(t["function"]["name"] == name for t in self.external_architect_tools):
|
||||
self.external_architect_tools.append(tool_definition)
|
||||
|
||||
self.external_tool_handlers[name] = handler
|
||||
if engineer_prompt:
|
||||
|
||||
if engineer_prompt and engineer_prompt not in self.engineer_prompt_extensions:
|
||||
self.engineer_prompt_extensions.append(engineer_prompt)
|
||||
if architect_prompt:
|
||||
if architect_prompt and architect_prompt not in self.architect_prompt_extensions:
|
||||
self.architect_prompt_extensions.append(architect_prompt)
|
||||
if status_formatter:
|
||||
self.tool_status_formatters[name] = status_formatter
|
||||
@@ -448,12 +456,46 @@ class ai:
|
||||
|
||||
def _truncate(self, text, limit=None):
|
||||
"""Truncate text to specified limit, keeping head (60%) and tail (40%)."""
|
||||
if not isinstance(text, str): return str(text)
|
||||
final_limit = limit or self.max_truncate
|
||||
if len(text) <= final_limit: return text
|
||||
head_limit = int(final_limit * 0.6)
|
||||
tail_limit = int(final_limit * 0.4)
|
||||
return (text[:head_limit] + f"\n\n[... OUTPUT TRUNCATED ...]\n\n" + text[-tail_limit:])
|
||||
|
||||
def _print_debug_observation(self, fn, obs):
|
||||
"""Prints a tool observation in a readable way during debug mode."""
|
||||
# Try to parse as JSON if it's a string
|
||||
if isinstance(obs, str):
|
||||
try:
|
||||
obs_data = json.loads(obs)
|
||||
except Exception:
|
||||
obs_data = obs
|
||||
else:
|
||||
obs_data = obs
|
||||
|
||||
if isinstance(obs_data, dict):
|
||||
elements = []
|
||||
for k, v in obs_data.items():
|
||||
elements.append(Text(f"• {k}:", style="key"))
|
||||
# Use Text for values to ensure newlines are rendered
|
||||
val = str(v)
|
||||
# If it's a multiline string from a delegation task, keep it clean
|
||||
elements.append(Text(val))
|
||||
|
||||
if not elements:
|
||||
content = Text("Empty data set")
|
||||
else:
|
||||
# Add a small spacer instead of a Rule for cleaner look
|
||||
content = Group(*elements)
|
||||
elif isinstance(obs_data, list):
|
||||
content = Text("\n".join(f"• {item}" for item in obs_data))
|
||||
else:
|
||||
content = Text(str(obs_data))
|
||||
|
||||
title = f"[bold]{fn}[/bold]"
|
||||
self.console.print(Panel(content, title=title, border_style="ai_status"))
|
||||
|
||||
def manage_memory_tool(self, content, action="append"):
|
||||
"""Save or update long-term memory. Only use when user explicitly requests it."""
|
||||
if not content or not content.strip():
|
||||
@@ -491,8 +533,8 @@ class ai:
|
||||
ts = data.get("tags")
|
||||
if isinstance(ts, dict): os_tag = ts.get("os", "unknown")
|
||||
res[name] = {"os": os_tag}
|
||||
return json.dumps(res)
|
||||
return json.dumps({"count": len(matched_names), "nodes": matched_names, "note": "Use 'get_node_info' for details."})
|
||||
return res
|
||||
return {"count": len(matched_names), "nodes": matched_names, "note": "Use 'get_node_info' for details."}
|
||||
except Exception as e:
|
||||
return f"Error listing nodes: {str(e)}"
|
||||
|
||||
@@ -566,7 +608,7 @@ class ai:
|
||||
if not matched_names: return "No nodes found matching filter."
|
||||
thisnodes_dict = self.config.getitems(matched_names, extract=True)
|
||||
result = nodes(thisnodes_dict, config=self.config).run(commands)
|
||||
return self._truncate(json.dumps(result))
|
||||
return result
|
||||
except Exception as e:
|
||||
return f"Error executing commands: {str(e)}"
|
||||
|
||||
@@ -575,7 +617,7 @@ class ai:
|
||||
try:
|
||||
d = self.config.getitem(node_name, extract=True)
|
||||
if 'password' in d: d['password'] = '***'
|
||||
return json.dumps(d)
|
||||
return d
|
||||
except Exception as e:
|
||||
return f"Error getting node info: {str(e)}"
|
||||
|
||||
@@ -619,7 +661,7 @@ class ai:
|
||||
self.console.print(f"[warning] You can press Ctrl+C to interrupt and get a summary.[/warning]")
|
||||
soft_limit_warned = True
|
||||
|
||||
if status: status.update(f"[ai_status]Engineer: Analyzing mission... (step {iteration})")
|
||||
if status and not chat_history: status.update(f"[ai_status]Engineer: Analyzing mission... (step {iteration})")
|
||||
|
||||
try:
|
||||
safe_messages = self._sanitize_messages(messages)
|
||||
@@ -642,8 +684,8 @@ class ai:
|
||||
for tc in resp_msg.tool_calls:
|
||||
fn, args = tc.function.name, json.loads(tc.function.arguments)
|
||||
|
||||
# Notificación en tiempo real de la tarea técnica
|
||||
if status:
|
||||
# Notificación en tiempo real de la tarea técnica (Only if not in Architect loop)
|
||||
if status and not chat_history:
|
||||
if fn == "list_nodes": status.update(f"[ai_status]Engineer: [SEARCH] {args.get('filter_pattern','.*')}")
|
||||
elif fn == "run_commands":
|
||||
cmds = args.get('commands', [])
|
||||
@@ -652,7 +694,8 @@ class ai:
|
||||
elif fn == "get_node_info": status.update(f"[ai_status]Engineer: [INSPECT] {args.get('node_name','')}")
|
||||
elif fn in self.tool_status_formatters: status.update(self.tool_status_formatters[fn](args))
|
||||
|
||||
if debug: self.console.print(Panel(Text(json.dumps(args, indent=2)), title=f"[bold engineer]Engineer Tool: {fn}[/bold engineer]", border_style="engineer"))
|
||||
if debug:
|
||||
self._print_debug_observation(f"Decision: {fn}", args)
|
||||
|
||||
if fn == "list_nodes": obs = self.list_nodes_tool(**args)
|
||||
elif fn == "run_commands": obs = self.run_commands_tool(**args, status=status)
|
||||
@@ -660,8 +703,12 @@ class ai:
|
||||
elif fn in self.external_tool_handlers: obs = self.external_tool_handlers[fn](self, **args)
|
||||
else: obs = f"Error: Unknown tool '{fn}'."
|
||||
|
||||
if debug: self.console.print(Panel(Text(str(obs)), title=f"[bold pass]Engineer Observation: {fn}[/bold pass]", border_style="success"))
|
||||
messages.append({"tool_call_id": tc.id, "role": "tool", "name": fn, "content": obs})
|
||||
if debug:
|
||||
self._print_debug_observation(f"Observation: {fn}", obs)
|
||||
|
||||
# Ensure observation is a string and truncated for the LLM
|
||||
obs_str = obs if isinstance(obs, str) else json.dumps(obs)
|
||||
messages.append({"tool_call_id": tc.id, "role": "tool", "name": fn, "content": self._truncate(obs_str)})
|
||||
|
||||
if iteration >= self.hard_limit_iterations:
|
||||
self.console.print(f"[error]⛔ Engineer reached hard limit ({self.hard_limit_iterations} steps). Forcing stop.[/error]")
|
||||
@@ -675,30 +722,46 @@ class ai:
|
||||
|
||||
def _get_engineer_tools(self):
|
||||
"""Define tools available to the Engineer."""
|
||||
tools = [
|
||||
base_tools = [
|
||||
{"type": "function", "function": {"name": "list_nodes", "description": "Lists available nodes in the inventory.", "parameters": {"type": "object", "properties": {"filter_pattern": {"type": "string", "description": "Regex to filter nodes (e.g. '.*', 'border.*')."}}}}},
|
||||
{"type": "function", "function": {"name": "run_commands", "description": "Runs one or more commands on matched nodes. MANDATORY: You MUST call 'list_nodes' first to verify the target list.", "parameters": {"type": "object", "properties": {"nodes_filter": {"type": "string", "description": "Exact node name or verified filter pattern."}, "commands": {"type": "array", "items": {"type": "string"}, "description": "List of commands (e.g. ['show ip route', 'show int desc'])."}}, "required": ["nodes_filter", "commands"]}}},
|
||||
{"type": "function", "function": {"name": "get_node_info", "description": "Gets full metadata for a specific node.", "parameters": {"type": "object", "properties": {"node_name": {"type": "string"}}, "required": ["node_name"]}}}
|
||||
]
|
||||
|
||||
if self.architect_key:
|
||||
tools.extend([
|
||||
base_tools.extend([
|
||||
{"type": "function", "function": {"name": "consult_architect", "description": "Ask the Strategic Reasoning Engine for advice on complex design, architecture, or troubleshooting decisions. You remain in control and will present the response to the user. Use this for: configuration planning, design validation, complex troubleshooting.", "parameters": {"type": "object", "properties": {"question": {"type": "string", "description": "Strategic question or decision needed."}, "technical_summary": {"type": "string", "description": "Technical findings and context gathered so far."}}, "required": ["question", "technical_summary"]}}},
|
||||
{"type": "function", "function": {"name": "escalate_to_architect", "description": "Transfer full control to the Strategic Reasoning Engine. Use ONLY when the user explicitly requests the Architect or when the problem requires strategic oversight beyond consultation. After escalation, the Architect takes over the conversation.", "parameters": {"type": "object", "properties": {"reason": {"type": "string", "description": "Why you're escalating (e.g. 'User requested Architect', 'Complex multi-site design needed')."}, "context": {"type": "string", "description": "Full context and findings to hand over."}}, "required": ["reason", "context"]}}}
|
||||
])
|
||||
|
||||
tools.extend(self.external_engineer_tools)
|
||||
return tools
|
||||
# Deduplicate by name to prevent Gemini BadRequestError
|
||||
all_tools = base_tools + self.external_engineer_tools
|
||||
seen_names = set()
|
||||
unique_tools = []
|
||||
for t in all_tools:
|
||||
name = t["function"]["name"]
|
||||
if name not in seen_names:
|
||||
unique_tools.append(t)
|
||||
seen_names.add(name)
|
||||
return unique_tools
|
||||
|
||||
def _get_architect_tools(self):
|
||||
"""Define tools available to the Strategic Reasoning Engine."""
|
||||
tools = [
|
||||
base_tools = [
|
||||
{"type": "function", "function": {"name": "delegate_to_engineer", "description": "Delegates a technical mission to the Engineer.", "parameters": {"type": "object", "properties": {"task": {"type": "string", "description": "Detailed technical mission or goal."}}, "required": ["task"]}}},
|
||||
{"type": "function", "function": {"name": "return_to_engineer", "description": "Return control to the Engineer. Use this when your strategic analysis is complete and the Engineer should handle the rest of the conversation.", "parameters": {"type": "object", "properties": {"summary": {"type": "string", "description": "Brief summary of your analysis to hand over to the Engineer."}}, "required": ["summary"]}}},
|
||||
{"type": "function", "function": {"name": "manage_memory_tool", "description": "Saves information to long-term memory. MANDATORY: Only use this if the user explicitly asks to remember or save something.", "parameters": {"type": "object", "properties": {"content": {"type": "string"}, "action": {"type": "string", "enum": ["append", "replace"]}}, "required": ["content"]}}}
|
||||
]
|
||||
tools.extend(self.external_architect_tools)
|
||||
return tools
|
||||
|
||||
all_tools = base_tools + self.external_architect_tools
|
||||
seen_names = set()
|
||||
unique_tools = []
|
||||
for t in all_tools:
|
||||
name = t["function"]["name"]
|
||||
if name not in seen_names:
|
||||
unique_tools.append(t)
|
||||
seen_names.add(name)
|
||||
return unique_tools
|
||||
|
||||
def _get_sessions(self):
|
||||
"""Returns a list of session metadata sorted by date."""
|
||||
@@ -902,12 +965,16 @@ class ai:
|
||||
soft_limit_warned = True
|
||||
|
||||
label = "[architect][bold]Architect[/bold][/architect]" if current_brain == "architect" else "[engineer][bold]Engineer[/bold][/engineer]"
|
||||
if status: status.update(f"{label} is thinking... (step {iteration})")
|
||||
if status:
|
||||
# Notify responder identity ONLY for web/remote clients (StatusBridge has is_web)
|
||||
if getattr(status, "is_web", False):
|
||||
status.update(f"__RESPONDER__:{current_brain}")
|
||||
status.update(f"{label} is thinking... (step {iteration})")
|
||||
|
||||
streamed_response = False
|
||||
try:
|
||||
safe_messages = self._sanitize_messages(messages)
|
||||
if stream and not debug:
|
||||
if stream and (not debug or chunk_callback):
|
||||
response, streamed_response = self._stream_completion(
|
||||
model=model, messages=safe_messages, tools=tools, api_key=key,
|
||||
status=status, label=label, debug=debug, num_retries=3,
|
||||
@@ -947,7 +1014,10 @@ class ai:
|
||||
messages.append(msg_dict)
|
||||
|
||||
if debug and resp_msg.content:
|
||||
self.console.print(Panel(Markdown(resp_msg.content), title=f"{label} Reasoning", border_style="architect" if current_brain == "architect" else "engineer"))
|
||||
# In CLI debug mode, only print intermediate reasoning if there are tool calls.
|
||||
# If there are no tool calls, this content is the final answer and will be printed by the caller.
|
||||
if resp_msg.tool_calls:
|
||||
self.console.print(Panel(Markdown(resp_msg.content), title=f"[{current_brain}][bold]{label} Reasoning[/bold][/{current_brain}]", border_style="architect" if current_brain == "architect" else "engineer"))
|
||||
|
||||
if not resp_msg.tool_calls: break
|
||||
|
||||
@@ -967,7 +1037,8 @@ class ai:
|
||||
if fn == "delegate_to_engineer": status.update(f"[architect]Architect: [DELEGATING MISSION] {args.get('task','')[:40]}...")
|
||||
elif fn == "manage_memory_tool": status.update(f"[architect]Architect: [UPDATING MEMORY]")
|
||||
|
||||
if debug: self.console.print(Panel(Text(json.dumps(args, indent=2)), title=f"{label} Decision: {fn}", border_style="debug"))
|
||||
if debug:
|
||||
self._print_debug_observation(f"Decision: {fn}", args)
|
||||
|
||||
if fn == "delegate_to_engineer":
|
||||
obs, eng_usage = self._engineer_loop(args["task"], status=status, debug=debug, chat_history=messages[:-1])
|
||||
@@ -1025,9 +1096,13 @@ class ai:
|
||||
elif fn == "manage_memory_tool": obs = self.manage_memory_tool(**args)
|
||||
elif fn in self.external_tool_handlers: obs = self.external_tool_handlers[fn](self, **args)
|
||||
else: obs = f"Error: {fn} unknown."
|
||||
|
||||
messages.append({"tool_call_id": tc.id, "role": "tool", "name": fn, "content": obs})
|
||||
|
||||
|
||||
if debug and fn not in ["delegate_to_engineer", "consult_architect", "escalate_to_architect", "return_to_engineer"]:
|
||||
self._print_debug_observation(f"Observation: {fn}", obs)
|
||||
|
||||
# Ensure observation is a string and truncated for the LLM
|
||||
obs_str = obs if isinstance(obs, str) else json.dumps(obs)
|
||||
messages.append({"tool_call_id": tc.id, "role": "tool", "name": fn, "content": self._truncate(obs_str)})
|
||||
# Inject pending user message AFTER all tool responses are added
|
||||
if pending_user_message:
|
||||
messages.append({"role": "user", "content": pending_user_message})
|
||||
@@ -1053,14 +1128,25 @@ class ai:
|
||||
if last_msg.get("tool_calls"):
|
||||
for tc in last_msg["tool_calls"]:
|
||||
messages.append({"tool_call_id": tc.get("id"), "role": "tool", "name": tc.get("function", {}).get("name"), "content": "Operation cancelled by user."})
|
||||
messages.append({"role": "user", "content": "USER INTERRUPTED. Briefly summarize what you were doing and stop."})
|
||||
|
||||
# Use a fresh list for the summary call to avoid history corruption
|
||||
summary_messages = list(messages)
|
||||
summary_messages.append({"role": "user", "content": "USER INTERRUPTED. Briefly summarize what you were doing and stop."})
|
||||
try:
|
||||
safe_messages = self._sanitize_messages(messages)
|
||||
safe_messages = self._sanitize_messages(summary_messages)
|
||||
# Use tools=None to force a text summary during interruption
|
||||
response = completion(model=model, messages=safe_messages, tools=None, api_key=key)
|
||||
resp_msg = response.choices[0].message
|
||||
messages.append(resp_msg.model_dump(exclude_none=True))
|
||||
except Exception: pass
|
||||
|
||||
# IMPORTANT: Manually trigger callback for the summary so Web UI sees it
|
||||
if chunk_callback and resp_msg.content:
|
||||
chunk_callback(resp_msg.content)
|
||||
except Exception:
|
||||
error_msg = "Operation interrupted by user. Summary unavailable."
|
||||
messages.append({"role": "assistant", "content": error_msg})
|
||||
if chunk_callback:
|
||||
chunk_callback(error_msg)
|
||||
finally:
|
||||
# Auto-save session
|
||||
self.save_session(messages, model=model)
|
||||
|
||||
+2
-2
@@ -48,7 +48,7 @@ def stop_api():
|
||||
return port
|
||||
|
||||
def debug_api(port=8048, config=None):
|
||||
from .grpc.server import serve
|
||||
from .grpc_layer.server import serve
|
||||
conf = config or configfile()
|
||||
server = serve(conf, port=port, debug=True)
|
||||
printer.info(f"gRPC Server running in debug mode on port {port}...")
|
||||
@@ -63,7 +63,7 @@ def start_server(port=8048, config=None):
|
||||
if base_dir not in sys.path:
|
||||
sys.path.insert(0, base_dir)
|
||||
|
||||
from connpy.grpc.server import serve
|
||||
from connpy.grpc_layer.server import serve
|
||||
conf = config or configfile()
|
||||
server = serve(conf, port=port, debug=False)
|
||||
_wait_for_termination()
|
||||
|
||||
@@ -32,6 +32,7 @@ Here are some important instructions and tips for configuring your new node:
|
||||
- telnet
|
||||
- kubectl (`kubectl exec`)
|
||||
- docker (`docker exec`)
|
||||
- ssm (`aws ssm start-session`)
|
||||
|
||||
3. **Optional Values**:
|
||||
- You can leave any value empty except for the hostname/IP.
|
||||
|
||||
@@ -122,6 +122,10 @@ class NodeHandler:
|
||||
printer.error(f"Node '{args.data}' already exists.")
|
||||
sys.exit(1)
|
||||
uniques = self.app.services.nodes.explode_unique(args.data)
|
||||
|
||||
# Fast fail if parent folder does not exist
|
||||
self.app.services.nodes.validate_parent_folder(args.data)
|
||||
|
||||
printer.console.print(Markdown(get_instructions()))
|
||||
|
||||
new_node_data = self.forms.questions_nodes(args.data, uniques)
|
||||
|
||||
@@ -14,14 +14,14 @@ class Validators:
|
||||
raise inquirer.errors.ValidationError("", reason="Profile {} don't exist".format(current))
|
||||
return True
|
||||
|
||||
def profile_protocol_validation(self, answers, current, regex = "(^ssh$|^telnet$|^kubectl$|^docker$|^$)"):
|
||||
def profile_protocol_validation(self, answers, current, regex = "(^ssh$|^telnet$|^kubectl$|^docker$|^ssm$|^$)"):
|
||||
if not re.match(regex, current):
|
||||
raise inquirer.errors.ValidationError("", reason="Pick between ssh, telnet, kubectl, docker or leave empty")
|
||||
raise inquirer.errors.ValidationError("", reason="Pick between ssh, telnet, kubectl, docker, ssm or leave empty")
|
||||
return True
|
||||
|
||||
def protocol_validation(self, answers, current, regex = "(^ssh$|^telnet$|^kubectl$|^docker$|^$|^@.+$)"):
|
||||
def protocol_validation(self, answers, current, regex = "(^ssh$|^telnet$|^kubectl$|^docker$|^ssm$|^$|^@.+$)"):
|
||||
if not re.match(regex, current):
|
||||
raise inquirer.errors.ValidationError("", reason="Pick between ssh, telnet, kubectl, docker leave empty or @profile")
|
||||
raise inquirer.errors.ValidationError("", reason="Pick between ssh, telnet, kubectl, docker, ssm, leave empty or @profile")
|
||||
if current.startswith("@"):
|
||||
if current[1:] not in self.app.profiles:
|
||||
raise inquirer.errors.ValidationError("", reason="Profile {} don't exist".format(current))
|
||||
|
||||
+12
-1
@@ -445,8 +445,19 @@ class connapp:
|
||||
|
||||
try:
|
||||
if args.subcommand in getattr(self.plugins, "remote_plugins", {}):
|
||||
import json as _json
|
||||
for chunk in self.services.plugins.invoke_plugin(args.subcommand, args):
|
||||
print(chunk, end="", flush=True)
|
||||
if "__interact__" in chunk:
|
||||
try:
|
||||
data = _json.loads(chunk.strip())
|
||||
params = data.get("__interact__")
|
||||
if params:
|
||||
self.services.nodes.connect_dynamic(params, debug=getattr(args, 'debug', False))
|
||||
break
|
||||
except (ValueError, KeyError):
|
||||
print(chunk, end="", flush=True)
|
||||
else:
|
||||
print(chunk, end="", flush=True)
|
||||
elif args.subcommand in self.plugins.plugins:
|
||||
self.plugins.plugins[args.subcommand].Entrypoint(args, self.plugins.plugin_parsers[args.subcommand].parser, self)
|
||||
else:
|
||||
|
||||
+30
-8
@@ -264,7 +264,8 @@ class node:
|
||||
size = re.search('columns=([0-9]+).*lines=([0-9]+)',str(os.get_terminal_size()))
|
||||
self.child.setwinsize(int(size.group(2)),int(size.group(1)))
|
||||
if logger:
|
||||
logger("success", "Connected to " + self.unique + " at " + self.host + (":" if self.port != '' else '') + self.port + " via: " + self.protocol)
|
||||
port_str = f":{self.port}" if self.port and self.protocol not in ["ssm", "kubectl", "docker"] else ""
|
||||
logger("success", f"Connected to {self.unique} at {self.host}{port_str} via: {self.protocol}")
|
||||
|
||||
if 'logfile' in dir(self):
|
||||
# Initialize self.mylog
|
||||
@@ -343,7 +344,8 @@ class node:
|
||||
now = datetime.datetime.now().strftime('%Y-%m-%d_%H%M%S')
|
||||
if connect == True:
|
||||
if logger:
|
||||
logger("success", "Connected to " + self.unique + " at " + self.host + (":" if self.port != '' else '') + self.port + " via: " + self.protocol)
|
||||
port_str = f":{self.port}" if self.port and self.protocol not in ["ssm", "kubectl", "docker"] else ""
|
||||
logger("success", f"Connected to {self.unique} at {self.host}{port_str} via: {self.protocol}")
|
||||
|
||||
# Attempt to set the terminal size
|
||||
try:
|
||||
@@ -444,7 +446,8 @@ class node:
|
||||
connect = self._connect(timeout = timeout, logger = logger)
|
||||
if connect == True:
|
||||
if logger:
|
||||
logger("success", "Connected to " + self.unique + " at " + self.host + (":" if self.port != '' else '') + self.port + " via: " + self.protocol)
|
||||
port_str = f":{self.port}" if self.port and self.protocol not in ["ssm", "kubectl", "docker"] else ""
|
||||
logger("success", f"Connected to {self.unique} at {self.host}{port_str} via: {self.protocol}")
|
||||
|
||||
# Attempt to set the terminal size
|
||||
try:
|
||||
@@ -549,6 +552,19 @@ class node:
|
||||
cmd += f" {docker_command}"
|
||||
return cmd
|
||||
|
||||
@MethodHook
|
||||
def _generate_ssm_cmd(self):
|
||||
region = self.tags.get("region", "") if isinstance(self.tags, dict) else ""
|
||||
profile = self.tags.get("profile", "") if isinstance(self.tags, dict) else ""
|
||||
cmd = f"aws ssm start-session --target {self.host}"
|
||||
if region:
|
||||
cmd += f" --region {region}"
|
||||
if profile:
|
||||
cmd += f" --profile {profile}"
|
||||
if self.options:
|
||||
cmd += f" {self.options}"
|
||||
return cmd
|
||||
|
||||
@MethodHook
|
||||
def _get_cmd(self):
|
||||
if self.protocol in ["ssh", "sftp"]:
|
||||
@@ -559,6 +575,8 @@ class node:
|
||||
return self._generate_kube_cmd()
|
||||
elif self.protocol == "docker":
|
||||
return self._generate_docker_cmd()
|
||||
elif self.protocol == "ssm":
|
||||
return self._generate_ssm_cmd()
|
||||
else:
|
||||
printer.error(f"Invalid protocol: {self.protocol}")
|
||||
sys.exit(1)
|
||||
@@ -579,7 +597,8 @@ class node:
|
||||
"sftp": ['yes/no', 'refused', 'supported', 'Invalid|[u|U]sage: sftp', 'ssh-keygen.*\"', 'timeout|timed.out', 'unavailable', 'closed', password_prompt, prompt, 'suspend', pexpect.EOF, pexpect.TIMEOUT, "No route to host", "resolve hostname", "no matching", "[b|B]ad (owner|permissions)"],
|
||||
"telnet": ['[u|U]sername:', 'refused', 'supported', 'invalid|unrecognized option', 'ssh-keygen.*\"', 'timeout|timed.out', 'unavailable', 'closed', password_prompt, prompt, 'suspend', pexpect.EOF, pexpect.TIMEOUT, "No route to host", "resolve hostname", "no matching", "[b|B]ad (owner|permissions)"],
|
||||
"kubectl": ['[u|U]sername:', '[r|R]efused', '[E|e]rror', 'DEPRECATED', pexpect.TIMEOUT, password_prompt, prompt, pexpect.EOF, "expired|invalid"],
|
||||
"docker": ['[u|U]sername:', 'Cannot', '[E|e]rror', 'failed', 'not a docker command', 'unknown', 'unable to resolve', pexpect.TIMEOUT, password_prompt, prompt, pexpect.EOF]
|
||||
"docker": ['[u|U]sername:', 'Cannot', '[E|e]rror', 'failed', 'not a docker command', 'unknown', 'unable to resolve', pexpect.TIMEOUT, password_prompt, prompt, pexpect.EOF],
|
||||
"ssm": ['[u|U]sername:', 'Cannot', '[E|e]rror', 'failed', 'SessionManagerPlugin', 'unknown', 'unable to resolve', pexpect.TIMEOUT, password_prompt, prompt, pexpect.EOF]
|
||||
}
|
||||
|
||||
error_indices = {
|
||||
@@ -587,7 +606,8 @@ class node:
|
||||
"sftp": [1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, 16],
|
||||
"telnet": [1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, 16],
|
||||
"kubectl": [1, 2, 3, 4, 8], # Define error indices for kube
|
||||
"docker": [1, 2, 3, 4, 5, 6, 7] # Define error indices for docker
|
||||
"docker": [1, 2, 3, 4, 5, 6, 7], # Define error indices for docker
|
||||
"ssm": [1, 2, 3, 4, 5, 6, 7]
|
||||
}
|
||||
|
||||
eof_indices = {
|
||||
@@ -595,7 +615,8 @@ class node:
|
||||
"sftp": [8, 9, 10, 11],
|
||||
"telnet": [8, 9, 10, 11],
|
||||
"kubectl": [5, 6, 7], # Define eof indices for kube
|
||||
"docker": [8, 9, 10] # Define eof indices for docker
|
||||
"docker": [8, 9, 10], # Define eof indices for docker
|
||||
"ssm": [8, 9, 10]
|
||||
}
|
||||
|
||||
initial_indices = {
|
||||
@@ -603,7 +624,8 @@ class node:
|
||||
"sftp": [0],
|
||||
"telnet": [0],
|
||||
"kubectl": [0], # Define special indices for kube
|
||||
"docker": [0] # Define special indices for docker
|
||||
"docker": [0], # Define special indices for docker
|
||||
"ssm": [0]
|
||||
}
|
||||
|
||||
attempts = 1
|
||||
@@ -627,7 +649,7 @@ class node:
|
||||
if results in initial_indices[self.protocol]:
|
||||
if self.protocol in ["ssh", "sftp"]:
|
||||
child.sendline('yes')
|
||||
elif self.protocol in ["telnet", "kubectl", "docker"]:
|
||||
elif self.protocol in ["telnet", "kubectl", "docker", "ssm"]:
|
||||
if self.user:
|
||||
child.sendline(self.user)
|
||||
else:
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -0,0 +1,8 @@
|
||||
import sys
|
||||
import os
|
||||
|
||||
# gRPC generated files use absolute imports that assume their directory is in sys.path.
|
||||
# We add this directory to sys.path to allow imports like 'import connpy_pb2' to succeed.
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
if current_dir not in sys.path:
|
||||
sys.path.insert(0, current_dir)
|
||||
File diff suppressed because one or more lines are too long
@@ -3,7 +3,7 @@
|
||||
import grpc
|
||||
import warnings
|
||||
|
||||
from . import connpy_pb2 as connpy__pb2
|
||||
import connpy_pb2 as connpy__pb2
|
||||
from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2
|
||||
|
||||
GRPC_GENERATED_VERSION = '1.80.0'
|
||||
@@ -85,6 +85,11 @@ class NodeServiceStub(object):
|
||||
request_serializer=connpy__pb2.BulkRequest.SerializeToString,
|
||||
response_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString,
|
||||
_registered_method=True)
|
||||
self.validate_parent_folder = channel.unary_unary(
|
||||
'/connpy.NodeService/validate_parent_folder',
|
||||
request_serializer=connpy__pb2.IdRequest.SerializeToString,
|
||||
response_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString,
|
||||
_registered_method=True)
|
||||
self.set_reserved_names = channel.unary_unary(
|
||||
'/connpy.NodeService/set_reserved_names',
|
||||
request_serializer=connpy__pb2.ListRequest.SerializeToString,
|
||||
@@ -170,6 +175,12 @@ class NodeServiceServicer(object):
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def validate_parent_folder(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def set_reserved_names(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
@@ -247,6 +258,11 @@ def add_NodeServiceServicer_to_server(servicer, server):
|
||||
request_deserializer=connpy__pb2.BulkRequest.FromString,
|
||||
response_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString,
|
||||
),
|
||||
'validate_parent_folder': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.validate_parent_folder,
|
||||
request_deserializer=connpy__pb2.IdRequest.FromString,
|
||||
response_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString,
|
||||
),
|
||||
'set_reserved_names': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.set_reserved_names,
|
||||
request_deserializer=connpy__pb2.ListRequest.FromString,
|
||||
@@ -548,6 +564,33 @@ class NodeService(object):
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
|
||||
@staticmethod
|
||||
def validate_parent_folder(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(
|
||||
request,
|
||||
target,
|
||||
'/connpy.NodeService/validate_parent_folder',
|
||||
connpy__pb2.IdRequest.SerializeToString,
|
||||
google_dot_protobuf_dot_empty__pb2.Empty.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
|
||||
@staticmethod
|
||||
def set_reserved_names(request,
|
||||
target,
|
||||
@@ -12,6 +12,7 @@ from . import connpy_pb2, connpy_pb2_grpc, remote_plugin_pb2, remote_plugin_pb2_
|
||||
import json
|
||||
from .utils import to_value, from_value, to_struct, from_struct
|
||||
from ..services.exceptions import ConnpyError
|
||||
from .. import printer
|
||||
|
||||
# Import local services
|
||||
from ..services.node_service import NodeService
|
||||
@@ -24,16 +25,34 @@ from ..services.execution_service import ExecutionService
|
||||
from ..services.import_export_service import ImportExportService
|
||||
|
||||
def handle_errors(func):
|
||||
def wrapper(*args, **kwargs):
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except ConnpyError as e:
|
||||
context = kwargs.get("context") or args[-1]
|
||||
context.abort(grpc.StatusCode.INTERNAL, str(e))
|
||||
except Exception as e:
|
||||
context = kwargs.get("context") or args[-1]
|
||||
context.abort(grpc.StatusCode.UNKNOWN, str(e))
|
||||
return wrapper
|
||||
import inspect
|
||||
if inspect.isgeneratorfunction(func):
|
||||
def wrapper(*args, **kwargs):
|
||||
try:
|
||||
for item in func(*args, **kwargs):
|
||||
yield item
|
||||
except ConnpyError as e:
|
||||
context = kwargs.get("context") or args[-1]
|
||||
context.abort(grpc.StatusCode.INTERNAL, str(e))
|
||||
except Exception as e:
|
||||
context = kwargs.get("context") or args[-1]
|
||||
context.abort(grpc.StatusCode.UNKNOWN, str(e))
|
||||
finally:
|
||||
printer.clear_thread_state()
|
||||
return wrapper
|
||||
else:
|
||||
def wrapper(*args, **kwargs):
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except ConnpyError as e:
|
||||
context = kwargs.get("context") or args[-1]
|
||||
context.abort(grpc.StatusCode.INTERNAL, str(e))
|
||||
except Exception as e:
|
||||
context = kwargs.get("context") or args[-1]
|
||||
context.abort(grpc.StatusCode.UNKNOWN, str(e))
|
||||
finally:
|
||||
printer.clear_thread_state()
|
||||
return wrapper
|
||||
|
||||
class NodeServicer(connpy_pb2_grpc.NodeServiceServicer):
|
||||
def __init__(self, config):
|
||||
@@ -56,18 +75,72 @@ class NodeServicer(connpy_pb2_grpc.NodeServiceServicer):
|
||||
unique_id = first_req.id
|
||||
sftp = first_req.sftp
|
||||
debug = first_req.debug
|
||||
printer.console.print(f"[debug][DEBUG][/debug] gRPC interact_node request for: [bold cyan]{unique_id}[/bold cyan]")
|
||||
|
||||
node_data = self.service.config.getitem(unique_id, extract=False)
|
||||
profile_service = ProfileService(self.service.config)
|
||||
resolved_data = profile_service.resolve_node_data(node_data)
|
||||
|
||||
n = node(unique_id, **resolved_data, config=self.service.config)
|
||||
if sftp:
|
||||
n.protocol = "sftp"
|
||||
if first_req.connection_params_json:
|
||||
import json
|
||||
params = json.loads(first_req.connection_params_json)
|
||||
base_node_id = params.get("base_node")
|
||||
# Valid attributes that a node object accepts
|
||||
valid_attrs = ['host', 'options', 'logs', 'password', 'port', 'protocol', 'user', 'jumphost']
|
||||
|
||||
fallback_id = f"{unique_id}@remote"
|
||||
if unique_id == "dynamic" and params.get("host"):
|
||||
fallback_id = f"dynamic-{params.get('host')}@remote"
|
||||
|
||||
if base_node_id:
|
||||
# Look up the base node in config and use its full data
|
||||
nodes = self.service.config._getallnodes(base_node_id)
|
||||
if nodes:
|
||||
device = self.service.config.getitem(nodes[0])
|
||||
# Override device properties with any passed in params
|
||||
for attr in valid_attrs:
|
||||
if attr in params:
|
||||
device[attr] = params[attr]
|
||||
|
||||
if "tags" in params:
|
||||
device_tags = device.get("tags", {})
|
||||
if not isinstance(device_tags, dict):
|
||||
device_tags = {}
|
||||
device_tags.update(params["tags"])
|
||||
device["tags"] = device_tags
|
||||
|
||||
node_name = params.get("name", base_node_id)
|
||||
n = node(node_name, **device, config=self.service.config)
|
||||
else:
|
||||
# base_node not found, fall back to dynamic
|
||||
node_name = params.get("name", fallback_id)
|
||||
n = node(node_name, host=params.get("host", ""), config=self.service.config)
|
||||
for attr in valid_attrs:
|
||||
if attr in params:
|
||||
setattr(n, attr, params[attr])
|
||||
if "tags" in params:
|
||||
n.tags = params["tags"]
|
||||
else:
|
||||
node_name = params.get("name", fallback_id)
|
||||
n = node(node_name, host=params.get("host", ""), config=self.service.config)
|
||||
for attr in valid_attrs:
|
||||
if attr in params:
|
||||
setattr(n, attr, params[attr])
|
||||
if "tags" in params:
|
||||
n.tags = params["tags"]
|
||||
else:
|
||||
node_data = self.service.config.getitem(unique_id, extract=False)
|
||||
if not node_data:
|
||||
context.abort(grpc.StatusCode.NOT_FOUND, f"Node {unique_id} not found")
|
||||
profile_service = ProfileService(self.service.config)
|
||||
resolved_data = profile_service.resolve_node_data(node_data)
|
||||
n = node(unique_id, **resolved_data, config=self.service.config)
|
||||
if sftp:
|
||||
n.protocol = "sftp"
|
||||
|
||||
connect = n._connect(debug=debug)
|
||||
if connect != True:
|
||||
context.abort(grpc.StatusCode.INTERNAL, "Failed to connect to node")
|
||||
yield connpy_pb2.InteractResponse(success=False, error_message=str(connect))
|
||||
return
|
||||
|
||||
# Signal successful connection to the client
|
||||
yield connpy_pb2.InteractResponse(success=True)
|
||||
|
||||
import threading
|
||||
import queue
|
||||
@@ -145,6 +218,11 @@ class NodeServicer(connpy_pb2_grpc.NodeServiceServicer):
|
||||
def explode_unique(self, request, context):
|
||||
return connpy_pb2.ValueResponse(data=to_value(self.service.explode_unique(request.id)))
|
||||
|
||||
@handle_errors
|
||||
def validate_parent_folder(self, request, context):
|
||||
self.service.validate_parent_folder(request.id)
|
||||
return Empty()
|
||||
|
||||
@handle_errors
|
||||
def generate_cache(self, request, context):
|
||||
self.service.generate_cache()
|
||||
@@ -446,16 +524,18 @@ class ImportExportServicer(connpy_pb2_grpc.ImportExportServiceServicer):
|
||||
return Empty()
|
||||
|
||||
class StatusBridge:
|
||||
def __init__(self, q, request_queue=None):
|
||||
def __init__(self, q, request_queue=None, is_web=False):
|
||||
self.q = q
|
||||
self.request_queue = request_queue
|
||||
self.on_interrupt = self._force_interrupt
|
||||
self.thread = None
|
||||
self.is_web = is_web
|
||||
|
||||
def _force_interrupt(self):
|
||||
"""Forcefully raise KeyboardInterrupt in the target thread."""
|
||||
if self.thread and self.thread.ident:
|
||||
# Standard Python trick to raise an exception in a specific thread
|
||||
import ctypes
|
||||
ctypes.pythonapi.PyThreadState_SetAsyncExc(
|
||||
ctypes.c_long(self.thread.ident),
|
||||
ctypes.py_object(KeyboardInterrupt)
|
||||
@@ -477,13 +557,32 @@ class StatusBridge:
|
||||
|
||||
def _print_to_queue(self, msg_type, *args, **kwargs):
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from io import StringIO
|
||||
from ..printer import connpy_theme
|
||||
|
||||
processed_args = list(args)
|
||||
if self.is_web:
|
||||
# Remove Panels to avoid box characters on web, but preserve Title
|
||||
processed_args = []
|
||||
for arg in args:
|
||||
if isinstance(arg, Panel):
|
||||
# If it has a title, prepend it to the content to allow detection
|
||||
content = arg.renderable
|
||||
if arg.title:
|
||||
processed_args.append(f"{arg.title}\n")
|
||||
processed_args.append(content)
|
||||
else:
|
||||
processed_args.append(arg)
|
||||
|
||||
buf = StringIO()
|
||||
# Use a high-quality console for rendering with the app's theme
|
||||
c = Console(file=buf, force_terminal=True, width=100, theme=connpy_theme)
|
||||
c.print(*args, **kwargs)
|
||||
self.q.put((msg_type, buf.getvalue()))
|
||||
# force_terminal=False removes ANSI escape codes for Web
|
||||
c = Console(file=buf, force_terminal=not self.is_web, width=100, theme=connpy_theme)
|
||||
c.print(*processed_args, **kwargs)
|
||||
|
||||
text_content = buf.getvalue().strip()
|
||||
if text_content:
|
||||
self.q.put((msg_type, text_content))
|
||||
|
||||
def confirm(self, prompt, default="n"):
|
||||
"""Bridge confirmation to the gRPC client."""
|
||||
@@ -520,94 +619,108 @@ class AIServicer(connpy_pb2_grpc.AIServiceServicer):
|
||||
def ask(self, request_iterator, context):
|
||||
import queue
|
||||
import threading
|
||||
|
||||
# In bidirectional mode, the first request contains the query
|
||||
try:
|
||||
first_request = next(request_iterator)
|
||||
except StopIteration:
|
||||
return
|
||||
|
||||
history = from_value(first_request.chat_history)
|
||||
|
||||
overrides = {}
|
||||
if first_request.engineer_model: overrides["engineer_model"] = first_request.engineer_model
|
||||
if first_request.engineer_api_key: overrides["engineer_api_key"] = first_request.engineer_api_key
|
||||
if first_request.architect_model: overrides["architect_model"] = first_request.architect_model
|
||||
if first_request.architect_api_key: overrides["architect_api_key"] = first_request.architect_api_key
|
||||
|
||||
chunk_queue = queue.Queue()
|
||||
request_queue = queue.Queue()
|
||||
bridge = StatusBridge(chunk_queue, request_queue=request_queue)
|
||||
bridge = None
|
||||
history = []
|
||||
is_web = False
|
||||
|
||||
# Start a thread to pull subsequent requests from the client (confirmations)
|
||||
def pull_requests():
|
||||
try:
|
||||
for req in request_iterator:
|
||||
if req.interrupt and bridge.on_interrupt:
|
||||
bridge.on_interrupt()
|
||||
request_queue.put(req)
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
request_queue.put(None)
|
||||
|
||||
threading.Thread(target=pull_requests, daemon=True).start()
|
||||
# Dedicated event to signal AI thread to stop
|
||||
ai_thread = None
|
||||
agent_instance = None
|
||||
|
||||
def callback(chunk):
|
||||
chunk_queue.put(("text", chunk))
|
||||
|
||||
result_container = {}
|
||||
|
||||
def run_ai():
|
||||
def run_ai_task(input_text, session_id, debug, overrides, trust):
|
||||
nonlocal history, bridge, agent_instance
|
||||
try:
|
||||
# Run the AI interaction (this blocks this specific thread)
|
||||
res = self.service.ask(
|
||||
first_request.input_text,
|
||||
dryrun=first_request.dryrun,
|
||||
input_text,
|
||||
chat_history=history if history else None,
|
||||
session_id=first_request.session_id if first_request.session_id else None,
|
||||
debug=first_request.debug,
|
||||
session_id=session_id,
|
||||
debug=debug,
|
||||
status=bridge,
|
||||
console=bridge,
|
||||
confirm_handler=bridge.confirm,
|
||||
chunk_callback=callback,
|
||||
trust=first_request.trust,
|
||||
trust=trust,
|
||||
**overrides
|
||||
)
|
||||
result_container["res"] = res
|
||||
|
||||
# Update history for next message
|
||||
if "chat_history" in res:
|
||||
history = res["chat_history"]
|
||||
|
||||
# Send final chunk marker
|
||||
chunk_queue.put(("final_mark", res))
|
||||
except Exception as e:
|
||||
chunk_queue.put(("status", f"[bold fail]Error: {str(e)}[/bold fail]"))
|
||||
result_container["error"] = e
|
||||
import traceback
|
||||
print(f"AI Task Error: {e}")
|
||||
traceback.print_exc()
|
||||
chunk_queue.put(("status", f"Error: {str(e)}"))
|
||||
|
||||
def request_listener():
|
||||
nonlocal bridge, is_web, ai_thread, agent_instance
|
||||
try:
|
||||
for req in request_iterator:
|
||||
if req.interrupt:
|
||||
if bridge and bridge.on_interrupt:
|
||||
bridge.on_interrupt()
|
||||
continue
|
||||
|
||||
if req.confirmation_answer:
|
||||
request_queue.put(req)
|
||||
continue
|
||||
|
||||
if req.input_text:
|
||||
is_web = "web" in (req.session_id or "").lower() or (req.session_id or "").lower().startswith("ws-")
|
||||
if not bridge:
|
||||
bridge = StatusBridge(chunk_queue, request_queue=request_queue, is_web=is_web)
|
||||
|
||||
overrides = {}
|
||||
if req.engineer_model: overrides["engineer_model"] = req.engineer_model
|
||||
if req.engineer_api_key: overrides["engineer_api_key"] = req.engineer_api_key
|
||||
|
||||
# Start AI in its own thread so we can keep listening for interrupts
|
||||
ai_thread = threading.Thread(
|
||||
target=run_ai_task,
|
||||
args=(req.input_text, req.session_id, req.debug, overrides, req.trust),
|
||||
daemon=True
|
||||
)
|
||||
ai_thread.start()
|
||||
except Exception as e:
|
||||
print(f"Request Listener Error: {e}")
|
||||
finally:
|
||||
chunk_queue.put(None) # Sentinel
|
||||
# When client closes stream, send sentinel
|
||||
chunk_queue.put((None, None))
|
||||
|
||||
t = threading.Thread(target=run_ai, daemon=True)
|
||||
bridge.thread = t
|
||||
t.start()
|
||||
# Start listening for client requests/signals
|
||||
threading.Thread(target=request_listener, daemon=True).start()
|
||||
|
||||
# Main response loop (yields to gRPC)
|
||||
while True:
|
||||
item = chunk_queue.get()
|
||||
if item is None:
|
||||
if item == (None, None):
|
||||
break
|
||||
|
||||
|
||||
msg_type, val = item
|
||||
if msg_type == "text":
|
||||
yield connpy_pb2.AIResponse(text_chunk=val, is_final=False)
|
||||
elif msg_type == "status":
|
||||
yield connpy_pb2.AIResponse(status_update=val, is_final=False)
|
||||
if is_web and "is thinking" in val.lower(): continue
|
||||
clean_val = val.replace("[ai_status]", "").replace("[/ai_status]", "")
|
||||
yield connpy_pb2.AIResponse(status_update=clean_val, is_final=False)
|
||||
elif msg_type == "debug":
|
||||
yield connpy_pb2.AIResponse(debug_message=val, is_final=False)
|
||||
elif msg_type == "important":
|
||||
yield connpy_pb2.AIResponse(important_message=val, is_final=False)
|
||||
elif msg_type == "confirm":
|
||||
yield connpy_pb2.AIResponse(status_update=val, requires_confirmation=True, is_final=False)
|
||||
|
||||
if "error" in result_container:
|
||||
raise result_container["error"]
|
||||
|
||||
yield connpy_pb2.AIResponse(
|
||||
is_final=True,
|
||||
full_result=to_struct(result_container.get("res", {}))
|
||||
)
|
||||
elif msg_type == "final_mark":
|
||||
yield connpy_pb2.AIResponse(is_final=True, full_result=to_struct(val))
|
||||
|
||||
@handle_errors
|
||||
def confirm(self, request, context):
|
||||
@@ -663,8 +776,8 @@ class SystemServicer(connpy_pb2_grpc.SystemServiceServicer):
|
||||
class LoggingInterceptor(grpc.ServerInterceptor):
|
||||
def __init__(self):
|
||||
from rich.console import Console
|
||||
from ..printer import connpy_theme
|
||||
self.console = Console(theme=connpy_theme)
|
||||
from ..printer import connpy_theme, get_original_stdout
|
||||
self.console = Console(theme=connpy_theme, file=get_original_stdout())
|
||||
|
||||
def intercept_service(self, continuation, handler_call_details):
|
||||
import time
|
||||
@@ -73,11 +73,112 @@ class NodeStub:
|
||||
except OSError:
|
||||
break
|
||||
|
||||
# Fetch node details for the connection message
|
||||
try:
|
||||
node_details = self.get_node_details(unique_id)
|
||||
host = node_details.get("host", "unknown")
|
||||
port = str(node_details.get("port", ""))
|
||||
protocol = "sftp" if sftp else node_details.get("protocol", "ssh")
|
||||
port_str = f":{port}" if port and protocol not in ["ssm", "kubectl", "docker"] else ""
|
||||
conn_msg = f"Connected to {unique_id} at {host}{port_str} via: {protocol}"
|
||||
except Exception:
|
||||
conn_msg = f"Connected to {unique_id}"
|
||||
|
||||
old_tty = termios.tcgetattr(sys.stdin)
|
||||
try:
|
||||
tty.setraw(sys.stdin.fileno())
|
||||
response_iterator = self.stub.interact_node(request_generator())
|
||||
|
||||
# First response is connection status
|
||||
try:
|
||||
first_res = next(response_iterator)
|
||||
if first_res.success:
|
||||
# Connection established on server, show success message
|
||||
termios.tcsetattr(sys.stdin, termios.TCSADRAIN, old_tty)
|
||||
printer.success(conn_msg)
|
||||
tty.setraw(sys.stdin.fileno())
|
||||
else:
|
||||
# Connection failed on server
|
||||
termios.tcsetattr(sys.stdin, termios.TCSADRAIN, old_tty)
|
||||
printer.error(f"Connection failed: {first_res.error_message}")
|
||||
return
|
||||
except StopIteration:
|
||||
return
|
||||
|
||||
for res in response_iterator:
|
||||
if res.stdout_data:
|
||||
os.write(sys.stdout.fileno(), res.stdout_data)
|
||||
finally:
|
||||
termios.tcsetattr(sys.stdin, termios.TCSADRAIN, old_tty)
|
||||
|
||||
@handle_errors
|
||||
def connect_dynamic(self, connection_params, debug=False):
|
||||
import sys
|
||||
import select
|
||||
import tty
|
||||
import termios
|
||||
import os
|
||||
import json
|
||||
|
||||
params_json = json.dumps(connection_params)
|
||||
|
||||
def request_generator():
|
||||
cols, rows = 80, 24
|
||||
try:
|
||||
size = os.get_terminal_size()
|
||||
cols, rows = size.columns, size.lines
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
yield connpy_pb2.InteractRequest(
|
||||
id="dynamic", debug=debug, cols=cols, rows=rows,
|
||||
connection_params_json=params_json
|
||||
)
|
||||
|
||||
while True:
|
||||
r, _, _ = select.select([sys.stdin.fileno()], [], [])
|
||||
if r:
|
||||
try:
|
||||
data = os.read(sys.stdin.fileno(), 1024)
|
||||
if not data:
|
||||
break
|
||||
yield connpy_pb2.InteractRequest(stdin_data=data)
|
||||
except OSError:
|
||||
break
|
||||
|
||||
# Prepare connection message
|
||||
try:
|
||||
node_name = connection_params.get("name", "dynamic@remote")
|
||||
host = connection_params.get("host", "dynamic")
|
||||
port = str(connection_params.get("port", ""))
|
||||
protocol = connection_params.get("protocol", "ssh")
|
||||
port_str = f":{port}" if port and protocol not in ["ssm", "kubectl", "docker"] else ""
|
||||
conn_msg = f"Connected to {node_name} at {host}{port_str} via: {protocol}"
|
||||
except Exception:
|
||||
node_name = connection_params.get("name", "dynamic@remote") if isinstance(connection_params, dict) else "dynamic@remote"
|
||||
conn_msg = f"Connected to {node_name}"
|
||||
|
||||
old_tty = termios.tcgetattr(sys.stdin)
|
||||
try:
|
||||
tty.setraw(sys.stdin.fileno())
|
||||
response_iterator = self.stub.interact_node(request_generator())
|
||||
|
||||
# First response is connection status
|
||||
try:
|
||||
first_res = next(response_iterator)
|
||||
if first_res.success:
|
||||
# Connection established on server, show success message
|
||||
termios.tcsetattr(sys.stdin, termios.TCSADRAIN, old_tty)
|
||||
printer.success(conn_msg)
|
||||
tty.setraw(sys.stdin.fileno())
|
||||
else:
|
||||
# Connection failed on server
|
||||
termios.tcsetattr(sys.stdin, termios.TCSADRAIN, old_tty)
|
||||
printer.error(f"Connection failed: {first_res.error_message}")
|
||||
return
|
||||
except StopIteration:
|
||||
return
|
||||
|
||||
for res in response_iterator:
|
||||
if res.stdout_data:
|
||||
os.write(sys.stdout.fileno(), res.stdout_data)
|
||||
@@ -104,6 +205,10 @@ class NodeStub:
|
||||
def explode_unique(self, unique_id):
|
||||
return from_value(self.stub.explode_unique(connpy_pb2.IdRequest(id=unique_id)).data)
|
||||
|
||||
@handle_errors
|
||||
def validate_parent_folder(self, unique_id):
|
||||
self.stub.validate_parent_folder(connpy_pb2.IdRequest(id=unique_id))
|
||||
|
||||
@handle_errors
|
||||
def generate_cache(self, nodes=None, folders=None, profiles=None):
|
||||
# 1. Update remote cache on server
|
||||
@@ -226,6 +331,30 @@ class ProfileStub:
|
||||
if self.node_stub:
|
||||
self.node_stub._trigger_local_cache_sync()
|
||||
|
||||
class ConfigStub:
|
||||
def __init__(self, channel, remote_host):
|
||||
self.stub = connpy_pb2_grpc.ConfigServiceStub(channel)
|
||||
self.remote_host = remote_host
|
||||
|
||||
@handle_errors
|
||||
def get_settings(self):
|
||||
return from_struct(self.stub.get_settings(Empty()).data)
|
||||
|
||||
@handle_errors
|
||||
def update_setting(self, key, value):
|
||||
self.stub.update_setting(connpy_pb2.UpdateRequest(key=key, value=to_value(value)))
|
||||
|
||||
@handle_errors
|
||||
def get_default_dir(self):
|
||||
return self.stub.get_default_dir(Empty()).value
|
||||
|
||||
@handle_errors
|
||||
def set_config_folder(self, folder):
|
||||
self.stub.set_config_folder(connpy_pb2.StringRequest(value=folder))
|
||||
|
||||
@handle_errors
|
||||
def encrypt_password(self, password):
|
||||
return self.stub.encrypt_password(connpy_pb2.StringRequest(value=password)).value
|
||||
|
||||
class PluginStub:
|
||||
def __init__(self, channel, remote_host):
|
||||
+146
-30
@@ -1,7 +1,71 @@
|
||||
# Lazy-loaded printer module to speed up CLI startup
|
||||
_console = None
|
||||
_err_console = None
|
||||
_theme = None
|
||||
import sys
|
||||
import threading
|
||||
import io
|
||||
|
||||
_local = threading.local()
|
||||
|
||||
class ThreadLocalStream:
|
||||
def __init__(self, original):
|
||||
self._original = original
|
||||
|
||||
def _get_stream(self):
|
||||
s = getattr(_local, 'stream', None)
|
||||
return s if s is not None else self._original
|
||||
|
||||
def write(self, data):
|
||||
stream = self._get_stream()
|
||||
if stream:
|
||||
stream.write(data)
|
||||
|
||||
def flush(self):
|
||||
stream = self._get_stream()
|
||||
if stream:
|
||||
stream.flush()
|
||||
|
||||
def isatty(self):
|
||||
stream = self._get_stream()
|
||||
return stream.isatty() if stream else False
|
||||
|
||||
def __getattr__(self, name):
|
||||
# Avoid recursion during initialization or if _original is not yet set
|
||||
if name in ('_original', '_get_stream'):
|
||||
raise AttributeError(name)
|
||||
stream = self._get_stream()
|
||||
if stream:
|
||||
return getattr(stream, name)
|
||||
raise AttributeError(f"'NoneType' object has no attribute '{name}'")
|
||||
|
||||
# Patch stdout/stderr only once at module level
|
||||
if not isinstance(sys.stdout, ThreadLocalStream):
|
||||
sys.stdout = ThreadLocalStream(sys.stdout)
|
||||
if not isinstance(sys.stderr, ThreadLocalStream):
|
||||
sys.stderr = ThreadLocalStream(sys.stderr)
|
||||
|
||||
def _get_local():
|
||||
if not hasattr(_local, 'console'):
|
||||
_local.console = None
|
||||
if not hasattr(_local, 'err_console'):
|
||||
_local.err_console = None
|
||||
if not hasattr(_local, 'theme'):
|
||||
_local.theme = None
|
||||
return _local
|
||||
|
||||
def set_thread_stream(stream):
|
||||
if stream is None:
|
||||
if hasattr(_local, 'stream'):
|
||||
del _local.stream
|
||||
else:
|
||||
_local.stream = stream
|
||||
|
||||
def get_original_stdout():
|
||||
if isinstance(sys.stdout, ThreadLocalStream):
|
||||
return sys.stdout._original
|
||||
return sys.stdout
|
||||
|
||||
def get_original_stderr():
|
||||
if isinstance(sys.stderr, ThreadLocalStream):
|
||||
return sys.stderr._original
|
||||
return sys.stderr
|
||||
|
||||
# Centralized design system
|
||||
STYLES = {
|
||||
@@ -23,24 +87,76 @@ STYLES = {
|
||||
}
|
||||
|
||||
def _get_console():
|
||||
global _console, _theme
|
||||
if _console is None:
|
||||
local = _get_local()
|
||||
|
||||
# Self-healing patch: if sys.stdout was replaced (e.g. by pytest), re-wrap it.
|
||||
if not isinstance(sys.stdout, ThreadLocalStream):
|
||||
sys.stdout = ThreadLocalStream(sys.stdout)
|
||||
|
||||
current_out = sys.stdout
|
||||
|
||||
# Detect if we need to recreate the console (stream changed or closed)
|
||||
needs_recreate = (local.console is None or
|
||||
getattr(local, '_last_stdout', None) is not current_out)
|
||||
|
||||
# Extra check for closed files in test environments
|
||||
if not needs_recreate and local.console is not None:
|
||||
try:
|
||||
if hasattr(local.console.file, 'closed') and local.console.file.closed:
|
||||
needs_recreate = True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if needs_recreate:
|
||||
from rich.console import Console
|
||||
from rich.theme import Theme
|
||||
if _theme is None:
|
||||
_theme = Theme(STYLES)
|
||||
_console = Console(theme=_theme)
|
||||
return _console
|
||||
if local.theme is None:
|
||||
local.theme = Theme(STYLES)
|
||||
local.console = Console(theme=local.theme, file=current_out)
|
||||
local._last_stdout = current_out
|
||||
|
||||
return local.console
|
||||
|
||||
def _get_err_console():
|
||||
global _err_console, _theme
|
||||
if _err_console is None:
|
||||
local = _get_local()
|
||||
|
||||
# Self-healing patch for stderr
|
||||
if not isinstance(sys.stderr, ThreadLocalStream):
|
||||
sys.stderr = ThreadLocalStream(sys.stderr)
|
||||
|
||||
current_err = sys.stderr
|
||||
|
||||
needs_recreate = (local.err_console is None or
|
||||
getattr(local, '_last_stderr', None) is not current_err)
|
||||
|
||||
if not needs_recreate and local.err_console is not None:
|
||||
try:
|
||||
if hasattr(local.err_console.file, 'closed') and local.err_console.file.closed:
|
||||
needs_recreate = True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if needs_recreate:
|
||||
from rich.console import Console
|
||||
from rich.theme import Theme
|
||||
if _theme is None:
|
||||
_theme = Theme(STYLES)
|
||||
_err_console = Console(stderr=True, theme=_theme)
|
||||
return _err_console
|
||||
if local.theme is None:
|
||||
local.theme = Theme(STYLES)
|
||||
local.err_console = Console(stderr=True, theme=local.theme, file=current_err)
|
||||
local._last_stderr = current_err
|
||||
|
||||
return local.err_console
|
||||
|
||||
def set_thread_console(console):
|
||||
_get_local().console = console
|
||||
|
||||
def set_thread_err_console(console):
|
||||
_get_local().err_console = console
|
||||
|
||||
def clear_thread_state():
|
||||
"""Removes all thread-local printer state. Useful for gRPC thread reuse."""
|
||||
for attr in ["stream", "console", "err_console", "theme", "_last_stdout", "_last_stderr"]:
|
||||
if hasattr(_local, attr):
|
||||
delattr(_local, attr)
|
||||
|
||||
@property
|
||||
def console():
|
||||
@@ -52,18 +168,18 @@ def err_console():
|
||||
|
||||
@property
|
||||
def connpy_theme():
|
||||
global _theme
|
||||
if _theme is None:
|
||||
local = _get_local()
|
||||
if local.theme is None:
|
||||
from rich.theme import Theme
|
||||
_theme = Theme(STYLES)
|
||||
return _theme
|
||||
local.theme = Theme(STYLES)
|
||||
return local.theme
|
||||
|
||||
def apply_theme(user_styles=None):
|
||||
"""
|
||||
Updates the global console themes with user-defined styles.
|
||||
If a style is missing in user_styles, it falls back to the default in STYLES.
|
||||
"""
|
||||
global _theme, _console, _err_console
|
||||
local = _get_local()
|
||||
from rich.theme import Theme
|
||||
|
||||
# Start with a copy of defaults
|
||||
@@ -74,11 +190,11 @@ def apply_theme(user_styles=None):
|
||||
if key in active_styles:
|
||||
active_styles[key] = value
|
||||
|
||||
_theme = Theme(active_styles)
|
||||
if _console:
|
||||
_console.push_theme(_theme)
|
||||
if _err_console:
|
||||
_err_console.push_theme(_theme)
|
||||
local.theme = Theme(active_styles)
|
||||
if local.console:
|
||||
local.console.push_theme(local.theme)
|
||||
if local.err_console:
|
||||
local.err_console.push_theme(local.theme)
|
||||
return active_styles
|
||||
|
||||
|
||||
@@ -273,10 +389,10 @@ err_console = _ErrConsoleProxy()
|
||||
# theme also needs to be lazy
|
||||
class _ThemeProxy:
|
||||
def __getattr__(self, name):
|
||||
global _theme
|
||||
if _theme is None:
|
||||
local = _get_local()
|
||||
if local.theme is None:
|
||||
from rich.theme import Theme
|
||||
_theme = Theme(STYLES)
|
||||
return getattr(_theme, name)
|
||||
local.theme = Theme(STYLES)
|
||||
return getattr(local.theme, name)
|
||||
|
||||
connpy_theme = _ThemeProxy()
|
||||
|
||||
@@ -16,6 +16,7 @@ service NodeService {
|
||||
rpc delete_node (DeleteRequest) returns (google.protobuf.Empty) {}
|
||||
rpc move_node (MoveRequest) returns (google.protobuf.Empty) {}
|
||||
rpc bulk_add (BulkRequest) returns (google.protobuf.Empty) {}
|
||||
rpc validate_parent_folder (IdRequest) returns (google.protobuf.Empty) {}
|
||||
rpc set_reserved_names (ListRequest) returns (google.protobuf.Empty) {}
|
||||
rpc interact_node (stream InteractRequest) returns (stream InteractResponse) {}
|
||||
rpc full_replace (FullReplaceRequest) returns (google.protobuf.Empty) {}
|
||||
@@ -87,10 +88,13 @@ message InteractRequest {
|
||||
bytes stdin_data = 4;
|
||||
int32 cols = 5;
|
||||
int32 rows = 6;
|
||||
string connection_params_json = 7;
|
||||
}
|
||||
|
||||
message InteractResponse {
|
||||
bytes stdout_data = 1;
|
||||
bool success = 2;
|
||||
string error_message = 3;
|
||||
}
|
||||
|
||||
message FilterRequest {
|
||||
|
||||
@@ -73,10 +73,13 @@ class NodeService(BaseService):
|
||||
|
||||
def get_node_details(self, unique_id):
|
||||
"""Return full configuration dictionary for a specific node."""
|
||||
details = self.config.getitem(unique_id)
|
||||
if not details:
|
||||
try:
|
||||
details = self.config.getitem(unique_id)
|
||||
if not details:
|
||||
raise NodeNotFoundError(f"Node '{unique_id}' not found.")
|
||||
return details
|
||||
except (KeyError, TypeError):
|
||||
raise NodeNotFoundError(f"Node '{unique_id}' not found.")
|
||||
return details
|
||||
|
||||
def explode_unique(self, unique_id):
|
||||
"""Explode a unique ID into a dictionary of its parts."""
|
||||
@@ -86,6 +89,14 @@ class NodeService(BaseService):
|
||||
"""Generate and update the internal nodes cache."""
|
||||
self.config._generate_nodes_cache(nodes=nodes, folders=folders, profiles=profiles)
|
||||
|
||||
def validate_parent_folder(self, unique_id):
|
||||
"""Check if parent folder exists for a given node unique ID."""
|
||||
node_folder = unique_id.partition("@")[2]
|
||||
if node_folder:
|
||||
parent_folder = f"@{node_folder}"
|
||||
if parent_folder not in self.config._getallfolders():
|
||||
raise NodeNotFoundError(f"Folder '{parent_folder}' not found.")
|
||||
|
||||
|
||||
def add_node(self, unique_id, data, is_folder=False):
|
||||
"""Logic for adding a new node or folder to configuration."""
|
||||
@@ -104,9 +115,7 @@ class NodeService(BaseService):
|
||||
|
||||
# Check if parent folder exists when creating a subfolder
|
||||
if "subfolder" in uniques:
|
||||
parent_folder = f"@{uniques['folder']}"
|
||||
if parent_folder not in all_folders:
|
||||
raise NodeNotFoundError(f"Folder '{parent_folder}' not found.")
|
||||
self.validate_parent_folder(unique_id)
|
||||
|
||||
self.config._folder_add(**uniques)
|
||||
self.config._saveconfig(self.config.file)
|
||||
@@ -115,11 +124,7 @@ class NodeService(BaseService):
|
||||
raise NodeAlreadyExistsError(f"Node '{unique_id}' already exists.")
|
||||
|
||||
# Check if parent folder exists when creating a node in a folder
|
||||
node_folder = unique_id.partition("@")[2]
|
||||
if node_folder:
|
||||
parent_folder = f"@{node_folder}"
|
||||
if parent_folder not in all_folders:
|
||||
raise NodeNotFoundError(f"Folder '{parent_folder}' not found.")
|
||||
self.validate_parent_folder(unique_id)
|
||||
|
||||
# Ensure 'id' is in data for config._connections_add
|
||||
if "id" not in data:
|
||||
|
||||
@@ -180,6 +180,7 @@ class PluginService(BaseService):
|
||||
from ..services.exceptions import InvalidConfigurationError
|
||||
from connpy.plugins import Plugins
|
||||
class MockApp:
|
||||
is_mock = True
|
||||
def __init__(self, config):
|
||||
from ..core import node, nodes
|
||||
from ..ai import ai
|
||||
@@ -191,14 +192,20 @@ class PluginService(BaseService):
|
||||
self.ai = ai
|
||||
|
||||
self.services = ServiceProvider(config, mode="local")
|
||||
|
||||
# Get settings for CLI behavior
|
||||
settings = self.services.config_svc.get_settings()
|
||||
self.case = settings.get("case", False)
|
||||
self.fzf = settings.get("fzf", False)
|
||||
|
||||
try:
|
||||
self.nodes_list = self.services.nodes.list_nodes()
|
||||
self.folders = self.services.nodes.list_folders()
|
||||
self.profiles = self.services.profiles.list_profiles()
|
||||
except Exception:
|
||||
self.nodes_list = {}
|
||||
self.folders = {}
|
||||
self.profiles = {}
|
||||
self.nodes_list = []
|
||||
self.folders = []
|
||||
self.profiles = []
|
||||
|
||||
args = Namespace(**args_dict)
|
||||
|
||||
@@ -225,26 +232,26 @@ class PluginService(BaseService):
|
||||
from .. import printer
|
||||
from rich.console import Console
|
||||
|
||||
from rich.console import Console
|
||||
buf = io.StringIO()
|
||||
old_console = printer.console
|
||||
old_err_console = printer.err_console
|
||||
old_console = printer._get_console()
|
||||
old_err_console = printer._get_err_console()
|
||||
|
||||
printer.console = Console(file=buf, theme=printer.connpy_theme, force_terminal=True)
|
||||
printer.err_console = Console(file=buf, theme=printer.connpy_theme, force_terminal=True)
|
||||
|
||||
old_stdout = sys.stdout
|
||||
sys.stdout = buf
|
||||
printer.set_thread_console(Console(file=buf, theme=printer.connpy_theme, force_terminal=True))
|
||||
printer.set_thread_err_console(Console(file=buf, theme=printer.connpy_theme, force_terminal=True))
|
||||
printer.set_thread_stream(buf)
|
||||
|
||||
try:
|
||||
if hasattr(module, "Entrypoint"):
|
||||
module.Entrypoint(args, parser, app)
|
||||
except Exception as e:
|
||||
import traceback
|
||||
printer.err_console.print(traceback.format_exc())
|
||||
except BaseException as e:
|
||||
if not isinstance(e, SystemExit):
|
||||
import traceback
|
||||
printer.err_console.print(traceback.format_exc())
|
||||
finally:
|
||||
sys.stdout = old_stdout
|
||||
printer.console = old_console
|
||||
printer.err_console = old_err_console
|
||||
printer.set_thread_console(old_console)
|
||||
printer.set_thread_err_console(old_err_console)
|
||||
printer.set_thread_stream(None)
|
||||
|
||||
for line in buf.getvalue().splitlines(keepends=True):
|
||||
yield line
|
||||
|
||||
@@ -58,7 +58,7 @@ class ServiceProvider:
|
||||
raise InvalidConfigurationError("Remote host must be specified in remote mode")
|
||||
|
||||
import grpc
|
||||
from ..grpc.stubs import NodeStub, ProfileStub, PluginStub, AIStub, ExecutionStub, ImportExportStub, SystemStub
|
||||
from ..grpc_layer.stubs import NodeStub, ProfileStub, PluginStub, AIStub, ExecutionStub, ImportExportStub, SystemStub
|
||||
|
||||
channel = grpc.insecure_channel(self.remote_host)
|
||||
|
||||
|
||||
@@ -157,9 +157,9 @@ class SyncService(BaseService):
|
||||
if os.path.exists(self.config.key):
|
||||
zipf.write(self.config.key, ".osk")
|
||||
|
||||
# Manage retention (max 10 backups)
|
||||
# Manage retention (max 100 backups)
|
||||
backups = self.list_backups()
|
||||
if len(backups) >= 10:
|
||||
if len(backups) >= 100:
|
||||
oldest = min(backups, key=lambda x: x['timestamp'] or '0')
|
||||
self.delete_backup(oldest['id'])
|
||||
|
||||
@@ -360,7 +360,7 @@ class SyncService(BaseService):
|
||||
|
||||
if not sync_enabled: return
|
||||
|
||||
printer.info("Triggering auto-sync...")
|
||||
|
||||
if self.check_login_status() != True:
|
||||
printer.warning("Auto-sync: Not logged in to Google Drive.")
|
||||
return
|
||||
|
||||
@@ -269,16 +269,16 @@ class TestToolMethods:
|
||||
|
||||
def test_list_nodes_tool_found(self, myai):
|
||||
result = myai.list_nodes_tool("router.*")
|
||||
parsed = json.loads(result)
|
||||
parsed = json.loads(result) if isinstance(result, str) else result
|
||||
assert "router1" in str(parsed)
|
||||
|
||||
def test_list_nodes_tool_not_found(self, myai):
|
||||
result = myai.list_nodes_tool("nonexistent_pattern_xyz")
|
||||
assert "No nodes found" in result
|
||||
assert "No nodes found" in str(result)
|
||||
|
||||
def test_get_node_info_masks_password(self, myai):
|
||||
result = myai.get_node_info_tool("router1")
|
||||
parsed = json.loads(result)
|
||||
parsed = json.loads(result) if isinstance(result, str) else result
|
||||
assert parsed["password"] == "***"
|
||||
|
||||
def test_is_safe_command_show(self, myai):
|
||||
|
||||
@@ -99,6 +99,23 @@ class TestCommandGeneration:
|
||||
assert "telnet 10.0.0.1" in cmd
|
||||
assert "23" in cmd
|
||||
|
||||
def test_ssm_cmd_basic(self):
|
||||
n = self._make_node(protocol="ssm", host="i-12345")
|
||||
cmd = n._get_cmd()
|
||||
assert "aws ssm start-session" in cmd
|
||||
assert "--target i-12345" in cmd
|
||||
|
||||
def test_ssm_cmd_tags(self):
|
||||
n = self._make_node(protocol="ssm", host="i-12345", tags={"region": "us-west-2", "profile": "prod"})
|
||||
cmd = n._get_cmd()
|
||||
assert "--region us-west-2" in cmd
|
||||
assert "--profile prod" in cmd
|
||||
|
||||
def test_ssm_cmd_options(self):
|
||||
n = self._make_node(protocol="ssm", host="i-12345", options="--document-name AWS-StartInteractiveCommand")
|
||||
cmd = n._get_cmd()
|
||||
assert "--document-name AWS-StartInteractiveCommand" in cmd
|
||||
|
||||
def test_kubectl_cmd(self):
|
||||
n = self._make_node(protocol="kubectl", host="my-pod", tags={"kube_command": "/bin/sh"})
|
||||
cmd = n._get_cmd()
|
||||
|
||||
@@ -0,0 +1,202 @@
|
||||
import pytest
|
||||
import grpc
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
from unittest.mock import MagicMock, patch
|
||||
from concurrent import futures
|
||||
from connpy.grpc_layer import server, connpy_pb2, connpy_pb2_grpc, stubs
|
||||
from connpy.services.exceptions import ConnpyError
|
||||
|
||||
class MockContext:
|
||||
def abort(self, code, details):
|
||||
raise Exception(f"gRPC Abort: {code} - {details}")
|
||||
|
||||
# --- UNIT TESTS (with mocks) ---
|
||||
|
||||
class TestNodeServicerNaming:
|
||||
@pytest.fixture
|
||||
def servicer(self, populated_config):
|
||||
return server.NodeServicer(populated_config)
|
||||
|
||||
@patch("connpy.core.node")
|
||||
def test_interact_node_uses_passed_name(self, mock_node, servicer):
|
||||
# Setup request with custom name
|
||||
params = {"name": "custom-node-name@test", "host": "1.2.3.4", "protocol": "ssh"}
|
||||
request = connpy_pb2.InteractRequest(
|
||||
id="dynamic",
|
||||
connection_params_json=json.dumps(params)
|
||||
)
|
||||
|
||||
# Mock node to allow _connect
|
||||
mock_node_instance = MagicMock()
|
||||
mock_node_instance._connect.return_value = True
|
||||
mock_node.return_value = mock_node_instance
|
||||
|
||||
# We only need the first iteration of the generator to check naming
|
||||
gen = servicer.interact_node(iter([request]), MockContext())
|
||||
next(gen) # Skip the success response
|
||||
|
||||
# Verify that node() was called with the custom name
|
||||
mock_node.assert_called()
|
||||
found = False
|
||||
for call in mock_node.call_args_list:
|
||||
if call.args[0] == "custom-node-name@test":
|
||||
found = True
|
||||
break
|
||||
assert found
|
||||
|
||||
@patch("connpy.core.node")
|
||||
def test_interact_node_fallback_naming(self, mock_node, servicer):
|
||||
# Setup request without custom name but with host
|
||||
params = {"host": "my-instance", "protocol": "ssm"}
|
||||
request = connpy_pb2.InteractRequest(
|
||||
id="dynamic",
|
||||
connection_params_json=json.dumps(params)
|
||||
)
|
||||
|
||||
mock_node_instance = MagicMock()
|
||||
mock_node_instance._connect.return_value = True
|
||||
mock_node.return_value = mock_node_instance
|
||||
|
||||
gen = servicer.interact_node(iter([request]), MockContext())
|
||||
next(gen)
|
||||
|
||||
# Verify fallback name: dynamic-{host}@remote
|
||||
found = False
|
||||
for call in mock_node.call_args_list:
|
||||
if call.args[0] == "dynamic-my-instance@remote":
|
||||
found = True
|
||||
break
|
||||
assert found
|
||||
|
||||
class TestStubsMessageFormatting:
|
||||
@patch("termios.tcsetattr")
|
||||
@patch("termios.tcgetattr")
|
||||
@patch("tty.setraw")
|
||||
@patch("os.read")
|
||||
@patch("select.select")
|
||||
def test_connect_dynamic_msg_formatting_ssm(self, mock_select, mock_read, mock_setraw, mock_getattr, mock_setattr):
|
||||
from connpy.grpc_layer.stubs import NodeStub
|
||||
|
||||
mock_getattr.return_value = [0, 0, 0, 0, 0, 0, [0] * 32]
|
||||
mock_channel = MagicMock()
|
||||
stub = NodeStub(mock_channel, "localhost:8048")
|
||||
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.success = True
|
||||
stub.stub.interact_node.return_value = iter([mock_resp])
|
||||
|
||||
with patch("connpy.printer.success") as mock_success:
|
||||
with patch("sys.stdin.fileno", return_value=0):
|
||||
mock_select.return_value = ([], [], [])
|
||||
params = {"protocol": "ssm", "host": "i-12345", "name": "my-ssm-node@aws"}
|
||||
|
||||
with patch("select.select", side_effect=KeyboardInterrupt):
|
||||
try:
|
||||
stub.connect_dynamic(params)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
|
||||
mock_success.assert_called()
|
||||
msg = mock_success.call_args[0][0]
|
||||
assert "Connected to my-ssm-node@aws" in msg
|
||||
assert "at i-12345" in msg
|
||||
assert ":22" not in msg
|
||||
assert "via: ssm" in msg
|
||||
|
||||
|
||||
# --- INTEGRATION TESTS (Real Server/Stub Communication) ---
|
||||
|
||||
class TestGRPCIntegration:
|
||||
@pytest.fixture
|
||||
def grpc_server(self, populated_config):
|
||||
"""Starts a local gRPC server for integration testing."""
|
||||
srv = grpc.server(futures.ThreadPoolExecutor(max_workers=5))
|
||||
|
||||
# Register services
|
||||
connpy_pb2_grpc.add_NodeServiceServicer_to_server(server.NodeServicer(populated_config), srv)
|
||||
connpy_pb2_grpc.add_ProfileServiceServicer_to_server(server.ProfileServicer(populated_config), srv)
|
||||
connpy_pb2_grpc.add_ConfigServiceServicer_to_server(server.ConfigServicer(populated_config), srv)
|
||||
connpy_pb2_grpc.add_ExecutionServiceServicer_to_server(server.ExecutionServicer(populated_config), srv)
|
||||
connpy_pb2_grpc.add_ImportExportServiceServicer_to_server(server.ImportExportServicer(populated_config), srv)
|
||||
|
||||
port = srv.add_insecure_port('127.0.0.1:0')
|
||||
srv.start()
|
||||
yield f"127.0.0.1:{port}"
|
||||
srv.stop(0)
|
||||
|
||||
@pytest.fixture
|
||||
def channel(self, grpc_server):
|
||||
with grpc.insecure_channel(grpc_server) as channel:
|
||||
yield channel
|
||||
|
||||
@pytest.fixture
|
||||
def node_stub(self, channel):
|
||||
return stubs.NodeStub(channel, "localhost")
|
||||
|
||||
@pytest.fixture
|
||||
def profile_stub(self, channel):
|
||||
return stubs.ProfileStub(channel, "localhost")
|
||||
|
||||
@pytest.fixture
|
||||
def config_stub(self, channel):
|
||||
return stubs.ConfigStub(channel, "localhost")
|
||||
|
||||
def test_list_nodes_integration(self, node_stub):
|
||||
nodes = node_stub.list_nodes()
|
||||
assert "router1" in nodes
|
||||
assert "server1@office" in nodes
|
||||
|
||||
def test_get_node_details_integration(self, node_stub):
|
||||
details = node_stub.get_node_details("router1")
|
||||
assert details["host"] == "10.0.0.1"
|
||||
|
||||
def test_node_not_found_integration(self, node_stub):
|
||||
with pytest.raises(ConnpyError) as exc:
|
||||
node_stub.get_node_details("non-existent")
|
||||
assert "Node 'non-existent' not found." in str(exc.value)
|
||||
|
||||
def test_list_profiles_integration(self, profile_stub):
|
||||
profiles = profile_stub.list_profiles()
|
||||
assert "office-user" in profiles
|
||||
|
||||
def test_get_settings_integration(self, config_stub):
|
||||
settings = config_stub.get_settings()
|
||||
assert "idletime" in settings
|
||||
|
||||
def test_update_setting_integration(self, config_stub):
|
||||
config_stub.update_setting("idletime", 99)
|
||||
settings = config_stub.get_settings()
|
||||
assert settings["idletime"] == 99
|
||||
|
||||
def test_add_delete_node_integration(self, node_stub):
|
||||
node_stub.add_node("integration-test-node", {"host": "9.9.9.9"})
|
||||
assert "integration-test-node" in node_stub.list_nodes()
|
||||
node_stub.delete_node("integration-test-node")
|
||||
assert "integration-test-node" not in node_stub.list_nodes()
|
||||
|
||||
def test_import_yaml_integration(self, channel, node_stub):
|
||||
import yaml
|
||||
from connpy.grpc_layer import stubs
|
||||
stub = stubs.ImportExportStub(channel, "localhost")
|
||||
|
||||
# ImportExportService expects a flat dict of nodes, not a full config structure
|
||||
inventory = {
|
||||
"imported-node": {"host": "8.8.8.8", "protocol": "ssh", "type": "connection"}
|
||||
}
|
||||
yaml_content = yaml.dump(inventory)
|
||||
|
||||
import tempfile
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
|
||||
f.write(yaml_content)
|
||||
temp_path = f.name
|
||||
|
||||
try:
|
||||
stub.import_from_file(temp_path)
|
||||
# Verify the node was imported and is visible via NodeStub
|
||||
nodes = node_stub.list_nodes()
|
||||
assert "imported-node" in nodes
|
||||
finally:
|
||||
if os.path.exists(temp_path):
|
||||
os.remove(temp_path)
|
||||
@@ -0,0 +1,65 @@
|
||||
import threading
|
||||
import io
|
||||
import time
|
||||
import sys
|
||||
import pytest
|
||||
from connpy import printer
|
||||
|
||||
def test_printer_thread_isolation():
|
||||
"""Verify that printer output is isolated per thread when using set_thread_stream."""
|
||||
num_threads = 5
|
||||
iterations = 20
|
||||
results = {}
|
||||
|
||||
def worker(thread_id):
|
||||
# Create a private buffer for this thread
|
||||
buf = io.StringIO()
|
||||
printer.set_thread_stream(buf)
|
||||
|
||||
# Ensure we have a clean console for this thread
|
||||
# In a real gRPC request, this happens automatically as it's a new thread
|
||||
printer.set_thread_console(None)
|
||||
|
||||
# Each thread prints its own ID
|
||||
expected_msg = f"Thread-{thread_id}"
|
||||
for _ in range(iterations):
|
||||
printer.info(expected_msg)
|
||||
time.sleep(0.01)
|
||||
|
||||
results[thread_id] = buf.getvalue()
|
||||
printer.set_thread_stream(None)
|
||||
|
||||
threads = []
|
||||
for i in range(num_threads):
|
||||
t = threading.Thread(target=worker, args=(i,))
|
||||
threads.append(t)
|
||||
t.start()
|
||||
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
# Validation
|
||||
for thread_id, output in results.items():
|
||||
expected_msg = f"Thread-{thread_id}"
|
||||
assert expected_msg in output
|
||||
|
||||
# Ensure no leaks
|
||||
for other_id in range(num_threads):
|
||||
if other_id == thread_id: continue
|
||||
assert f"Thread-{other_id}" not in output
|
||||
|
||||
def test_printer_manual_stream():
|
||||
"""Verify that setting a thread stream correctly captures printer output in the current thread."""
|
||||
buf = io.StringIO()
|
||||
|
||||
# We must clear the thread-local console to force it to pick up the new sys.stdout proxy
|
||||
printer.set_thread_console(None)
|
||||
printer.set_thread_stream(buf)
|
||||
|
||||
printer.info("Captured-Message")
|
||||
|
||||
output = buf.getvalue()
|
||||
printer.set_thread_stream(None)
|
||||
printer.set_thread_console(None)
|
||||
|
||||
assert "Captured-Message" in output
|
||||
Reference in New Issue
Block a user