lightning_sdk-0.1.35-py3-none-any.whl → lightning_sdk-0.1.37-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. lightning_sdk/__init__.py +1 -1
  2. lightning_sdk/ai_hub.py +27 -13
  3. lightning_sdk/api/ai_hub_api.py +49 -19
  4. lightning_sdk/api/utils.py +1 -0
  5. lightning_sdk/cli/entrypoint.py +2 -2
  6. lightning_sdk/cli/models.py +45 -15
  7. lightning_sdk/cli/run.py +72 -0
  8. lightning_sdk/job/base.py +48 -20
  9. lightning_sdk/job/job.py +21 -11
  10. lightning_sdk/job/v1.py +9 -9
  11. lightning_sdk/job/v2.py +6 -6
  12. lightning_sdk/lightning_cloud/cli/__main__.py +15 -13
  13. lightning_sdk/lightning_cloud/openapi/api/cloud_space_service_api.py +5 -1
  14. lightning_sdk/lightning_cloud/openapi/models/agents_id_body.py +27 -1
  15. lightning_sdk/lightning_cloud/openapi/models/cloud_space_id_versionpublications_body1.py +27 -1
  16. lightning_sdk/lightning_cloud/openapi/models/create_deployment_request_defines_a_spec_for_the_job_that_allows_for_autoscaling_jobs.py +27 -1
  17. lightning_sdk/lightning_cloud/openapi/models/deployments_id_body.py +27 -1
  18. lightning_sdk/lightning_cloud/openapi/models/externalv1_cloud_space_instance_status.py +27 -1
  19. lightning_sdk/lightning_cloud/openapi/models/v1_assistant.py +27 -1
  20. lightning_sdk/lightning_cloud/openapi/models/v1_checkbox.py +29 -3
  21. lightning_sdk/lightning_cloud/openapi/models/v1_deployment.py +27 -1
  22. lightning_sdk/lightning_cloud/openapi/models/v1_deployment_release.py +27 -1
  23. lightning_sdk/lightning_cloud/openapi/models/v1_deployment_state.py +1 -0
  24. lightning_sdk/lightning_cloud/openapi/models/v1_deployment_status.py +27 -1
  25. lightning_sdk/lightning_cloud/openapi/models/v1_get_model_files_response.py +27 -1
  26. lightning_sdk/lightning_cloud/openapi/models/v1_get_user_response.py +27 -1
  27. lightning_sdk/lightning_cloud/openapi/models/v1_google_cloud_direct_v1.py +1 -27
  28. lightning_sdk/lightning_cloud/openapi/models/v1_input.py +29 -3
  29. lightning_sdk/lightning_cloud/openapi/models/v1_job.py +27 -1
  30. lightning_sdk/lightning_cloud/openapi/models/v1_magic_link_login_request.py +27 -1
  31. lightning_sdk/lightning_cloud/openapi/models/v1_metric_value.py +1 -27
  32. lightning_sdk/lightning_cloud/openapi/models/v1_metrics.py +27 -1
  33. lightning_sdk/lightning_cloud/openapi/models/v1_select.py +29 -3
  34. lightning_sdk/lightning_cloud/openapi/models/v1_update_user_request.py +27 -1
  35. lightning_sdk/lightning_cloud/openapi/models/v1_user_features.py +105 -1
  36. lightning_sdk/lightning_cloud/rest_client.py +1 -4
  37. lightning_sdk/lightning_cloud/source_code/tar.py +1 -3
  38. {lightning_sdk-0.1.35.dist-info → lightning_sdk-0.1.37.dist-info}/METADATA +1 -2
  39. {lightning_sdk-0.1.35.dist-info → lightning_sdk-0.1.37.dist-info}/RECORD +43 -42
  40. {lightning_sdk-0.1.35.dist-info → lightning_sdk-0.1.37.dist-info}/LICENSE +0 -0
  41. {lightning_sdk-0.1.35.dist-info → lightning_sdk-0.1.37.dist-info}/WHEEL +0 -0
  42. {lightning_sdk-0.1.35.dist-info → lightning_sdk-0.1.37.dist-info}/entry_points.txt +0 -0
  43. {lightning_sdk-0.1.35.dist-info → lightning_sdk-0.1.37.dist-info}/top_level.txt +0 -0
lightning_sdk/__init__.py CHANGED
@@ -27,5 +27,5 @@ __all__ = [
27
27
  "AIHub",
28
28
  ]
29
29
 
30
- __version__ = "0.1.35"
30
+ __version__ = "0.1.37"
31
31
  _check_version_and_prompt_upgrade(__version__)
lightning_sdk/ai_hub.py CHANGED
@@ -15,8 +15,17 @@ class AIHub:
15
15
  """An interface to interact with the AI Hub.
16
16
 
17
17
  Example:
18
- ai_hub = AIHub()
19
- api_list = ai_hub.list_apis()
18
+ from lightning_sdk import AIHub
19
+ hub = AIHub()
20
+
21
+ # List public API templates
22
+ api_list = hub.list_apis()
23
+
24
+ # Get detailed information about an API template
25
+ api_info = hub.api_info("temp_xxxx")
26
+
27
+ # Deploy an API template
28
+ deployment = hub.deploy("temp_xxxx")
20
29
  """
21
30
 
22
31
  def __init__(self) -> None:
@@ -28,7 +37,7 @@ class AIHub:
28
37
 
29
38
  Example:
30
39
  ai_hub = AIHub()
31
- api_info = ai_hub.api_info("api_12345")
40
+ api_info = ai_hub.api_info("temp_xxxx")
32
41
 
33
42
  Args:
34
43
  api_id: The ID of the API for which information is requested.
@@ -70,8 +79,11 @@ class AIHub:
70
79
  },
71
80
  }
72
81
 
73
- def list_apis(self, search: Optional[str] = None) -> List[Dict[str, str]]:
74
- """Get a list of AI Hub API templates.
82
+ def list_apis(
83
+ self,
84
+ search: Optional[str] = None,
85
+ ) -> List[Dict[str, str]]:
86
+ """Get a list of public AI Hub API templates.
75
87
 
76
88
  Example:
77
89
  ai_hub = AIHub()
@@ -93,9 +105,6 @@ class AIHub:
93
105
  "name": template.name,
94
106
  "description": template.description,
95
107
  "creator_username": template.creator_username,
96
- "created_on": template.creation_timestamp.strftime("%Y-%m-%d %H:%M:%S")
97
- if template.creation_timestamp
98
- else None,
99
108
  }
100
109
  results.append(result)
101
110
  return results
@@ -127,14 +136,18 @@ class AIHub:
127
136
  name: Optional[str] = None,
128
137
  teamspace: Optional[Union[str, "Teamspace"]] = None,
129
138
  org: Optional[Union[str, "Organization"]] = None,
130
- **kwargs: Dict[str, Any],
139
+ api_arguments: Optional[Dict[str, Any]] = None,
131
140
  ) -> Dict[str, Union[str, bool]]:
132
141
  """Deploy an API from the AI Hub.
133
142
 
134
143
  Example:
135
144
  from lightning_sdk import AIHub
136
- ai_hub = AIHub()
137
- deployment = ai_hub.deploy("temp_01jc37n6qpqkdptjpyep0z06hy", batch_size=10)
145
+ hub = AIHub()
146
+ deployment = hub.deploy("temp_xxxx")
147
+
148
+ # Using API arguments
149
+ api_arugments = {"model": "unitary/toxic-bert", "batch_size" 10, "token": "lit_xxxx"}
150
+ deployment = hub.deploy("temp_xxxx", api_arugments=api_arugments)
138
151
 
139
152
  Args:
140
153
  api_id: The ID of the API you want to deploy.
@@ -142,7 +155,7 @@ class AIHub:
142
155
  name: Name for the deployed API. Defaults to None.
143
156
  teamspace: The team or group for deployment. Defaults to None.
144
157
  org: The organization for deployment. Defaults to None.
145
- **kwargs: Additional keyword arguments for deployment.
158
+ api_arguments: Additional API argument, such as model name, or batch size.
146
159
 
147
160
  Returns:
148
161
  A dictionary containing the name of the deployed API,
@@ -155,8 +168,9 @@ class AIHub:
155
168
  teamspace = self._authenticate(teamspace, org)
156
169
  teamspace_id = teamspace.id
157
170
 
171
+ api_arguments = api_arguments or {}
158
172
  deployment = self._api.deploy_api(
159
- template_id=api_id, cluster_id=cluster, project_id=teamspace_id, name=name, **kwargs
173
+ template_id=api_id, cluster_id=cluster, project_id=teamspace_id, name=name, api_arguments=api_arguments
160
174
  )
161
175
  url = quote(f"{LIGHTNING_CLOUD_URL}/{teamspace._org.name}/{teamspace.name}/jobs/{deployment.name}", safe=":/()")
162
176
  print("Deployment available at:", url)
@@ -1,6 +1,5 @@
1
- import re
2
1
  import traceback
3
- from typing import List, Optional
2
+ from typing import Dict, List, Optional
4
3
 
5
4
  import backoff
6
5
 
@@ -8,7 +7,10 @@ from lightning_sdk.lightning_cloud.openapi.models import (
8
7
  CreateDeploymentRequestDefinesASpecForTheJobThatAllowsForAutoscalingJobs,
9
8
  V1Deployment,
10
9
  V1DeploymentTemplate,
11
- V1ParameterizationSpec,
10
+ V1DeploymentTemplateParameter,
11
+ V1DeploymentTemplateParameterPlacement,
12
+ V1DeploymentTemplateParameterType,
13
+ V1JobSpec,
12
14
  )
13
15
  from lightning_sdk.lightning_cloud.openapi.models.v1_deployment_template_gallery_response import (
14
16
  V1DeploymentTemplateGalleryResponse,
@@ -37,31 +39,59 @@ class AIHubApi:
37
39
  ).templates
38
40
 
39
41
  @staticmethod
40
- def _parse_and_update_args(cmd: str, **kwargs: dict) -> list:
41
- """Parse the command and update the arguments with the provided kwargs.
42
+ def _update_parameters(
43
+ job: V1JobSpec, placements: List[V1DeploymentTemplateParameterPlacement], pattern: str, value: str
44
+ ) -> None:
45
+ for placement in placements:
46
+ if placement == V1DeploymentTemplateParameterPlacement.COMMAND:
47
+ job.command = job.command.replace(pattern, str(value))
48
+ if placement == V1DeploymentTemplateParameterPlacement.ENTRYPOINT:
49
+ job.entrypoint = job.entrypoint.replace(pattern, str(value))
42
50
 
43
- >>> _parse_and_update_args("--arg1 1 --arg2=2", arg1=3)
44
- ['--arg1 3']
45
- """
46
- keys = [key.lstrip("-") for key in re.findall(r"--\w+", cmd)]
47
- arguments = {}
48
- for key in keys:
49
- if key in kwargs:
50
- arguments[key] = kwargs[key]
51
- return [f"--{k} {v}" for k, v in arguments.items()]
51
+ if placement == V1DeploymentTemplateParameterPlacement.ENV:
52
+ for e in job.env:
53
+ if e.value == pattern:
54
+ e.value = str(value)
52
55
 
53
56
  @staticmethod
54
- def _resolve_api_arguments(parameter_spec: "V1ParameterizationSpec", **kwargs: dict) -> str:
55
- return " ".join(AIHubApi._parse_and_update_args(parameter_spec.command, **kwargs))
57
+ def _set_parameters(
58
+ job: V1JobSpec, parameters: List[V1DeploymentTemplateParameter], api_arguments: Dict[str, str]
59
+ ) -> V1JobSpec:
60
+ for p in parameters:
61
+ if p.name not in api_arguments:
62
+ if p.type == V1DeploymentTemplateParameterType.INPUT and p.input and p.input.default_value:
63
+ api_arguments[p.name] = p.input.default_value
64
+
65
+ if p.type == V1DeploymentTemplateParameterType.SELECT and p.select and len(p.select.options) > 0:
66
+ api_arguments[p.name] = p.select.options[0]
67
+
68
+ if p.type == V1DeploymentTemplateParameterType.CHECKBOX and p.checkbox:
69
+ api_arguments[p.name] = (
70
+ (p.checkbox.true_value or "True")
71
+ if p.checkbox.is_checked
72
+ else (p.checkbox.false_value or "False")
73
+ )
74
+
75
+ for p in parameters:
76
+ name = p.name
77
+ pattern = f"${{{name}}}"
78
+ if name in api_arguments:
79
+ AIHubApi._update_parameters(job, p.placements, pattern, api_arguments[name])
80
+ elif not p.required:
81
+ AIHubApi._update_parameters(job, p.placements, pattern, "")
82
+ else:
83
+ raise ValueError(f"API reqires argument '{p.name}' but is not provided with api_arguments.")
84
+
85
+ return job
56
86
 
57
87
  def deploy_api(
58
- self, template_id: str, project_id: str, cluster_id: str, name: Optional[str], **kwargs: dict
88
+ self, template_id: str, project_id: str, cluster_id: str, name: Optional[str], api_arguments: Dict[str, str]
59
89
  ) -> V1Deployment:
60
90
  template = self._client.deployment_templates_service_get_deployment_template(template_id)
61
91
  name = name or template.name
62
92
  template.spec_v2.endpoint.id = None
63
- command = self._resolve_api_arguments(template.parameter_spec, **kwargs)
64
- template.spec_v2.job.command = command
93
+
94
+ AIHubApi._set_parameters(template.spec_v2.job, template.parameter_spec.parameters, api_arguments)
65
95
  return self._client.jobs_service_create_deployment(
66
96
  project_id=project_id,
67
97
  body=CreateDeploymentRequestDefinesASpecForTheJobThatAllowsForAutoscalingJobs(
@@ -509,6 +509,7 @@ def _download_model_files(
509
509
  pbar = tqdm(
510
510
  desc=f"Downloading {version}",
511
511
  unit="B",
512
+ total=float(response.size_bytes),
512
513
  unit_scale=True,
513
514
  unit_divisor=1000,
514
515
  )
@@ -5,6 +5,7 @@ from lightning_sdk.api.studio_api import _cloud_url
5
5
  from lightning_sdk.cli.ai_hub import _AIHub
6
6
  from lightning_sdk.cli.download import _Downloads
7
7
  from lightning_sdk.cli.legacy import _LegacyLightningCLI
8
+ from lightning_sdk.cli.run import _Run
8
9
  from lightning_sdk.cli.upload import _Uploads
9
10
  from lightning_sdk.lightning_cloud.login import Auth
10
11
 
@@ -19,8 +20,7 @@ class StudioCLI:
19
20
  self.upload = _Uploads()
20
21
  self.aihub = _AIHub()
21
22
 
22
- if _LIGHTNING_AVAILABLE:
23
- self.run = _LegacyLightningCLI()
23
+ self.run = _Run(legacy_run=_LegacyLightningCLI() if _LIGHTNING_AVAILABLE else None)
24
24
 
25
25
  def login(self) -> None:
26
26
  """Login to Lightning AI Studios."""
@@ -1,8 +1,11 @@
1
- from typing import Tuple
1
+ import os
2
+ from typing import Any, Dict, List, Tuple
2
3
 
3
4
  from lightning_sdk.api import OrgApi, UserApi
4
5
  from lightning_sdk.cli.exceptions import StudioCliError
6
+ from lightning_sdk.lightning_cloud.openapi.models import V1Membership, V1OwnerType
5
7
  from lightning_sdk.teamspace import Teamspace
8
+ from lightning_sdk.user import User
6
9
  from lightning_sdk.utils.resolve import _get_authed_user
7
10
 
8
11
 
@@ -17,22 +20,49 @@ def _parse_model_name(name: str) -> Tuple[str, str, str]:
17
20
  return org_name, teamspace_name, model_name
18
21
 
19
22
 
23
+ def _get_teamspace_and_path(
24
+ ts: V1Membership, org_api: OrgApi, user_api: UserApi, authed_user: User
25
+ ) -> Tuple[str, Dict[str, Any]]:
26
+ if ts.owner_type == V1OwnerType.ORGANIZATION:
27
+ org = org_api._get_org_by_id(ts.owner_id)
28
+ return f"{org.name}/{ts.name}", {"name": ts.name, "org": org.name}
29
+
30
+ if ts.owner_type == V1OwnerType.USER and ts.owner_id != authed_user.id:
31
+ user = user_api._get_user_by_id(ts.owner_id) # todo: check also the name
32
+ return f"{user.username}/{ts.name}", {"name": ts.name, "user": User(name=user.username)}
33
+
34
+ if ts.owner_type == V1OwnerType.USER:
35
+ return f"{authed_user.name}/{ts.name}", {"name": ts.name, "user": authed_user}
36
+
37
+ raise StudioCliError(f"Unknown organization type {ts.owner_type}")
38
+
39
+
40
+ def _list_teamspaces() -> List[str]:
41
+ org_api = OrgApi()
42
+ user_api = UserApi()
43
+ authed_user = _get_authed_user()
44
+
45
+ return [
46
+ _get_teamspace_and_path(ts, org_api, user_api, authed_user)[0]
47
+ for ts in user_api._get_all_teamspace_memberships("")
48
+ ]
49
+
50
+
20
51
  def _get_teamspace(name: str, organization: str) -> Teamspace:
21
52
  """Get a Teamspace object from the SDK."""
22
53
  org_api = OrgApi()
23
- user = _get_authed_user()
24
- teamspaces = {}
25
- for ts in UserApi()._get_all_teamspace_memberships(""):
26
- if ts.owner_type == "organization":
27
- org = org_api._get_org_by_id(ts.owner_id)
28
- teamspaces[f"{org.name}/{ts.name}"] = {"name": ts.name, "org": org.name}
29
- elif ts.owner_type == "user": # todo: check also the name
30
- teamspaces[f"{user.name}/{ts.name}"] = {"name": ts.name, "user": user}
31
- else:
32
- raise StudioCliError(f"Unknown organization type {ts.owner_type}")
54
+ user_api = UserApi()
55
+ authed_user = _get_authed_user()
33
56
 
34
57
  requested_teamspace = f"{organization}/{name}".lower()
35
- if requested_teamspace not in teamspaces:
36
- options = "\n\t".join(teamspaces.keys())
37
- raise StudioCliError(f"Teamspace `{requested_teamspace}` not found. Available teamspaces: \n\t{options}")
38
- return Teamspace(**teamspaces[requested_teamspace])
58
+
59
+ for ts in user_api._get_all_teamspace_memberships(""):
60
+ if ts.name != name:
61
+ continue
62
+
63
+ teamspace_path, teamspace = _get_teamspace_and_path(ts, org_api, user_api, authed_user)
64
+ if requested_teamspace == teamspace_path:
65
+ return Teamspace(**teamspace)
66
+
67
+ options = f"{os.linesep}\t".join(_list_teamspaces())
68
+ raise StudioCliError(f"Teamspace `{requested_teamspace}` not found. Available teamspaces: {os.linesep}\t{options}")
@@ -0,0 +1,72 @@
1
+ from typing import TYPE_CHECKING, Dict, Optional
2
+
3
+ from lightning_sdk.job import Job
4
+ from lightning_sdk.machine import Machine
5
+
6
+ if TYPE_CHECKING:
7
+ from lightning_sdk.cli.legacy import _LegacyLightningCLI
8
+
9
+ _MACHINE_VALUES = tuple([machine.value for machine in Machine])
10
+
11
+
12
+ class _Run:
13
+ """Run async workloads on the Lightning AI platform."""
14
+
15
+ def __init__(self, legacy_run: Optional["_LegacyLightningCLI"] = None) -> None:
16
+ if legacy_run is not None:
17
+ self.app = legacy_run.app
18
+ self.model = legacy_run.model
19
+
20
+ # Need to set the docstring here for f-strings to work.
21
+ # Sadly this is the only way to really show options as f-strings are not allowed as docstrings directly
22
+ # and fire does not show values for literals, just that it is a literal.
23
+ docstr = f"""Run async workloads using a docker image or a compute environment from your studio.
24
+
25
+ Args:
26
+ name: The name of the job. Needs to be unique within the teamspace.
27
+ machine: The machine type to run the job on. One of {", ".join(_MACHINE_VALUES)}.
28
+ command: The command to run inside your job. Required if using a studio. Optional if using an image.
29
+ If not provided for images, will run the container entrypoint and default command.
30
+ studio: The studio env to run the job with. Mutually exclusive with image.
31
+ image: The docker image to run the job with. Mutually exclusive with studio.
32
+ teamspace: The teamspace the job should be associated with. Defaults to the current teamspace.
33
+ org: The organization owning the teamspace (if any). Defaults to the current organization.
34
+ user: The user owning the teamspace (if any). Defaults to the current user.
35
+ cluster: The cluster to run the job on. Defaults to the studio cluster if running with studio compute env.
36
+ If not provided will fall back to the teamspaces default cluster.
37
+ env: Environment variables to set inside the job.
38
+ interruptible: Whether the job should run on interruptible instances. They are cheaper but can be preempted.
39
+ """
40
+ self.job.__func__.__doc__ = docstr
41
+
42
+ # TODO: sadly, fire displays both Optional[type] and Union[type, None] as Optional[Optional]
43
+ # see https://github.com/google/python-fire/pull/513
44
+ # might need to move to different cli library
45
+ def job(
46
+ self,
47
+ name: str,
48
+ machine: str,
49
+ command: Optional[str] = None,
50
+ studio: Optional[str] = None,
51
+ image: Optional[str] = None,
52
+ teamspace: Optional[str] = None,
53
+ org: Optional[str] = None,
54
+ user: Optional[str] = None,
55
+ cluster: Optional[str] = None,
56
+ env: Optional[Dict[str, str]] = None,
57
+ interruptible: bool = False,
58
+ ) -> None:
59
+ machine_enum = Machine(machine.upper())
60
+ Job.run(
61
+ name=name,
62
+ machine=machine_enum,
63
+ command=command,
64
+ studio=studio,
65
+ image=image,
66
+ teamspace=teamspace,
67
+ org=org,
68
+ user=user,
69
+ cluster=cluster,
70
+ env=env,
71
+ interruptible=interruptible,
72
+ )
lightning_sdk/job/base.py CHANGED
@@ -16,15 +16,20 @@ class _BaseJob(ABC):
16
16
  def __init__(
17
17
  self,
18
18
  name: str,
19
- teamspace: Union[str, "Teamspace"] = None,
20
- org: Union[str, "Organization"] = None,
21
- user: Union[str, "User"] = None,
22
- cluster: Optional[str] = None,
19
+ teamspace: Union[str, "Teamspace", None] = None,
20
+ org: Union[str, "Organization", None] = None,
21
+ user: Union[str, "User", None] = None,
23
22
  *,
24
23
  _fetch_job: bool = True,
25
24
  ) -> None:
26
- self._teamspace = _resolve_teamspace(teamspace=teamspace, org=org, user=user)
27
- self._cluster = cluster
25
+ _teamspace = _resolve_teamspace(teamspace=teamspace, org=org, user=user)
26
+ if _teamspace is None:
27
+ raise ValueError(
28
+ "Cannot resolve the teamspace from provided arguments."
29
+ f" Got teamspace={teamspace}, org={org}, user={user}."
30
+ )
31
+ else:
32
+ self._teamspace = _teamspace
28
33
  self._name = name
29
34
  self._job = None
30
35
 
@@ -37,18 +42,25 @@ class _BaseJob(ABC):
37
42
  name: str,
38
43
  machine: "Machine",
39
44
  command: Optional[str] = None,
40
- studio: Optional["Studio"] = None,
45
+ studio: Union["Studio", str, None] = None,
41
46
  image: Optional[str] = None,
42
- teamspace: Union[str, "Teamspace"] = None,
43
- org: Union[str, "Organization"] = None,
44
- user: Union[str, "User"] = None,
47
+ teamspace: Union[str, "Teamspace", None] = None,
48
+ org: Union[str, "Organization", None] = None,
49
+ user: Union[str, "User", None] = None,
45
50
  cluster: Optional[str] = None,
46
51
  env: Optional[Dict[str, str]] = None,
47
52
  interruptible: bool = False,
48
53
  ) -> "_BaseJob":
54
+ from lightning_sdk.studio import Studio
55
+
49
56
  if not name:
50
57
  raise ValueError("A job needs to have a name!")
51
- if studio is not None:
58
+
59
+ if image is None:
60
+ if not isinstance(studio, Studio):
61
+ studio = Studio(name=studio, teamspace=teamspace, org=org, user=user, cluster=cluster, create_ok=False)
62
+
63
+ # studio is a Studio instance at this point
52
64
  if teamspace is None:
53
65
  teamspace = studio.teamspace
54
66
  else:
@@ -60,11 +72,30 @@ class _BaseJob(ABC):
60
72
  "Can only run jobs with Studio envs in the teamspace of that Studio."
61
73
  )
62
74
 
63
- # TODO: resolve studio and support string studios
64
- # TODO: assertions for studio to be on cluster
65
- # TODO: if cluster is not provided use studio cluster if provided, otherwise use default cluster from teamspace
66
- inst = cls(name=name, teamspace=teamspace, org=org, user=user, cluster=cluster, _fetch_job=False)
67
- inst._submit(machine=machine, command=command, studio=studio, image=image, env=env, interruptible=interruptible)
75
+ if cluster is None:
76
+ cluster = studio.cluster
77
+
78
+ if cluster != studio.cluster:
79
+ raise ValueError(
80
+ "Studio cluster does not match provided cluster. "
81
+ "Can only run jobs with Studio envs in the same cluster."
82
+ )
83
+ else:
84
+ if studio is not None:
85
+ raise RuntimeError(
86
+ "image and studio are mutually exclusive as both define the environment to run the job in"
87
+ )
88
+
89
+ inst = cls(name=name, teamspace=teamspace, org=org, user=user, _fetch_job=False)
90
+ inst._submit(
91
+ machine=machine,
92
+ cluster=cluster,
93
+ command=command,
94
+ studio=studio,
95
+ image=image,
96
+ env=env,
97
+ interruptible=interruptible,
98
+ )
68
99
  return inst
69
100
 
70
101
  @abstractmethod
@@ -76,6 +107,7 @@ class _BaseJob(ABC):
76
107
  image: Optional[str] = None,
77
108
  env: Optional[Dict[str, str]] = None,
78
109
  interruptible: bool = False,
110
+ cluster: Optional[str] = None,
79
111
  ) -> None:
80
112
  """Submits a job and updates the internal _job attribute as well as the _name attribute."""
81
113
 
@@ -123,7 +155,3 @@ class _BaseJob(ABC):
123
155
  @property
124
156
  def teamspace(self) -> "Teamspace":
125
157
  return self._teamspace
126
-
127
- @property
128
- def cluster(self) -> Optional[str]:
129
- return self._cluster
lightning_sdk/job/job.py CHANGED
@@ -28,17 +28,20 @@ class Job(_BaseJob):
28
28
  def __init__(
29
29
  self,
30
30
  name: str,
31
- teamspace: Union[str, "Teamspace"] = None,
32
- org: Union[str, "Organization"] = None,
33
- user: Union[str, "User"] = None,
34
- cluster: Optional[str] = None,
31
+ teamspace: Union[str, "Teamspace", None] = None,
32
+ org: Union[str, "Organization", None] = None,
33
+ user: Union[str, "User", None] = None,
35
34
  *,
36
35
  _fetch_job: bool = True,
37
36
  ) -> None:
38
37
  internal_job_cls = _JobV2 if _has_jobs_v2() else _JobV1
39
38
 
40
39
  self._internal_job = internal_job_cls(
41
- name=name, teamspace=teamspace, org=org, user=user, cluster=cluster, _fetch_job=_fetch_job
40
+ name=name,
41
+ teamspace=teamspace,
42
+ org=org,
43
+ user=user,
44
+ _fetch_job=_fetch_job,
42
45
  )
43
46
 
44
47
  @classmethod
@@ -47,11 +50,11 @@ class Job(_BaseJob):
47
50
  name: str,
48
51
  machine: "Machine",
49
52
  command: Optional[str] = None,
50
- studio: Optional["Studio"] = None,
51
- image: Optional[str] = None,
52
- teamspace: Union[str, "Teamspace"] = None,
53
- org: Union[str, "Organization"] = None,
54
- user: Union[str, "User"] = None,
53
+ studio: Union["Studio", str, None] = None,
54
+ image: Union[str, None] = None,
55
+ teamspace: Union[str, "Teamspace", None] = None,
56
+ org: Union[str, "Organization", None] = None,
57
+ user: Union[str, "User", None] = None,
55
58
  cluster: Optional[str] = None,
56
59
  env: Optional[Dict[str, str]] = None,
57
60
  interruptible: bool = False,
@@ -81,9 +84,16 @@ class Job(_BaseJob):
81
84
  image: Optional[str] = None,
82
85
  env: Optional[Dict[str, str]] = None,
83
86
  interruptible: bool = False,
87
+ cluster: Optional[str] = None,
84
88
  ) -> None:
85
89
  return self._internal_job._submit(
86
- machine=machine, command=command, studio=studio, image=image, env=env, interruptible=interruptible
90
+ machine=machine,
91
+ cluster=cluster,
92
+ command=command,
93
+ studio=studio,
94
+ image=image,
95
+ env=env,
96
+ interruptible=interruptible,
87
97
  )
88
98
 
89
99
  def stop(self) -> None:
lightning_sdk/job/v1.py CHANGED
@@ -20,15 +20,14 @@ class _JobV1(_BaseJob):
20
20
  def __init__(
21
21
  self,
22
22
  name: str,
23
- teamspace: Union[str, "Teamspace"] = None,
24
- org: Union[str, "Organization"] = None,
25
- user: Union[str, "User"] = None,
26
- cluster: Optional[str] = None,
23
+ teamspace: Union[str, "Teamspace", None] = None,
24
+ org: Union[str, "Organization", None] = None,
25
+ user: Union[str, "User", None] = None,
27
26
  *,
28
27
  _fetch_job: bool = True,
29
28
  ) -> None:
30
29
  self._job_api = JobApiV1()
31
- super().__init__(name=name, teamspace=teamspace, org=org, user=user, cluster=cluster, _fetch_job=_fetch_job)
30
+ super().__init__(name=name, teamspace=teamspace, org=org, user=user, _fetch_job=_fetch_job)
32
31
 
33
32
  @classmethod
34
33
  def run(
@@ -37,9 +36,9 @@ class _JobV1(_BaseJob):
37
36
  machine: "Machine",
38
37
  command: str,
39
38
  studio: "Studio",
40
- teamspace: Union[str, "Teamspace"] = None,
41
- org: Union[str, "Organization"] = None,
42
- user: Union[str, "User"] = None,
39
+ teamspace: Union[str, "Teamspace", None] = None,
40
+ org: Union[str, "Organization", None] = None,
41
+ user: Union[str, "User", None] = None,
43
42
  cluster: Optional[str] = None,
44
43
  interruptible: bool = False,
45
44
  ) -> "_BaseJob":
@@ -65,6 +64,7 @@ class _JobV1(_BaseJob):
65
64
  image: Optional[str] = None,
66
65
  env: Optional[Dict[str, str]] = None,
67
66
  interruptible: bool = False,
67
+ cluster: Optional[str] = None,
68
68
  ) -> None:
69
69
  if studio is None:
70
70
  raise ValueError("Studio is required for submitting jobs")
@@ -85,7 +85,7 @@ class _JobV1(_BaseJob):
85
85
  command=command,
86
86
  studio_id=studio._studio.id,
87
87
  teamspace_id=self._teamspace.id,
88
- cluster_id=self._cluster,
88
+ cluster_id=cluster,
89
89
  machine=machine,
90
90
  interruptible=interruptible,
91
91
  )
lightning_sdk/job/v2.py CHANGED
@@ -16,15 +16,14 @@ class _JobV2(_BaseJob):
16
16
  def __init__(
17
17
  self,
18
18
  name: str,
19
- teamspace: Union[str, "Teamspace"] = None,
20
- org: Union[str, "Organization"] = None,
21
- user: Union[str, "User"] = None,
22
- cluster: Optional[str] = None,
19
+ teamspace: Union[str, "Teamspace", None] = None,
20
+ org: Union[str, "Organization", None] = None,
21
+ user: Union[str, "User", None] = None,
23
22
  *,
24
23
  _fetch_job: bool = True,
25
24
  ) -> None:
26
25
  self._job_api = JobApiV2()
27
- super().__init__(name=name, teamspace=teamspace, org=org, user=user, cluster=cluster, _fetch_job=_fetch_job)
26
+ super().__init__(name=name, teamspace=teamspace, org=org, user=user, _fetch_job=_fetch_job)
28
27
 
29
28
  def _submit(
30
29
  self,
@@ -34,6 +33,7 @@ class _JobV2(_BaseJob):
34
33
  image: Optional[str] = None,
35
34
  env: Optional[Dict[str, str]] = None,
36
35
  interruptible: bool = False,
36
+ cluster: Optional[str] = None,
37
37
  ) -> None:
38
38
  # Command is required if Studio is provided to know what to run
39
39
  # Image is mutually exclusive with Studio
@@ -55,7 +55,7 @@ class _JobV2(_BaseJob):
55
55
  submitted = self._job_api.submit_job(
56
56
  name=self.name,
57
57
  command=command,
58
- cluster_id=self._cluster,
58
+ cluster_id=cluster,
59
59
  teamspace_id=self._teamspace.id,
60
60
  studio_id=studio_id,
61
61
  image=image,
@@ -1,16 +1,9 @@
1
- import click
2
1
 
3
2
 
4
- @click.group()
5
- def main():
6
- pass
7
-
8
-
9
- @main.command()
10
- def login():
3
+ def login() -> None:
11
4
  """Authorize the CLI to access Grid AI resources for a particular user.
12
- Use login command to force authenticate,
13
- a web browser will open to complete the authentication.
5
+
6
+ Use login command to force authenticate, a web browser will open to complete the authentication.
14
7
  """
15
8
  from lightning_sdk.lightning_cloud.login import Auth # local to avoid circular import
16
9
 
@@ -18,10 +11,19 @@ def login():
18
11
  auth.clear()
19
12
  auth._run_server()
20
13
 
21
-
22
- @main.command()
23
- def logout():
14
+ def logout() -> None:
24
15
  """Logout from LightningCloud"""
25
16
  from lightning_sdk.lightning_cloud.login import Auth # local to avoid circular import
26
17
 
27
18
  Auth.clear()
19
+
20
+
21
+ def main() -> None:
22
+ """CLI entrypoint."""
23
+ from fire import Fire
24
+
25
+ Fire({"login": login, "logout": logout})
26
+
27
+
28
+ if __name__ == "__main__":
29
+ main()