"""Configuration management for srunx."""
import copy
import json
import os
from pathlib import Path
from typing import Any

from pydantic import BaseModel, Field

from srunx.logging import get_logger
from srunx.models import ContainerResource
# Module-level logger, obtained from srunx's logging helper using this module's name.
logger = get_logger(__name__)
[docs]
class ResourceDefaults(BaseModel):
    """Default values for compute resource requests.

    The four integer counts are bounded below through ``ge``; the remaining
    string settings are optional and default to ``None``.
    """

    # --- Integer counts (validated against a lower bound) ---
    nodes: int = Field(
        description="Default number of compute nodes",
        default=1,
        ge=1,
    )
    gpus_per_node: int = Field(
        description="Default number of GPUs per node",
        default=0,
        ge=0,
    )
    ntasks_per_node: int = Field(
        description="Default number of tasks per node",
        default=1,
        ge=1,
    )
    cpus_per_task: int = Field(
        description="Default number of CPUs per task",
        default=1,
        ge=1,
    )

    # --- Optional free-form string settings ---
    memory_per_node: str | None = Field(
        description="Default memory per node",
        default=None,
    )
    time_limit: str | None = Field(
        description="Default time limit",
        default=None,
    )
    nodelist: str | None = Field(
        description="Default nodelist",
        default=None,
    )
    partition: str | None = Field(
        description="Default partition",
        default=None,
    )
[docs]
class EnvironmentDefaults(BaseModel):
    """Default values for the execution environment.

    Covers an optional conda environment, virtual environment path and
    container resource, plus a mapping of environment variables that
    defaults to an empty dict.
    """

    conda: str | None = Field(
        description="Default conda environment",
        default=None,
    )
    venv: str | None = Field(
        description="Default virtual environment path",
        default=None,
    )
    container: ContainerResource | None = Field(
        description="Default container resource",
        default=None,
    )
    # default_factory gives every instance its own dict.
    env_vars: dict[str, str] = Field(
        description="Default environment variables",
        default_factory=dict,
    )
[docs]
class NotificationConfig(BaseModel):
    """Settings for notifications.

    Currently holds a single optional Slack webhook URL.
    """

    slack_webhook_url: str | None = Field(
        description="Slack webhook URL for notifications",
        default=None,
    )
[docs]
class SrunxConfig(BaseModel):
    """Top-level srunx configuration model.

    Groups the resource, environment and notification sections together
    with the log directory and an optional working directory.
    """

    # Nested sections; each is built from its own defaults when omitted.
    resources: ResourceDefaults = Field(
        default_factory=ResourceDefaults,
    )
    environment: EnvironmentDefaults = Field(
        default_factory=EnvironmentDefaults,
    )
    notifications: NotificationConfig = Field(
        default_factory=NotificationConfig,
    )

    # General settings.
    log_dir: str = Field(
        description="Default log directory",
        default="logs",
    )
    work_dir: str | None = Field(
        description="Default working directory",
        default=None,
    )
[docs]
def get_config_paths() -> list[Path]:
    """Return candidate configuration files, lowest precedence first.

    The returned list always has four entries, in this order:

    1. system-wide file (``/etc/srunx/config.json`` on POSIX,
       ``C:/ProgramData/srunx/config.json`` otherwise),
    2. user-wide file (``~/.config/srunx/config.json`` on POSIX,
       ``~/AppData/Roaming/srunx/config.json`` otherwise),
    3. ``.srunx.json`` in the current working directory,
    4. ``srunx.json`` in the current working directory.
    """
    # Pick both platform-dependent locations in a single branch.
    if os.name == "posix":
        system_file = Path("/etc/srunx/config.json")
        user_dir = Path.home() / ".config" / "srunx"
    else:
        system_file = Path("C:/ProgramData/srunx/config.json")
        user_dir = Path.home() / "AppData" / "Roaming" / "srunx"

    # Project-level files live in the current working directory.
    project_root = Path.cwd()

    return [
        system_file,
        user_dir / "config.json",
        project_root / ".srunx.json",
        project_root / "srunx.json",
    ]
[docs]
def load_config_from_file(config_path: Path) -> dict[str, Any]:
    """Load configuration from a JSON file.

    Args:
        config_path: Location of the JSON file to read.

    Returns:
        The parsed top-level JSON object. An empty dict is returned when the
        file does not exist, cannot be read or decoded, is not valid JSON, or
        holds something other than a JSON object at the top level.
    """
    try:
        if not config_path.exists():
            return {}
        logger.debug(f"Loading config from {config_path}")
        with open(config_path, encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, json.JSONDecodeError, UnicodeDecodeError) as e:
        # UnicodeDecodeError: the file is not valid UTF-8 text.
        logger.warning(f"Failed to load config from {config_path}: {e}")
        return {}

    # Valid JSON can still be a list, string, number or null. Callers merge
    # the result as a mapping, so anything else is rejected here.
    if not isinstance(data, dict):
        logger.warning(
            f"Failed to load config from {config_path}: "
            f"expected a JSON object at the top level, got {type(data).__name__}"
        )
        return {}
    return data
[docs]
def merge_config(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]:
    """Recursively merge configuration dictionaries.

    Keys from ``override`` win. When both sides hold a dict under the same
    key, the two dicts are merged recursively instead of being replaced.

    Args:
        base: Lower-precedence configuration.
        override: Higher-precedence configuration.

    Returns:
        A new dictionary. Neither input is modified, and the result shares no
        nested containers with either input, so mutating it later cannot
        leak back into ``base`` or ``override``.
    """
    # Deep copy so nested dicts that are not overridden are not aliased.
    result = copy.deepcopy(base)
    for key, value in override.items():
        current = result.get(key)
        if isinstance(current, dict) and isinstance(value, dict):
            result[key] = merge_config(current, value)
        else:
            # Copy so the result does not alias containers owned by override.
            result[key] = copy.deepcopy(value)
    return result
[docs]
def load_config_from_env() -> dict[str, Any]:
"""Load configuration from environment variables."""
config: dict[str, Any] = {}
# Resource defaults from environment
resources: dict[str, Any] = {}
if nodes := os.getenv("SRUNX_DEFAULT_NODES"):
try:
resources["nodes"] = int(nodes)
except ValueError:
logger.warning(f"Invalid SRUNX_DEFAULT_NODES value: {nodes}")
if gpus := os.getenv("SRUNX_DEFAULT_GPUS_PER_NODE"):
try:
resources["gpus_per_node"] = int(gpus)
except ValueError:
logger.warning(f"Invalid SRUNX_DEFAULT_GPUS_PER_NODE value: {gpus}")
if ntasks := os.getenv("SRUNX_DEFAULT_NTASKS_PER_NODE"):
try:
resources["ntasks_per_node"] = int(ntasks)
except ValueError:
logger.warning(f"Invalid SRUNX_DEFAULT_NTASKS_PER_NODE value: {ntasks}")
if cpus := os.getenv("SRUNX_DEFAULT_CPUS_PER_TASK"):
try:
resources["cpus_per_task"] = int(cpus)
except ValueError:
logger.warning(f"Invalid SRUNX_DEFAULT_CPUS_PER_TASK value: {cpus}")
if memory := os.getenv("SRUNX_DEFAULT_MEMORY_PER_NODE"):
resources["memory_per_node"] = memory
if time_limit := os.getenv("SRUNX_DEFAULT_TIME_LIMIT"):
resources["time_limit"] = time_limit
if nodelist := os.getenv("SRUNX_DEFAULT_NODELIST"):
resources["nodelist"] = nodelist
if partition := os.getenv("SRUNX_DEFAULT_PARTITION"):
resources["partition"] = partition
if resources:
config["resources"] = resources
# Environment defaults from environment
environment: dict[str, Any] = {}
if conda := os.getenv("SRUNX_DEFAULT_CONDA"):
environment["conda"] = conda
if venv := os.getenv("SRUNX_DEFAULT_VENV"):
environment["venv"] = venv
if container := os.getenv("SRUNX_DEFAULT_CONTAINER"):
environment["container"] = {"image": container}
if container_runtime := os.getenv("SRUNX_DEFAULT_CONTAINER_RUNTIME"):
# Only override runtime on an existing container config —
# runtime alone (without image) is not a valid container.
if "container" in environment:
environment["container"]["runtime"] = container_runtime
if environment:
config["environment"] = environment
# General defaults from environment
if log_dir := os.getenv("SRUNX_DEFAULT_LOG_DIR"):
config["log_dir"] = log_dir
if work_dir := os.getenv("SRUNX_DEFAULT_WORK_DIR"):
config["work_dir"] = work_dir
return config
[docs]
def load_config() -> SrunxConfig:
    """Build the effective configuration from every source.

    Config files are merged in the order given by ``get_config_paths()``
    (lowest precedence first), then environment variables are merged on top.
    If the merged data fails validation, a warning is logged and a
    default-constructed ``SrunxConfig`` is returned instead.
    """
    merged: dict[str, Any] = {}

    # Files first: later paths override earlier ones.
    for path in get_config_paths():
        if layer := load_config_from_file(path):
            merged = merge_config(merged, layer)

    # Environment variables have the final say.
    if env_layer := load_config_from_env():
        merged = merge_config(merged, env_layer)

    # Validate; fall back to defaults if the merged data is rejected.
    try:
        validated = SrunxConfig.model_validate(merged)
    except Exception as e:
        logger.warning(f"Failed to validate config: {e}. Using defaults.")
        validated = SrunxConfig()
    return validated
[docs]
def save_user_config(config: SrunxConfig) -> None:
    """Save configuration to user config file.

    Merges SrunxConfig fields into the existing file so that
    SSH profile data (managed by ConfigManager) is preserved.

    Only fields explicitly set on ``config`` are written
    (``exclude_unset=True``); they replace same-named top-level keys in the
    file, and all other top-level keys are kept.

    If the existing file cannot be read, is not valid JSON, or does not hold
    a JSON object, a warning is logged and the file is overwritten with the
    SrunxConfig fields only.

    Args:
        config: Configuration whose explicitly set fields should be saved.

    Raises:
        OSError: If the directory or file cannot be created or written.
    """
    # get_config_paths() lists the system file first and the user-wide
    # file second.
    user_config_path = get_config_paths()[1]

    # Load existing data to preserve non-SrunxConfig keys (e.g. SSH profiles).
    existing: dict[str, Any] = {}
    if user_config_path.exists():
        try:
            with open(user_config_path, encoding="utf-8") as f:
                content = f.read().strip()
            loaded = json.loads(content) if content else {}
        except (OSError, json.JSONDecodeError, UnicodeDecodeError) as e:
            # Best effort: keep going, but say that old content is lost.
            logger.warning(
                f"Could not read existing config at {user_config_path}: {e}. "
                "Existing content will be overwritten."
            )
        else:
            if isinstance(loaded, dict):
                existing = loaded
            else:
                logger.warning(
                    f"Existing config at {user_config_path} is not a JSON object "
                    f"(got {type(loaded).__name__}). Existing content will be overwritten."
                )

    # Merge: SrunxConfig fields overwrite, other keys preserved.
    existing.update(config.model_dump(exclude_unset=True))

    # Serialize before opening the file: opening with "w" truncates it, so a
    # serialization error after that point would destroy the old config.
    payload = json.dumps(existing, indent=2)

    try:
        # Create directory if it doesn't exist.
        user_config_path.parent.mkdir(parents=True, exist_ok=True)
        with open(user_config_path, "w", encoding="utf-8") as f:
            f.write(payload)
    except OSError as e:
        logger.error(f"Failed to save config to {user_config_path}: {e}")
        raise
    logger.info(f"Configuration saved to {user_config_path}")
[docs]
def create_example_config() -> str:
    """Return an example configuration as a JSON string indented by two spaces.

    The example has ``resources`` and ``environment`` sections followed by
    ``log_dir`` and ``work_dir`` entries.
    """
    # Each section is built separately; insertion order fixes the key order
    # of the JSON output.
    resources_section = {
        "nodes": 1,
        "gpus_per_node": 1,
        "ntasks_per_node": 1,
        "cpus_per_task": 8,
        "memory_per_node": "32GB",
        "time_limit": "2:00:00",
        "partition": "gpu",
    }
    container_section = {
        "image": "nvcr.io/nvidia/pytorch:24.01-py3",
        "runtime": "pyxis",
    }
    environment_section = {
        "conda": "ml_env",
        "container": container_section,
        "env_vars": {"CUDA_VISIBLE_DEVICES": "0", "OMP_NUM_THREADS": "8"},
    }

    document = {
        "resources": resources_section,
        "environment": environment_section,
        "log_dir": "slurm_logs",
        "work_dir": "/scratch/username",
    }
    return json.dumps(document, indent=2)
# Global config instance: starts as None and is filled in by get_config()
# on first use, or replaced when get_config(reload=True) is called.
_config: SrunxConfig | None = None
[docs]
def get_config(reload: bool = False) -> SrunxConfig:
    """Return the shared configuration instance, loading it when needed.

    Args:
        reload: When true, call ``load_config()`` again even if a
            configuration is already cached.

    Returns:
        The module-level ``SrunxConfig`` instance.
    """
    global _config
    # Fast path: reuse the cached instance unless a reload was requested.
    if _config is not None and not reload:
        return _config
    _config = load_config()
    return _config