databricks-air 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55) hide show
  1. cli/__init__.py +5 -0
  2. cli/ai_training_client.py +51 -0
  3. cli/base_client.py +101 -0
  4. cli/changelog.md +57 -0
  5. cli/cli_display.py +723 -0
  6. cli/cli_entrypoint.py +3259 -0
  7. cli/cli_output.py +234 -0
  8. cli/cli_progress.py +152 -0
  9. cli/compute.py +78 -0
  10. cli/compute_options_client.py +37 -0
  11. cli/docker_utils.py +82 -0
  12. cli/error_detection.py +195 -0
  13. cli/get_env_secrets.py +135 -0
  14. cli/image_registration/__init__.py +41 -0
  15. cli/image_registration/docker_config_creds.py +143 -0
  16. cli/image_registration/image_client.py +224 -0
  17. cli/image_registration/image_credentials.py +317 -0
  18. cli/image_registration/image_policy.py +234 -0
  19. cli/jobs_api_client.py +1454 -0
  20. cli/json_output.py +187 -0
  21. cli/log_streaming.py +1376 -0
  22. cli/mlflow_metrics.py +173 -0
  23. cli/mlflow_rest_client.py +161 -0
  24. cli/mlflow_system_metrics.py +127 -0
  25. cli/node_sanity_check.sh +338 -0
  26. cli/run_config.py +756 -0
  27. cli/run_harness.py +255 -0
  28. cli/sdk/__init__.py +83 -0
  29. cli/sdk/compute.py +88 -0
  30. cli/sdk/config.py +602 -0
  31. cli/sdk/enums.py +69 -0
  32. cli/sdk/events.py +89 -0
  33. cli/sdk/exceptions.py +60 -0
  34. cli/sdk/models.py +108 -0
  35. cli/sdk/py.typed +0 -0
  36. cli/serverless_policy_client.py +104 -0
  37. cli/telemetry.py +232 -0
  38. cli/utils/__init__.py +134 -0
  39. cli/utils/auth.py +276 -0
  40. cli/utils/git_state.py +644 -0
  41. cli/utils/mapi/__init__.py +27 -0
  42. cli/utils/mapi/api.py +117 -0
  43. cli/utils/mapi/launch_script.py +577 -0
  44. cli/utils/retry.py +115 -0
  45. cli/utils/snapshot.py +255 -0
  46. cli/utils/uploads.py +237 -0
  47. cli/utils/workspace.py +140 -0
  48. cli/version.py +92 -0
  49. cli/yaml_config.py +19 -0
  50. cli/yaml_help.py +971 -0
  51. cli/yaml_overrides.py +239 -0
  52. databricks_air-0.1.0.dist-info/METADATA +30 -0
  53. databricks_air-0.1.0.dist-info/RECORD +55 -0
  54. databricks_air-0.1.0.dist-info/WHEEL +4 -0
  55. databricks_air-0.1.0.dist-info/entry_points.txt +2 -0
cli/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ """Databricks AIR CLI package."""
2
+
3
+ # NOTE: version and git sha are determined at release time; the versions listed here are placeholders
4
+ __version__ = "0.1.0"
5
+ __git_sha__ = "4a95901bfafec7edf7c9f9f6919b19b1b2749338"
cli/ai_training_client.py ADDED
@@ -0,0 +1,51 @@
1
+ """Client for AiTrainingService — the per-user training-workflow surface on AICM.
2
+
3
+ Authenticated as the calling user via the Databricks SDK's [[WorkspaceClient]], matching how
4
+ [[cli.image_registration.image_client.ImageClient]] talks to AICM today.
5
+ """
6
+
7
+ import logging
8
+ from typing import Any, Dict, Optional
9
+
10
+ from cli.utils import get_workspace_client
11
+ from databricks.sdk import WorkspaceClient
12
+
13
+ log = logging.getLogger(__name__)
14
+
15
+
16
class AiTrainingClient:
    """REST client for /api/2.0/ai-training/workflows.

    Requests go through a workspace client obtained from ``get_workspace_client``;
    the client is created on first use and then reused for later calls.
    """

    API_PATH = "/api/2.0/ai-training/workflows"

    def __init__(self, workspace_host: Optional[str] = None):
        """Remember the target host without creating a workspace client yet.

        Args:
            workspace_host: Host forwarded verbatim to ``get_workspace_client``
                when the workspace client is first needed. ``None`` is passed
                through unchanged.
        """
        self.workspace_host = workspace_host
        self._workspace_client: Optional[WorkspaceClient] = None

    def _get_workspace_client(self) -> WorkspaceClient:
        """Return the cached workspace client, creating it on the first call."""
        if self._workspace_client is None:
            self._workspace_client = get_workspace_client(workspace_host=self.workspace_host)
        return self._workspace_client

    def _api_request(
        self,
        method: str,
        endpoint: str,
        data: Optional[Dict[str, Any]] = None,
        params: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """Send one request to an endpoint below ``API_PATH``.

        Args:
            method: HTTP method passed to ``api_client.do``.
            endpoint: Path suffix appended to ``API_PATH``.
            data: Optional request body.
            params: Optional query parameters.

        Returns:
            The response returned by ``api_client.do``, or an empty dict when
            that response is falsy.

        Raises:
            RuntimeError: If ``api_client.do`` raises; the original exception is
                chained as the cause.
        """
        # Obtaining the client stays outside the try block so that failures
        # there propagate unchanged instead of being reported as request errors.
        w = self._get_workspace_client()
        path = f"{self.API_PATH}{endpoint}"
        try:
            response = w.api_client.do(method, path, body=data, query=params)
        except Exception as e:
            raise RuntimeError(f"AiTrainingService request failed: {e}") from e
        return response or {}

    def cancel_workflow(self, job_run_id: str) -> None:
        """Cancel a training workflow by its Jobs job_run_id.

        Maps to CancelTrainingWorkflow with a job_run_id ref. No-op server-side if the run is
        already in a terminal state.

        Args:
            job_run_id: Run identifier placed in the request path.

        Raises:
            RuntimeError: If the underlying request fails.
        """
        # Imported locally so this block does not depend on a file-level import.
        from urllib.parse import quote

        # Percent-encode the identifier as a single path segment so characters
        # such as "/", "?" or "#" cannot change which endpoint is addressed.
        # Plain numeric ids are left exactly as they were.
        run_ref = quote(str(job_run_id), safe="")
        self._api_request(method="POST", endpoint=f"/by-run-id/{run_ref}/cancel", data={})
cli/base_client.py ADDED
@@ -0,0 +1,101 @@
1
+ """Base client interface and types for workload management.
2
+ This module defines the abstract interface that all workload clients must implement,
3
+ ensuring consistent behavior across different backend APIs (Jobs API, CMv3, etc.).
4
+ It also defines the client type enumeration.
5
+ """
6
+
7
+ from abc import ABC, abstractmethod
8
+ from enum import Enum
9
+ from typing import Dict, Any, Optional
10
+ from cli.compute import GPUType
11
+
12
+
13
class WorkloadClientType(Enum):
    """Enumeration of the supported workload client types."""

    JOBS_API = "Jobs API"
    AI_COMPUTE_MANAGER = "AI Compute Manager (CMv3)"

    def __str__(self) -> str:
        """Render the member as its value string rather than ``Class.NAME``."""
        display_name = self.value
        return display_name
22
+
23
+
24
class WorkloadClient(ABC):
    """Abstract interface shared by every workload management client.

    Concrete clients (Jobs API, CMv3, etc.) implement these members so that
    callers can use any backend through the same method signatures.
    """

    @abstractmethod
    def create_workload(
        self,
        gpus: int,
        gpus_per_node: int,
        python_file: str,
        func_dir: str,
        experiment_name: str,
        timeout_seconds: int,
        param_to_args: dict,
        env_sync_timeout: int,
        gpu_type: GPUType,
        max_retries: int = 3,
    ) -> str:
        """Create a workload and return its workload/run ID.

        Args:
            gpus: Number of GPUs required.
            gpus_per_node: GPUs per node (for distributed workloads).
            python_file: Path to the Python script to execute.
            func_dir: Directory containing the function.
            experiment_name: Name of the function to execute.
                NOTE(review): the parameter name suggests an experiment name;
                confirm the intended meaning against the implementations.
            timeout_seconds: Timeout for the workload.
            param_to_args: Parameters to pass to the function.
            env_sync_timeout: Environment synchronization timeout.
            gpu_type: Type of GPU required.
            max_retries: Maximum number of retries for failed tasks (default: 3).

        Returns:
            The workload/run ID as a string.

        Raises:
            Exception: If workload creation fails.
        """

    @abstractmethod
    def get_workload_status(self, workload_id: str) -> Dict[str, Any]:
        """Look up the status and metadata of a workload.

        Args:
            workload_id: Unique workload identifier.

        Returns:
            Dictionary containing workload status and metadata.
        """

    @abstractmethod
    def terminate_workload(self, workload_id: str, reason: Optional[str] = None) -> bool:
        """Terminate a running workload.

        Args:
            workload_id: Unique workload identifier.
            reason: Optional reason for termination.

        Returns:
            True if termination was successful, False otherwise.
        """

    @abstractmethod
    def list_workloads(self, **kwargs) -> Dict[str, Any]:
        """List workloads.

        Args:
            **kwargs: Client-specific parameters for listing (pagination,
                filters, etc.).

        Returns:
            Dictionary containing workloads and pagination info.
        """

    @property
    @abstractmethod
    def client_type(self) -> WorkloadClientType:
        """Client type, for logging and debugging.

        Returns:
            The WorkloadClientType member identifying this client.
        """
cli/changelog.md ADDED
@@ -0,0 +1,57 @@
1
+ # Start of CLI Changelog
2
+
3
+ Version 0.1.0
4
+ [CNXT-2129] air run --dry-run now implies -v unless verbosity is already set.
5
+ [CNXT-2130] air cancel: print a friendly "Run X not found" message instead of leaking the Jobs API URL.
6
+ [CNXT-2155] air now emits a "watch" telemetry event capturing submission-to-first-streamed-log latency on run --watch and the logs command.
7
+ [CNXT-2164] Fix `air run` crash on native Windows by skipping the POSIX-only SIGALRM upload timeout
8
+ Fix --json error envelope classification: derive retryable from a closed ErrorKind set (only TRANSIENT retries), classify uniformly by SDK type / HTTP status / filesystem error across all commands, and emit error{NOT_FOUND} for subscribe on a missing run
9
+ air --json now reports usage/early-exit failures (unknown subcommand, missing or invalid arguments, mutually-exclusive misuse, no command, and register/logs validation) as a structured USER error envelope on stdout with a non-zero exit instead of plain-text argparse usage or human help, and no longer double-prints the profile-not-found warning around the envelope
10
+ Launch artifacts now upload under .air/cli_launch/<experiment>/<run_name>_<uuid> with the command script named command.sh, and --help/--version help text is capitalized.
11
+ Deprecation warnings for renamed YAML keys now route through the `air` logger (shown on stderr, including in `--json` mode); the config schema and GPU types are now defined in the internal `air` SDK.
12
+ Show [Beta] badges in help output: `air --help` (AIR CLI is in Beta) and Beta GPU types (e.g. GPU_1xH100) in `air -h config.compute`
13
+ Stamp the installed CLI wheel version on BYOT submissions (client_version field) so the training service can correlate runs to a CLI release.
14
+ `air cancel --all` now lists target runs and workspace before prompting for y/yes confirmation; skip with -y/--yes.
15
+ Route `air cancel` through AiTrainingService instead of the Jobs API (no user-visible change).
16
+ Forward `credentials-for-read` response headers to the pre-signed URL fetch in `mlflow_rest_client.download_artifact`. Fixes log/artifact downloads on Azure-backed workspaces where SAS URIs require headers like `x-ms-version` to authenticate.
17
+ Route air diagnostics ([INFO]/[WARNING] messages) to stderr so `air logs | tee /dev/tty | pbcopy` captures only training log lines
18
+ Fix cryptic "'str' object has no attribute 'decode'" error on submit when no Databricks profile is configured; surface the real authentication error instead
19
+ Fix `sgcli --json run --watch` returning immediately with `status=PENDING` instead of blocking until terminal. The JSON path now emits a `SUBMITTED` event with run_id, streams `STATUS`/`LOG`/`ALERT` JSONL events through the watch, then emits a final success envelope whose `status` reflects the actual terminal state (`SUCCESS`/`FAILED`/`TIMEDOUT`/`CANCELED`).
20
+ Fix post-submit MLflow run-name update clobbering the user-supplied `mlflow_run_name`. The CLI now passes the user's run_name (when set) to the post-submit `/api/2.0/mlflow/runs/update` call, falling back to `job_run_name` (== `experiment_name`) only when the user did not set one. Previously the CLI always passed `job_run_name`, which overwrote the run name the ai-training server had just set from `gen_ai_compute_task.mlflow_run_name`.
21
+ Fix the "Waiting for run to start" spinner smearing into the first lines of log output during `air logs` and `air run --watch`. The Rich Live render thread was still drawing the spinner while raw bytes were being written to stdout from inside `print_new_logs`; `live.stop()` only ran in the caller after the write returned. The fix threads an `on_before_print` hook through `print_new_logs` that the streaming caller uses to tear down the spinner strictly before any byte hits stdout, and only when there is actually new content to emit (a no-op poll keeps the spinner running).
22
+ Fix spinner artifacts in redirected output and JSON mode by detecting TTY status before creating Rich console objects
23
+ [air-cli] Align `air get run` display with `air list runs`: rename "Task Run ID" -> "Run ID", adopt list-runs color palette, add User and Accelerators rows.
24
+ `air get run`: started_at and duration_seconds now describe the reported (latest) attempt instead of the whole run, so a retried run no longer reports time spent on earlier failed attempts.
25
+ Validate GPU_* accelerator availability via AICM ListComputeOptions.
26
+ [air-cli] Help/error polish: drop duplicate `--download_to` alias in `air logs -h`; add description lines to subcommand help; `air` (no args) and `air config[.field]` now print help; remove UNCOMMITTED-CHANGES table from `air -h config.code_source`; long YAML help (`air -h config[.field]`) now opens in a pager unless `--json` / piped; pull_wheel.sh gains `--download-only` (downloads to /tmp and prints the `uv tool install` command); `Failed to get run output: 404` during list-runs no longer logs at ERROR; YAML validation deduplicates the "Available fields are: …" list when multiple unknown fields share a parent.
27
+ Added support for inline dependencies in the workload YAML: environment.dependencies now accepts a list of packages alongside an environment.version field, as an alternative to pointing at a separate requirements.yaml file.
28
+ Remove `no_interpolation` field and `--no-interpolation` flag; OmegaConf variable interpolation is now disabled unconditionally. Literal `${VAR}` strings in YAML are preserved as-is. Bash `$VAR` shell expansion at runtime on the worker node is unaffected.
29
+ [air-cli] Plug --json envelope coverage gaps so machine consumers always get a structured envelope: `--version --json` returns `{version}` (instead of the ASCII banner); `--json changelog` returns `{version, changelog}`; `--json get pools` returns `{pools[], workspace_url, workspace_id}`; `--json run --dry-run` returns `{status: DRY_RUN_OK, dry_run: true}`; `--json register image` returns success envelopes (cached + new) and a classified error envelope (USER/INTERNAL_ERROR) on failure.
30
+ PuPr filter consolidation: replace `get runs` flags `--user`, `--experiment`, and the deprecated `--all` with a single repeatable `--filter KEY=VALUE` flag (supported keys: user, experiment). `--all-users` and the other listing flags are unchanged.
31
+ Pass user run_name through gen_ai_compute_task.mlflow_run_name on the TS path.
32
+ sgcli logs: follow multi-chunk MLflow log streams; default to last 10000 lines for completed runs
33
+ Group code_source.snapshot.{git_branch,git_commit,use_remote_head,remote_alias} under a nested `git:` object; consolidate remote behavior into bool-or-string `git.remote`.
34
+ Allow --override to add new YAML fields and auto-create missing nested blocks; typos still surface from Pydantic schema validation.
35
+ --override with an unknown field now reports a clean error naming the invalid key instead of a raw pydantic validation dump
36
+ Without git_branch/git_commit, package working tree as plain tar; upload git state (base/tip/dirty + diff) as WSFS sidecars for the backend to apply MLflow tags.
37
+ PuPr command renames: monitor->subscribe, get runs->list runs, get status->get run, get logs->logs; cancel now takes one or more run IDs or --all; removed --no-interpolation, get-run --watch, logs --local-rank; hid get pools, register image, logs --review; deprecated aliases (get runs, get status, get logs) still parse with a warning.
38
+ PuPr cleanup: `environment` block is now optional in YAML, and `--host` is removed from `sgcli list runs` / `sgcli get pools` (workspace comes from `-p/--profile`).
39
+ Expand `air register image -h` to walk through the three credential methods (docker login -- recommended, `--interactive-authenticate`, `--scope`/`--key`) with concrete examples.
40
+ [air-cli] Reject workload configs whose YAML exceeds 1 MB before submission, with a clear error pointing at oversized parameters/command fields.
41
+ Remove the `bash_script` YAML field. Use the top-level `command:` field instead. Configs that still specify `bash_script` now fail validation with a hint pointing at the replacement.
42
+ Removed legacy YAML field and command aliases. Unrecognized fields and commands are now rejected as errors. Older client image versions are accepted again.
43
+ Remove deprecated YAML fields: 'workspace' (use -p/--profile) and 'code_source.snapshot.allow_uncommitted' (commit changes or pin with git_commit / git_branch + use_remote_head). Raise 'command' cap from 500 to 1000 lines.
44
+ Remove --local-rank from `sgcli logs` / `sgcli subscribe` and the `local_rank` parameter from `sdk.get_logs()` — per-rank logs are not produced by the platform; every rank funnels into the consolidated per-node stream
45
+ Remove the deprecated nested `environment.env_variables` and `environment.env_variables_secrets` YAML fields. Use the top-level `env_variables:` and `secrets:` fields instead. Configs that still set the nested forms now fail validation with a hint pointing at the replacement.
46
+ [BREAKING] Rename CLI from `sgcli` to `air`; wheel from `databricks-serverless-gpu-cli` to `databricks-air`. The `SGCLI_DISABLE_TELEMETRY` env var is now `AIR_DISABLE_TELEMETRY`. No deprecation alias — update scripts that invoke `sgcli`.
47
+ Fix path expansion to support ~ (tilde) home directory shorthand in YAML config file paths
48
+ Revert the use of task_run_id as the user-facing identifier; the CLI again uses job_run_id as the canonical handle (#1929533 undone).
49
+ [air-cli] Validation hardening: reject experiment_name > 100 chars client-side (Jobs API task_key limit); reject empty/whitespace docker_image_url in `air register image`; fix `air get logs --review` (without `--lines`) rendering "Last None lines per node" by coercing the None default to 200.
50
+ Faster snapshot submission: detect uncommitted changes once and scope the check to include_paths, avoiding a redundant full-repo `git status` (slow on large monorepos).
51
+ [air-cli] `air run` now shows self-clearing spinners while uploading the YAML configuration files and while packaging/uploading the code snapshot.
52
+ Surface the server's error and run-termination messages (e.g. version upgrade notices) so job submission failures and runs that fail before producing logs report the actual reason instead of a raw HTTP error or bare status.
53
+ Telemetry: emit job_run_id (Jobs API run_id) on sgcli run events to enable joins against AI scheduler activity logs.
54
+ Assign a usage policy by name via usage_policy_name; the usage_policy_id/budget_policy_id fields and the DATABRICKS_USAGE_POLICY_ID env var have been removed.
55
+ Install requirements.yaml deps with `uv pip install` layered on the base environment (no forced `-U`, so the image's torch and other packages are preserved). `--trusted-host` is ignored with a warning (uv configures trust per index URL).
56
+ `air --version` now prints an ASCII art banner alongside the version string.
57
+ Rename user-facing YAML fields to align with cross-product Databricks naming. Old names are temporarily accepted with a deprecation warning and will be removed in an upcoming release. Renames: env_variables_secrets -> secrets, run_name -> mlflow_run_name, experiment_directory -> mlflow_experiment_directory, budget_policy_id -> usage_policy_id, code_source.snapshot.repo_path -> code_source.snapshot.root_path, compute.gpus -> compute.num_accelerators, compute.gpu_type -> compute.accelerator_type, compute.gpu_node_pool_id -> compute.node_pool_id, compute.gpu_pool_name -> compute.pool_name.