lightning-sdk 0.1.48__py3-none-any.whl → 0.1.50__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30) hide show
  1. lightning_sdk/__init__.py +3 -1
  2. lightning_sdk/api/job_api.py +13 -6
  3. lightning_sdk/api/lit_container_api.py +37 -1
  4. lightning_sdk/api/mmt_api.py +12 -6
  5. lightning_sdk/api/utils.py +7 -0
  6. lightning_sdk/cli/download.py +20 -1
  7. lightning_sdk/cli/entrypoint.py +11 -0
  8. lightning_sdk/cli/list.py +60 -2
  9. lightning_sdk/cli/run.py +19 -4
  10. lightning_sdk/cli/upload.py +32 -1
  11. lightning_sdk/job/base.py +23 -4
  12. lightning_sdk/job/job.py +4 -3
  13. lightning_sdk/job/v1.py +4 -4
  14. lightning_sdk/job/v2.py +7 -10
  15. lightning_sdk/job/work.py +2 -2
  16. lightning_sdk/lightning_cloud/openapi/models/v1_cluster_spec.py +1 -29
  17. lightning_sdk/lightning_cloud/openapi/models/v1_lambda_labs_direct_v1.py +31 -3
  18. lightning_sdk/lightning_cloud/openapi/models/v1_user_features.py +27 -1
  19. lightning_sdk/lightning_cloud/openapi/models/v1_vultr_direct_v1.py +27 -1
  20. lightning_sdk/lit_container.py +40 -0
  21. lightning_sdk/mmt/base.py +22 -5
  22. lightning_sdk/mmt/mmt.py +5 -3
  23. lightning_sdk/mmt/v1.py +5 -3
  24. lightning_sdk/mmt/v2.py +11 -10
  25. {lightning_sdk-0.1.48.dist-info → lightning_sdk-0.1.50.dist-info}/METADATA +1 -1
  26. {lightning_sdk-0.1.48.dist-info → lightning_sdk-0.1.50.dist-info}/RECORD +30 -30
  27. {lightning_sdk-0.1.48.dist-info → lightning_sdk-0.1.50.dist-info}/LICENSE +0 -0
  28. {lightning_sdk-0.1.48.dist-info → lightning_sdk-0.1.50.dist-info}/WHEEL +0 -0
  29. {lightning_sdk-0.1.48.dist-info → lightning_sdk-0.1.50.dist-info}/entry_points.txt +0 -0
  30. {lightning_sdk-0.1.48.dist-info → lightning_sdk-0.1.50.dist-info}/top_level.txt +0 -0
lightning_sdk/__init__.py CHANGED
@@ -4,6 +4,7 @@ from lightning_sdk.constants import __GLOBAL_LIGHTNING_UNIQUE_IDS_STORE__ # noq
4
4
  from lightning_sdk.helpers import _check_version_and_prompt_upgrade
5
5
  from lightning_sdk.job import Job
6
6
  from lightning_sdk.machine import Machine
7
+ from lightning_sdk.mmt import MMT
7
8
  from lightning_sdk.organization import Organization
8
9
  from lightning_sdk.plugin import JobsPlugin, MultiMachineTrainingPlugin, Plugin, SlurmJobsPlugin
9
10
  from lightning_sdk.status import Status
@@ -15,6 +16,7 @@ __all__ = [
15
16
  "Job",
16
17
  "JobsPlugin",
17
18
  "Machine",
19
+ "MMT",
18
20
  "MultiMachineTrainingPlugin",
19
21
  "Organization",
20
22
  "Plugin",
@@ -27,5 +29,5 @@ __all__ = [
27
29
  "AIHub",
28
30
  ]
29
31
 
30
- __version__ = "0.1.48"
32
+ __version__ = "0.1.50"
31
33
  _check_version_and_prompt_upgrade(__version__)
@@ -1,11 +1,11 @@
1
1
  import time
2
- from typing import TYPE_CHECKING, Dict, List, Optional
2
+ from typing import TYPE_CHECKING, Dict, List, Optional, Union
3
3
  from urllib.request import urlopen
4
4
 
5
5
  from lightning_sdk.api.utils import (
6
6
  _COMPUTE_NAME_TO_MACHINE,
7
- _MACHINE_TO_COMPUTE_NAME,
8
7
  _create_app,
8
+ _machine_to_compute_name,
9
9
  remove_datetime_prefix,
10
10
  )
11
11
  from lightning_sdk.api.utils import (
@@ -120,7 +120,7 @@ class JobApiV1:
120
120
  studio_id: str,
121
121
  teamspace_id: str,
122
122
  cloud_account: str,
123
- machine: Machine,
123
+ machine: Union[Machine, str],
124
124
  interruptible: bool,
125
125
  ) -> Externalv1LightningappInstance:
126
126
  """Creates an arbitrary app."""
@@ -130,7 +130,7 @@ class JobApiV1:
130
130
  teamspace_id=teamspace_id,
131
131
  cloud_account=cloud_account,
132
132
  plugin_type="job",
133
- compute=_MACHINE_TO_COMPUTE_NAME[machine],
133
+ compute=_machine_to_compute_name(machine),
134
134
  name=name,
135
135
  entrypoint=command,
136
136
  interruptible=interruptible,
@@ -180,6 +180,10 @@ class JobApiV1:
180
180
 
181
181
  raise RuntimeError("Could not extract command from app")
182
182
 
183
+ def get_total_cost(self, job: Externalv1LightningappInstance) -> float:
184
+ status: V1LightningappInstanceStatus = job.status
185
+ return status.total_cost
186
+
183
187
 
184
188
  class JobApiV2:
185
189
  # these are stages the job can be in.
@@ -205,7 +209,7 @@ class JobApiV2:
205
209
  teamspace_id: str,
206
210
  studio_id: Optional[str],
207
211
  image: Optional[str],
208
- machine: Machine,
212
+ machine: Union[Machine, str],
209
213
  interruptible: bool,
210
214
  env: Optional[Dict[str, str]],
211
215
  image_credentials: Optional[str],
@@ -219,7 +223,7 @@ class JobApiV2:
219
223
  for k, v in env.items():
220
224
  env_vars.append(V1EnvVar(name=k, value=v))
221
225
 
222
- instance_name = _MACHINE_TO_COMPUTE_NAME[machine]
226
+ instance_name = _machine_to_compute_name(machine)
223
227
 
224
228
  run_id = __GLOBAL_LIGHTNING_UNIQUE_IDS_STORE__[studio_id] if studio_id is not None else ""
225
229
 
@@ -337,3 +341,6 @@ class JobApiV2:
337
341
  return _COMPUTE_NAME_TO_MACHINE.get(
338
342
  instance_type, _COMPUTE_NAME_TO_MACHINE.get(instance_name, instance_type or instance_name)
339
343
  )
344
+
345
+ def get_total_cost(self, job: V1Job) -> float:
346
+ return job.total_cost
@@ -1,13 +1,23 @@
1
- from typing import List
1
+ from typing import Generator, List
2
2
 
3
+ from lightning_sdk.api.utils import _get_registry_url
3
4
  from lightning_sdk.lightning_cloud.openapi.models import V1DeleteLitRepositoryResponse
4
5
  from lightning_sdk.lightning_cloud.rest_client import LightningClient
6
+ from lightning_sdk.teamspace import Teamspace
5
7
 
6
8
 
7
9
  class LitContainerApi:
8
10
  def __init__(self) -> None:
9
11
  self._client = LightningClient(max_tries=3)
10
12
 
13
+ import docker
14
+
15
+ try:
16
+ self._docker_client = docker.from_env()
17
+ self._docker_client.ping()
18
+ except docker.errors.DockerException as e:
19
+ raise RuntimeError(f"Failed to connect to Docker daemon: {e!s}. Is Docker running?") from None
20
+
11
21
  def list_containers(self, project_id: str) -> List:
12
22
  project = self._client.lit_registry_service_get_lit_project_registry(project_id)
13
23
  return project.repositories
@@ -17,3 +27,29 @@ class LitContainerApi:
17
27
  return self._client.lit_registry_service_delete_lit_repository(project_id, container)
18
28
  except Exception as ex:
19
29
  raise ValueError(f"Could not delete container {container} from project {project_id}") from ex
30
+
31
+ def upload_container(self, container: str, teamspace: Teamspace, tag: str) -> Generator[str, None, None]:
32
+ import docker
33
+
34
+ try:
35
+ self._docker_client.images.get(container)
36
+ except docker.errors.ImageNotFound:
37
+ raise ValueError(f"Container {container} does not exist") from None
38
+
39
+ registry_url = _get_registry_url()
40
+ repository = f"{registry_url}/lit-container/{teamspace.owner.name}/{teamspace.name}/{container}"
41
+ tagged = self._docker_client.api.tag(container, repository, tag)
42
+ if not tagged:
43
+ raise ValueError(f"Could not tag container {container} with {repository}:{tag}")
44
+ return self._docker_client.api.push(repository, stream=True, decode=True)
45
+
46
+ def download_container(self, container: str, teamspace: Teamspace, tag: str) -> Generator[str, None, None]:
47
+ import docker
48
+
49
+ registry_url = _get_registry_url()
50
+ repository = f"{registry_url}/lit-container/{teamspace.owner.name}/{teamspace.name}/{container}"
51
+ try:
52
+ self._docker_client.images.pull(repository, tag=tag)
53
+ except docker.errors.APIError as e:
54
+ raise ValueError(f"Could not pull container {container} from {repository}:{tag}") from e
55
+ return self._docker_client.api.tag(repository, container, tag)
@@ -1,12 +1,12 @@
1
1
  import json
2
2
  import time
3
- from typing import TYPE_CHECKING, Dict, List, Optional
3
+ from typing import TYPE_CHECKING, Dict, List, Optional, Union
4
4
 
5
5
  from lightning_sdk.api.job_api import JobApiV1
6
6
  from lightning_sdk.api.utils import (
7
7
  _COMPUTE_NAME_TO_MACHINE,
8
- _MACHINE_TO_COMPUTE_NAME,
9
8
  _create_app,
9
+ _machine_to_compute_name,
10
10
  )
11
11
  from lightning_sdk.api.utils import (
12
12
  _get_cloud_url as _cloud_url,
@@ -43,13 +43,13 @@ class MMTApiV1(JobApiV1):
43
43
  cloud_account: Optional[str],
44
44
  teamspace_id: str,
45
45
  studio_id: str,
46
- machine: Machine,
46
+ machine: Union[Machine, str],
47
47
  interruptible: bool,
48
48
  strategy: str,
49
49
  ) -> Externalv1LightningappInstance:
50
50
  """Creates a multi-machine job with given commands."""
51
51
  distributed_args = {
52
- "cloud_compute": _MACHINE_TO_COMPUTE_NAME[machine],
52
+ "cloud_compute": _machine_to_compute_name(machine),
53
53
  "num_instances": num_machines,
54
54
  "strategy": strategy,
55
55
  }
@@ -80,7 +80,7 @@ class MMTApiV2:
80
80
  teamspace_id: str,
81
81
  studio_id: Optional[str],
82
82
  image: Optional[str],
83
- machine: Machine,
83
+ machine: Union[Machine, str],
84
84
  interruptible: bool,
85
85
  env: Optional[Dict[str, str]],
86
86
  image_credentials: Optional[str],
@@ -94,7 +94,7 @@ class MMTApiV2:
94
94
  for k, v in env.items():
95
95
  env_vars.append(V1EnvVar(name=k, value=v))
96
96
 
97
- instance_name = _MACHINE_TO_COMPUTE_NAME[machine]
97
+ instance_name = _machine_to_compute_name(machine)
98
98
 
99
99
  run_id = __GLOBAL_LIGHTNING_UNIQUE_IDS_STORE__[studio_id] if studio_id is not None else ""
100
100
 
@@ -203,3 +203,9 @@ class MMTApiV2:
203
203
  return _COMPUTE_NAME_TO_MACHINE.get(
204
204
  instance_type, _COMPUTE_NAME_TO_MACHINE.get(instance_name, instance_type or instance_name)
205
205
  )
206
+
207
+ def get_total_cost(self, job: V1MultiMachineJob) -> float:
208
+ return job.total_cost
209
+
210
+ def get_num_machines(self, job: V1MultiMachineJob) -> int:
211
+ return job.machines
@@ -354,6 +354,7 @@ def _machine_to_compute_name(machine: Union[Machine, str]) -> str:
354
354
  _COMPUTE_NAME_TO_MACHINE: Dict[str, Machine] = {v: k for k, v in _MACHINE_TO_COMPUTE_NAME.items()}
355
355
 
356
356
  _DEFAULT_CLOUD_URL = "https://lightning.ai"
357
+ _DEFAULT_REGISTRY_URL = "litcr.io"
357
358
 
358
359
 
359
360
  def _get_cloud_url() -> str:
@@ -362,6 +363,12 @@ def _get_cloud_url() -> str:
362
363
  return cloud_url
363
364
 
364
365
 
366
+ def _get_registry_url() -> str:
367
+ registry_url = os.environ.get("LIGHTNING_REGISTRY_URL", _DEFAULT_REGISTRY_URL)
368
+ os.environ["LIGHTNING_REGISTRY_URL"] = registry_url
369
+ return registry_url
370
+
371
+
365
372
  def _sanitize_studio_remote_path(path: str, studio_id: str) -> str:
366
373
  return f"/cloudspaces/{studio_id}/code/content/{path.replace('/teamspace/studios/this_studio/', '')}"
367
374
 
@@ -3,14 +3,18 @@ import re
3
3
  from pathlib import Path
4
4
  from typing import Optional
5
5
 
6
+ from rich.console import Console
7
+
8
+ from lightning_sdk.api.lit_container_api import LitContainerApi
6
9
  from lightning_sdk.cli.exceptions import StudioCliError
7
10
  from lightning_sdk.cli.studios_menu import _StudiosMenu
11
+ from lightning_sdk.cli.teamspace_menu import _TeamspacesMenu
8
12
  from lightning_sdk.models import download_model
9
13
  from lightning_sdk.studio import Studio
10
14
  from lightning_sdk.utils.resolve import _get_authed_user, skip_studio_init
11
15
 
12
16
 
13
- class _Downloads(_StudiosMenu):
17
+ class _Downloads(_StudiosMenu, _TeamspacesMenu):
14
18
  """Download files and folders from Lightning AI."""
15
19
 
16
20
  def model(self, name: str, download_dir: str = ".") -> None:
@@ -130,3 +134,18 @@ class _Downloads(_StudiosMenu):
130
134
  f"Could not download the file from the given Studio {studio}. "
131
135
  "Please contact Lightning AI directly to resolve this issue."
132
136
  ) from e
137
+
138
+ def container(self, container: str, teamspace: Optional[str] = None, tag: str = "latest") -> None:
139
+ """Download a docker container from a teamspace.
140
+
141
+ Args:
142
+ container: The name of the container to download.
143
+ teamspace: The name of the teamspace to download the container from.
144
+ tag: The tag of the container to download.
145
+ """
146
+ resolved_teamspace = self._resolve_teamspace(teamspace)
147
+ console = Console()
148
+ with console.status("Downloading container..."):
149
+ api = LitContainerApi()
150
+ api.download_container(container, resolved_teamspace, tag)
151
+ console.print("Container downloaded successfully", style="green")
@@ -1,3 +1,7 @@
1
+ import sys
2
+ from types import TracebackType
3
+ from typing import Type
4
+
1
5
  from fire import Fire
2
6
  from lightning_utilities.core.imports import RequirementCache
3
7
 
@@ -32,6 +36,8 @@ class StudioCLI:
32
36
  self.inspect = _Inspect()
33
37
  self.stop = _Stop()
34
38
 
39
+ sys.excepthook = _notify_exception
40
+
35
41
  def login(self) -> None:
36
42
  """Login to Lightning AI Studios."""
37
43
  auth = Auth()
@@ -48,6 +54,11 @@ class StudioCLI:
48
54
  auth.clear()
49
55
 
50
56
 
57
+ def _notify_exception(exception_type: Type[BaseException], value: BaseException, tb: TracebackType) -> None: # No
58
+ """CLI won't show tracebacks, just print the exception message."""
59
+ print(value)
60
+
61
+
51
62
  def main_cli() -> None:
52
63
  """CLI entrypoint."""
53
64
  Fire(StudioCLI(), name="lightning")
lightning_sdk/cli/list.py CHANGED
@@ -20,7 +20,36 @@ class _List(_TeamspacesMenu):
20
20
  """
21
21
  resolved_teamspace = self._resolve_teamspace(teamspace=teamspace)
22
22
 
23
- print("Available Jobs:\n" + "\n".join([j.name for j in resolved_teamspace.jobs]))
23
+ jobs = resolved_teamspace.jobs
24
+
25
+ table = Table(
26
+ pad_edge=True,
27
+ )
28
+ table.add_column("Name")
29
+ table.add_column("Teamspace")
30
+ table.add_column("Studio")
31
+ table.add_column("Image")
32
+ table.add_column("Status")
33
+ table.add_column("Machine")
34
+ table.add_column("Total Cost")
35
+ for j in jobs:
36
+ # we know we just fetched these, so no need to refetch
37
+ j._prevent_refetch_latest = True
38
+ j._internal_job._prevent_refetch_latest = True
39
+
40
+ studio = j.studio
41
+ table.add_row(
42
+ j.name,
43
+ f"{j.teamspace.owner.name}/{j.teamspace.name}",
44
+ studio.name if studio else None,
45
+ j.image,
46
+ str(j.status),
47
+ str(j.machine),
48
+ f"{j.total_cost:.3f}",
49
+ )
50
+
51
+ console = Console()
52
+ console.print(table)
24
53
 
25
54
  def mmts(self, teamspace: Optional[str] = None) -> None:
26
55
  """List multi-machine jobs for a given teamspace.
@@ -32,7 +61,36 @@ class _List(_TeamspacesMenu):
32
61
  """
33
62
  resolved_teamspace = self._resolve_teamspace(teamspace=teamspace)
34
63
 
35
- print("Available MMTs:\n" + "\n".join([j.name for j in resolved_teamspace.multi_machine_jobs]))
64
+ jobs = resolved_teamspace.multi_machine_jobs
65
+
66
+ table = Table(pad_edge=True)
67
+ table.add_column("Name")
68
+ table.add_column("Teamspace")
69
+ table.add_column("Studio")
70
+ table.add_column("Image")
71
+ table.add_column("Status")
72
+ table.add_column("Machine")
73
+ table.add_column("Num Machines")
74
+ table.add_column("Total Cost")
75
+ for j in jobs:
76
+ # we know we just fetched these, so no need to refetch
77
+ j._prevent_refetch_latest = True
78
+ j._internal_job._prevent_refetch_latest = True
79
+
80
+ studio = j.studio
81
+ table.add_row(
82
+ j.name,
83
+ f"{j.teamspace.owner.name}/{j.teamspace.name}",
84
+ studio.name if studio else None,
85
+ j.image,
86
+ str(j.status),
87
+ str(j.machine),
88
+ str(j.num_machines),
89
+ str(j.total_cost),
90
+ )
91
+
92
+ console = Console()
93
+ console.print(table)
36
94
 
37
95
  def containers(self, teamspace: Optional[str] = None) -> None:
38
96
  """Display the list of available containers.
lightning_sdk/cli/run.py CHANGED
@@ -1,4 +1,4 @@
1
- from typing import TYPE_CHECKING, Dict, Optional
1
+ from typing import TYPE_CHECKING, Dict, Optional, Union
2
2
 
3
3
  from lightning_sdk.job import Job
4
4
  from lightning_sdk.machine import Machine
@@ -111,7 +111,7 @@ class _Run:
111
111
  # might need to move to different cli library
112
112
  def job(
113
113
  self,
114
- name: str,
114
+ name: Optional[str] = None,
115
115
  machine: Optional[str] = None,
116
116
  command: Optional[str] = None,
117
117
  studio: Optional[str] = None,
@@ -128,16 +128,27 @@ class _Run:
128
128
  artifacts_remote: Optional[str] = None,
129
129
  entrypoint: str = "sh -c",
130
130
  ) -> None:
131
+ if not name:
132
+ from datetime import datetime
133
+
134
+ timestr = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
135
+ name = f"job-{timestr}"
136
+
131
137
  if machine is None:
132
138
  # TODO: infer from studio
133
139
  machine = "CPU"
134
- machine_enum = Machine(machine.upper())
140
+ machine_enum: Union[str, Machine]
141
+ try:
142
+ machine_enum = Machine[machine.upper()]
143
+ except KeyError:
144
+ machine_enum = machine
135
145
 
136
146
  resolved_teamspace = Teamspace(name=teamspace, org=org, user=user)
137
147
 
138
148
  if cloud_account is None:
139
149
  cloud_account = resolved_teamspace.default_cloud_account
140
150
  machine_enum = Machine(machine.upper())
151
+
141
152
  Job.run(
142
153
  name=name,
143
154
  machine=machine_enum,
@@ -188,7 +199,11 @@ class _Run:
188
199
  if machine is None:
189
200
  # TODO: infer from studio
190
201
  machine = "CPU"
191
- machine_enum = Machine(machine.upper())
202
+ machine_enum: Union[str, Machine]
203
+ try:
204
+ machine_enum = Machine[machine.upper()]
205
+ except KeyError:
206
+ machine_enum = machine
192
207
 
193
208
  resolved_teamspace = Teamspace(name=teamspace, org=org, user=user)
194
209
  if cloud_account is None:
@@ -4,18 +4,22 @@ import os
4
4
  from pathlib import Path
5
5
  from typing import Dict, List, Optional
6
6
 
7
+ from rich.console import Console
8
+ from rich.progress import Progress, SpinnerColumn, TextColumn, TimeElapsedColumn
7
9
  from simple_term_menu import TerminalMenu
8
10
  from tqdm import tqdm
9
11
 
12
+ from lightning_sdk.api.lit_container_api import LitContainerApi
10
13
  from lightning_sdk.api.utils import _get_cloud_url
11
14
  from lightning_sdk.cli.exceptions import StudioCliError
12
15
  from lightning_sdk.cli.studios_menu import _StudiosMenu
16
+ from lightning_sdk.cli.teamspace_menu import _TeamspacesMenu
13
17
  from lightning_sdk.models import upload_model
14
18
  from lightning_sdk.studio import Studio
15
19
  from lightning_sdk.utils.resolve import _get_authed_user, skip_studio_init
16
20
 
17
21
 
18
- class _Uploads(_StudiosMenu):
22
+ class _Uploads(_StudiosMenu, _TeamspacesMenu):
19
23
  """Upload files and folders to Lightning AI."""
20
24
 
21
25
  _studio_upload_status_path = "~/.lightning/studios/uploads"
@@ -146,6 +150,33 @@ class _Uploads(_StudiosMenu):
146
150
  )
147
151
  print(f"See your file at {studio_url}")
148
152
 
153
+ def container(self, container: str, tag: str = "latest", teamspace: Optional[str] = None) -> None:
154
+ teamspace = self._resolve_teamspace(teamspace)
155
+ api = LitContainerApi()
156
+ console = Console()
157
+ with Progress(
158
+ SpinnerColumn(),
159
+ TextColumn("[progress.description]{task.description}"),
160
+ TimeElapsedColumn(),
161
+ console=console,
162
+ transient=False,
163
+ ) as progress:
164
+ push_task = progress.add_task("Pushing Docker image", total=None)
165
+ resp = api.upload_container(container, teamspace, tag)
166
+ for line in resp:
167
+ if "status" in line:
168
+ console.print(line["status"], style="bright_black")
169
+ progress.update(push_task, description="Pushing Docker image")
170
+ elif "aux" in line:
171
+ console.print(line["aux"], style="bright_black")
172
+ elif "error" in line:
173
+ progress.stop()
174
+ console.print(f"\n[red]{line}[/red]")
175
+ return
176
+ else:
177
+ console.print(line, style="bright_black")
178
+ progress.update(push_task, description="[green]Container pushed![/green]")
179
+
149
180
  def _start_parallel_upload(
150
181
  self, executor: concurrent.futures.ThreadPoolExecutor, studio: Studio, upload_state: Dict[str, str]
151
182
  ) -> List[concurrent.futures.Future]:
lightning_sdk/job/base.py CHANGED
@@ -16,7 +16,7 @@ if TYPE_CHECKING:
16
16
  class MachineDict(TypedDict):
17
17
  name: str
18
18
  status: "Status"
19
- machine: "Machine"
19
+ machine: Union["Machine", str]
20
20
 
21
21
 
22
22
  class JobDict(MachineDict):
@@ -24,6 +24,7 @@ class JobDict(MachineDict):
24
24
  teamspace: str
25
25
  studio: Optional[str]
26
26
  image: Optional[str]
27
+ total_cost: float
27
28
 
28
29
 
29
30
  class _BaseJob(ABC):
@@ -61,11 +62,13 @@ class _BaseJob(ABC):
61
62
  if _fetch_job:
62
63
  self._update_internal_job()
63
64
 
65
+ self._prevent_refetch_latest = False
66
+
64
67
  @classmethod
65
68
  def run(
66
69
  cls,
67
70
  name: str,
68
- machine: "Machine",
71
+ machine: Union["Machine", str],
69
72
  command: Optional[str] = None,
70
73
  studio: Union["Studio", str, None] = None,
71
74
  image: Optional[str] = None,
@@ -202,7 +205,7 @@ class _BaseJob(ABC):
202
205
  @abstractmethod
203
206
  def _submit(
204
207
  self,
205
- machine: "Machine",
208
+ machine: Union["Machine", str],
206
209
  command: Optional[str] = None,
207
210
  studio: Optional["Studio"] = None,
208
211
  image: Optional[str] = None,
@@ -268,7 +271,7 @@ class _BaseJob(ABC):
268
271
 
269
272
  @property
270
273
  @abstractmethod
271
- def machine(self) -> "Machine":
274
+ def machine(self) -> Union["Machine", str]:
272
275
  """The machine type the job is running on."""
273
276
 
274
277
  @property
@@ -332,6 +335,7 @@ class _BaseJob(ABC):
332
335
  "command": self.command,
333
336
  "status": self.status,
334
337
  "machine": self.machine,
338
+ "total_cost": self.total_cost,
335
339
  }
336
340
 
337
341
  def json(self) -> str:
@@ -361,3 +365,18 @@ class _BaseJob(ABC):
361
365
  self._update_internal_job()
362
366
 
363
367
  return self._job
368
+
369
+ @property
370
+ def total_cost(self) -> float:
371
+ """The number of credits the job was consuming so far."""
372
+ return self._job_api.get_total_cost(self._latest_job)
373
+
374
+ @property
375
+ def _latest_job(self) -> Any:
376
+ """Guarantees to fetch the latest version of a job before returning it."""
377
+ # in some cases we know we just refetched the latest state, no need to refetch again
378
+ if self._prevent_refetch_latest:
379
+ return self._guaranteed_job
380
+
381
+ self._update_internal_job()
382
+ return self._job
lightning_sdk/job/job.py CHANGED
@@ -89,7 +89,7 @@ class Job(_BaseJob):
89
89
  def run(
90
90
  cls,
91
91
  name: str,
92
- machine: "Machine",
92
+ machine: Union["Machine", str],
93
93
  command: Optional[str] = None,
94
94
  studio: Union["Studio", str, None] = None,
95
95
  image: Union[str, None] = None,
@@ -169,7 +169,7 @@ class Job(_BaseJob):
169
169
 
170
170
  def _submit(
171
171
  self,
172
- machine: "Machine",
172
+ machine: Union["Machine", str],
173
173
  command: Optional[str] = None,
174
174
  studio: Optional["Studio"] = None,
175
175
  image: Optional[str] = None,
@@ -225,6 +225,7 @@ class Job(_BaseJob):
225
225
  cloud_account_auth=cloud_account_auth,
226
226
  artifacts_local=artifacts_local,
227
227
  artifacts_remote=artifacts_remote,
228
+ entrypoint=entrypoint,
228
229
  )
229
230
  return self
230
231
 
@@ -248,7 +249,7 @@ class Job(_BaseJob):
248
249
  return self._internal_job.status
249
250
 
250
251
  @property
251
- def machine(self) -> "Machine":
252
+ def machine(self) -> Union["Machine", str]:
252
253
  """The machine type the job is running on."""
253
254
  return self._internal_job.machine
254
255
 
lightning_sdk/job/v1.py CHANGED
@@ -44,7 +44,7 @@ class _JobV1(_BaseJob):
44
44
  def run(
45
45
  cls,
46
46
  name: str,
47
- machine: "Machine",
47
+ machine: Union["Machine", str],
48
48
  command: str,
49
49
  studio: "Studio",
50
50
  teamspace: Union[str, "Teamspace", None] = None,
@@ -89,7 +89,7 @@ class _JobV1(_BaseJob):
89
89
 
90
90
  def _submit(
91
91
  self,
92
- machine: "Machine",
92
+ machine: Union["Machine", str],
93
93
  command: Optional[str] = None,
94
94
  studio: Optional["Studio"] = None,
95
95
  image: Optional[str] = None,
@@ -174,7 +174,7 @@ class _JobV1(_BaseJob):
174
174
 
175
175
  def stop(self) -> None:
176
176
  """Stops the job. is blocking until the ob is stopped."""
177
- if self.status in (Status.Stopped, Status.Failed):
177
+ if self.status in (Status.Stopped, Status.Completed, Status.Failed):
178
178
  return None
179
179
 
180
180
  return self._job_api.stop_job(self._job.id, self.teamspace.id)
@@ -195,7 +195,7 @@ class _JobV1(_BaseJob):
195
195
  return Work(_work[0].id, self, self.teamspace)
196
196
 
197
197
  @property
198
- def machine(self) -> "Machine":
198
+ def machine(self) -> Union["Machine", str]:
199
199
  """Get the machine the job is running on."""
200
200
  return self.work.machine
201
201
 
lightning_sdk/job/v2.py CHANGED
@@ -1,13 +1,13 @@
1
- from typing import TYPE_CHECKING, Any, Dict, Optional, Union
1
+ from typing import TYPE_CHECKING, Dict, Optional, Union
2
2
 
3
3
  from lightning_sdk.api.job_api import JobApiV2
4
4
  from lightning_sdk.api.utils import _get_cloud_url
5
5
  from lightning_sdk.job.base import _BaseJob
6
+ from lightning_sdk.status import Status
6
7
 
7
8
  if TYPE_CHECKING:
8
9
  from lightning_sdk.machine import Machine
9
10
  from lightning_sdk.organization import Organization
10
- from lightning_sdk.status import Status
11
11
  from lightning_sdk.studio import Studio
12
12
  from lightning_sdk.teamspace import Teamspace
13
13
  from lightning_sdk.user import User
@@ -37,7 +37,7 @@ class _JobV2(_BaseJob):
37
37
 
38
38
  def _submit(
39
39
  self,
40
- machine: "Machine",
40
+ machine: Union["Machine", str],
41
41
  command: Optional[str] = None,
42
42
  studio: Optional["Studio"] = None,
43
43
  image: Optional[str] = None,
@@ -121,6 +121,9 @@ class _JobV2(_BaseJob):
121
121
 
122
122
  def stop(self) -> None:
123
123
  """Stop the job. If the job is already stopped, this is a no-op. This is blocking until the job is stopped."""
124
+ if self.status in (Status.Stopped, Status.Completed, Status.Failed):
125
+ return
126
+
124
127
  self._job_api.stop_job(job_id=self._guaranteed_job.id, teamspace_id=self._teamspace.id)
125
128
 
126
129
  def delete(self) -> None:
@@ -134,19 +137,13 @@ class _JobV2(_BaseJob):
134
137
  cloudspace_id=self._guaranteed_job.spec.cloudspace_id,
135
138
  )
136
139
 
137
- @property
138
- def _latest_job(self) -> Any:
139
- """Guarantees to fetch the latest version of a job before returning it."""
140
- self._update_internal_job()
141
- return self._job
142
-
143
140
  @property
144
141
  def status(self) -> "Status":
145
142
  """The current status of the job."""
146
143
  return self._job_api._job_state_to_external(self._latest_job.state)
147
144
 
148
145
  @property
149
- def machine(self) -> "Machine":
146
+ def machine(self) -> Union["Machine", str]:
150
147
  """The machine type the job is running on."""
151
148
  # only fetch the job it it hasn't been fetched yet as machine cannot change over time
152
149
  return self._job_api._get_job_machine_from_spec(self._guaranteed_job.spec)