lightning-sdk 0.1.55__py3-none-any.whl → 0.1.56__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. lightning_sdk/__init__.py +1 -1
  2. lightning_sdk/ai_hub.py +22 -0
  3. lightning_sdk/api/ai_hub_api.py +21 -2
  4. lightning_sdk/api/deployment_api.py +4 -3
  5. lightning_sdk/api/job_api.py +5 -10
  6. lightning_sdk/api/mmt_api.py +1 -4
  7. lightning_sdk/api/studio_api.py +5 -7
  8. lightning_sdk/api/teamspace_api.py +7 -0
  9. lightning_sdk/api/utils.py +1 -27
  10. lightning_sdk/cli/configure.py +92 -0
  11. lightning_sdk/cli/connect.py +31 -0
  12. lightning_sdk/cli/delete.py +6 -4
  13. lightning_sdk/cli/download.py +1 -1
  14. lightning_sdk/cli/entrypoint.py +8 -1
  15. lightning_sdk/cli/generate.py +13 -36
  16. lightning_sdk/cli/inspect.py +4 -2
  17. lightning_sdk/cli/jobs_menu.py +2 -1
  18. lightning_sdk/cli/list.py +5 -10
  19. lightning_sdk/cli/mmts_menu.py +2 -1
  20. lightning_sdk/cli/run.py +3 -3
  21. lightning_sdk/cli/serve.py +1 -2
  22. lightning_sdk/cli/start.py +2 -2
  23. lightning_sdk/cli/stop.py +5 -3
  24. lightning_sdk/cli/studios_menu.py +24 -1
  25. lightning_sdk/cli/switch.py +2 -2
  26. lightning_sdk/cli/teamspace_menu.py +2 -1
  27. lightning_sdk/cli/upload.py +6 -4
  28. lightning_sdk/lightning_cloud/openapi/__init__.py +4 -0
  29. lightning_sdk/lightning_cloud/openapi/api/cluster_service_api.py +105 -0
  30. lightning_sdk/lightning_cloud/openapi/api/jobs_service_api.py +113 -0
  31. lightning_sdk/lightning_cloud/openapi/api/lit_registry_service_api.py +7 -3
  32. lightning_sdk/lightning_cloud/openapi/api/projects_service_api.py +1 -5
  33. lightning_sdk/lightning_cloud/openapi/models/__init__.py +4 -0
  34. lightning_sdk/lightning_cloud/openapi/models/id_reportrestarttimings_body.py +123 -0
  35. lightning_sdk/lightning_cloud/openapi/models/project_id_litregistry_body.py +2 -0
  36. lightning_sdk/lightning_cloud/openapi/models/v1_cluster_accelerator.py +27 -1
  37. lightning_sdk/lightning_cloud/openapi/models/v1_get_cluster_accelerator_demand_response.py +123 -0
  38. lightning_sdk/lightning_cloud/openapi/models/v1_job.py +27 -1
  39. lightning_sdk/lightning_cloud/openapi/models/v1_lit_registry_artifact.py +27 -1
  40. lightning_sdk/lightning_cloud/openapi/models/v1_lit_registry_project.py +8 -0
  41. lightning_sdk/lightning_cloud/openapi/models/v1_lit_repository.py +27 -1
  42. lightning_sdk/lightning_cloud/openapi/models/v1_report_restart_timings_response.py +97 -0
  43. lightning_sdk/lightning_cloud/openapi/models/v1_restart_timing.py +175 -0
  44. lightning_sdk/lightning_cloud/openapi/models/v1_validate_deployment_image_request.py +27 -1
  45. lightning_sdk/machine.py +59 -27
  46. lightning_sdk/studio.py +5 -1
  47. lightning_sdk/teamspace.py +25 -0
  48. {lightning_sdk-0.1.55.dist-info → lightning_sdk-0.1.56.dist-info}/METADATA +2 -1
  49. {lightning_sdk-0.1.55.dist-info → lightning_sdk-0.1.56.dist-info}/RECORD +53 -47
  50. {lightning_sdk-0.1.55.dist-info → lightning_sdk-0.1.56.dist-info}/LICENSE +0 -0
  51. {lightning_sdk-0.1.55.dist-info → lightning_sdk-0.1.56.dist-info}/WHEEL +0 -0
  52. {lightning_sdk-0.1.55.dist-info → lightning_sdk-0.1.56.dist-info}/entry_points.txt +0 -0
  53. {lightning_sdk-0.1.55.dist-info → lightning_sdk-0.1.56.dist-info}/top_level.txt +0 -0
lightning_sdk/__init__.py CHANGED
@@ -29,5 +29,5 @@ __all__ = [
29
29
  "AIHub",
30
30
  ]
31
31
 
32
- __version__ = "0.1.55"
32
+ __version__ = "0.1.56"
33
33
  _check_version_and_prompt_upgrade(__version__)
lightning_sdk/ai_hub.py CHANGED
@@ -8,6 +8,7 @@ from lightning_sdk.utils.resolve import _resolve_teamspace
8
8
 
9
9
  if TYPE_CHECKING:
10
10
  from lightning_sdk import Organization, Teamspace
11
+ from lightning_sdk.machine import Machine
11
12
 
12
13
 
13
14
  class AIHub:
@@ -107,6 +108,7 @@ class AIHub:
107
108
  teamspace: Optional[Union[str, "Teamspace"]] = None,
108
109
  org: Optional[Union[str, "Organization"]] = None,
109
110
  user: Optional[Union[str, "User"]] = None,
111
+ machine: Optional[Union[str, "Machine"]] = None,
110
112
  ) -> Dict[str, Union[str, bool]]:
111
113
  """Deploy an API from the AI Hub.
112
114
 
@@ -128,6 +130,7 @@ class AIHub:
128
130
  teamspace: The team or group for deployment. Defaults to None.
129
131
  org: The organization for deployment. Don't pass user with this. Defaults to None.
130
132
  user: The user for deployment. Don't pass org with this. Defaults to None.
133
+ machine: The machine to run the deployment on. Defaults to the first option set in the AI Hub template.
131
134
 
132
135
  Returns:
133
136
  A dictionary containing the name of the deployed API,
@@ -153,6 +156,7 @@ class AIHub:
153
156
  project_id=teamspace_id,
154
157
  name=name,
155
158
  api_arguments=api_arguments,
159
+ machine=machine,
156
160
  )
157
161
 
158
162
  url = (
@@ -171,4 +175,22 @@ class AIHub:
171
175
  "deployment_url": url,
172
176
  "api_endpoint": deployment.status.urls[0],
173
177
  "interruptible": deployment.spec.spot,
178
+ "teamspace id": teamspace_id,
174
179
  }
180
+
181
+ def delete_deployment(self, deployment: Dict[str, Union[str, bool]]) -> None:
182
+ """Delete a deployment from the AI Hub.
183
+
184
+ Example:
185
+ from lightning_sdk import AIHub
186
+ hub = AIHub()
187
+ deployment = hub.run("temp_xxxx")
188
+ hub.delete_deployment(deployment)
189
+
190
+ Args:
191
+ deployment: The deployment dictionary returned by the run method.
192
+ """
193
+ if "teamspace id" not in deployment or "id" not in deployment:
194
+ raise ValueError("Deployment dictionary must contain 'teamspace id' and 'id' keys.")
195
+
196
+ self._api.delete_api(deployment["id"], deployment["teamspace id"])
@@ -1,8 +1,10 @@
1
1
  import traceback
2
- from typing import Dict, List, Optional, Tuple
2
+ from typing import Dict, List, Optional, Tuple, Union
3
3
 
4
4
  import backoff
5
5
 
6
+ from lightning_sdk.api.deployment_api import apply_change
7
+ from lightning_sdk.api.utils import _machine_to_compute_name
6
8
  from lightning_sdk.lightning_cloud.openapi.models import (
7
9
  CreateDeploymentRequestDefinesASpecForTheJobThatAllowsForAutoscalingJobs,
8
10
  V1Deployment,
@@ -16,6 +18,7 @@ from lightning_sdk.lightning_cloud.openapi.models.v1_deployment_template_gallery
16
18
  V1DeploymentTemplateGalleryResponse,
17
19
  )
18
20
  from lightning_sdk.lightning_cloud.rest_client import LightningClient
21
+ from lightning_sdk.machine import Machine
19
22
 
20
23
 
21
24
  class AIHubApi:
@@ -110,7 +113,13 @@ class AIHubApi:
110
113
  return job
111
114
 
112
115
  def run_api(
113
- self, template_id: str, project_id: str, cloud_account: str, name: Optional[str], api_arguments: Dict[str, str]
116
+ self,
117
+ template_id: str,
118
+ project_id: str,
119
+ cloud_account: str,
120
+ name: Optional[str],
121
+ api_arguments: Dict[str, str],
122
+ machine: Optional[Union[str, Machine]],
114
123
  ) -> V1Deployment:
115
124
  template = self._client.deployment_templates_service_get_deployment_template(template_id)
116
125
  name = name or template.name
@@ -123,6 +132,13 @@ class AIHubApi:
123
132
  template.spec_v2.autoscaling.enabled = True
124
133
 
125
134
  AIHubApi._set_parameters(template.spec_v2.job, template.parameter_spec.parameters, api_arguments)
135
+ if machine and isinstance(machine, Machine):
136
+ apply_change(template.spec_v2.job, "instance_name", _machine_to_compute_name(machine))
137
+ apply_change(template.spec_v2.job, "instance_type", _machine_to_compute_name(machine))
138
+ elif machine and isinstance(machine, str):
139
+ apply_change(template.spec_v2.job, "instance_name", machine)
140
+ apply_change(template.spec_v2.job, "instance_type", machine)
141
+
126
142
  return self._client.jobs_service_create_deployment(
127
143
  project_id=project_id,
128
144
  body=CreateDeploymentRequestDefinesASpecForTheJobThatAllowsForAutoscalingJobs(
@@ -134,3 +150,6 @@ class AIHubApi:
134
150
  spec=template.spec_v2.job,
135
151
  ),
136
152
  )
153
+
154
+ def delete_api(self, deployment_id: str, teamspace_id: str) -> None:
155
+ self._client.jobs_service_delete_deployment(project_id=teamspace_id, id=deployment_id)
@@ -1,7 +1,7 @@
1
1
  from time import sleep
2
2
  from typing import Any, List, Literal, Optional, Union
3
3
 
4
- from lightning_sdk.api.utils import _MACHINE_TO_COMPUTE_NAME
4
+ from lightning_sdk.api.utils import _machine_to_compute_name
5
5
  from lightning_sdk.lightning_cloud.openapi import (
6
6
  CreateDeploymentRequestDefinesASpecForTheJobThatAllowsForAutoscalingJobs,
7
7
  V1AutoscalingSpec,
@@ -270,7 +270,8 @@ class DeploymentApi:
270
270
 
271
271
  # Any updates to the Job Spec triggers a new release
272
272
  if machine:
273
- apply_change(deployment.spec, "instance_name", _MACHINE_TO_COMPUTE_NAME[machine])
273
+ apply_change(deployment.spec, "instance_name", _machine_to_compute_name(machine))
274
+ apply_change(deployment.spec, "instance_type", _machine_to_compute_name(machine))
274
275
 
275
276
  requires_release = False
276
277
  requires_release |= apply_change(deployment.spec, "image", environment)
@@ -554,7 +555,7 @@ def to_spec(
554
555
  env=to_env(env),
555
556
  image=environment,
556
557
  spot=spot,
557
- instance_name=_MACHINE_TO_COMPUTE_NAME[machine],
558
+ instance_name=_machine_to_compute_name(machine),
558
559
  readiness_probe=to_health_check(health_check),
559
560
  )
560
561
 
@@ -3,7 +3,6 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Union
3
3
  from urllib.request import urlopen
4
4
 
5
5
  from lightning_sdk.api.utils import (
6
- _COMPUTE_NAME_TO_MACHINE,
7
6
  _create_app,
8
7
  _machine_to_compute_name,
9
8
  remove_datetime_prefix,
@@ -98,13 +97,11 @@ class JobApiV1:
98
97
  def get_machine_from_work(self, work: Externalv1Lightningwork) -> Machine:
99
98
  spec: V1LightningworkSpec = work.spec
100
99
  # prefer user-requested config if specified
101
- compute_config: V1UserRequestedComputeConfig = spec.user_requested_compute_config
102
- compute: str = compute_config.name
103
- if compute:
104
- return _COMPUTE_NAME_TO_MACHINE[compute]
100
+ user_requested_compute_config: V1UserRequestedComputeConfig = spec.user_requested_compute_config
101
+ if user_requested_compute_config.name:
102
+ return Machine(user_requested_compute_config.name, user_requested_compute_config.name)
105
103
  compute_config: V1ComputeConfig = spec.compute_config
106
- compute: str = compute_config.instance_type
107
- return _COMPUTE_NAME_TO_MACHINE[compute]
104
+ return Machine(compute_config.instance_type, compute_config.instance_type)
108
105
 
109
106
  def get_studio_name(self, job: Externalv1LightningappInstance) -> str:
110
107
  cs: V1CloudSpace = self._client.cloud_space_service_get_cloud_space(
@@ -343,9 +340,7 @@ class JobApiV2:
343
340
  instance_name = spec.instance_name
344
341
  instance_type = spec.instance_type
345
342
 
346
- return _COMPUTE_NAME_TO_MACHINE.get(
347
- instance_type, _COMPUTE_NAME_TO_MACHINE.get(instance_name, instance_type or instance_name)
348
- )
343
+ return Machine(instance_name, instance_type or instance_name)
349
344
 
350
345
  def get_total_cost(self, job: V1Job) -> float:
351
346
  return job.total_cost
@@ -4,7 +4,6 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Union
4
4
 
5
5
  from lightning_sdk.api.job_api import JobApiV1
6
6
  from lightning_sdk.api.utils import (
7
- _COMPUTE_NAME_TO_MACHINE,
8
7
  _create_app,
9
8
  _machine_to_compute_name,
10
9
  resolve_path_mappings,
@@ -205,9 +204,7 @@ class MMTApiV2:
205
204
  instance_name = spec.instance_name
206
205
  instance_type = spec.instance_type
207
206
 
208
- return _COMPUTE_NAME_TO_MACHINE.get(
209
- instance_type, _COMPUTE_NAME_TO_MACHINE.get(instance_name, instance_type or instance_name)
210
- )
207
+ return Machine(instance_name, instance_type or instance_name)
211
208
 
212
209
  def get_total_cost(self, job: V1MultiMachineJob) -> float:
213
210
  return job.total_cost
@@ -12,8 +12,6 @@ import requests
12
12
  from tqdm import tqdm
13
13
 
14
14
  from lightning_sdk.api.utils import (
15
- _COMPUTE_NAME_TO_MACHINE,
16
- _MACHINE_TO_COMPUTE_NAME,
17
15
  _create_app,
18
16
  _DummyBody,
19
17
  _DummyResponse,
@@ -257,7 +255,7 @@ class StudioApi:
257
255
  response: V1CloudSpaceInstanceConfig = self._client.cloud_space_service_get_cloud_space_instance_config(
258
256
  project_id=teamspace_id, id=studio_id
259
257
  )
260
- return _COMPUTE_NAME_TO_MACHINE[response.compute_config.name]
258
+ return Machine(response.compute_config.name, response.compute_config.name)
261
259
 
262
260
  def get_interruptible(self, studio_id: str, teamspace_id: str) -> bool:
263
261
  """Get whether the Studio is running on a interruptible instance."""
@@ -582,7 +580,7 @@ class StudioApi:
582
580
  plugin_type="job",
583
581
  entrypoint=entrypoint,
584
582
  name=name,
585
- compute=_MACHINE_TO_COMPUTE_NAME[machine],
583
+ compute=_machine_to_compute_name(machine),
586
584
  interruptible=interruptible,
587
585
  )
588
586
 
@@ -600,7 +598,7 @@ class StudioApi:
600
598
  ) -> Externalv1LightningappInstance:
601
599
  """Creates a multi-machine job with given commands."""
602
600
  distributed_args = {
603
- "cloud_compute": _MACHINE_TO_COMPUTE_NAME[machine],
601
+ "cloud_compute": _machine_to_compute_name(machine),
604
602
  "num_instances": num_instances,
605
603
  "strategy": strategy,
606
604
  }
@@ -628,7 +626,7 @@ class StudioApi:
628
626
  ) -> Externalv1LightningappInstance:
629
627
  """Creates a multi-machine job with given commands."""
630
628
  data_prep_args = {
631
- "cloud_compute": _MACHINE_TO_COMPUTE_NAME[machine],
629
+ "cloud_compute": _machine_to_compute_name(machine),
632
630
  "num_instances": num_instances,
633
631
  }
634
632
  return self._create_app(
@@ -665,7 +663,7 @@ class StudioApi:
665
663
  teamspace_id=teamspace_id,
666
664
  cloud_account=cloud_account,
667
665
  plugin_type="inference_plugin",
668
- compute=_MACHINE_TO_COMPUTE_NAME[machine],
666
+ compute=_machine_to_compute_name(machine),
669
667
  entrypoint=entrypoint,
670
668
  name=name,
671
669
  min_replicas=min_replicas,
@@ -14,6 +14,7 @@ from lightning_sdk.lightning_cloud.openapi import (
14
14
  ProjectIdModelsBody,
15
15
  V1Assistant,
16
16
  V1CloudSpace,
17
+ V1ClusterAccelerator,
17
18
  V1Endpoint,
18
19
  V1Job,
19
20
  V1ModelVersionArchive,
@@ -295,3 +296,9 @@ class TeamspaceApi:
295
296
  ).lightningapps
296
297
  jobs = self._client.jobs_service_list_multi_machine_jobs(project_id=teamspace_id).multi_machine_jobs
297
298
  return apps, jobs
299
+
300
+ def list_machines(self, teamspace_id: str, cloud_account: str) -> List[V1ClusterAccelerator]:
301
+ response = self._client.cluster_service_list_project_cluster_accelerators(
302
+ project_id=teamspace_id, id=cloud_account
303
+ )
304
+ return response.accelerator
@@ -322,38 +322,12 @@ class _DummyResponse:
322
322
  self.data = data
323
323
 
324
324
 
325
- # TODO: This should really come from some kind of metadata service
326
- _MACHINE_TO_COMPUTE_NAME: Dict[Machine, str] = {
327
- Machine.CPU_SMALL: "m3.medium",
328
- Machine.CPU: "cpu-4",
329
- Machine.DATA_PREP: "data-large",
330
- Machine.DATA_PREP_MAX: "data-max",
331
- Machine.DATA_PREP_ULTRA: "data-ultra",
332
- Machine.T4: "g4dn.2xlarge",
333
- Machine.T4_X_4: "g4dn.12xlarge",
334
- Machine.L4: "g6.4xlarge",
335
- Machine.L4_X_4: "g6.12xlarge",
336
- Machine.L4_X_8: "g6.48xlarge",
337
- Machine.A10G: "g5.8xlarge",
338
- Machine.A10G_X_4: "g5.12xlarge",
339
- Machine.A10G_X_8: "g5.48xlarge",
340
- Machine.L40S: "g6e.4xlarge",
341
- Machine.L40S_X_4: "g6e.12xlarge",
342
- Machine.L40S_X_8: "g6e.48xlarge",
343
- Machine.A100_X_8: "p4d.24xlarge",
344
- Machine.H100_X_8: "p5.48xlarge",
345
- Machine.H200_X_8: "p5e.48xlarge",
346
- }
347
-
348
-
349
325
  def _machine_to_compute_name(machine: Union[Machine, str]) -> str:
350
326
  if isinstance(machine, Machine):
351
- return _MACHINE_TO_COMPUTE_NAME[machine]
327
+ return machine.instance_type
352
328
  return machine
353
329
 
354
330
 
355
- _COMPUTE_NAME_TO_MACHINE: Dict[str, Machine] = {v: k for k, v in _MACHINE_TO_COMPUTE_NAME.items()}
356
-
357
331
  _DEFAULT_CLOUD_URL = "https://lightning.ai"
358
332
  _DEFAULT_REGISTRY_URL = "litcr.io"
359
333
 
@@ -0,0 +1,92 @@
1
+ import platform
2
+ import uuid
3
+ from pathlib import Path
4
+ from typing import Optional, Union
5
+
6
+ from rich.console import Console
7
+
8
+ from lightning_sdk.cli.generate import _Generate
9
+ from lightning_sdk.lightning_cloud.login import Auth
10
+
11
+
12
+ def _download_file(url: str, local_path: Path, overwrite: bool = True, chmod: Optional[int] = None) -> None:
13
+ """Download a file from a URL."""
14
+ import requests
15
+
16
+ if local_path.exists() and not overwrite:
17
+ raise FileExistsError(f"The file {local_path} already exists and overwrite is set to False.")
18
+
19
+ response = requests.get(url, stream=True)
20
+ response.raise_for_status()
21
+
22
+ with open(local_path, "wb") as file:
23
+ for chunk in response.iter_content(chunk_size=8192):
24
+ file.write(chunk)
25
+ if chmod is not None:
26
+ local_path.chmod(0o600)
27
+
28
+
29
+ class _Configure(_Generate):
30
+ """Configure lightning products."""
31
+
32
+ @staticmethod
33
+ def _download_ssh_keys(
34
+ api_key: str,
35
+ key_id: str = "",
36
+ ssh_home: Union[str, Path] = "",
37
+ ssh_key_name: str = "lightning_rsa",
38
+ overwrite: bool = False,
39
+ ) -> None:
40
+ if not ssh_home:
41
+ ssh_home = Path.home() / ".ssh"
42
+ elif isinstance(ssh_home, str):
43
+ ssh_home = Path(ssh_home)
44
+ if not key_id:
45
+ key_id = str(uuid.uuid4())
46
+
47
+ path_key = ssh_home / ssh_key_name
48
+ path_pub = ssh_home / f"{ssh_key_name}.pub"
49
+
50
+ # todo: consider hitting the API to get the key pair directly instead of using wget
51
+ _download_file(
52
+ f"https://lightning.ai/setup/ssh-gen?t={api_key}&id={key_id}&machineName={platform.node()}",
53
+ path_key,
54
+ overwrite=overwrite,
55
+ chmod=0o600,
56
+ )
57
+ _download_file(f"https://lightning.ai/setup/ssh-public?t={api_key}&id={key_id}", path_pub, overwrite=overwrite)
58
+
59
+ def ssh(self, overwrite: bool = False, ssh_key_name: str = "lightning_rsa") -> None:
60
+ """Get SSH config entry for a studio.
61
+
62
+ Args:
63
+ overwrite: Whether to overwrite the SSH key and config if they already exist.
64
+ ssh_key_name: The name of the SSH key to generate
65
+ """
66
+ auth = Auth()
67
+ auth.authenticate()
68
+ console = Console()
69
+ ssh_dir = Path.home() / ".ssh"
70
+ ssh_dir.mkdir(parents=True, exist_ok=True)
71
+
72
+ key_path = ssh_dir / ssh_key_name
73
+ config_path = ssh_dir / "config"
74
+
75
+ # Check if the SSH key already exists
76
+ if key_path.exists() and (key_path.with_suffix(".pub")).exists() and not overwrite:
77
+ console.print(f"SSH key already exists at {key_path}")
78
+ else:
79
+ self._download_ssh_keys(auth.api_key, ssh_home=ssh_dir, ssh_key_name=ssh_key_name, overwrite=overwrite)
80
+ console.print(f"SSH key generated and saved to {key_path}")
81
+
82
+ # Check if the SSH config already contains the required configuration
83
+ config_content = self._generate_ssh_config(str(key_path))
84
+ if config_path.exists():
85
+ with config_path.open("r") as config_file:
86
+ if config_content.strip() in config_file.read():
87
+ console.print("SSH config already contains the required configuration.")
88
+ return
89
+
90
+ with config_path.open("a") as config_file:
91
+ config_file.write(config_content)
92
+ console.print(f"SSH config updated at {config_path}")
@@ -0,0 +1,31 @@
1
+ import subprocess
2
+ import sys
3
+ from typing import Optional
4
+
5
+ from lightning_sdk.cli.configure import _Configure
6
+ from lightning_sdk.lightning_cloud.login import Auth
7
+
8
+
9
+ class _Connect(_Configure):
10
+ """Connect to lightning products."""
11
+
12
+ def studio(self, name: Optional[str] = None, teamspace: Optional[str] = None) -> None:
13
+ """Connect to a studio via SSH.
14
+
15
+ Args:
16
+ name: The name of the studio to connect to.
17
+ teamspace: The teamspace the studio is part of. Should be of format <OWNER>/<TEAMSPACE_NAME>.
18
+ """
19
+ auth = Auth()
20
+ auth.authenticate() # this is maybe not needed
21
+ studio = self._get_studio(name=name, teamspace=teamspace)
22
+ host = "ssh.lightning.ai"
23
+ username = f"s_{studio._studio.id}"
24
+
25
+ self.ssh(overwrite=False)
26
+
27
+ try:
28
+ subprocess.run(["ssh", f"{username}@{host}"])
29
+ except Exception as ex:
30
+ print(f"Failed to establish SSH connection: {ex}")
31
+ sys.exit(1)
@@ -1,5 +1,7 @@
1
1
  from typing import Optional
2
2
 
3
+ from rich.console import Console
4
+
3
5
  from lightning_sdk.cli.exceptions import StudioCliError
4
6
  from lightning_sdk.cli.job_and_mmt_action import _JobAndMMTAction
5
7
  from lightning_sdk.cli.teamspace_menu import _TeamspacesMenu
@@ -22,7 +24,7 @@ class _Delete(_JobAndMMTAction, _TeamspacesMenu):
22
24
  resolved_teamspace = self._resolve_teamspace(teamspace=teamspace)
23
25
  try:
24
26
  api.delete_container(container, resolved_teamspace.name, resolved_teamspace.owner.name)
25
- print(f"Container {container} deleted successfully.")
27
+ Console().print(f"Container {container} deleted successfully.")
26
28
  except Exception as e:
27
29
  raise StudioCliError(
28
30
  f"Could not delete container {container} from project {resolved_teamspace.name}: {e}"
@@ -41,7 +43,7 @@ class _Delete(_JobAndMMTAction, _TeamspacesMenu):
41
43
  job = super().job(name=name, teamspace=teamspace)
42
44
 
43
45
  job.delete()
44
- print(f"Successfully deleted {job.name}!")
46
+ Console().print(f"Successfully deleted {job.name}!")
45
47
 
46
48
  def mmt(self, name: Optional[str] = None, teamspace: Optional[str] = None) -> None:
47
49
  """Delete a multi-machine job.
@@ -56,7 +58,7 @@ class _Delete(_JobAndMMTAction, _TeamspacesMenu):
56
58
  mmt = super().mmt(name=name, teamspace=teamspace)
57
59
 
58
60
  mmt.delete()
59
- print(f"Successfully deleted {mmt.name}!")
61
+ Console().print(f"Successfully deleted {mmt.name}!")
60
62
 
61
63
  def studio(self, name: Optional[str] = None, teamspace: Optional[str] = None) -> None:
62
64
  """Delete an existing studio.
@@ -82,4 +84,4 @@ class _Delete(_JobAndMMTAction, _TeamspacesMenu):
82
84
  studio = Studio(name=name, teamspace=teamspace, org=None, user=owner, create_ok=False)
83
85
 
84
86
  studio.delete()
85
- print("Studio successfully deleted")
87
+ Console().print("Studio successfully deleted")
@@ -143,8 +143,8 @@ class _Downloads(_StudiosMenu, _TeamspacesMenu):
143
143
  teamspace: The name of the teamspace to download the container from.
144
144
  tag: The tag of the container to download.
145
145
  """
146
- resolved_teamspace = self._resolve_teamspace(teamspace)
147
146
  console = Console()
147
+ resolved_teamspace = self._resolve_teamspace(teamspace)
148
148
  with console.status("Downloading container..."):
149
149
  api = LitContainerApi()
150
150
  api.download_container(container, resolved_teamspace, tag)
@@ -4,9 +4,13 @@ from typing import Type
4
4
 
5
5
  from fire import Fire
6
6
  from lightning_utilities.core.imports import RequirementCache
7
+ from rich.console import Console
8
+ from rich.panel import Panel
7
9
 
8
10
  from lightning_sdk.api.studio_api import _cloud_url
9
11
  from lightning_sdk.cli.ai_hub import _AIHub
12
+ from lightning_sdk.cli.configure import _Configure
13
+ from lightning_sdk.cli.connect import _Connect
10
14
  from lightning_sdk.cli.delete import _Delete
11
15
  from lightning_sdk.cli.download import _Downloads
12
16
  from lightning_sdk.cli.generate import _Generate
@@ -41,6 +45,8 @@ class StudioCLI:
41
45
  self.start = _Start()
42
46
  self.switch = _Switch()
43
47
  self.generate = _Generate()
48
+ self.connect = _Connect()
49
+ self.configure = _Configure()
44
50
 
45
51
  sys.excepthook = _notify_exception
46
52
 
@@ -62,7 +68,8 @@ class StudioCLI:
62
68
 
63
69
  def _notify_exception(exception_type: Type[BaseException], value: BaseException, tb: TracebackType) -> None: # No
64
70
  """CLI won't show tracebacks, just print the exception message."""
65
- print(value)
71
+ console = Console()
72
+ console.print(Panel(value))
66
73
 
67
74
 
68
75
  def main_cli() -> None:
@@ -2,57 +2,34 @@ from typing import Optional
2
2
 
3
3
  from rich.console import Console
4
4
 
5
- from lightning_sdk import Studio
5
+ from lightning_sdk.cli.studios_menu import _StudiosMenu
6
6
 
7
7
 
8
- class _Generate:
8
+ class _Generate(_StudiosMenu):
9
9
  """Generate configs (such as ssh for studio) and print them to commandline."""
10
10
 
11
- console = Console()
12
-
13
- def _generate_ssh_config(self, name: str, studio_id: str) -> str:
14
- """Generate SSH config entry for the studio.
15
-
16
- Args:
17
- name: Studio name
18
- studio_id: Studio space ID
19
-
20
- Returns:
21
- str: SSH config entry
22
- """
23
- return f"""# ssh s_{studio_id}@ssh.lightning.ai
24
-
25
- Host {name}
26
- User s_{studio_id}
27
- Hostname ssh.lightning.ai
28
- IdentityFile ~/.ssh/lightning_rsa
11
+ @staticmethod
12
+ def _generate_ssh_config(key_path: str) -> str:
13
+ return f"""Host ssh.lightning.ai
14
+ IdentityFile {key_path}
29
15
  IdentitiesOnly yes
30
16
  ServerAliveInterval 15
31
17
  ServerAliveCountMax 4
32
18
  StrictHostKeyChecking no
33
- UserKnownHostsFile=/dev/null"""
19
+ UserKnownHostsFile=/dev/null
20
+ """
34
21
 
35
22
  def ssh(self, name: Optional[str] = None, teamspace: Optional[str] = None) -> None:
36
- """Get SSH config entry for a studio. Will start the studio if needed.
23
+ """Get SSH config entry for a studio.
37
24
 
38
25
  Args:
39
- name: The name of the studio to stop.
26
+ name: The name of the studio to obtain SSH config.
40
27
  If not specified, tries to infer from the environment (e.g. when run from within a Studio.)
41
28
  teamspace: The teamspace the studio is part of. Should be of format <OWNER>/<TEAMSPACE_NAME>.
42
29
  If not specified, tries to infer from the environment (e.g. when run from within a Studio.)
43
30
  """
44
- if teamspace:
45
- ts_splits = teamspace.split("/")
46
- if len(ts_splits) != 2:
47
- raise ValueError(f"Teamspace should be of format <OWNER>/<TEAMSPACE_NAME> but got {teamspace}")
48
- owner, teamspace = ts_splits
49
- else:
50
- owner, teamspace = None, None
51
-
52
- try:
53
- studio = Studio(name=name, teamspace=teamspace, org=owner, user=None, create_ok=False)
54
- except (RuntimeError, ValueError):
55
- studio = Studio(name=name, teamspace=teamspace, org=None, user=owner, create_ok=False)
31
+ studio = self._get_studio(name=name, teamspace=teamspace)
56
32
 
57
33
  # Print the SSH config
58
- self.console.print(self._generate_ssh_config(name, studio._studio.id))
34
+ conf = f"# ssh s_{studio._studio.id}@ssh.lightning.ai\n\n" + self._generate_ssh_config("~/.ssh/lightning_rsa")
35
+ Console().print(conf)
@@ -1,5 +1,7 @@
1
1
  from typing import Optional
2
2
 
3
+ from rich.console import Console
4
+
3
5
  from lightning_sdk.cli.job_and_mmt_action import _JobAndMMTAction
4
6
 
5
7
 
@@ -16,7 +18,7 @@ class _Inspect(_JobAndMMTAction):
16
18
  If not specified can be selected interactively.
17
19
 
18
20
  """
19
- print(super().job(name=name, teamspace=teamspace).json())
21
+ (super().job(name=name, teamspace=teamspace).json())
20
22
 
21
23
  def mmt(self, name: Optional[str] = None, teamspace: Optional[str] = None) -> None:
22
24
  """Inspect a multi-machine job for further details as JSON.
@@ -28,4 +30,4 @@ class _Inspect(_JobAndMMTAction):
28
30
  If not specified can be selected interactively.
29
31
 
30
32
  """
31
- print(super().mmt(name=name, teamspace=teamspace).json())
33
+ Console().print(super().mmt(name=name, teamspace=teamspace).json())
@@ -1,5 +1,6 @@
1
1
  from typing import Dict, List, Optional
2
2
 
3
+ from rich.console import Console
3
4
  from simple_term_menu import TerminalMenu
4
5
 
5
6
  from lightning_sdk.cli.exceptions import StudioCliError
@@ -20,7 +21,7 @@ class _JobsMenu:
20
21
  if j.name == job:
21
22
  return j
22
23
 
23
- print("Could not find Job {job}, please select it from the list:")
24
+ Console().print("Could not find Job {job}, please select it from the list:")
24
25
  return self._get_job_from_interactive_menu(possible_jobs)
25
26
 
26
27
  @staticmethod