Files
spark-control/image/app/ssh.py
T
Grant ae8efa1754 Initial scaffold: image/ FastAPI app, models.yaml, docs
- image/ FastAPI app: /api/status, /api/swap, /api/swap/{id}/stream, /api/test-connection
- models.yaml: 5-model catalog (qwen3-vl, gemma4, qwen36, qwen3-235b-fp8, qwen25-72b)
- README, runbook, known-issues
- Dry-run swap verified against live Spark 1 (gemma4 currently loaded)
2026-05-12 09:29:13 -05:00

92 lines
2.9 KiB
Python

"""Async wrappers around the system `ssh` client.
We shell out rather than use Paramiko/asyncssh so that:
- Host key + auth behavior is identical to what a user would see at the shell.
- The same ssh config file (`~/.ssh/config`) and key files work in dev.
- We don't pull in a heavy crypto dependency for the container image.
"""
from __future__ import annotations
import asyncio
from typing import AsyncIterator
from .config import Settings
def _base_args(settings: Settings) -> list[str]:
args = [
"ssh",
"-o", "BatchMode=yes",
"-o", "StrictHostKeyChecking=accept-new",
"-o", "ServerAliveInterval=15",
"-o", "ServerAliveCountMax=4",
]
if settings.ssh_key_path:
args += ["-i", settings.ssh_key_path]
if settings.ssh_known_hosts:
args += ["-o", f"UserKnownHostsFile={settings.ssh_known_hosts}"]
return args
async def ssh_run(
    host: str,
    user: str,
    command: str,
    settings: Settings,
    timeout: float = 30.0,
) -> tuple[int, str, str]:
    """Run a one-shot SSH command and collect its output.

    Args:
        host: Remote host to connect to.
        user: Remote login name.
        command: Command string handed to the remote side.
        settings: Application settings used to build the ssh options.
        timeout: Seconds to wait for the command to finish.

    Returns:
        `(rc, stdout, stderr)` with both streams decoded, undecodable bytes
        replaced. On timeout the child is killed and the result is
        `(124, "", "timeout after {timeout}s")`.

    Raises:
        asyncio.CancelledError: If the awaiting task is cancelled. The ssh
            child is killed before the exception propagates.
    """
    args = _base_args(settings) + [f"{user}@{host}", command]
    proc = await asyncio.create_subprocess_exec(
        *args,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )
    try:
        stdout_b, stderr_b = await asyncio.wait_for(proc.communicate(), timeout=timeout)
    except asyncio.TimeoutError:
        # The child may have exited while wait_for was unwinding; only signal
        # a process that is still running, and tolerate losing that race.
        if proc.returncode is None:
            try:
                proc.kill()
            except ProcessLookupError:
                pass
        await proc.wait()
        return 124, "", f"timeout after {timeout}s"
    except BaseException:
        # Cancellation (or any unexpected error) must not leave an orphaned
        # ssh process behind. Kill it and let the exception continue.
        if proc.returncode is None:
            try:
                proc.kill()
            except ProcessLookupError:
                pass
        raise
    # communicate() returns only after the process has exited; the assert
    # narrows the Optional type for type checkers.
    assert proc.returncode is not None
    return proc.returncode, stdout_b.decode(errors="replace"), stderr_b.decode(errors="replace")
class StreamHandle:
    """Out-parameter that receives the exit status of an `ssh_stream()` run.

    `returncode` starts as `None`. When a handle is passed to `ssh_stream()`,
    the generator's cleanup step stores the child's return code on it.
    """

    returncode: int | None

    def __init__(self) -> None:
        # No exit status is known until the stream has been wound down.
        self.returncode = None
async def ssh_stream(
    host: str,
    user: str,
    command: str,
    settings: Settings,
    handle: StreamHandle | None = None,
) -> AsyncIterator[str]:
    """Yield output lines from a long-running SSH command as they arrive.

    stderr is merged into stdout. Each line is decoded with undecodable bytes
    replaced, and trailing CR/LF characters are stripped.

    The generator may be aborted by closing it (e.g. `break` in `async for`);
    the child SSH process is then terminated and waited on during cleanup.
    When the output ends on its own, the child is given a short grace period
    to exit by itself so that its real exit status is reported.

    Args:
        host: Remote host to connect to.
        user: Remote login name.
        command: Command string handed to the remote side.
        settings: Application settings used to build the ssh options.
        handle: Optional holder whose `returncode` is set during cleanup.

    Yields:
        One decoded output line at a time.
    """
    args = _base_args(settings) + [f"{user}@{host}", command]
    proc = await asyncio.create_subprocess_exec(
        *args,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.STDOUT,
    )
    assert proc.stdout is not None
    try:
        async for raw in proc.stdout:
            yield raw.decode(errors="replace").rstrip("\r\n")
        # EOF on the pipe only means ssh closed its output; the exit status
        # may not have been collected yet. Wait for a natural exit instead of
        # signalling a process that finished normally, which could replace
        # the real return code with a signal status.
        try:
            await asyncio.wait_for(proc.wait(), timeout=5)
        except asyncio.TimeoutError:
            pass  # still running after EOF: fall through to terminate below
    finally:
        if proc.returncode is None:
            # Aborted early, failed, or hung after EOF: stop the child.
            try:
                proc.terminate()
            except ProcessLookupError:
                pass  # exited between the check and the signal
            try:
                await asyncio.wait_for(proc.wait(), timeout=5)
            except asyncio.TimeoutError:
                try:
                    proc.kill()
                except ProcessLookupError:
                    pass
                await proc.wait()
        if handle is not None:
            handle.returncode = proc.returncode