mithril-client 0.1.0a1__cp314-cp314-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mithril/__init__.py +7 -0
- mithril/_mcli.cpython-314-darwin.so +0 -0
- mithril/_mcli.pyi +7 -0
- mithril/_mcli_entry.py +75 -0
- mithril/api/__init__.py +7 -0
- mithril/api/bindings/.gitattributes +2 -0
- mithril/api/bindings/__init__.py +10 -0
- mithril/api/bindings/api/__init__.py +1 -0
- mithril/api/bindings/api/api_keys/__init__.py +1 -0
- mithril/api/bindings/api/api_keys/create_api_key_v2_api_keys_post.py +179 -0
- mithril/api/bindings/api/api_keys/get_api_keys_v2_api_keys_get.py +141 -0
- mithril/api/bindings/api/api_keys/revoke_api_key_v2_api_keys_key_fid_delete.py +173 -0
- mithril/api/bindings/api/image_versions/__init__.py +1 -0
- mithril/api/bindings/api/image_versions/get_image_versions_v2_image_versions_get.py +141 -0
- mithril/api/bindings/api/image_versions/get_mcc_image_versions_v2_mcc_image_versions_get.py +179 -0
- mithril/api/bindings/api/instance_types/__init__.py +1 -0
- mithril/api/bindings/api/instance_types/get_instance_types_v2_instance_types_get.py +137 -0
- mithril/api/bindings/api/instances/__init__.py +1 -0
- mithril/api/bindings/api/instances/get_instance_status_v2_instances_instance_id_status_get.py +165 -0
- mithril/api/bindings/api/instances/get_instances_v2_instances_get.py +409 -0
- mithril/api/bindings/api/kubernetes_clusters/__init__.py +1 -0
- mithril/api/bindings/api/kubernetes_clusters/create_kubernetes_cluster_v2_kubernetes_clusters_post.py +171 -0
- mithril/api/bindings/api/kubernetes_clusters/delete_kubernetes_cluster_v2_kubernetes_clusters_cluster_fid_delete.py +163 -0
- mithril/api/bindings/api/kubernetes_clusters/get_kubernetes_cluster_v2_kubernetes_clusters_cluster_fid_get.py +165 -0
- mithril/api/bindings/api/kubernetes_clusters/get_kubernetes_clusters_v2_kubernetes_clusters_get.py +175 -0
- mithril/api/bindings/api/lifecycle_scripts/__init__.py +1 -0
- mithril/api/bindings/api/lifecycle_scripts/create_lifecycle_script_v2_lifecycle_scripts_post.py +171 -0
- mithril/api/bindings/api/lifecycle_scripts/delete_lifecycle_script_v2_lifecycle_scripts_ls_fid_delete.py +155 -0
- mithril/api/bindings/api/lifecycle_scripts/get_lifecycle_script_content_v2_lifecycle_scripts_ls_fid_content_get.py +155 -0
- mithril/api/bindings/api/lifecycle_scripts/list_lifecycle_scripts_v2_lifecycle_scripts_get.py +247 -0
- mithril/api/bindings/api/lifecycle_scripts/update_lifecycle_script_v2_lifecycle_scripts_ls_fid_patch.py +179 -0
- mithril/api/bindings/api/pricing/__init__.py +1 -0
- mithril/api/bindings/api/pricing/get_current_prices_v2_v2_pricing_current_get.py +217 -0
- mithril/api/bindings/api/pricing/get_historical_prices_v2_v2_pricing_history_get.py +222 -0
- mithril/api/bindings/api/profile/__init__.py +1 -0
- mithril/api/bindings/api/profile/get_me_v2_me_get.py +132 -0
- mithril/api/bindings/api/profile/get_my_teammates_v2_me_teammates_get.py +153 -0
- mithril/api/bindings/api/projects/__init__.py +1 -0
- mithril/api/bindings/api/projects/get_projects_v2_projects_get.py +137 -0
- mithril/api/bindings/api/quotas/__init__.py +1 -0
- mithril/api/bindings/api/quotas/get_quotas_v2_quotas_get.py +175 -0
- mithril/api/bindings/api/reservations/__init__.py +1 -0
- mithril/api/bindings/api/reservations/create_reservation_v2_reservation_post.py +171 -0
- mithril/api/bindings/api/reservations/extend_reservation_v2_reservation_reservation_fid_extend_post.py +187 -0
- mithril/api/bindings/api/reservations/get_availability_v2_reservation_availability_get.py +664 -0
- mithril/api/bindings/api/reservations/get_extension_availability_v2_reservation_reservation_fid_extension_availability_get.py +165 -0
- mithril/api/bindings/api/reservations/get_reservations_v2_reservation_get.py +309 -0
- mithril/api/bindings/api/reservations/update_reservation_v2_reservation_reservation_fid_patch.py +187 -0
- mithril/api/bindings/api/spot/__init__.py +1 -0
- mithril/api/bindings/api/spot/cancel_bid_v2_spot_bids_bid_fid_delete.py +161 -0
- mithril/api/bindings/api/spot/create_bid_v2_spot_bids_post.py +171 -0
- mithril/api/bindings/api/spot/get_auctions_v2_spot_availability_get.py +137 -0
- mithril/api/bindings/api/spot/get_bid_history_v2_spot_bids_bid_fid_history_get.py +193 -0
- mithril/api/bindings/api/spot/get_bid_status_v2_spot_bids_bid_fid_status_get.py +189 -0
- mithril/api/bindings/api/spot/get_bid_v2_spot_bids_bid_fid_get.py +163 -0
- mithril/api/bindings/api/spot/get_bids_v2_spot_bids_get.py +330 -0
- mithril/api/bindings/api/spot/update_bid_v2_spot_bids_bid_fid_patch.py +185 -0
- mithril/api/bindings/api/ssh_keys/__init__.py +1 -0
- mithril/api/bindings/api/ssh_keys/create_ssh_key_v2_ssh_keys_post.py +175 -0
- mithril/api/bindings/api/ssh_keys/delete_ssh_key_v2_ssh_keys_ssh_key_fid_delete.py +167 -0
- mithril/api/bindings/api/ssh_keys/get_ssh_keys_v2_ssh_keys_get.py +175 -0
- mithril/api/bindings/api/ssh_keys/update_ssh_key_v2_ssh_keys_ssh_key_fid_patch.py +187 -0
- mithril/api/bindings/api/volumes/__init__.py +1 -0
- mithril/api/bindings/api/volumes/create_volume_v2_volumes_post.py +211 -0
- mithril/api/bindings/api/volumes/delete_volume_v2_volumes_volume_fid_delete.py +199 -0
- mithril/api/bindings/api/volumes/get_volumes_v2_volumes_get.py +239 -0
- mithril/api/bindings/api/volumes/update_volume_v2_volumes_volume_fid_patch.py +243 -0
- mithril/api/bindings/client.py +284 -0
- mithril/api/bindings/errors.py +18 -0
- mithril/api/bindings/models/__init__.py +169 -0
- mithril/api/bindings/models/api_key_model.py +114 -0
- mithril/api/bindings/models/auction_model.py +146 -0
- mithril/api/bindings/models/availability_slot_model.py +76 -0
- mithril/api/bindings/models/bid_history_event_model.py +157 -0
- mithril/api/bindings/models/bid_history_event_model_event_type.py +19 -0
- mithril/api/bindings/models/bid_history_response.py +84 -0
- mithril/api/bindings/models/bid_model.py +191 -0
- mithril/api/bindings/models/bid_model_status.py +14 -0
- mithril/api/bindings/models/bid_status_response.py +72 -0
- mithril/api/bindings/models/bid_status_response_status.py +15 -0
- mithril/api/bindings/models/check_availability_response.py +60 -0
- mithril/api/bindings/models/create_api_key_request.py +68 -0
- mithril/api/bindings/models/create_api_key_response.py +122 -0
- mithril/api/bindings/models/create_bid_request.py +116 -0
- mithril/api/bindings/models/create_kubernetes_cluster_request.py +136 -0
- mithril/api/bindings/models/create_kubernetes_cluster_request_k8s_version.py +11 -0
- mithril/api/bindings/models/create_lifecycle_script_request.py +115 -0
- mithril/api/bindings/models/create_reservation_request.py +124 -0
- mithril/api/bindings/models/create_ssh_key_request.py +99 -0
- mithril/api/bindings/models/create_volume_request.py +98 -0
- mithril/api/bindings/models/create_volume_request_disk_interface.py +11 -0
- mithril/api/bindings/models/created_ssh_key_model.py +122 -0
- mithril/api/bindings/models/current_prices_response.py +202 -0
- mithril/api/bindings/models/extend_reservation_request.py +60 -0
- mithril/api/bindings/models/extension_availability_response.py +68 -0
- mithril/api/bindings/models/get_availability_v2_reservation_availability_get_mode.py +12 -0
- mithril/api/bindings/models/get_bids_response.py +96 -0
- mithril/api/bindings/models/get_bids_v2_spot_bids_get_sort_by.py +11 -0
- mithril/api/bindings/models/get_bids_v2_spot_bids_get_status.py +14 -0
- mithril/api/bindings/models/get_instances_response.py +96 -0
- mithril/api/bindings/models/get_instances_v2_instances_get_order_type_in_type_0_item.py +11 -0
- mithril/api/bindings/models/get_instances_v2_instances_get_sort_by.py +12 -0
- mithril/api/bindings/models/get_instances_v2_instances_get_status_in_type_0_item.py +24 -0
- mithril/api/bindings/models/get_latest_end_time_response.py +68 -0
- mithril/api/bindings/models/get_reservations_response.py +96 -0
- mithril/api/bindings/models/get_reservations_v2_reservation_get_sort_by.py +11 -0
- mithril/api/bindings/models/get_reservations_v2_reservation_get_status.py +14 -0
- mithril/api/bindings/models/historical_price_point_model.py +94 -0
- mithril/api/bindings/models/historical_prices_response_model.py +76 -0
- mithril/api/bindings/models/http_validation_error.py +78 -0
- mithril/api/bindings/models/image_version_model.py +224 -0
- mithril/api/bindings/models/instance_model.py +211 -0
- mithril/api/bindings/models/instance_model_status.py +24 -0
- mithril/api/bindings/models/instance_status_response.py +141 -0
- mithril/api/bindings/models/instance_status_response_status.py +24 -0
- mithril/api/bindings/models/instance_type_model.py +170 -0
- mithril/api/bindings/models/kubernetes_cluster_model.py +207 -0
- mithril/api/bindings/models/kubernetes_cluster_model_status.py +12 -0
- mithril/api/bindings/models/launch_specification_model.py +152 -0
- mithril/api/bindings/models/lifecycle_script_model.py +134 -0
- mithril/api/bindings/models/lifecycle_script_scope.py +12 -0
- mithril/api/bindings/models/list_lifecycle_scripts_response.py +96 -0
- mithril/api/bindings/models/list_lifecycle_scripts_v2_lifecycle_scripts_get_sort_by.py +11 -0
- mithril/api/bindings/models/me_response.py +126 -0
- mithril/api/bindings/models/new_ssh_key_model.py +100 -0
- mithril/api/bindings/models/persistent_disk_change.py +92 -0
- mithril/api/bindings/models/project_model.py +76 -0
- mithril/api/bindings/models/public_lifecycle_script_scope.py +11 -0
- mithril/api/bindings/models/quota_model.py +132 -0
- mithril/api/bindings/models/reservation_model.py +215 -0
- mithril/api/bindings/models/reservation_model_status.py +14 -0
- mithril/api/bindings/models/size.py +70 -0
- mithril/api/bindings/models/size_unit.py +18 -0
- mithril/api/bindings/models/sort_direction.py +11 -0
- mithril/api/bindings/models/teammate_response.py +158 -0
- mithril/api/bindings/models/update_bid_request.py +143 -0
- mithril/api/bindings/models/update_lifecycle_script_request.py +109 -0
- mithril/api/bindings/models/update_reservation_request.py +103 -0
- mithril/api/bindings/models/update_ssh_key_request.py +60 -0
- mithril/api/bindings/models/update_volume_request.py +65 -0
- mithril/api/bindings/models/validation_error.py +89 -0
- mithril/api/bindings/models/volume_model.py +140 -0
- mithril/api/bindings/models/volume_model_attachments.py +46 -0
- mithril/api/bindings/models/volume_model_interface.py +11 -0
- mithril/api/bindings/types.py +56 -0
- mithril/api/client.py +138 -0
- mithril/cli/__init__.py +7 -0
- mithril/cli/commands/__init__.py +15 -0
- mithril/cli/commands/help.py +88 -0
- mithril/cli/commands/launch.py +353 -0
- mithril/cli/main.py +68 -0
- mithril/cli/utils/__init__.py +1 -0
- mithril/cli/utils/skypilot_passthrough.py +38 -0
- mithril/cli/utils/streaming.py +235 -0
- mithril/cli/utils/volumes.py +110 -0
- mithril/config.py +47 -0
- mithril/py.typed +0 -0
- mithril/sky/__init__.py +141 -0
- mithril/sky/client.py +176 -0
- mithril_client-0.1.0a1.dist-info/METADATA +56 -0
- mithril_client-0.1.0a1.dist-info/RECORD +163 -0
- mithril_client-0.1.0a1.dist-info/WHEEL +4 -0
- mithril_client-0.1.0a1.dist-info/entry_points.txt +3 -0
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
"""SkyPilot passthrough helpers for the Mithril CLI."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import importlib
|
|
6
|
+
import sys
|
|
7
|
+
from contextlib import suppress
|
|
8
|
+
|
|
9
|
+
# SkyPilot subcommands that are surfaced directly as `ml <command>`; any other
# sky subcommand is reached through `ml sky <command>` (see `_cli_command` in
# mithril/cli/utils/streaming.py, which builds user hints from this tuple).
SKY_ALIAS_COMMANDS: tuple[str, ...] = ("exec", "status", "start", "down", "stop")
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def run_sky(*args: str) -> int:
    """Run the SkyPilot CLI in-process with the given arguments.

    Note: sky.cli is not part of skypilot's public API, but it's the stable
    entry point for the `sky` command (defined in skypilot's setup.py). We
    invoke it directly to avoid requiring `sky` to be on PATH (e.g., when
    installed via `uv tool`).

    Args:
        *args: Arguments forwarded verbatim to the ``sky`` command.

    Returns:
        A process-style exit code: 0 on success; otherwise the status carried
        by ``SystemExit`` or by the exception's ``exit_code``, falling back
        to 1.
    """
    try:
        sky_cli = importlib.import_module("sky.client.cli.command").cli
        # NOTE(review): with standalone_mode=False Click returns the command's
        # result (or the code given to ctx.exit()) instead of raising; that
        # value is ignored here — confirm non-zero ctx.exit() codes need not
        # be propagated.
        sky_cli.main(args=list(args), standalone_mode=False)
    except SystemExit as e:
        # Mirror the interpreter's own SystemExit semantics: ``None`` (plain
        # ``sys.exit()``) means success, an int is the exit status, and any
        # other value is reported on stderr and treated as failure.
        if e.code is None:
            return 0
        if isinstance(e.code, int):
            return e.code
        print(e.code, file=sys.stderr)
        return 1
    except Exception as e:  # noqa: BLE001
        # SkyPilot is Click-based and may raise ClickException; avoid importing click
        # directly and rely on duck-typing to preserve decent error output.
        show = getattr(e, "show", None)
        if callable(show):
            with suppress(Exception):
                show()
            exit_code = getattr(e, "exit_code", 1)
            return exit_code if isinstance(exit_code, int) else 1

        print(f"Error: {e}", file=sys.stderr)
        return 1
    else:
        return 0
|
|
@@ -0,0 +1,235 @@
|
|
|
1
|
+
"""Custom streaming utilities for Mithril CLI.
|
|
2
|
+
|
|
3
|
+
This module provides a polling-based approach to monitoring SkyPilot requests,
|
|
4
|
+
giving full control over the UI rather than parsing SkyPilot's log output.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import time
|
|
10
|
+
from typing import TYPE_CHECKING, TypeVar
|
|
11
|
+
|
|
12
|
+
from rich.live import Live
|
|
13
|
+
from rich.spinner import Spinner
|
|
14
|
+
from rich.text import Text
|
|
15
|
+
|
|
16
|
+
from mithril.cli.utils.skypilot_passthrough import SKY_ALIAS_COMMANDS
|
|
17
|
+
|
|
18
|
+
T = TypeVar("T")
|
|
19
|
+
|
|
20
|
+
if TYPE_CHECKING:
|
|
21
|
+
from collections.abc import Callable
|
|
22
|
+
|
|
23
|
+
from rich.console import Console
|
|
24
|
+
from sky.server.common import RequestId
|
|
25
|
+
|
|
26
|
+
from mithril.sky import SkyClient
|
|
27
|
+
|
|
28
|
+
# Request statuses (as reported by `sky.api_status`) after which the launch
# poll loop stops polling.
TERMINAL_STATUSES = {"SUCCEEDED", "FAILED", "CANCELLED"}
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def stream_job_logs(
    cluster_name: str,
    job_id: int,
    console: Console,
    *,
    sky: SkyClient,
    follow: bool = True,
    filter_fn: Callable[[str], str | None] | None = None,
) -> int:
    """Stream job logs with optional filtering.

    Args:
        cluster_name: The cluster name
        job_id: The job ID to stream logs for
        console: Rich console for output
        sky: SkyClient instance
        follow: Whether to follow the logs (like tail -f)
        filter_fn: Optional function to filter/transform each line.
            Return the line to print, or None to skip it.

    Returns:
        Exit code: 0 if the job's final status is SUCCEEDED, 1 otherwise
        (including when the job's status cannot be found).
    """
    # Use preload_content=False to get an iterator
    log_iter = sky.tail_logs(
        cluster_name=cluster_name,
        job_id=job_id,
        follow=follow,
        preload_content=False,
    )

    for line in log_iter:
        if line is None:
            # End of stream
            break
        output = filter_fn(line) if filter_fn else line
        if output is not None:
            console.print(output, end="", markup=False, highlight=False)

    # Get the exit code by checking job status
    job_statuses = sky.get(sky.job_status(cluster_name, job_ids=[job_id]))
    status = job_statuses.get(job_id)

    if status is None:
        return 1  # Job not found

    # Statuses are typed as job_lib.JobStatus members. str() of a plain Enum
    # member is "JobStatus.SUCCEEDED", so compare the underlying value; plain
    # strings (no `.value`) are compared as-is.
    status_value = getattr(status, "value", status)
    return 0 if str(status_value) == "SUCCEEDED" else 1
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def poll_launch(  # noqa: UP047 - PEP 695 syntax requires 3.12+
    request_id: RequestId[T],
    console: Console,
    *,
    sky: SkyClient,
    poll_interval: float = 1.0,
    quiet: bool = False,
    initial_cluster_name: str | None = None,
) -> tuple[T, str | None]:
    """Poll a launch request until completion, showing custom progress.

    Args:
        request_id: The SkyPilot request ID from sky.launch()
        console: Rich console for output
        sky: SkyClient instance
        poll_interval: How often to poll for status updates (seconds)
        quiet: If True, suppress progress output
        initial_cluster_name: Optional cluster name to seed status messages

    Returns:
        A ``(result, cluster_name)`` tuple: ``result`` is what
        ``sky.get(request_id)`` returns (for a launch, ``(job_id, handle)``)
        and ``cluster_name`` is the latest cluster name reported while
        polling, or ``initial_cluster_name`` if none was reported.

    Raises:
        KeyboardInterrupt: Re-raised after printing a provision-log hint
            (unless ``quiet``) for the most recently observed cluster name.
    """
    cluster_name = initial_cluster_name

    def _remember_cluster_name(name: str | None) -> None:
        # Keep the outer variable current while polling, so a Ctrl-C in the
        # middle of the loop still knows a server-reported cluster name
        # (the loop's return value is never seen on interrupt).
        nonlocal cluster_name
        cluster_name = name

    initial_renderable, last_display_text = _render_status_message(
        None,
        cluster_name,
    )

    spinner = Spinner("dots", text="Launching...")
    live = Live(spinner, console=console, refresh_per_second=10, transient=True)

    try:
        if not quiet:
            spinner.update(text=initial_renderable)
            live.start()
        _, cluster_name = _run_poll_loop(
            request_id=request_id,
            poll_interval=poll_interval,
            spinner=spinner,
            live=live,
            last_display_text=last_display_text,
            cluster_name=cluster_name,
            quiet=quiet,
            sky=sky,
            on_cluster_name=_remember_cluster_name,
        )
    except KeyboardInterrupt:
        if not quiet and cluster_name:
            _print_provision_hint(console, cluster_name)
        raise
    finally:
        if not quiet:
            live.stop()

    # Get the final result (this will raise if there was an error)
    result = sky.get(request_id)

    return result, cluster_name


def format_status_message(status_msg: str, cluster_name: str | None) -> Text:
    """Format a status message for display.

    Appends the cluster name to status messages so users know which cluster
    is being worked on. For example: "Launching" becomes "Launching (my-cluster)",
    followed by a dimmed hint with the command that shows provisioning logs.

    The cluster name is optional because it isn't available at the start of a
    launch—SkyPilot assigns it during provisioning. Early status messages won't
    have a cluster name yet.

    Note: the combined string is parsed with ``Text.from_markup``, so any Rich
    markup contained in ``status_msg`` is interpreted.
    """
    formatted = status_msg
    if cluster_name:
        formatted = f"{formatted} ({cluster_name})"
        formatted = (
            f"{formatted} [dim]View logs: {_cli_command('logs')} --provision "
            f"{cluster_name}[/dim]"
        )
    return Text.from_markup(formatted)


def _default_launch_message(cluster_name: str | None) -> str:
    """Return the markup shown before any server status message arrives."""
    if cluster_name:
        return (
            "Launching [dim]View logs: "
            f"{_cli_command('logs')} --provision {cluster_name}[/dim]"
        )
    return "Launching..."


def _render_status_message(
    status_msg: str | None,
    cluster_name: str | None,
) -> tuple[Text, str]:
    """Build the spinner text for a status update.

    Returns:
        ``(renderable, plain_text)``; the plain text is used by the poll loop
        to skip redundant display updates.
    """
    if status_msg:
        renderable = format_status_message(status_msg, cluster_name)
        return renderable, renderable.plain
    message = _default_launch_message(cluster_name)
    renderable = Text.from_markup(message)
    return renderable, renderable.plain


def _run_poll_loop(
    *,
    request_id: str,
    poll_interval: float,
    spinner: Spinner,
    live: Live,
    last_display_text: str,
    cluster_name: str | None,
    quiet: bool,
    sky: SkyClient,
    on_cluster_name: Callable[[str | None], None] | None = None,
) -> tuple[str, str | None]:
    """Poll ``sky.api_status`` until the request reaches a terminal status.

    Args:
        request_id: Request to poll.
        poll_interval: Seconds to sleep between polls.
        spinner: Spinner whose text is updated with status messages.
        live: Live display that re-renders ``spinner``.
        last_display_text: Plain text currently shown, used to skip no-op
            updates.
        cluster_name: Last known cluster name; kept if the server reports none.
        quiet: If True, the display is never updated.
        sky: SkyClient instance.
        on_cluster_name: Optional callback invoked with the current cluster
            name after every status poll that returned a result.

    Returns:
        ``(last_display_text, cluster_name)`` as of the terminal poll.
    """
    while True:
        statuses = sky.api_status(request_ids=[request_id])

        if not statuses:
            # No status returned yet; wait and poll again.
            time.sleep(poll_interval)
            continue

        status = statuses[0]
        current_status = status.status
        status_msg = status.status_msg
        # Keep the previously known name if this poll reports none.
        cluster_name = status.cluster_name or cluster_name
        if on_cluster_name is not None:
            on_cluster_name(cluster_name)

        if not quiet:
            renderable, display_text = _render_status_message(
                status_msg,
                cluster_name,
            )
            # Only touch the live display when the visible text changes.
            if display_text != last_display_text:
                last_display_text = display_text
                spinner.update(text=renderable)
                live.update(spinner)

        if current_status in TERMINAL_STATUSES:
            break

        time.sleep(poll_interval)

    return last_display_text, cluster_name
|
|
224
|
+
|
|
225
|
+
|
|
226
|
+
def _print_provision_hint(console: Console, cluster_name: str) -> None:
    """Print a dimmed line telling the user how to view provisioning logs.

    Args:
        console: Rich console to print on (the message uses Rich markup).
        cluster_name: Cluster whose provision logs the hint points at.
    """
    logs_command = _cli_command("logs")
    hint = f"[dim]View logs: {logs_command} --provision {cluster_name}[/dim]"
    console.print(hint)
|
|
230
|
+
|
|
231
|
+
|
|
232
|
+
def _cli_command(command: str) -> str:
    """Return the ``ml`` invocation a user should type for a sky subcommand.

    Subcommands listed in ``SKY_ALIAS_COMMANDS`` are available directly as
    ``ml <command>``; everything else goes through ``ml sky <command>``.
    """
    prefix = "ml" if command in SKY_ALIAS_COMMANDS else "ml sky"
    return f"{prefix} {command}"
|
|
@@ -0,0 +1,110 @@
|
|
|
1
|
+
"""Volume synchronization between Mithril and SkyPilot.
|
|
2
|
+
|
|
3
|
+
This module ensures that volumes referenced in a SkyPilot task are registered
|
|
4
|
+
with SkyPilot before launch. It checks volumes against the Mithril API and
|
|
5
|
+
automatically registers any that exist in Mithril but aren't yet known to SkyPilot.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
from typing import TYPE_CHECKING
|
|
11
|
+
|
|
12
|
+
from mithril.api.bindings.models.volume_model_interface import VolumeModelInterface
|
|
13
|
+
|
|
14
|
+
# SkyPilot volume type constants from sky.utils.volume.VolumeType enum.
# Used as the `type` of the SkyPilot volume spec when registering a Mithril
# volume: File-interface volumes map to the file-share type, Block-interface
# volumes to the block type.
MITHRIL_FILE_SHARE_TYPE = "mithril-file-share"
MITHRIL_BLOCK_TYPE = "mithril-block"
|
|
17
|
+
|
|
18
|
+
if TYPE_CHECKING:
|
|
19
|
+
from mithril.api.bindings.models import VolumeModel
|
|
20
|
+
from mithril.api.client import MithrilClient
|
|
21
|
+
from mithril.sky import SkyClient
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def _get_skypilot_volumes(sky: SkyClient) -> set[str]:
|
|
25
|
+
"""Get volume names registered with SkyPilot.
|
|
26
|
+
|
|
27
|
+
Args:
|
|
28
|
+
sky: SkyClient instance.
|
|
29
|
+
|
|
30
|
+
Returns:
|
|
31
|
+
Set of volume names currently registered in SkyPilot.
|
|
32
|
+
"""
|
|
33
|
+
request_id = sky.volumes.ls()
|
|
34
|
+
volumes = sky.get(request_id)
|
|
35
|
+
return {vol.name for vol in volumes if vol.name}
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def _register_volume_with_skypilot(vol: VolumeModel, sky: SkyClient) -> None:
|
|
39
|
+
"""Register a Mithril volume with SkyPilot.
|
|
40
|
+
|
|
41
|
+
Args:
|
|
42
|
+
vol: VolumeModel from Mithril API.
|
|
43
|
+
sky: SkyClient instance.
|
|
44
|
+
|
|
45
|
+
Raises:
|
|
46
|
+
Exception: If registration fails. sky.get() will raise RuntimeError,
|
|
47
|
+
RequestCancelled, or the original exception from the request.
|
|
48
|
+
"""
|
|
49
|
+
# Map Mithril volume to SkyPilot volume spec
|
|
50
|
+
match vol.interface:
|
|
51
|
+
case VolumeModelInterface.FILE:
|
|
52
|
+
volume_type = MITHRIL_FILE_SHARE_TYPE
|
|
53
|
+
case VolumeModelInterface.BLOCK:
|
|
54
|
+
volume_type = MITHRIL_BLOCK_TYPE
|
|
55
|
+
|
|
56
|
+
# Note: SkyPilot treats GB and Gi identically (both map to 2**30 bytes)
|
|
57
|
+
sky_volume = sky.volumes.Volume(
|
|
58
|
+
name=vol.name,
|
|
59
|
+
type=volume_type,
|
|
60
|
+
infra=f"mithril/{vol.region}",
|
|
61
|
+
size=f"{vol.capacity_gb}GB",
|
|
62
|
+
)
|
|
63
|
+
request_id = sky.volumes.apply(sky_volume)
|
|
64
|
+
sky.get(request_id)
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def ensure_volumes_registered(
    volume_names: list[str],
    *,
    mithril: MithrilClient,
    sky: SkyClient,
) -> list[str]:
    """Ensure task volumes are registered with SkyPilot.

    For each volume name:
    1. Check if already registered with SkyPilot
    2. If not, check if it exists in Mithril
    3. If found in Mithril, register it with SkyPilot

    Names found in neither SkyPilot nor Mithril are skipped here (this
    function does not report them). Volumes are processed in the order they
    appear in ``volume_names``, with duplicates ignored.

    Args:
        volume_names: List of volume names to ensure are registered.
        mithril: MithrilClient instance for Mithril API operations.
        sky: SkyClient instance for SkyPilot operations.

    Returns:
        List of volume names that were newly registered, in input order.
    """
    if not volume_names:
        return []

    # Check what's already in SkyPilot. dict.fromkeys dedupes while keeping
    # the caller's order, so registration order and the result are
    # deterministic (iterating a set is not).
    sky_volumes = _get_skypilot_volumes(sky)
    missing = [name for name in dict.fromkeys(volume_names) if name not in sky_volumes]

    if not missing:
        return []

    mithril_by_name = {v.name: v for v in mithril.list_volumes()}

    registered: list[str] = []
    for name in missing:
        vol = mithril_by_name.get(name)
        if vol is None:
            # Not a Mithril volume either; nothing to register.
            continue
        _register_volume_with_skypilot(vol, sky)
        registered.append(name)

    return registered
|
mithril/config.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
"""Mithril configuration management.
|
|
2
|
+
|
|
3
|
+
Delegates to the Rust implementation for config loading.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
from dataclasses import dataclass
|
|
9
|
+
|
|
10
|
+
from mithril._mcli import MithrilConfig as RustConfig
|
|
11
|
+
from mithril._mcli import load_config as rust_load_config
|
|
12
|
+
|
|
13
|
+
# Default Mithril API base URL.
# NOTE(review): not referenced elsewhere in this module — `load_config` takes
# `api_url` from the Rust loader as-is; confirm whether this constant is still
# used by other modules or can be removed.
DEFAULT_API_URL = "https://api.mithril.ai"
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class ConfigError(Exception):
    """Signal that Mithril configuration is missing or invalid.

    ``load_config`` raises this in place of the ``RuntimeError`` reported by
    the native config loader, chaining the original error as ``__cause__``.
    """
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@dataclass(frozen=True, slots=True)
|
|
21
|
+
class Config:
|
|
22
|
+
"""Resolved configuration for Mithril operations."""
|
|
23
|
+
|
|
24
|
+
api_key: str
|
|
25
|
+
project_id: str
|
|
26
|
+
api_url: str
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def load_config() -> Config:
|
|
30
|
+
"""Load and validate configuration from environment and config file.
|
|
31
|
+
|
|
32
|
+
Returns:
|
|
33
|
+
Validated Config object.
|
|
34
|
+
|
|
35
|
+
Raises:
|
|
36
|
+
ConfigError: If required configuration (api_key, project_id) is missing.
|
|
37
|
+
"""
|
|
38
|
+
try:
|
|
39
|
+
rust_cfg: RustConfig = rust_load_config()
|
|
40
|
+
return Config(
|
|
41
|
+
api_key=rust_cfg.api_key,
|
|
42
|
+
project_id=rust_cfg.project_id,
|
|
43
|
+
api_url=rust_cfg.api_url,
|
|
44
|
+
)
|
|
45
|
+
except RuntimeError as e:
|
|
46
|
+
# Convert Rust RuntimeError to Python ConfigError
|
|
47
|
+
raise ConfigError(str(e)) from e
|
mithril/py.typed
ADDED
|
File without changes
|
mithril/sky/__init__.py
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
1
|
+
"""SkyPilot wrapper for stable imports and testability.
|
|
2
|
+
|
|
3
|
+
This module provides two interfaces:
|
|
4
|
+
|
|
5
|
+
1. **Module-level API** for SDK developers::
|
|
6
|
+
|
|
7
|
+
from mithril import sky
|
|
8
|
+
|
|
9
|
+
task = sky.Task(run="echo hello")
|
|
10
|
+
resources = sky.Resources(accelerators="A100:4")
|
|
11
|
+
sky.launch(task=task, cluster_name="dev", ...)
|
|
12
|
+
|
|
13
|
+
2. **SkyClient class** for CLI code and testing (supports parameter injection)::
|
|
14
|
+
|
|
15
|
+
from mithril.sky import SkyClient
|
|
16
|
+
|
|
17
|
+
client = SkyClient()
|
|
18
|
+
# or inject a mock in tests
|
|
19
|
+
build_task(params, sky=mock_client)
|
|
20
|
+
|
|
21
|
+
The module lazily imports SkyPilot to keep `ml --help` fast.
|
|
22
|
+
"""
|
|
23
|
+
|
|
24
|
+
from __future__ import annotations
|
|
25
|
+
|
|
26
|
+
from typing import TYPE_CHECKING, TypeVar
|
|
27
|
+
|
|
28
|
+
from mithril.sky.client import SkyClient
|
|
29
|
+
|
|
30
|
+
T = TypeVar("T")
|
|
31
|
+
|
|
32
|
+
if TYPE_CHECKING:
|
|
33
|
+
from collections.abc import Iterable
|
|
34
|
+
|
|
35
|
+
from sky import Task as SkyTask
|
|
36
|
+
from sky import backends
|
|
37
|
+
from sky.server.common import RequestId
|
|
38
|
+
from sky.server.requests import payloads
|
|
39
|
+
from sky.skylet import job_lib
|
|
40
|
+
|
|
41
|
+
# Names exported by `from mithril.sky import *`. The lazily-resolved
# attributes in `_LAZY_ATTRS` are not listed, so a star-import does not
# trigger their lookup.
__all__ = [
    "SkyClient",
    "api_info",
    "api_status",
    "get",
    "job_status",
    "launch",
    "tail_logs",
]

# Default client instance for module-level API; every module-level function
# below delegates to it.
_client = SkyClient()

# Module-level attributes (Task, Resources, clouds, volumes) are accessed via
# __getattr__ to maintain lazy loading - SkyPilot is only imported when first accessed.
_LAZY_ATTRS = {"Task", "Resources", "clouds", "volumes"}
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def __getattr__(name: str) -> object:
    """Resolve lazy module attributes (PEP 562) from the default client.

    Only names in ``_LAZY_ATTRS`` are forwarded to ``_client``; anything else
    raises ``AttributeError`` exactly like a normal missing module attribute.
    """
    if name not in _LAZY_ATTRS:
        msg = f"module {__name__!r} has no attribute {name!r}"
        raise AttributeError(msg)
    return getattr(_client, name)
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def __dir__() -> list[str]:
    """List real module globals plus the lazily-provided attribute names."""
    names = list(globals())
    names.extend(_LAZY_ATTRS)
    return names
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def launch(
    *,
    task: SkyTask,
    cluster_name: str | None = None,
    retry_until_up: bool = False,
    idle_minutes_to_autostop: int | None = None,
    dryrun: bool = False,
    down: bool = False,
) -> RequestId[tuple[int | None, backends.ResourceHandle | None]]:
    """Submit a SkyPilot launch through the default client.

    Args:
        task: A sky.Task object
        cluster_name: Name of the cluster (auto-generated if None)
        retry_until_up: Keep retrying until the cluster is up
        idle_minutes_to_autostop: Auto-stop after idle for this many minutes
        dryrun: If True, don't actually launch
        down: Tear down cluster after job completes

    Returns:
        Request ID for the launch operation; resolve it with ``get``.
    """
    options = {
        "cluster_name": cluster_name,
        "retry_until_up": retry_until_up,
        "idle_minutes_to_autostop": idle_minutes_to_autostop,
        "dryrun": dryrun,
        "down": down,
    }
    return _client.launch(task=task, **options)
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def api_status(*, request_ids: list[str]) -> list[payloads.RequestPayload]:
    """Fetch status payloads for the given SkyPilot request IDs."""
    client = _client
    return client.api_status(request_ids=request_ids)
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def get(request_id: RequestId[T]) -> T:  # noqa: UP047 - PEP 695 syntax requires 3.12+
    """Block on a SkyPilot request and return its result."""
    client = _client
    return client.get(request_id)
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def tail_logs(
    *,
    cluster_name: str,
    job_id: int,
    follow: bool = True,
    preload_content: bool = False,
) -> Iterable[str | None]:
    """Stream a job's log lines through the default client."""
    stream_options = {
        "cluster_name": cluster_name,
        "job_id": job_id,
        "follow": follow,
        "preload_content": preload_content,
    }
    return _client.tail_logs(**stream_options)
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
def job_status(
    cluster_name: str,
    *,
    job_ids: list[int] | None = None,
) -> RequestId[dict[int | None, job_lib.JobStatus | None]]:
    """Request job statuses on a cluster; resolve the returned ID with ``get``."""
    client = _client
    return client.job_status(cluster_name, job_ids=job_ids)
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
def api_info() -> object:
    """Return API server info (status, version, etc.) from the default client."""
    client = _client
    return client.api_info()
|