mithril-client 0.1.0a1__cp314-cp314-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163)
  1. mithril/__init__.py +7 -0
  2. mithril/_mcli.cpython-314-darwin.so +0 -0
  3. mithril/_mcli.pyi +7 -0
  4. mithril/_mcli_entry.py +75 -0
  5. mithril/api/__init__.py +7 -0
  6. mithril/api/bindings/.gitattributes +2 -0
  7. mithril/api/bindings/__init__.py +10 -0
  8. mithril/api/bindings/api/__init__.py +1 -0
  9. mithril/api/bindings/api/api_keys/__init__.py +1 -0
  10. mithril/api/bindings/api/api_keys/create_api_key_v2_api_keys_post.py +179 -0
  11. mithril/api/bindings/api/api_keys/get_api_keys_v2_api_keys_get.py +141 -0
  12. mithril/api/bindings/api/api_keys/revoke_api_key_v2_api_keys_key_fid_delete.py +173 -0
  13. mithril/api/bindings/api/image_versions/__init__.py +1 -0
  14. mithril/api/bindings/api/image_versions/get_image_versions_v2_image_versions_get.py +141 -0
  15. mithril/api/bindings/api/image_versions/get_mcc_image_versions_v2_mcc_image_versions_get.py +179 -0
  16. mithril/api/bindings/api/instance_types/__init__.py +1 -0
  17. mithril/api/bindings/api/instance_types/get_instance_types_v2_instance_types_get.py +137 -0
  18. mithril/api/bindings/api/instances/__init__.py +1 -0
  19. mithril/api/bindings/api/instances/get_instance_status_v2_instances_instance_id_status_get.py +165 -0
  20. mithril/api/bindings/api/instances/get_instances_v2_instances_get.py +409 -0
  21. mithril/api/bindings/api/kubernetes_clusters/__init__.py +1 -0
  22. mithril/api/bindings/api/kubernetes_clusters/create_kubernetes_cluster_v2_kubernetes_clusters_post.py +171 -0
  23. mithril/api/bindings/api/kubernetes_clusters/delete_kubernetes_cluster_v2_kubernetes_clusters_cluster_fid_delete.py +163 -0
  24. mithril/api/bindings/api/kubernetes_clusters/get_kubernetes_cluster_v2_kubernetes_clusters_cluster_fid_get.py +165 -0
  25. mithril/api/bindings/api/kubernetes_clusters/get_kubernetes_clusters_v2_kubernetes_clusters_get.py +175 -0
  26. mithril/api/bindings/api/lifecycle_scripts/__init__.py +1 -0
  27. mithril/api/bindings/api/lifecycle_scripts/create_lifecycle_script_v2_lifecycle_scripts_post.py +171 -0
  28. mithril/api/bindings/api/lifecycle_scripts/delete_lifecycle_script_v2_lifecycle_scripts_ls_fid_delete.py +155 -0
  29. mithril/api/bindings/api/lifecycle_scripts/get_lifecycle_script_content_v2_lifecycle_scripts_ls_fid_content_get.py +155 -0
  30. mithril/api/bindings/api/lifecycle_scripts/list_lifecycle_scripts_v2_lifecycle_scripts_get.py +247 -0
  31. mithril/api/bindings/api/lifecycle_scripts/update_lifecycle_script_v2_lifecycle_scripts_ls_fid_patch.py +179 -0
  32. mithril/api/bindings/api/pricing/__init__.py +1 -0
  33. mithril/api/bindings/api/pricing/get_current_prices_v2_v2_pricing_current_get.py +217 -0
  34. mithril/api/bindings/api/pricing/get_historical_prices_v2_v2_pricing_history_get.py +222 -0
  35. mithril/api/bindings/api/profile/__init__.py +1 -0
  36. mithril/api/bindings/api/profile/get_me_v2_me_get.py +132 -0
  37. mithril/api/bindings/api/profile/get_my_teammates_v2_me_teammates_get.py +153 -0
  38. mithril/api/bindings/api/projects/__init__.py +1 -0
  39. mithril/api/bindings/api/projects/get_projects_v2_projects_get.py +137 -0
  40. mithril/api/bindings/api/quotas/__init__.py +1 -0
  41. mithril/api/bindings/api/quotas/get_quotas_v2_quotas_get.py +175 -0
  42. mithril/api/bindings/api/reservations/__init__.py +1 -0
  43. mithril/api/bindings/api/reservations/create_reservation_v2_reservation_post.py +171 -0
  44. mithril/api/bindings/api/reservations/extend_reservation_v2_reservation_reservation_fid_extend_post.py +187 -0
  45. mithril/api/bindings/api/reservations/get_availability_v2_reservation_availability_get.py +664 -0
  46. mithril/api/bindings/api/reservations/get_extension_availability_v2_reservation_reservation_fid_extension_availability_get.py +165 -0
  47. mithril/api/bindings/api/reservations/get_reservations_v2_reservation_get.py +309 -0
  48. mithril/api/bindings/api/reservations/update_reservation_v2_reservation_reservation_fid_patch.py +187 -0
  49. mithril/api/bindings/api/spot/__init__.py +1 -0
  50. mithril/api/bindings/api/spot/cancel_bid_v2_spot_bids_bid_fid_delete.py +161 -0
  51. mithril/api/bindings/api/spot/create_bid_v2_spot_bids_post.py +171 -0
  52. mithril/api/bindings/api/spot/get_auctions_v2_spot_availability_get.py +137 -0
  53. mithril/api/bindings/api/spot/get_bid_history_v2_spot_bids_bid_fid_history_get.py +193 -0
  54. mithril/api/bindings/api/spot/get_bid_status_v2_spot_bids_bid_fid_status_get.py +189 -0
  55. mithril/api/bindings/api/spot/get_bid_v2_spot_bids_bid_fid_get.py +163 -0
  56. mithril/api/bindings/api/spot/get_bids_v2_spot_bids_get.py +330 -0
  57. mithril/api/bindings/api/spot/update_bid_v2_spot_bids_bid_fid_patch.py +185 -0
  58. mithril/api/bindings/api/ssh_keys/__init__.py +1 -0
  59. mithril/api/bindings/api/ssh_keys/create_ssh_key_v2_ssh_keys_post.py +175 -0
  60. mithril/api/bindings/api/ssh_keys/delete_ssh_key_v2_ssh_keys_ssh_key_fid_delete.py +167 -0
  61. mithril/api/bindings/api/ssh_keys/get_ssh_keys_v2_ssh_keys_get.py +175 -0
  62. mithril/api/bindings/api/ssh_keys/update_ssh_key_v2_ssh_keys_ssh_key_fid_patch.py +187 -0
  63. mithril/api/bindings/api/volumes/__init__.py +1 -0
  64. mithril/api/bindings/api/volumes/create_volume_v2_volumes_post.py +211 -0
  65. mithril/api/bindings/api/volumes/delete_volume_v2_volumes_volume_fid_delete.py +199 -0
  66. mithril/api/bindings/api/volumes/get_volumes_v2_volumes_get.py +239 -0
  67. mithril/api/bindings/api/volumes/update_volume_v2_volumes_volume_fid_patch.py +243 -0
  68. mithril/api/bindings/client.py +284 -0
  69. mithril/api/bindings/errors.py +18 -0
  70. mithril/api/bindings/models/__init__.py +169 -0
  71. mithril/api/bindings/models/api_key_model.py +114 -0
  72. mithril/api/bindings/models/auction_model.py +146 -0
  73. mithril/api/bindings/models/availability_slot_model.py +76 -0
  74. mithril/api/bindings/models/bid_history_event_model.py +157 -0
  75. mithril/api/bindings/models/bid_history_event_model_event_type.py +19 -0
  76. mithril/api/bindings/models/bid_history_response.py +84 -0
  77. mithril/api/bindings/models/bid_model.py +191 -0
  78. mithril/api/bindings/models/bid_model_status.py +14 -0
  79. mithril/api/bindings/models/bid_status_response.py +72 -0
  80. mithril/api/bindings/models/bid_status_response_status.py +15 -0
  81. mithril/api/bindings/models/check_availability_response.py +60 -0
  82. mithril/api/bindings/models/create_api_key_request.py +68 -0
  83. mithril/api/bindings/models/create_api_key_response.py +122 -0
  84. mithril/api/bindings/models/create_bid_request.py +116 -0
  85. mithril/api/bindings/models/create_kubernetes_cluster_request.py +136 -0
  86. mithril/api/bindings/models/create_kubernetes_cluster_request_k8s_version.py +11 -0
  87. mithril/api/bindings/models/create_lifecycle_script_request.py +115 -0
  88. mithril/api/bindings/models/create_reservation_request.py +124 -0
  89. mithril/api/bindings/models/create_ssh_key_request.py +99 -0
  90. mithril/api/bindings/models/create_volume_request.py +98 -0
  91. mithril/api/bindings/models/create_volume_request_disk_interface.py +11 -0
  92. mithril/api/bindings/models/created_ssh_key_model.py +122 -0
  93. mithril/api/bindings/models/current_prices_response.py +202 -0
  94. mithril/api/bindings/models/extend_reservation_request.py +60 -0
  95. mithril/api/bindings/models/extension_availability_response.py +68 -0
  96. mithril/api/bindings/models/get_availability_v2_reservation_availability_get_mode.py +12 -0
  97. mithril/api/bindings/models/get_bids_response.py +96 -0
  98. mithril/api/bindings/models/get_bids_v2_spot_bids_get_sort_by.py +11 -0
  99. mithril/api/bindings/models/get_bids_v2_spot_bids_get_status.py +14 -0
  100. mithril/api/bindings/models/get_instances_response.py +96 -0
  101. mithril/api/bindings/models/get_instances_v2_instances_get_order_type_in_type_0_item.py +11 -0
  102. mithril/api/bindings/models/get_instances_v2_instances_get_sort_by.py +12 -0
  103. mithril/api/bindings/models/get_instances_v2_instances_get_status_in_type_0_item.py +24 -0
  104. mithril/api/bindings/models/get_latest_end_time_response.py +68 -0
  105. mithril/api/bindings/models/get_reservations_response.py +96 -0
  106. mithril/api/bindings/models/get_reservations_v2_reservation_get_sort_by.py +11 -0
  107. mithril/api/bindings/models/get_reservations_v2_reservation_get_status.py +14 -0
  108. mithril/api/bindings/models/historical_price_point_model.py +94 -0
  109. mithril/api/bindings/models/historical_prices_response_model.py +76 -0
  110. mithril/api/bindings/models/http_validation_error.py +78 -0
  111. mithril/api/bindings/models/image_version_model.py +224 -0
  112. mithril/api/bindings/models/instance_model.py +211 -0
  113. mithril/api/bindings/models/instance_model_status.py +24 -0
  114. mithril/api/bindings/models/instance_status_response.py +141 -0
  115. mithril/api/bindings/models/instance_status_response_status.py +24 -0
  116. mithril/api/bindings/models/instance_type_model.py +170 -0
  117. mithril/api/bindings/models/kubernetes_cluster_model.py +207 -0
  118. mithril/api/bindings/models/kubernetes_cluster_model_status.py +12 -0
  119. mithril/api/bindings/models/launch_specification_model.py +152 -0
  120. mithril/api/bindings/models/lifecycle_script_model.py +134 -0
  121. mithril/api/bindings/models/lifecycle_script_scope.py +12 -0
  122. mithril/api/bindings/models/list_lifecycle_scripts_response.py +96 -0
  123. mithril/api/bindings/models/list_lifecycle_scripts_v2_lifecycle_scripts_get_sort_by.py +11 -0
  124. mithril/api/bindings/models/me_response.py +126 -0
  125. mithril/api/bindings/models/new_ssh_key_model.py +100 -0
  126. mithril/api/bindings/models/persistent_disk_change.py +92 -0
  127. mithril/api/bindings/models/project_model.py +76 -0
  128. mithril/api/bindings/models/public_lifecycle_script_scope.py +11 -0
  129. mithril/api/bindings/models/quota_model.py +132 -0
  130. mithril/api/bindings/models/reservation_model.py +215 -0
  131. mithril/api/bindings/models/reservation_model_status.py +14 -0
  132. mithril/api/bindings/models/size.py +70 -0
  133. mithril/api/bindings/models/size_unit.py +18 -0
  134. mithril/api/bindings/models/sort_direction.py +11 -0
  135. mithril/api/bindings/models/teammate_response.py +158 -0
  136. mithril/api/bindings/models/update_bid_request.py +143 -0
  137. mithril/api/bindings/models/update_lifecycle_script_request.py +109 -0
  138. mithril/api/bindings/models/update_reservation_request.py +103 -0
  139. mithril/api/bindings/models/update_ssh_key_request.py +60 -0
  140. mithril/api/bindings/models/update_volume_request.py +65 -0
  141. mithril/api/bindings/models/validation_error.py +89 -0
  142. mithril/api/bindings/models/volume_model.py +140 -0
  143. mithril/api/bindings/models/volume_model_attachments.py +46 -0
  144. mithril/api/bindings/models/volume_model_interface.py +11 -0
  145. mithril/api/bindings/types.py +56 -0
  146. mithril/api/client.py +138 -0
  147. mithril/cli/__init__.py +7 -0
  148. mithril/cli/commands/__init__.py +15 -0
  149. mithril/cli/commands/help.py +88 -0
  150. mithril/cli/commands/launch.py +353 -0
  151. mithril/cli/main.py +68 -0
  152. mithril/cli/utils/__init__.py +1 -0
  153. mithril/cli/utils/skypilot_passthrough.py +38 -0
  154. mithril/cli/utils/streaming.py +235 -0
  155. mithril/cli/utils/volumes.py +110 -0
  156. mithril/config.py +47 -0
  157. mithril/py.typed +0 -0
  158. mithril/sky/__init__.py +141 -0
  159. mithril/sky/client.py +176 -0
  160. mithril_client-0.1.0a1.dist-info/METADATA +56 -0
  161. mithril_client-0.1.0a1.dist-info/RECORD +163 -0
  162. mithril_client-0.1.0a1.dist-info/WHEEL +4 -0
  163. mithril_client-0.1.0a1.dist-info/entry_points.txt +3 -0
@@ -0,0 +1,38 @@
1
+ """SkyPilot passthrough helpers for the Mithril CLI."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import importlib
6
+ import sys
7
+ from contextlib import suppress
8
+
9
# SkyPilot subcommands that user-facing hints spell as `ml <command>` instead
# of `ml sky <command>` (streaming._cli_command checks membership here).
# NOTE(review): presumably these are registered as top-level `ml` aliases by
# the CLI command module — confirm there when adding entries.
SKY_ALIAS_COMMANDS: tuple[str, ...] = ("exec", "status", "start", "down", "stop")
10
+
11
+
12
def run_sky(*args: str) -> int:
    """Run the SkyPilot CLI in-process and return a process-style exit code.

    Note: sky.cli is not part of skypilot's public API, but it's the stable
    entry point for the `sky` command (defined in skypilot's setup.py). We
    invoke it directly to avoid requiring `sky` to be on PATH (e.g., when
    installed via `uv tool`).

    Args:
        *args: Arguments forwarded to the ``sky`` command, e.g. ``("status",)``.

    Returns:
        0 on success; otherwise the exit code reported by SkyPilot/Click, or 1
        when no usable code is available (import failure, unexpected error).
    """
    try:
        sky_cli = importlib.import_module("sky.client.cli.command").cli
        # standalone_mode=False: Click propagates errors to us instead of
        # calling sys.exit() itself.
        # NOTE(review): in this mode Click *returns* the code of an explicit
        # ctx.exit(n) instead of raising; that value is ignored here, so such
        # exits report 0. Confirm whether any sky command relies on ctx.exit.
        sky_cli.main(args=list(args), standalone_mode=False)
    except SystemExit as e:
        # Follow the interpreter's SystemExit convention: None means success,
        # an int is the exit status, anything else is printed and maps to 1.
        if e.code is None:
            return 0
        if isinstance(e.code, int):
            return e.code
        print(e.code, file=sys.stderr)
        return 1
    except Exception as e:  # noqa: BLE001
        # SkyPilot is Click-based and may raise ClickException; avoid importing click
        # directly and rely on duck-typing to preserve decent error output.
        show = getattr(e, "show", None)
        if callable(show):
            with suppress(Exception):
                show()
            exit_code = getattr(e, "exit_code", 1)
            return exit_code if isinstance(exit_code, int) else 1

        # Fall back to the exception type so message-less exceptions do not
        # print a bare "Error: ".
        print(f"Error: {str(e) or type(e).__name__}", file=sys.stderr)
        return 1
    else:
        return 0
@@ -0,0 +1,235 @@
1
+ """Custom streaming utilities for Mithril CLI.
2
+
3
+ This module provides a polling-based approach to monitoring SkyPilot requests,
4
+ giving full control over the UI rather than parsing SkyPilot's log output.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import time
10
+ from typing import TYPE_CHECKING, TypeVar
11
+
12
+ from rich.live import Live
13
+ from rich.spinner import Spinner
14
+ from rich.text import Text
15
+
16
+ from mithril.cli.utils.skypilot_passthrough import SKY_ALIAS_COMMANDS
17
+
18
# Result type carried by a SkyPilot RequestId; lets poll_launch declare that it
# returns the same type sky.get(request_id) produces.
T = TypeVar("T")
19
+
20
+ if TYPE_CHECKING:
21
+ from collections.abc import Callable
22
+
23
+ from rich.console import Console
24
+ from sky.server.common import RequestId
25
+
26
+ from mithril.sky import SkyClient
27
+
28
# SkyPilot request statuses treated as final: _run_poll_loop stops polling once
# a request reports one of these. Immutable, like the tuple-typed
# SKY_ALIAS_COMMANDS, so it cannot be altered at runtime.
TERMINAL_STATUSES: frozenset[str] = frozenset({"SUCCEEDED", "FAILED", "CANCELLED"})
29
+
30
+
31
def stream_job_logs(
    cluster_name: str,
    job_id: int,
    console: Console,
    *,
    sky: SkyClient,
    follow: bool = True,
    filter_fn: Callable[[str], str | None] | None = None,
) -> int:
    """Stream job logs with optional filtering, then report the job outcome.

    Args:
        cluster_name: The cluster name
        job_id: The job ID to stream logs for
        console: Rich console for output
        sky: SkyClient instance
        follow: Whether to follow the logs (like tail -f)
        filter_fn: Optional function to filter/transform each line.
            Return the line to print, or None to skip it.

    Returns:
        0 if the job's status is SUCCEEDED once streaming ends, 1 otherwise
        (including when the job cannot be found).
    """
    # Use preload_content=False to get an iterator
    log_iter = sky.tail_logs(
        cluster_name=cluster_name,
        job_id=job_id,
        follow=follow,
        preload_content=False,
    )

    for line in log_iter:
        if line is None:
            # End of stream
            break
        text = filter_fn(line) if filter_fn is not None else line
        if text is not None:
            # end="": presumably each log line already ends with a newline.
            # markup/highlight off so log text is printed verbatim.
            console.print(text, end="", markup=False, highlight=False)

    # Get the exit code by checking job status
    job_statuses = sky.get(sky.job_status(cluster_name, job_ids=[job_id])) or {}
    status = job_statuses.get(job_id)
    if status is None:
        return 1  # Job not found

    # JobStatus is an Enum, and a plain Enum's str() is "JobStatus.SUCCEEDED",
    # which never equals "SUCCEEDED". Compare the member's value instead, and
    # fall back to the object itself in case a plain string comes back.
    status_name = getattr(status, "value", status)
    return 0 if status_name == "SUCCEEDED" else 1
83
+
84
+
85
def poll_launch(  # noqa: UP047 - PEP 695 syntax requires 3.12+
    request_id: RequestId[T],
    console: Console,
    *,
    sky: SkyClient,
    poll_interval: float = 1.0,
    quiet: bool = False,
    initial_cluster_name: str | None = None,
) -> tuple[T, str | None]:
    """Wait for a SkyPilot launch request to finish while showing a spinner.

    Args:
        request_id: The SkyPilot request ID from sky.launch()
        console: Rich console for output
        sky: SkyClient instance
        poll_interval: Seconds to wait between status polls
        quiet: If True, no spinner or hints are shown
        initial_cluster_name: Optional cluster name used for the first message

    Returns:
        A pair of (result of ``sky.get(request_id)``, last known cluster name).

    Raises:
        KeyboardInterrupt: Re-raised after printing a provisioning-logs hint
            (only when not quiet and a cluster name is known).
    """
    show_progress = not quiet
    current_cluster = initial_cluster_name
    first_renderable, shown_text = _render_status_message(None, current_cluster)

    spinner = Spinner("dots", text="Launching...")
    live = Live(spinner, console=console, refresh_per_second=10, transient=True)

    try:
        if show_progress:
            spinner.update(text=first_renderable)
            live.start()
        _, current_cluster = _run_poll_loop(
            request_id=request_id,
            poll_interval=poll_interval,
            spinner=spinner,
            live=live,
            last_display_text=shown_text,
            cluster_name=current_cluster,
            quiet=quiet,
            sky=sky,
        )
    except KeyboardInterrupt:
        # NOTE(review): current_cluster is only refreshed when the poll loop
        # returns, so an interrupt mid-poll falls back to initial_cluster_name.
        if show_progress and current_cluster:
            _print_provision_hint(console, current_cluster)
        raise
    finally:
        if show_progress:
            live.stop()

    # sky.get() raises if the launch request ended with an error.
    return sky.get(request_id), current_cluster
142
+
143
+
144
def format_status_message(status_msg: str, cluster_name: str | None) -> Text:
    """Turn a SkyPilot status message into rich Text for the spinner.

    When a cluster name is known, it is appended in parentheses (e.g.
    "Launching" -> "Launching (my-cluster)"), followed by a dimmed hint with
    the command for viewing provisioning logs. The name may be missing early
    in a launch because SkyPilot assigns it during provisioning; in that case
    the message is used as-is.

    The result is built with ``Text.from_markup``, so rich markup inside
    ``status_msg`` is rendered.
    """
    if not cluster_name:
        return Text.from_markup(status_msg)
    logs_hint = (
        f"[dim]View logs: {_cli_command('logs')} --provision {cluster_name}[/dim]"
    )
    return Text.from_markup(f"{status_msg} ({cluster_name}) {logs_hint}")
162
+
163
+
164
def _default_launch_message(cluster_name: str | None) -> str:
    """Return the markup-formatted spinner text used when no status message exists.

    Without a cluster name this is just "Launching..."; with one, a dimmed
    hint for viewing provisioning logs is appended.
    """
    if not cluster_name:
        return "Launching..."
    logs_cmd = _cli_command("logs")
    return f"Launching [dim]View logs: {logs_cmd} --provision {cluster_name}[/dim]"
171
+
172
+
173
def _render_status_message(
    status_msg: str | None,
    cluster_name: str | None,
) -> tuple[Text, str]:
    """Build the spinner renderable plus its plain-text form.

    A non-empty ``status_msg`` goes through ``format_status_message``;
    otherwise the default launch message is used. The plain text is what the
    poll loop compares to decide whether the display needs refreshing.
    """
    renderable = (
        format_status_message(status_msg, cluster_name)
        if status_msg
        else Text.from_markup(_default_launch_message(cluster_name))
    )
    return renderable, renderable.plain
183
+
184
+
185
def _run_poll_loop(
    *,
    request_id: str,
    poll_interval: float,
    spinner: Spinner,
    live: Live,
    last_display_text: str,
    cluster_name: str | None,
    quiet: bool,
    sky: SkyClient,
) -> tuple[str, str | None]:
    """Poll a SkyPilot request until its status is in TERMINAL_STATUSES.

    Each poll asks ``sky.api_status`` about the request. When not ``quiet``,
    the spinner is refreshed only if the rendered plain text changed. Sleeps
    ``poll_interval`` seconds between polls, including polls that return no
    status entry; no sleep follows the terminal poll.

    Returns:
        The last displayed plain text and the most recent known cluster name.
    """
    while True:
        records = sky.api_status(request_ids=[request_id])
        if records:
            record = records[0]
            state = record.status
            message = record.status_msg
            # Keep the previously known name when this poll does not report one.
            cluster_name = record.cluster_name or cluster_name

            if not quiet:
                renderable, text = _render_status_message(message, cluster_name)
                changed = text != last_display_text
                if changed:
                    last_display_text = text
                    spinner.update(text=renderable)
                    live.update(spinner)

            if state in TERMINAL_STATUSES:
                return last_display_text, cluster_name

        time.sleep(poll_interval)
224
+
225
+
226
def _print_provision_hint(console: Console, cluster_name: str) -> None:
    """Print a dimmed hint with the command for viewing provisioning logs."""
    logs_cmd = _cli_command("logs")
    hint = f"[dim]View logs: {logs_cmd} --provision {cluster_name}[/dim]"
    console.print(hint)
230
+
231
+
232
def _cli_command(command: str) -> str:
    """Return how a SkyPilot subcommand should be spelled in user-facing hints.

    Subcommands listed in SKY_ALIAS_COMMANDS render as ``ml <command>``; all
    others render as ``ml sky <command>``.
    """
    prefix = "ml" if command in SKY_ALIAS_COMMANDS else "ml sky"
    return f"{prefix} {command}"
@@ -0,0 +1,110 @@
1
+ """Volume synchronization between Mithril and SkyPilot.
2
+
3
+ This module ensures that volumes referenced in a SkyPilot task are registered
4
+ with SkyPilot before launch. It checks volumes against the Mithril API and
5
+ automatically registers any that exist in Mithril but aren't yet known to SkyPilot.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from typing import TYPE_CHECKING
11
+
12
+ from mithril.api.bindings.models.volume_model_interface import VolumeModelInterface
13
+
14
# SkyPilot volume type constants from sky.utils.volume.VolumeType enum.
# _register_volume_with_skypilot maps each Mithril volume interface onto one
# of these strings when building the SkyPilot Volume spec.
MITHRIL_FILE_SHARE_TYPE = "mithril-file-share"  # used for FILE-interface volumes
MITHRIL_BLOCK_TYPE = "mithril-block"  # used for BLOCK-interface volumes
17
+
18
+ if TYPE_CHECKING:
19
+ from mithril.api.bindings.models import VolumeModel
20
+ from mithril.api.client import MithrilClient
21
+ from mithril.sky import SkyClient
22
+
23
+
24
+ def _get_skypilot_volumes(sky: SkyClient) -> set[str]:
25
+ """Get volume names registered with SkyPilot.
26
+
27
+ Args:
28
+ sky: SkyClient instance.
29
+
30
+ Returns:
31
+ Set of volume names currently registered in SkyPilot.
32
+ """
33
+ request_id = sky.volumes.ls()
34
+ volumes = sky.get(request_id)
35
+ return {vol.name for vol in volumes if vol.name}
36
+
37
+
38
+ def _register_volume_with_skypilot(vol: VolumeModel, sky: SkyClient) -> None:
39
+ """Register a Mithril volume with SkyPilot.
40
+
41
+ Args:
42
+ vol: VolumeModel from Mithril API.
43
+ sky: SkyClient instance.
44
+
45
+ Raises:
46
+ Exception: If registration fails. sky.get() will raise RuntimeError,
47
+ RequestCancelled, or the original exception from the request.
48
+ """
49
+ # Map Mithril volume to SkyPilot volume spec
50
+ match vol.interface:
51
+ case VolumeModelInterface.FILE:
52
+ volume_type = MITHRIL_FILE_SHARE_TYPE
53
+ case VolumeModelInterface.BLOCK:
54
+ volume_type = MITHRIL_BLOCK_TYPE
55
+
56
+ # Note: SkyPilot treats GB and Gi identically (both map to 2**30 bytes)
57
+ sky_volume = sky.volumes.Volume(
58
+ name=vol.name,
59
+ type=volume_type,
60
+ infra=f"mithril/{vol.region}",
61
+ size=f"{vol.capacity_gb}GB",
62
+ )
63
+ request_id = sky.volumes.apply(sky_volume)
64
+ sky.get(request_id)
65
+
66
+
67
def ensure_volumes_registered(
    volume_names: list[str],
    *,
    mithril: MithrilClient,
    sky: SkyClient,
) -> list[str]:
    """Ensure task volumes are registered with SkyPilot.

    For each volume name:
    1. Check if already registered with SkyPilot
    2. If not, check if it exists in Mithril
    3. If found in Mithril, register it with SkyPilot

    Names found in neither SkyPilot nor Mithril are skipped without error.

    Args:
        volume_names: List of volume names to ensure are registered.
        mithril: MithrilClient instance for Mithril API operations.
        sky: SkyClient instance for SkyPilot operations.

    Returns:
        Names of the newly registered volumes, in the order they first appear
        in ``volume_names`` (duplicates collapsed).
    """
    if not volume_names:
        return []

    # Check what's already in SkyPilot. dict.fromkeys de-duplicates while
    # keeping the caller's order; iterating a set of strings would make the
    # registration order (and returned list) vary between processes because
    # str hashing is randomized.
    sky_volumes = _get_skypilot_volumes(sky)
    missing = [
        name for name in dict.fromkeys(volume_names) if name not in sky_volumes
    ]
    if not missing:
        return []

    # If Mithril lists several volumes with the same name, the last one wins.
    mithril_by_name = {v.name: v for v in mithril.list_volumes()}

    registered: list[str] = []
    for name in missing:
        vol = mithril_by_name.get(name)
        if vol is None:
            continue
        _register_volume_with_skypilot(vol, sky)
        registered.append(name)
    return registered
mithril/config.py ADDED
@@ -0,0 +1,47 @@
1
+ """Mithril configuration management.
2
+
3
+ Delegates to the Rust implementation for config loading.
4
+ """
5
+
6
from __future__ import annotations

from dataclasses import dataclass, field

from mithril._mcli import MithrilConfig as RustConfig
from mithril._mcli import load_config as rust_load_config
12
+
13
# Base URL of the hosted Mithril API. Not read by load_config() below, which
# takes api_url from the Rust loader; presumably used as a fallback elsewhere.
DEFAULT_API_URL = "https://api.mithril.ai"
14
+
15
+
16
class ConfigError(Exception):
    """Configuration is missing or invalid.

    load_config() raises this in place of the RuntimeError reported by the
    Rust config loader, with the original error chained as ``__cause__``.
    """
18
+
19
+
20
+ @dataclass(frozen=True, slots=True)
21
+ class Config:
22
+ """Resolved configuration for Mithril operations."""
23
+
24
+ api_key: str
25
+ project_id: str
26
+ api_url: str
27
+
28
+
29
def load_config() -> Config:
    """Load configuration through the Rust loader and wrap it in a Config.

    Returns:
        Validated Config object.

    Raises:
        ConfigError: If the Rust loader raises RuntimeError, e.g. when required
            configuration (api_key, project_id) is missing.
    """
    try:
        loaded: RustConfig = rust_load_config()
        resolved = Config(
            api_key=loaded.api_key,
            project_id=loaded.project_id,
            api_url=loaded.api_url,
        )
    except RuntimeError as err:
        # Surface Rust-side failures as the Python-level ConfigError.
        raise ConfigError(str(err)) from err
    return resolved
mithril/py.typed ADDED
File without changes
@@ -0,0 +1,141 @@
1
+ """SkyPilot wrapper for stable imports and testability.
2
+
3
+ This module provides two interfaces:
4
+
5
+ 1. **Module-level API** for SDK developers::
6
+
7
+ from mithril import sky
8
+
9
+ task = sky.Task(run="echo hello")
10
+ resources = sky.Resources(accelerators="A100:4")
11
+ sky.launch(task=task, cluster_name="dev", ...)
12
+
13
+ 2. **SkyClient class** for CLI code and testing (supports parameter injection)::
14
+
15
+ from mithril.sky import SkyClient
16
+
17
+ client = SkyClient()
18
+ # or inject a mock in tests
19
+ build_task(params, sky=mock_client)
20
+
21
+ The module lazily imports SkyPilot to keep `ml --help` fast.
22
+ """
23
+
24
+ from __future__ import annotations
25
+
26
+ from typing import TYPE_CHECKING, TypeVar
27
+
28
+ from mithril.sky.client import SkyClient
29
+
30
# Result type of a SkyPilot request: get(request_id: RequestId[T]) returns a T.
T = TypeVar("T")
31
+
32
+ if TYPE_CHECKING:
33
+ from collections.abc import Iterable
34
+
35
+ from sky import Task as SkyTask
36
+ from sky import backends
37
+ from sky.server.common import RequestId
38
+ from sky.server.requests import payloads
39
+ from sky.skylet import job_lib
40
+
41
# Names exported by `from mithril.sky import *`. The lazily resolved attributes
# (Task, Resources, clouds, volumes) are not listed, so a star-import only
# touches real module globals and never goes through __getattr__.
__all__ = [
    "SkyClient",
    "api_info",
    "api_status",
    "get",
    "job_status",
    "launch",
    "tail_logs",
]
50
+
51
# Default client instance for module-level API. Created once at import time;
# the module-level wrapper functions and __getattr__ all delegate to it.
_client = SkyClient()
53
+
54
# Module-level attributes (Task, Resources, clouds, volumes) are accessed via
# __getattr__ to maintain lazy loading - SkyPilot is only imported when first accessed.
# Immutable so the set of lazily exposed names cannot be changed at runtime.
_LAZY_ATTRS: frozenset[str] = frozenset({"Task", "Resources", "clouds", "volumes"})
57
+
58
+
59
def __getattr__(name: str) -> object:
    """Resolve lazily exposed SkyPilot attributes via the shared client.

    This is the module-level ``__getattr__`` hook, consulted only for names
    not found among the module's globals.

    Raises:
        AttributeError: For any name not listed in _LAZY_ATTRS.
    """
    if name not in _LAZY_ATTRS:
        msg = f"module {__name__!r} has no attribute {name!r}"
        raise AttributeError(msg)
    return getattr(_client, name)
65
+
66
+
67
def __dir__() -> list[str]:
    """List module globals plus the lazily resolved names, for dir()."""
    names = list(globals())
    names.extend(_LAZY_ATTRS)
    return names
70
+
71
+
72
def launch(
    *,
    task: SkyTask,
    cluster_name: str | None = None,
    retry_until_up: bool = False,
    idle_minutes_to_autostop: int | None = None,
    dryrun: bool = False,
    down: bool = False,
) -> RequestId[tuple[int | None, backends.ResourceHandle | None]]:
    """Submit a SkyPilot launch through the shared module-level client.

    Args:
        task: A sky.Task object
        cluster_name: Name of the cluster (auto-generated if None)
        retry_until_up: Keep retrying until the cluster is up
        idle_minutes_to_autostop: Auto-stop after idle for this many minutes
        dryrun: If True, don't actually launch
        down: Tear down cluster after job completes

    Returns:
        Request ID of the launch; pass it to ``get`` to obtain the result.
    """
    options = {
        "cluster_name": cluster_name,
        "retry_until_up": retry_until_up,
        "idle_minutes_to_autostop": idle_minutes_to_autostop,
        "dryrun": dryrun,
        "down": down,
    }
    return _client.launch(task=task, **options)
102
+
103
+
104
def api_status(*, request_ids: list[str]) -> list[payloads.RequestPayload]:
    """Return the status payloads of the given SkyPilot requests."""
    payload_list = _client.api_status(request_ids=request_ids)
    return payload_list
107
+
108
+
109
def get(request_id: RequestId[T]) -> T:  # noqa: UP047 - PEP 695 syntax requires 3.12+
    """Return the result of a SkyPilot request via the shared client."""
    result = _client.get(request_id)
    return result
112
+
113
+
114
def tail_logs(
    *,
    cluster_name: str,
    job_id: int,
    follow: bool = True,
    preload_content: bool = False,
) -> Iterable[str | None]:
    """Return the log stream of a SkyPilot job from the shared client."""
    stream_options = {
        "cluster_name": cluster_name,
        "job_id": job_id,
        "follow": follow,
        "preload_content": preload_content,
    }
    return _client.tail_logs(**stream_options)
128
+
129
+
130
def job_status(
    cluster_name: str,
    *,
    job_ids: list[int] | None = None,
) -> RequestId[dict[int | None, job_lib.JobStatus | None]]:
    """Request job statuses on a cluster; resolve the returned ID with ``get``."""
    request_id = _client.job_status(cluster_name, job_ids=job_ids)
    return request_id
137
+
138
+
139
def api_info() -> object:
    """Return the SkyPilot API server info (status, version, etc.)."""
    info = _client.api_info()
    return info