Source code for srunx.ssh.core.client

import logging
import os
import re
import shlex
import subprocess
import time
import uuid
from dataclasses import dataclass
from pathlib import Path

import paramiko  # type: ignore

from srunx.sync import RsyncClient

from .proxy_client import ProxySSHClient, create_proxy_aware_connection


@dataclass
class SlurmJob:
    job_id: str
    name: str
    status: str = "UNKNOWN"
    output_file: str | None = None
    error_file: str | None = None
    script_path: str | None = None  # Path to script on server
    is_local_script: bool = False  # Whether script was uploaded from local
    _cleanup: bool = False  # Whether to cleanup temporary files

class SSHSlurmClient:
    def __init__(
        self,
        hostname: str,
        username: str,
        password: str | None = None,
        key_filename: str | None = None,
        port: int = 22,
        proxy_jump: str | None = None,
        ssh_config_path: str | None = None,
        env_vars: dict | None = None,
        verbose: bool = False,
    ):
        self.hostname = hostname
        self.username = username
        self.password = password
        self.key_filename = key_filename
        self.port = port
        self.proxy_jump = proxy_jump
        self.ssh_config_path = ssh_config_path
        self.ssh_client: paramiko.SSHClient | None = None
        self.sftp_client: paramiko.SFTPClient | None = None
        self.proxy_client: ProxySSHClient | None = None
        self.logger = logging.getLogger(__name__)
        self.temp_dir = os.getenv("SRUNX_TEMP_DIR", "/tmp/srunx")
        self._slurm_path: str | None = None  # Cache for SLURM command paths
        self._slurm_env: dict[str, str] | None = None  # Cache for SLURM environment variables
        self.custom_env_vars = env_vars or {}  # Custom environment variables to pass
        self.verbose = verbose

        # RsyncClient for project sync (key-based auth only)
        self._rsync_client: RsyncClient | None = None
        if self.key_filename:
            try:
                self._rsync_client = RsyncClient(
                    hostname=self.hostname,
                    username=self.username,
                    port=self.port,
                    key_filename=self.key_filename,
                    proxy_jump=self.proxy_jump,
                    ssh_config_path=self.ssh_config_path,
                )
            except RuntimeError:
                self.logger.warning("rsync not available; sync_project() disabled")

    def connect(self) -> bool:
        try:
            if self.proxy_jump:
                # Use ProxyJump connection
                if not self.key_filename:
                    raise ValueError("ProxyJump requires key-based authentication")
                self.ssh_client, self.proxy_client = create_proxy_aware_connection(
                    hostname=self.hostname,
                    username=self.username,
                    key_filename=self.key_filename,
                    port=self.port,
                    proxy_jump=self.proxy_jump,
                    ssh_config_path=self.ssh_config_path,
                    logger=self.logger,
                )
            else:
                # Direct connection
                self.ssh_client = paramiko.SSHClient()
                self.ssh_client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
                if self.key_filename:
                    self.ssh_client.connect(
                        hostname=self.hostname,
                        username=self.username,
                        key_filename=self.key_filename,
                        port=self.port,
                    )
                else:
                    self.ssh_client.connect(
                        hostname=self.hostname,
                        username=self.username,
                        password=self.password,
                        port=self.port,
                    )

            # Create SFTP client for file transfers
            self.sftp_client = self.ssh_client.open_sftp()

            # Create temp directory on server (quoted to survive unusual paths)
            self.execute_command(f"mkdir -p {shlex.quote(self.temp_dir)}")

            # Initialize SLURM paths and environment
            self._initialize_slurm_paths()
            self._initialize_slurm_environment()

            # Final verification of SLURM setup
            self._verify_slurm_setup()

            connection_info = f"{self.hostname}"
            if self.proxy_jump:
                connection_info += f" (via {self.proxy_jump})"
            if self.verbose:
                self.logger.info(f"Successfully connected to {connection_info}")
            return True
        except Exception as e:
            self.logger.error(f"Failed to connect to {self.hostname}: {e}")
            return False

    def disconnect(self):
        if self.sftp_client:
            self.sftp_client.close()
            self.sftp_client = None
        if self.ssh_client:
            self.ssh_client.close()
            self.ssh_client = None
        if self.proxy_client:
            self.proxy_client.close_proxy()
            self.proxy_client = None
        info = self.hostname + (f" (via {self.proxy_jump})" if self.proxy_jump else "")
        self.logger.info(f"Disconnected from {info}")

    def test_connection(self) -> dict[str, str | bool]:
        """Test SSH connection and SLURM availability.

        Returns:
            Dictionary with test results including:
                - ssh_connected: Whether SSH connection succeeded
                - slurm_available: Whether SLURM commands are available
                - hostname: Remote hostname
                - user: Remote username
                - slurm_version: SLURM version if available
                - error: Error message if connection failed
        """
        result: dict[str, str | bool] = {
            "ssh_connected": False,
            "slurm_available": False,
            "hostname": "",
            "user": "",
            "slurm_version": "",
        }
        try:
            # Test SSH connection
            if not self.connect():
                result["error"] = "Failed to establish SSH connection"
                return result
            result["ssh_connected"] = True

            # Get hostname
            stdout_data, stderr_data, exit_code = self.execute_command("hostname")
            result["hostname"] = stdout_data.strip()

            # Get username
            stdout_data, stderr_data, exit_code = self.execute_command("whoami")
            result["user"] = stdout_data.strip()

            # Check SLURM availability and version, as promised by the docstring
            stdout_data, stderr_data, exit_code = self._execute_slurm_command(
                "sbatch --version"
            )
            if exit_code == 0 and stdout_data.strip():
                result["slurm_available"] = True
                result["slurm_version"] = stdout_data.strip().splitlines()[0]

            return result
        except Exception as e:
            result["error"] = str(e)
            return result
        finally:
            self.disconnect()

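A quick preflight sketch using test_connection (not part of the module source; host, user, and key path are hypothetical placeholders):

status = SSHSlurmClient(
    hostname="login.cluster.example.com",
    username="alice",
    key_filename="/home/alice/.ssh/id_ed25519",
).test_connection()
print(status["ssh_connected"], status["slurm_version"], status.get("error", ""))
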
    def execute_command(self, command: str) -> tuple[str, str, int]:
        if not self.ssh_client:
            raise ConnectionError("SSH client is not connected")

        stdin, stdout, stderr = self.ssh_client.exec_command(command)
        stdout_data = stdout.read().decode("utf-8", errors="replace")
        stderr_data = stderr.read().decode("utf-8", errors="replace")
        exit_code = stdout.channel.recv_exit_status()
        return stdout_data, stderr_data, exit_code

    def upload_file(self, local_path: str, remote_path: str | None = None) -> str:
        """Upload a local file to the server and return the remote path"""
        if not self.sftp_client:
            raise ConnectionError("SFTP client is not connected")

        local_path_obj = Path(local_path)
        if not local_path_obj.exists():
            raise FileNotFoundError(f"Local file not found: {local_path}")

        if remote_path is None:
            # Generate unique remote path in temp directory
            unique_id = str(uuid.uuid4())[:8]
            remote_filename = (
                f"{local_path_obj.stem}_{unique_id}{local_path_obj.suffix}"
            )
            remote_path = f"{self.temp_dir}/{remote_filename}"

        try:
            self.sftp_client.put(str(local_path_obj), remote_path)
            # Make the uploaded script executable if it's a script
            if local_path_obj.suffix in [".sh", ".py", ".pl", ".r"]:
                self.execute_command(f"chmod +x {shlex.quote(remote_path)}")
            if self.verbose:
                self.logger.info(f"Uploaded {local_path} to {remote_path}")
            return remote_path
        except Exception as e:
            self.logger.error(f"Failed to upload file: {e}")
            raise

    def cleanup_file(self, remote_path: str) -> None:
        """Remove a file from the server"""
        try:
            self.execute_command(f"rm -f {shlex.quote(remote_path)}")
            if self.verbose:
                self.logger.info(f"Cleaned up remote file: {remote_path}")
        except Exception as e:
            self.logger.warning(f"Failed to cleanup file {remote_path}: {e}")

    def file_exists(self, remote_path: str) -> bool:
        """Check if a file exists on the server"""
        stdout, stderr, exit_code = self.execute_command(
            f"test -f {shlex.quote(remote_path)} && echo 'exists' || echo 'not_found'"
        )
        exists = stdout.strip() == "exists"
        self.logger.debug(f"File existence check for {remote_path}: {exists}")
        return exists

    def validate_remote_script(self, remote_path: str) -> tuple[bool, str]:
        """Validate a remote script file and return (is_valid, error_message)"""
        quoted_path = shlex.quote(remote_path)

        # Check if file exists
        if not self.file_exists(remote_path):
            return False, f"Remote script file not found: {remote_path}"

        # Check if file is readable
        stdout, stderr, exit_code = self.execute_command(
            f"test -r {quoted_path} && echo 'readable' || echo 'not_readable'"
        )
        if stdout.strip() != "readable":
            return False, f"Remote script file is not readable: {remote_path}"

        # Check if file is executable (warn if not)
        stdout, stderr, exit_code = self.execute_command(
            f"test -x {quoted_path} && echo 'executable' || echo 'not_executable'"
        )
        if stdout.strip() != "executable":
            self.logger.warning(
                f"Remote script file is not executable: {remote_path}. SLURM may fail to run it."
            )

        # Check file size (warn if empty or too large)
        stdout, stderr, exit_code = self.execute_command(
            f"wc -c < {quoted_path} 2>/dev/null || echo '0'"
        )
        try:
            file_size = int(stdout.strip())
            if file_size == 0:
                self.logger.warning(f"Remote script file is empty: {remote_path}")
            elif file_size > 1024 * 1024:  # 1MB
                self.logger.warning(
                    f"Remote script file is very large ({file_size} bytes): {remote_path}"
                )
            self.logger.debug(f"Remote script file size: {file_size} bytes")
        except ValueError:
            self.logger.warning(f"Could not determine file size for: {remote_path}")

        # Basic syntax check for shell scripts
        if remote_path.endswith(".sh"):
            stdout, stderr, exit_code = self.execute_command(
                f"bash -n {quoted_path} 2>&1 || echo 'SYNTAX_ERROR'"
            )
            if "SYNTAX_ERROR" in stdout or exit_code != 0:
                return (
                    False,
                    f"Shell script syntax error in {remote_path}: {stdout.strip()}",
                )
            self.logger.debug(f"Shell script syntax check passed for {remote_path}")

        return True, "Script validation successful"

    def _initialize_slurm_paths(self) -> None:
        """Initialize SLURM command paths by checking common locations and environment"""
        try:
            # Try with login shell environment
            stdout, stderr, exit_code = self.execute_command(
                "bash -l -c 'echo $PATH' 2>/dev/null || echo ''"
            )
            if stdout.strip():
                login_path = stdout.strip()
                self.logger.debug(f"Login shell PATH: {login_path}")

            # Test if SLURM commands are available with login shell
            stdout, stderr, exit_code = self.execute_command(
                "bash -l -c 'which sbatch' 2>/dev/null || echo 'NOT_FOUND'"
            )
            self.logger.debug(
                f"SLURM which command result: stdout='{stdout}', stderr='{stderr}', exit_code={exit_code}"
            )
            if exit_code == 0 and "NOT_FOUND" not in stdout:
                sbatch_path = stdout.strip()
                self._slurm_path = sbatch_path.rsplit("/", 1)[0]  # Get directory
                if self.verbose:
                    self.logger.info(f"Found SLURM at: {self._slurm_path}")
                self.logger.debug(f"Full sbatch path: {sbatch_path}")
                return

            # Fallback: Check common SLURM installation paths
            common_paths = [
                "/cm/shared/apps/slurm/current/bin",
                "/usr/bin",
                "/usr/local/bin",
                "/opt/slurm/bin",
                "/cluster/slurm/bin",
            ]
            for path in common_paths:
                stdout, stderr, exit_code = self.execute_command(
                    f"test -f {path}/sbatch && echo 'FOUND' || echo 'NOT_FOUND'"
                )
                self.logger.debug(f"Checking {path}: {stdout.strip()}")
                if "FOUND" in stdout:
                    self._slurm_path = path
                    if self.verbose:
                        self.logger.info(f"Found SLURM at: {self._slurm_path}")
                    # Verify permissions
                    stdout, stderr, exit_code = self.execute_command(
                        f"test -x {path}/sbatch && echo 'EXECUTABLE' || echo 'NOT_EXECUTABLE'"
                    )
                    self.logger.debug(f"SLURM executable check: {stdout.strip()}")
                    return

            self.logger.warning("SLURM commands not found in standard locations")
        except Exception as e:
            self.logger.warning(f"Failed to initialize SLURM paths: {e}")

    def _initialize_slurm_environment(self) -> None:
        """Capture the actual working SLURM environment"""
        try:
            # Test if sbatch works with full environment setup
            env_setup = self._get_slurm_env_setup()
            test_cmd = f"{env_setup} && which sbatch 2>/dev/null || echo 'NOT_FOUND'"
            # shlex.quote keeps single quotes embedded in test_cmd from
            # breaking the bash -l -c wrapper
            stdout, stderr, exit_code = self.execute_command(
                f"bash -l -c {shlex.quote(test_cmd)}"
            )
            self.logger.debug(
                f"SLURM verification test: stdout='{stdout}', stderr='{stderr}', exit_code={exit_code}"
            )
            if exit_code == 0 and "NOT_FOUND" not in stdout and stdout.strip():
                sbatch_location = stdout.strip()
                if self.verbose:
                    self.logger.info(f"SLURM sbatch verified at: {sbatch_location}")

                # Try to get SLURM environment variables after setup
                env_cmd = (
                    f"{env_setup} && env | grep -E '^(SLURM|CLUSTER|PATH)' | head -20"
                )
                stdout, stderr, exit_code = self.execute_command(
                    f"bash -l -c {shlex.quote(env_cmd)}"
                )
                if exit_code == 0 and stdout.strip():
                    # Parse environment variables
                    env_vars = {}
                    for line in stdout.strip().split("\n"):
                        if "=" in line:
                            key, value = line.split("=", 1)
                            env_vars[key] = value
                    self._slurm_env = env_vars
                    if self.verbose:
                        self.logger.info(
                            f"Captured SLURM environment with {len(env_vars)} variables"
                        )
                    self.logger.debug(
                        f"Environment variables: {list(env_vars.keys())}"
                    )
            else:
                self.logger.warning(
                    f"SLURM sbatch test failed. stdout: {stdout}, stderr: {stderr}"
                )
        except Exception as e:
            self.logger.warning(f"Failed to initialize SLURM environment: {e}")

    def _get_slurm_command(self, command: str) -> str:
        """Get the full path for a SLURM command, or use login shell if path not found"""
        if self._slurm_path:
            return f"{self._slurm_path}/{command}"
        # Fallback: bare command; _execute_slurm_command resolves it via a login shell
        return command

    def _execute_slurm_command(self, command: str) -> tuple[str, str, int]:
        """Execute a SLURM command with proper environment"""
        # Always use full login shell environment with proper SLURM command path
        env_setup = self._get_slurm_env_setup()

        # Build the command with full path if available
        if self._slurm_path:
            # Replace SLURM commands with full paths
            slurm_commands = [
                "sbatch",
                "squeue",
                "sacct",
                "scancel",
                "scontrol",
                "sinfo",
            ]
            modified_command = command
            for cmd in slurm_commands:
                if modified_command.startswith(cmd + " ") or modified_command == cmd:
                    modified_command = modified_command.replace(
                        cmd, f"{self._slurm_path}/{cmd}", 1
                    )
                    break
            final_command = f"{env_setup} && {modified_command}"
        else:
            # Use login shell to find commands in PATH
            final_command = f"{env_setup} && {command}"

        # Execute with bash login shell; shlex.quote prevents single quotes
        # inside final_command from terminating the wrapper early
        full_cmd = f"bash -l -c {shlex.quote(final_command)}"
        self.logger.debug(f"Executing SLURM command: {full_cmd}")

        stdout, stderr, exit_code = self.execute_command(full_cmd)

        self.logger.debug(
            f"SLURM command result: exit_code={exit_code}, stdout_len={len(stdout)}, stderr_len={len(stderr)}"
        )
        if stderr and exit_code != 0:
            self.logger.debug(f"SLURM command stderr: {stderr[:500]}...")  # First 500 chars

        return stdout, stderr, exit_code

    def _get_slurm_env_setup(self) -> str:
        """Get environment setup commands for SLURM execution"""
        # Use the exact same environment setup as an interactive login shell
        env_commands = [
            "cd ~",
            # Fully simulate login shell environment
            "source /etc/profile 2>/dev/null || true",
            "source ~/.bash_profile 2>/dev/null || true",
            "source ~/.bashrc 2>/dev/null || true",
            "source ~/.profile 2>/dev/null || true",
            # Try multiple methods to load SLURM environment
            "module load slurm 2>/dev/null || true",
            "module load slurm/current 2>/dev/null || true",
            # Check if sbatch is in PATH after module loads
            'which sbatch >/dev/null 2>&1 || export PATH="$PATH:/cm/shared/apps/slurm/current/bin" 2>/dev/null || true',
            # Final fallback - add known SLURM paths
            'export PATH="$PATH:/usr/local/bin:/opt/slurm/bin:/cluster/slurm/bin" 2>/dev/null || true',
        ]

        # Add custom environment variables
        for key, value in self.custom_env_vars.items():
            # Use single quotes to avoid shell interpretation
            escaped_value = value.replace("'", "'\\''")
            env_commands.append(f"export {key}='{escaped_value}'")
            self.logger.debug(f"Adding custom environment variable: {key}={value}")

        return " && ".join(env_commands)

    def _verify_slurm_setup(self) -> None:
        """Verify that SLURM commands are working properly"""
        try:
            test_cmd = (
                "sbatch --version 2>/dev/null | head -1 || echo 'SLURM_NOT_AVAILABLE'"
            )
            stdout, stderr, exit_code = self._execute_slurm_command(test_cmd)
            if exit_code == 0 and "SLURM_NOT_AVAILABLE" not in stdout:
                if self.verbose:
                    self.logger.info(f"SLURM verification successful: {stdout.strip()}")
            else:
                self.logger.warning(f"SLURM verification failed: {stdout} / {stderr}")
        except Exception as e:
            self.logger.warning(f"SLURM verification error: {e}")

    def submit_sbatch_job(
        self,
        script_content: str,
        job_name: str | None = None,
        dependency: str | None = None,
    ) -> SlurmJob | None:
        """Submit an sbatch job with script content"""
        try:
            # Create temporary script file on server
            unique_id = str(uuid.uuid4())[:8]
            remote_script_path = f"{self.temp_dir}/job_{unique_id}.sh"
            self._write_remote_file(remote_script_path, script_content)

            # Make script executable
            self.execute_command(f"chmod +x {remote_script_path}")

            # Validate the uploaded script
            valid, validation_msg = self.validate_remote_script(remote_script_path)
            if not valid:
                self.logger.error(f"Script validation failed: {validation_msg}")
                return None

            # Submit the job
            cmd = f"{self._get_slurm_command('sbatch')}"
            if job_name:
                safe_name = re.sub(r"[^a-zA-Z0-9_.-]", "_", job_name)
                cmd += f" --job-name={safe_name}"
            if dependency:
                if not re.fullmatch(r"[a-z]+:\d+(,[a-z]+:\d+)*", dependency):
                    raise ValueError(f"Invalid dependency format: {dependency!r}")
                cmd += f" --dependency={dependency}"
            cmd += f" {remote_script_path}"

            stdout, stderr, exit_code = self._execute_slurm_command(cmd)
            if exit_code == 0:
                match = re.search(r"Submitted batch job (\d+)", stdout)
                if match:
                    job_id = match.group(1)
                    return SlurmJob(job_id=job_id, name=job_name or f"job_{job_id}")
            self.logger.error(
                f"Failed to submit job. stdout: {stdout}, stderr: {stderr}, exit_code: {exit_code}"
            )
            return None
        except Exception as e:
            self.logger.error(f"Job submission failed: {e}")
            return None

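A minimal usage sketch for submit_sbatch_job (not part of the module source; hostname, username, and key path are hypothetical placeholders):

client = SSHSlurmClient(
    hostname="login.cluster.example.com",
    username="alice",
    key_filename="/home/alice/.ssh/id_ed25519",
)
if client.connect():
    script = "#!/bin/bash\n#SBATCH --time=00:05:00\nhostname\n"
    job = client.submit_sbatch_job(script, job_name="smoke-test")
    if job:
        print(f"Submitted job {job.job_id} as {job.name}")
    client.disconnect()
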
    def submit_sbatch_file(
        self, script_path: str, job_name: str | None = None, cleanup: bool = True
    ) -> SlurmJob | None:
        """Submit an sbatch job from a local or remote file"""
        try:
            path_obj = Path(script_path)
            if path_obj.exists():
                # Local file: upload to temp directory
                remote_path = self.upload_file(script_path)

                # Generate job name from filename if not provided
                if not job_name:
                    job_name = path_obj.stem

                # Build sbatch command with job name and log file options
                safe_name = re.sub(r"[^a-zA-Z0-9_.-]", "_", job_name)
                cmd = f"{self._get_slurm_command('sbatch')}"
                cmd += f" -J {safe_name}"
                cmd += " -o $SLURM_LOG_DIR/%x_%j.log"
                cmd += f" {remote_path}"

                stdout, stderr, exit_code = self._execute_slurm_command(cmd)
                if exit_code == 0:
                    match = re.search(r"Submitted batch job (\d+)", stdout)
                    if match:
                        job_id = match.group(1)
                        job = SlurmJob(job_id=job_id, name=safe_name)
                        job.script_path = remote_path
                        job.is_local_script = True
                        job._cleanup = cleanup
                        return job
                self.logger.error(
                    f"Failed to submit job. stdout: {stdout}, stderr: {stderr}, exit_code: {exit_code}"
                )
                return None
            else:
                # Remote file: submit directly on server.
                # Validate remote file
                valid, validation_msg = self.validate_remote_script(script_path)
                if not valid:
                    self.logger.error(
                        f"Remote script validation failed: {validation_msg}"
                    )
                    return None

                # Generate job name from filename if not provided
                if not job_name:
                    job_name = Path(script_path).stem

                # Build sbatch command with job name and log file options
                safe_name = re.sub(r"[^a-zA-Z0-9_.-]", "_", job_name)
                cmd = f"{self._get_slurm_command('sbatch')}"
                cmd += f" -J {safe_name}"
                cmd += " -o $SLURM_LOG_DIR/%x_%j.log"
                cmd += f" {script_path}"

                stdout, stderr, exit_code = self._execute_slurm_command(cmd)
                if exit_code == 0:
                    match = re.search(r"Submitted batch job (\d+)", stdout)
                    if match:
                        job_id = match.group(1)
                        job = SlurmJob(job_id=job_id, name=safe_name)
                        job.script_path = script_path
                        return job
                self.logger.error(
                    f"Failed to submit job. stdout: {stdout}, stderr: {stderr}, exit_code: {exit_code}"
                )
                return None
        except Exception as e:
            self.logger.error(f"Job submission failed: {e}")
            return None

    def cleanup_job_files(self, job: SlurmJob) -> None:
        """Cleanup temporary files for a job if it was a local script"""
        if job.is_local_script and job.script_path and job._cleanup:
            self.cleanup_file(job.script_path)

    def sync_project(
        self,
        local_path: str | None = None,
        remote_path: str | None = None,
        *,
        delete: bool = True,
        dry_run: bool = False,
        exclude_patterns: list[str] | None = None,
    ) -> str:
        """Sync the local project directory to the remote workspace via rsync.

        Args:
            local_path: Local project root to sync. If None, uses git toplevel or cwd.
            remote_path: Remote destination. If None, uses the default
                ``~/.config/srunx/workspace/{repo_name}/``.
            delete: Remove remote files not present locally (default True).
            dry_run: Preview what would be transferred without syncing.
            exclude_patterns: Additional exclude patterns for this sync.

        Returns:
            The remote project path (for use with ``sbatch --chdir``).

        Raises:
            RuntimeError: If rsync is not available or key-based auth is not configured.
        """
        if self._rsync_client is None:
            raise RuntimeError(
                "sync_project() requires key-based SSH auth and rsync installed locally"
            )
        if local_path is None:
            local_path = self._detect_project_root()
        if remote_path is None:
            remote_path = RsyncClient.get_default_remote_path(local_path)

        result = self._rsync_client.push(
            local_path,
            remote_path,
            delete=delete,
            dry_run=dry_run,
            exclude_patterns=exclude_patterns,
        )
        if not result.success:
            raise RuntimeError(
                f"rsync failed (exit {result.returncode}): {result.stderr}"
            )
        if self.verbose:
            self.logger.info(f"Project synced to {self.hostname}:{remote_path}")
        return remote_path

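A short sketch of the sync flow the docstring describes (assumes client is a connected SSHSlurmClient constructed with key_filename, so _rsync_client is available; paths are illustrative):

remote_root = client.sync_project(dry_run=True)  # preview what rsync would transfer
remote_root = client.sync_project()              # actual push
print(f"Submit with: sbatch --chdir={remote_root} job.sh")
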
    @staticmethod
    def _detect_project_root() -> str:
        """Detect the project root directory via git or fallback to cwd."""
        try:
            result = subprocess.run(  # noqa: S603, S607
                ["git", "rev-parse", "--show-toplevel"],
                capture_output=True,
                text=True,
            )
            if result.returncode == 0:
                return result.stdout.strip()
        except FileNotFoundError:
            pass
        return str(Path.cwd())

    def get_job_status(self, job_id: str) -> str:
        """Get job status using SLURM commands"""
        try:
            # Validate job_id to prevent command injection.
            # Allow numeric job IDs with optional step/array components
            # (e.g., 12345, 12345_1, 12345.batch)
            if not re.match(r"^[0-9]+([._][A-Za-z0-9_-]+)?$", job_id):
                self.logger.error(f"Invalid job_id format: {job_id!r}")
                return "ERROR"

            # Try sacct for completed jobs, squeue for running/pending.
            # Do NOT use _get_slurm_command here; _execute_slurm_command
            # will prepend the path safely.
            sacct_cmd = (
                f"sacct -j {job_id} --format=JobID,State --noheader"
                " | grep -E '^[0-9]+' | head -1"
            )
            stdout, stderr, exit_code = self._execute_slurm_command(sacct_cmd)
            if exit_code == 0 and stdout.strip():
                status = stdout.strip().split()[1].split("+")[0]
                return status

            # Fallback to squeue
            squeue_cmd = f"squeue -j {job_id} -h -o %T | head -1"
            stdout, stderr, exit_code = self._execute_slurm_command(squeue_cmd)
            if exit_code == 0 and stdout.strip():
                return stdout.strip().split("\n")[0].strip()

            return "NOT_FOUND"
        except Exception as e:
            self.logger.error(f"Failed to get job status for job {job_id}: {e}")
            return "ERROR"

    @staticmethod
    def _quote_shell_path(path: str) -> str:
        """Quote a path for remote shell, handling ~ expansion.

        shlex.quote prevents ~ expansion, so paths starting with ~/ are
        converted to use $HOME with double quotes instead.
        """
        if path.startswith("~/"):
            # Double quotes allow $HOME expansion while preventing word splitting
            suffix = path[2:]
            return '"$HOME/' + suffix + '"'
        return shlex.quote(path)

    @staticmethod
    def _sanitize_job_id(job_id: str) -> str:
        """Sanitize a SLURM job ID, supporting array and step IDs.

        Valid formats: 12345, 12345_4, 12345_[1-10], 12345.0
        """
        job_id_str = str(job_id)
        if not re.fullmatch(r"[0-9][0-9_.\[\]\-]*", job_id_str):
            raise ValueError(f"Invalid SLURM job ID: {job_id_str!r}")
        return job_id_str

    def get_job_output(
        self,
        job_id: str,
        job_name: str | None = None,
        stdout_offset: int = 0,
        stderr_offset: int = 0,
    ) -> tuple[str, str, int, int]:
        """Get job output from SLURM log files.

        First tries ``scontrol show job`` to discover the actual StdOut/StdErr
        paths configured for the job. Falls back to pattern-based search if
        scontrol doesn't return usable paths.

        When *stdout_offset* / *stderr_offset* are non-zero, only the bytes
        **after** that position are returned (tail-like incremental reads).

        Returns:
            ``(stdout, stderr, new_stdout_offset, new_stderr_offset)``
        """
        try:
            safe_job_id = self._sanitize_job_id(job_id)
            safe_job_name = None
            if job_name and re.fullmatch(r"[\w\-\.]+", job_name):
                safe_job_name = job_name

            output_content = ""
            error_content = ""
            new_stdout_offset = stdout_offset
            new_stderr_offset = stderr_offset

            # ── 1. Try scontrol to get actual log paths ──────────────
            stdout_path, stderr_path = self._get_log_paths_from_scontrol(safe_job_id)
            if stdout_path:
                output_content, new_stdout_offset = self._read_file_from_offset(
                    stdout_path, stdout_offset
                )
            if stderr_path and stderr_path != stdout_path:
                error_content, new_stderr_offset = self._read_file_from_offset(
                    stderr_path, stderr_offset
                )
            if output_content or error_content:
                return (
                    output_content,
                    error_content,
                    new_stdout_offset,
                    new_stderr_offset,
                )

            # ── 2. Fallback: pattern-based file search (full read) ───
            out, err = self._get_job_output_by_pattern(safe_job_id, safe_job_name)
            return out, err, len(out.encode()), len(err.encode())
        except Exception as e:
            self.logger.error(f"Failed to get job output for {job_id}: {e}")
            return "", "", stdout_offset, stderr_offset

    def _read_file_from_offset(self, path: str, offset: int) -> tuple[str, int]:
        """Read a remote file from a byte offset. Returns (content, new_offset)."""
        quoted = shlex.quote(path)
        if offset > 0:
            # tail -c +N reads from byte N (1-indexed)
            out, _, rc = self.execute_command(
                f"tail -c +{offset + 1} {quoted} 2>/dev/null"
            )
        else:
            out, _, rc = self.execute_command(f"cat {quoted} 2>/dev/null")
        if rc != 0:
            return "", offset

        # Get current file size for next offset
        size_out, _, src = self.execute_command(f"wc -c < {quoted} 2>/dev/null")
        if src == 0 and size_out.strip().isdigit():
            new_offset = int(size_out.strip())
        else:
            new_offset = offset + len(out.encode())
        return out, new_offset

    def _get_log_paths_from_scontrol(
        self, job_id: str
    ) -> tuple[str | None, str | None]:
        """Query ``scontrol show job`` for StdOut / StdErr paths."""
        quoted_id = shlex.quote(job_id)
        stdout, _, rc = self._execute_slurm_command(
            f"scontrol show job {quoted_id} 2>/dev/null"
        )
        if rc != 0 or not stdout.strip():
            return None, None

        stdout_path: str | None = None
        stderr_path: str | None = None
        for line in stdout.splitlines():
            for token in line.split():
                if token.startswith("StdOut="):
                    stdout_path = token.split("=", 1)[1]
                elif token.startswith("StdErr="):
                    stderr_path = token.split("=", 1)[1]
        return stdout_path, stderr_path

    def _get_job_output_by_pattern(
        self, safe_job_id: str, safe_job_name: str | None
    ) -> tuple[str, str]:
        """Search common directories for SLURM log files by naming patterns."""
        potential_log_patterns = [
            f"{safe_job_name}_{safe_job_id}.log" if safe_job_name else None,
            f"*_{safe_job_id}.log",
            f"slurm-{safe_job_id}.out",
            f"slurm-{safe_job_id}.err",
            f"job_{safe_job_id}.log",
            f"{safe_job_id}.log",
            f"*_{safe_job_id}.out",
        ]
        patterns = [p for p in potential_log_patterns if p is not None]
        log_dirs = [
            os.environ.get("SLURM_LOG_DIR", "~/logs/slurm"),
            "./",
            "/tmp",
            "/var/log/slurm",
        ]

        output_content = ""
        error_content = ""
        found_files: list[str] = []
        for log_dir in log_dirs:
            quoted_dir = self._quote_shell_path(log_dir)
            for pattern in patterns:
                quoted_pattern = shlex.quote(pattern)
                find_cmd = (
                    f"find {quoted_dir} -name {quoted_pattern} -type f 2>/dev/null | head -5"
                )
                stdout, stderr, exit_code = self.execute_command(find_cmd)
                if exit_code == 0 and stdout.strip():
                    log_files = stdout.strip().split("\n")
                    for log_file in log_files:
                        if log_file.strip():
                            found_files.append(log_file.strip())
                            self.logger.debug(
                                f"Found potential log file: {log_file.strip()}"
                            )

        if found_files:
            primary_log = found_files[0]
            quoted_log = shlex.quote(primary_log)
            stdout_output, _, _ = self.execute_command(
                f"cat {quoted_log} 2>/dev/null || echo 'Could not read log file'"
            )
            output_content = stdout_output

            if len(found_files) > 1:
                for log_file in found_files[1:]:
                    if "err" in log_file.lower() or "error" in log_file.lower():
                        quoted_err_log = shlex.quote(log_file)
                        stderr_output, _, _ = self.execute_command(
                            f"cat {quoted_err_log} 2>/dev/null || echo ''"
                        )
                        error_content += stderr_output

            if self.verbose:
                self.logger.info(
                    f"Found {len(found_files)} log file(s) for job {safe_job_id}"
                )
                self.logger.debug(f"Primary log file: {primary_log}")
        else:
            self.logger.warning(
                f"No log files found for job {safe_job_id} using common patterns"
            )
            default_patterns: list[str] = []
            if safe_job_name:
                default_patterns.append(f"{safe_job_name}_{safe_job_id}.log")
            default_patterns.append(f"{safe_job_id}.log")
            for pattern in default_patterns:
                quoted_pattern = shlex.quote(pattern)
                stdout_output, _, _ = self.execute_command(
                    f"cat {quoted_pattern} 2>/dev/null || echo ''"
                )
                output_content += stdout_output

        return output_content, error_content

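The offset parameters of get_job_output support tail-like incremental polling. A hedged caller-side sketch (assumes client is a connected SSHSlurmClient; the job ID is illustrative):

import time

stdout_off = stderr_off = 0
while True:
    out, err, stdout_off, stderr_off = client.get_job_output(
        "12345", stdout_offset=stdout_off, stderr_offset=stderr_off
    )
    if out:
        print(out, end="")
    if err:
        print(err, end="")
    if client.get_job_status("12345") in {"COMPLETED", "FAILED", "CANCELLED", "TIMEOUT", "NOT_FOUND"}:
        break
    time.sleep(2)
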
    def get_job_output_detailed(
        self, job_id: str, job_name: str | None = None
    ) -> dict[str, str | list[str] | None]:
        """Get detailed job output information including found log files"""
        try:
            # Sanitize inputs to prevent shell injection
            safe_job_id = self._sanitize_job_id(job_id)
            safe_job_name = None
            if job_name:
                # Only allow alphanumeric, underscore, hyphen, dot in job names
                if re.fullmatch(r"[\w\-\.]+", job_name):
                    safe_job_name = job_name
                else:
                    self.logger.warning(
                        f"Rejecting unsafe job_name for log search: {job_name!r}"
                    )

            # Try multiple common SLURM log file patterns
            potential_log_patterns = [
                # Pattern from SBATCH directives: %x_%j.log (job_name_job_id.log)
                f"{safe_job_name}_{safe_job_id}.log" if safe_job_name else None,
                # Common SLURM_LOG_DIR patterns
                f"*_{safe_job_id}.log",
                # Default SLURM patterns
                f"{safe_job_id}.log",
                # Alternative patterns
                f"job_{safe_job_id}.log",
                f"*_{safe_job_id}.out",
            ]
            patterns = [p for p in potential_log_patterns if p is not None]
            log_dirs = [
                os.environ.get("SLURM_LOG_DIR", "~/logs/slurm"),
                "./",
                "/tmp",
                "/var/log/slurm",
            ]

            found_files: list[str] = []
            primary_log: str | None = None
            output_content = ""
            error_content = ""

            for log_dir in log_dirs:
                quoted_dir = self._quote_shell_path(log_dir)
                for pattern in patterns:
                    quoted_pattern = shlex.quote(pattern)
                    find_cmd = (
                        f"find {quoted_dir} -name {quoted_pattern} -type f 2>/dev/null | head -5"
                    )
                    stdout, stderr, exit_code = self.execute_command(find_cmd)
                    if exit_code == 0 and stdout.strip():
                        log_files = stdout.strip().split("\n")
                        for log_file in log_files:
                            if log_file.strip():
                                found_files.append(log_file.strip())

            # Deduplicate while preserving search order so primary_log is stable
            found_files = list(dict.fromkeys(found_files))

            if found_files:
                primary_log = found_files[0]
                quoted_log = shlex.quote(primary_log)
                stdout_output, _, _ = self.execute_command(
                    f"cat {quoted_log} 2>/dev/null || echo 'Could not read log file'"
                )
                output_content = stdout_output

                if len(found_files) > 1:
                    for log_file in found_files[1:]:
                        if "err" in log_file.lower() or "error" in log_file.lower():
                            quoted_err_log = shlex.quote(log_file)
                            stderr_output, _, _ = self.execute_command(
                                f"cat {quoted_err_log} 2>/dev/null || echo ''"
                            )
                            error_content += stderr_output

            return {
                "found_files": found_files,
                "primary_log": primary_log,
                "output": output_content,
                "error": error_content,
                "slurm_log_dir": os.environ.get("SLURM_LOG_DIR"),
                "searched_dirs": log_dirs,
            }
        except Exception as e:
            self.logger.error(f"Failed to get detailed output for {job_id}: {e}")
            return {
                "found_files": [],
                "primary_log": None,
                "output": "",
                "error": "",
                "slurm_log_dir": os.environ.get("SLURM_LOG_DIR"),
                "searched_dirs": [],
            }

    def monitor_job(
        self, job: SlurmJob, poll_interval: int = 10, timeout: int | None = None
    ) -> SlurmJob:
        """Monitor a job until completion"""
        start_time = time.time()
        while True:
            elapsed_time = time.time() - start_time
            job.status = self.get_job_status(job.job_id)

            if job.status in [
                "COMPLETED",
                "FAILED",
                "CANCELLED",
                "TIMEOUT",
                "NOT_FOUND",
            ]:
                break

            if timeout and elapsed_time > timeout:
                job.status = "TIMEOUT"
                break

            time.sleep(poll_interval)
        return job

    def tail_log(
        self,
        job_id: str,
        job_name: str | None = None,
        follow: bool = False,
        last_n: int | None = None,
        poll_interval: float = 1.0,
    ) -> dict[str, str | bool | None]:
        """Display job logs with optional real-time streaming via SSH.

        Args:
            job_id: SLURM job ID
            job_name: Job name for better log file detection
            follow: If True, continuously stream new log lines (like tail -f)
            last_n: Show only the last N lines
            poll_interval: Polling interval in seconds for follow mode

        Returns:
            Dictionary with log information:
                - success: Whether log retrieval was successful
                - log_content: Log content (empty in follow mode)
                - tail_command: Command to execute for follow mode (None in static mode)
                - status_message: Status or error message
                - log_file: Path to the primary log file
        """
        # Find the log file
        log_info = self.get_job_output_detailed(job_id, job_name)
        primary_log = log_info.get("primary_log")
        found_files = log_info.get("found_files", [])

        if not found_files:
            searched_dirs = log_info.get("searched_dirs", [])
            searched_dirs_list = (
                searched_dirs if isinstance(searched_dirs, list) else []
            )
            msg = f"No log files found for job {job_id}\n"
            msg += f"Searched in: {', '.join(searched_dirs_list)}\n"
            slurm_log_dir = log_info.get("slurm_log_dir")
            if slurm_log_dir:
                msg += f"SLURM_LOG_DIR: {slurm_log_dir}\n"
            return {
                "success": False,
                "log_content": "",
                "tail_command": None,
                "status_message": msg,
                "log_file": None,
            }

        if not primary_log:
            return {
                "success": False,
                "log_content": "",
                "tail_command": None,
                "status_message": "Could not find primary log file",
                "log_file": None,
            }

        # Convert primary_log to string for type safety
        primary_log_str = str(primary_log) if primary_log else ""
        # Quote the path to prevent shell metacharacter injection
        quoted_log_path = shlex.quote(primary_log_str)

        if follow:
            # Real-time streaming mode (like tail -f)
            tail_cmd = "tail -f"
            if last_n:
                tail_cmd = f"tail -n {last_n} -f"
            tail_cmd += f" {quoted_log_path}"
            return {
                "success": True,
                "log_content": "",
                "tail_command": tail_cmd,
                "status_message": f"Streaming logs from {primary_log_str} (Ctrl+C to stop)...",
                "log_file": primary_log_str,
            }
        else:
            # Static display mode
            output_raw = log_info.get("output", "")
            output = str(output_raw) if output_raw else ""
            if last_n and isinstance(output, str):
                lines = output.split("\n")
                output = "\n".join(lines[-last_n:])
            if output:
                return {
                    "success": True,
                    "log_content": output,
                    "tail_command": None,
                    "status_message": f"Log file: {primary_log_str}",
                    "log_file": primary_log_str,
                }
            else:
                return {
                    "success": True,
                    "log_content": "",
                    "tail_command": None,
                    "status_message": "Log file is empty",
                    "log_file": primary_log_str,
                }

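In follow mode, tail_log returns a shell command instead of streaming itself; the caller runs it over the existing SSH connection. A sketch (assumes client is connected; relies on paramiko's line-iterable channel file, which the class already uses for exec_command):

info = client.tail_log("12345", follow=True, last_n=50)
if info["success"] and info["tail_command"]:
    _, stream, _ = client.ssh_client.exec_command(str(info["tail_command"]))
    try:
        for line in stream:  # blocks until new lines arrive, like tail -f
            print(line, end="")
    except KeyboardInterrupt:
        pass  # Ctrl+C stops streaming
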
    def _write_remote_file(self, remote_path: str, content: str) -> None:
        if not self.sftp_client:
            raise ConnectionError("SFTP client is not connected")
        with self.sftp_client.open(remote_path, "w") as f:
            f.write(content)

    def _handle_slurm_error(
        self, command: str, error_output: str, exit_code: int
    ) -> None:
        """Handle SLURM command errors with helpful suggestions"""
        self.logger.error(
            f"{command} failed with exit code {exit_code}: {error_output}"
        )

        # Common error patterns and suggestions
        error_lower = error_output.lower()
        if "command not found" in error_lower or "sbatch: not found" in error_lower:
            self.logger.error("SLURM commands not found in PATH.")
            self.logger.error("Ensure SLURM is installed and configured on the server.")
        elif "permission denied" in error_lower:
            self.logger.error(
                "Permission denied. Check file permissions and user access."
            )
        elif "invalid partition" in error_lower:
            self.logger.error(
                "Invalid partition specified. Check available partitions with 'sinfo'."
            )
        elif "invalid qos" in error_lower:
            self.logger.error(
                "Invalid QoS specified. Check available QoS with 'sacctmgr show qos'."
            )

    def _execute_with_environment(self, command: str) -> tuple[str, str, int]:
        """Execute command with full login environment"""
        # shlex.quote keeps quotes embedded in the command from breaking the wrapper
        return self.execute_command(f"bash -l -c {shlex.quote(command)}")

    def __enter__(self):
        if self.connect():
            return self
        raise ConnectionError("Failed to establish SSH connection")

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.disconnect()
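
Putting it together: the __enter__/__exit__ pair above makes the client usable as a context manager. An end-to-end sketch (all connection details are hypothetical placeholders):

with SSHSlurmClient(
    hostname="login.cluster.example.com",
    username="alice",
    key_filename="/home/alice/.ssh/id_ed25519",
    proxy_jump="bastion.example.com",
) as client:
    job = client.submit_sbatch_file("train.sh")
    if job:
        job = client.monitor_job(job, poll_interval=30, timeout=3600)
        print(f"Job {job.job_id} finished with status {job.status}")
        client.cleanup_job_files(job)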