lightning-sdk 0.1.57__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (118)
  1. lightning_sdk/__init__.py +5 -3
  2. lightning_sdk/api/deployment_api.py +23 -11
  3. lightning_sdk/api/job_api.py +42 -7
  4. lightning_sdk/api/lit_container_api.py +88 -22
  5. lightning_sdk/api/mmt_api.py +46 -8
  6. lightning_sdk/api/pipeline_api.py +50 -0
  7. lightning_sdk/api/teamspace_api.py +2 -2
  8. lightning_sdk/api/utils.py +15 -5
  9. lightning_sdk/cli/ai_hub.py +30 -65
  10. lightning_sdk/cli/coloring.py +60 -0
  11. lightning_sdk/cli/configure.py +25 -40
  12. lightning_sdk/cli/connect.py +7 -20
  13. lightning_sdk/cli/create.py +83 -0
  14. lightning_sdk/cli/delete.py +72 -75
  15. lightning_sdk/cli/docker.py +77 -0
  16. lightning_sdk/cli/download.py +71 -111
  17. lightning_sdk/cli/entrypoint.py +44 -65
  18. lightning_sdk/cli/generate.py +28 -43
  19. lightning_sdk/cli/inspect.py +22 -50
  20. lightning_sdk/cli/list.py +281 -222
  21. lightning_sdk/cli/mmts_menu.py +1 -1
  22. lightning_sdk/cli/open.py +62 -0
  23. lightning_sdk/cli/run.py +430 -263
  24. lightning_sdk/cli/serve.py +162 -189
  25. lightning_sdk/cli/start.py +55 -36
  26. lightning_sdk/cli/stop.py +97 -55
  27. lightning_sdk/cli/switch.py +53 -36
  28. lightning_sdk/cli/upload.py +318 -245
  29. lightning_sdk/deployment/__init__.py +2 -0
  30. lightning_sdk/deployment/deployment.py +33 -8
  31. lightning_sdk/lightning_cloud/openapi/__init__.py +21 -0
  32. lightning_sdk/lightning_cloud/openapi/api/__init__.py +1 -0
  33. lightning_sdk/lightning_cloud/openapi/api/assistants_service_api.py +10 -6
  34. lightning_sdk/lightning_cloud/openapi/api/jobs_service_api.py +355 -4
  35. lightning_sdk/lightning_cloud/openapi/api/lit_logger_service_api.py +4 -4
  36. lightning_sdk/lightning_cloud/openapi/api/lit_registry_service_api.py +14 -2
  37. lightning_sdk/lightning_cloud/openapi/api/pipelines_service_api.py +670 -0
  38. lightning_sdk/lightning_cloud/openapi/api/storage_service_api.py +303 -4
  39. lightning_sdk/lightning_cloud/openapi/models/__init__.py +20 -0
  40. lightning_sdk/lightning_cloud/openapi/models/agents_id_body.py +17 -69
  41. lightning_sdk/lightning_cloud/openapi/models/cluster_id_capacityreservations_body.py +27 -1
  42. lightning_sdk/lightning_cloud/openapi/models/create.py +27 -1
  43. lightning_sdk/lightning_cloud/openapi/models/create_deployment_request_defines_a_spec_for_the_job_that_allows_for_autoscaling_jobs.py +53 -1
  44. lightning_sdk/lightning_cloud/openapi/models/deployments_id_body.py +105 -1
  45. lightning_sdk/lightning_cloud/openapi/models/id_visibility_body1.py +1 -27
  46. lightning_sdk/lightning_cloud/openapi/models/id_visibility_body2.py +149 -0
  47. lightning_sdk/lightning_cloud/openapi/models/org_id_memberships_body.py +27 -1
  48. lightning_sdk/lightning_cloud/openapi/models/orgs_id_body.py +157 -1
  49. lightning_sdk/lightning_cloud/openapi/models/pipelines_id_body.py +435 -0
  50. lightning_sdk/lightning_cloud/openapi/models/project_id_pipelines_body.py +201 -0
  51. lightning_sdk/lightning_cloud/openapi/models/projects_id_body.py +157 -1
  52. lightning_sdk/lightning_cloud/openapi/models/slurm_jobs_body.py +79 -1
  53. lightning_sdk/lightning_cloud/openapi/models/uploads_upload_id_body.py +1 -27
  54. lightning_sdk/lightning_cloud/openapi/models/uploads_upload_id_body1.py +175 -0
  55. lightning_sdk/lightning_cloud/openapi/models/v1_agent_job.py +79 -1
  56. lightning_sdk/lightning_cloud/openapi/models/v1_assistant.py +17 -69
  57. lightning_sdk/lightning_cloud/openapi/models/v1_capacity_block_offering.py +27 -1
  58. lightning_sdk/lightning_cloud/openapi/models/v1_cloud_space_artifact_event_type.py +1 -1
  59. lightning_sdk/lightning_cloud/openapi/models/v1_cluster_accelerator.py +131 -1
  60. lightning_sdk/lightning_cloud/openapi/models/v1_cluster_capacity_reservation.py +79 -1
  61. lightning_sdk/lightning_cloud/openapi/models/v1_cluster_security_options.py +27 -1
  62. lightning_sdk/lightning_cloud/openapi/models/v1_complete_upload_temporary_artifact_request.py +175 -0
  63. lightning_sdk/lightning_cloud/openapi/models/v1_create_deployment_request.py +461 -0
  64. lightning_sdk/lightning_cloud/openapi/models/v1_create_deployment_template_request.py +27 -1
  65. lightning_sdk/lightning_cloud/openapi/models/v1_create_job_request.py +201 -0
  66. lightning_sdk/lightning_cloud/openapi/models/v1_create_managed_endpoint_response.py +149 -0
  67. lightning_sdk/lightning_cloud/openapi/models/v1_create_multi_machine_job_request.py +253 -0
  68. lightning_sdk/lightning_cloud/openapi/models/v1_data_connection.py +27 -1
  69. lightning_sdk/lightning_cloud/openapi/models/v1_delete_pipeline_response.py +149 -0
  70. lightning_sdk/lightning_cloud/openapi/models/v1_deployment.py +105 -1
  71. lightning_sdk/lightning_cloud/openapi/models/v1_deployment_details.py +175 -0
  72. lightning_sdk/lightning_cloud/openapi/models/v1_deployment_template.py +53 -1
  73. lightning_sdk/lightning_cloud/openapi/models/v1_filestore_data_connection.py +201 -0
  74. lightning_sdk/lightning_cloud/openapi/models/v1_filesystem_job.py +27 -1
  75. lightning_sdk/lightning_cloud/openapi/models/v1_filesystem_mmt.py +27 -1
  76. lightning_sdk/lightning_cloud/openapi/models/v1_find_capacity_block_offering_response.py +29 -3
  77. lightning_sdk/lightning_cloud/openapi/models/v1_job.py +133 -3
  78. lightning_sdk/lightning_cloud/openapi/models/v1_job_spec.py +53 -1
  79. lightning_sdk/lightning_cloud/openapi/models/v1_job_timing.py +27 -1
  80. lightning_sdk/lightning_cloud/openapi/models/v1_list_pipelines_response.py +123 -0
  81. lightning_sdk/lightning_cloud/openapi/models/v1_lit_registry_artifact.py +27 -1
  82. lightning_sdk/lightning_cloud/openapi/models/v1_lit_repository.py +29 -1
  83. lightning_sdk/lightning_cloud/openapi/models/v1_managed_model.py +27 -1
  84. lightning_sdk/lightning_cloud/openapi/models/v1_multi_machine_job.py +27 -1
  85. lightning_sdk/lightning_cloud/openapi/models/v1_multi_machine_job_state.py +2 -0
  86. lightning_sdk/lightning_cloud/openapi/models/v1_organization.py +157 -1
  87. lightning_sdk/lightning_cloud/openapi/models/v1_pipeline.py +487 -0
  88. lightning_sdk/lightning_cloud/openapi/models/v1_pipeline_step.py +253 -0
  89. lightning_sdk/lightning_cloud/openapi/models/v1_pipeline_step_status.py +331 -0
  90. lightning_sdk/lightning_cloud/openapi/models/v1_pipeline_step_type.py +104 -0
  91. lightning_sdk/lightning_cloud/openapi/models/v1_project_settings.py +157 -1
  92. lightning_sdk/lightning_cloud/openapi/models/v1_restart_timing.py +27 -1
  93. lightning_sdk/lightning_cloud/openapi/models/v1_rule_resource.py +1 -0
  94. lightning_sdk/lightning_cloud/openapi/models/v1_shared_filesystem.py +201 -0
  95. lightning_sdk/lightning_cloud/openapi/models/v1_slurm_job.py +27 -1
  96. lightning_sdk/lightning_cloud/openapi/models/v1_update_job_visibility_response.py +97 -0
  97. lightning_sdk/lightning_cloud/openapi/models/v1_upload_temporary_artifact_request.py +123 -0
  98. lightning_sdk/lightning_cloud/openapi/models/v1_user_features.py +95 -355
  99. lightning_sdk/lightning_cloud/openapi/models/validate.py +27 -1
  100. lightning_sdk/lightning_cloud/rest_client.py +4 -2
  101. lightning_sdk/machine.py +25 -1
  102. lightning_sdk/models.py +18 -12
  103. lightning_sdk/pipeline/__init__.py +4 -0
  104. lightning_sdk/pipeline/pipeline.py +109 -0
  105. lightning_sdk/pipeline/types.py +268 -0
  106. lightning_sdk/pipeline/utils.py +69 -0
  107. lightning_sdk/plugin.py +9 -10
  108. lightning_sdk/services/utilities.py +2 -2
  109. lightning_sdk/studio.py +5 -1
  110. lightning_sdk/teamspace.py +1 -1
  111. lightning_sdk/utils/resolve.py +12 -1
  112. {lightning_sdk-0.1.57.dist-info → lightning_sdk-0.2.0.dist-info}/METADATA +6 -8
  113. {lightning_sdk-0.1.57.dist-info → lightning_sdk-0.2.0.dist-info}/RECORD +117 -88
  114. lightning_sdk/cli/legacy.py +0 -135
  115. {lightning_sdk-0.1.57.dist-info → lightning_sdk-0.2.0.dist-info}/LICENSE +0 -0
  116. {lightning_sdk-0.1.57.dist-info → lightning_sdk-0.2.0.dist-info}/WHEEL +0 -0
  117. {lightning_sdk-0.1.57.dist-info → lightning_sdk-0.2.0.dist-info}/entry_points.txt +0 -0
  118. {lightning_sdk-0.1.57.dist-info → lightning_sdk-0.2.0.dist-info}/top_level.txt +0 -0
@@ -45,6 +45,7 @@ class Validate(object):
45
45
  'check_is_public': 'bool',
46
46
  'cluster_ids': 'list[str]',
47
47
  'efs': 'V1EfsConfig',
48
+ 'filestore': 'V1FilestoreDataConnection',
48
49
  'gcp': 'V1GcpDataConnection',
49
50
  'gcs_folder': 'V1GCSFolderDataConnection',
50
51
  's3_folder': 'V1S3FolderDataConnection'
@@ -55,17 +56,19 @@ class Validate(object):
55
56
  'check_is_public': 'checkIsPublic',
56
57
  'cluster_ids': 'clusterIds',
57
58
  'efs': 'efs',
59
+ 'filestore': 'filestore',
58
60
  'gcp': 'gcp',
59
61
  'gcs_folder': 'gcsFolder',
60
62
  's3_folder': 's3Folder'
61
63
  }
62
64
 
63
- def __init__(self, aws: 'V1AwsDataConnection' =None, check_is_public: 'bool' =None, cluster_ids: 'list[str]' =None, efs: 'V1EfsConfig' =None, gcp: 'V1GcpDataConnection' =None, gcs_folder: 'V1GCSFolderDataConnection' =None, s3_folder: 'V1S3FolderDataConnection' =None): # noqa: E501
65
+ def __init__(self, aws: 'V1AwsDataConnection' =None, check_is_public: 'bool' =None, cluster_ids: 'list[str]' =None, efs: 'V1EfsConfig' =None, filestore: 'V1FilestoreDataConnection' =None, gcp: 'V1GcpDataConnection' =None, gcs_folder: 'V1GCSFolderDataConnection' =None, s3_folder: 'V1S3FolderDataConnection' =None): # noqa: E501
64
66
  """Validate - a model defined in Swagger""" # noqa: E501
65
67
  self._aws = None
66
68
  self._check_is_public = None
67
69
  self._cluster_ids = None
68
70
  self._efs = None
71
+ self._filestore = None
69
72
  self._gcp = None
70
73
  self._gcs_folder = None
71
74
  self._s3_folder = None
@@ -78,6 +81,8 @@ class Validate(object):
78
81
  self.cluster_ids = cluster_ids
79
82
  if efs is not None:
80
83
  self.efs = efs
84
+ if filestore is not None:
85
+ self.filestore = filestore
81
86
  if gcp is not None:
82
87
  self.gcp = gcp
83
88
  if gcs_folder is not None:
@@ -169,6 +174,27 @@ class Validate(object):
169
174
 
170
175
  self._efs = efs
171
176
 
177
+ @property
178
+ def filestore(self) -> 'V1FilestoreDataConnection':
179
+ """Gets the filestore of this Validate. # noqa: E501
180
+
181
+
182
+ :return: The filestore of this Validate. # noqa: E501
183
+ :rtype: V1FilestoreDataConnection
184
+ """
185
+ return self._filestore
186
+
187
+ @filestore.setter
188
+ def filestore(self, filestore: 'V1FilestoreDataConnection'):
189
+ """Sets the filestore of this Validate.
190
+
191
+
192
+ :param filestore: The filestore of this Validate. # noqa: E501
193
+ :type: V1FilestoreDataConnection
194
+ """
195
+
196
+ self._filestore = filestore
197
+
172
198
  @property
173
199
  def gcp(self) -> 'V1GcpDataConnection':
174
200
  """Gets the gcp of this Validate. # noqa: E501
@@ -32,7 +32,8 @@ from lightning_sdk.lightning_cloud.openapi import (
32
32
  StorageServiceApi,
33
33
  DeploymentTemplatesServiceApi,
34
34
  ModelsStoreApi,
35
- LitRegistryServiceApi
35
+ LitRegistryServiceApi,
36
+ PipelinesServiceApi,
36
37
  )
37
38
  from lightning_sdk.lightning_cloud.openapi.rest import ApiException
38
39
  from lightning_sdk.lightning_cloud.source_code.logs_socket_api import LightningLogsSocketAPI
@@ -92,7 +93,8 @@ class GridRestClient(
92
93
  StorageServiceApi,
93
94
  DeploymentTemplatesServiceApi,
94
95
  ModelsStoreApi,
95
- LitRegistryServiceApi
96
+ LitRegistryServiceApi,
97
+ PipelinesServiceApi,
96
98
  ):
97
99
 
98
100
  def __init__(self, api_client: Optional[ApiClient] = None):
lightning_sdk/machine.py CHANGED
@@ -1,5 +1,5 @@
1
1
  from dataclasses import dataclass
2
- from typing import ClassVar, Optional
2
+ from typing import Any, ClassVar, Optional, Tuple
3
3
 
4
4
 
5
5
  @dataclass(frozen=True)
@@ -42,6 +42,30 @@ class Machine:
42
42
  return self.instance_type == other.instance_type
43
43
  return False
44
44
 
45
+ def is_cpu(self) -> bool:
46
+ """Whether the machine is a CPU."""
47
+ return (
48
+ self == Machine.CPU_SMALL
49
+ or self == Machine.CPU
50
+ or self == Machine.DATA_PREP
51
+ or self == Machine.DATA_PREP_MAX
52
+ or self == Machine.DATA_PREP_ULTRA
53
+ )
54
+
55
+ @classmethod
56
+ def from_str(cls, machine: str, *additional_machine_ids: Any) -> "Machine":
57
+ possible_values: Tuple["Machine", ...] = tuple(
58
+ [machine for machine in cls.__dict__.values() if isinstance(machine, cls)]
59
+ )
60
+ for m in possible_values:
61
+ for machine_id in [machine, *additional_machine_ids]:
62
+ if machine_id in (getattr(m, "name", None), getattr(m, "instance_type", None)):
63
+ return m
64
+
65
+ if additional_machine_ids:
66
+ return cls(machine, *additional_machine_ids)
67
+ return cls(machine, machine)
68
+
45
69
 
46
70
  Machine.CPU_SMALL = Machine(name="CPU_SMALL", instance_type="m3.medium")
47
71
  Machine.CPU = Machine(name="CPU", instance_type="cpu-4")
lightning_sdk/models.py CHANGED
@@ -7,7 +7,7 @@ from lightning_sdk.api import OrgApi, TeamspaceApi, UserApi
7
7
  from lightning_sdk.lightning_cloud.openapi.models import V1Membership, V1OwnerType
8
8
  from lightning_sdk.lightning_cloud.openapi.rest import ApiException
9
9
  from lightning_sdk.user import User
10
- from lightning_sdk.utils.resolve import _get_authed_user
10
+ from lightning_sdk.utils.resolve import _get_authed_user, _resolve_teamspace
11
11
 
12
12
  if TYPE_CHECKING:
13
13
  from lightning_sdk.teamspace import Teamspace
@@ -76,20 +76,20 @@ def _parse_model_name_and_version(name: str) -> Tuple[str, str, str, str]:
76
76
  """Parse the name argument into its components."""
77
77
  try:
78
78
  org_name, teamspace_name, model_name = name.split("/")
79
- parts = model_name.split(":")
80
- if len(parts) == 1:
81
- return org_name, teamspace_name, parts[0], "latest"
82
- if len(parts) == 2:
83
- return org_name, teamspace_name, parts[0], parts[1]
84
- # The rest of the validation for name and version happens in the backend
85
- raise ValueError(
86
- "Model version is expected to be in the format `entity/modelname:version` separated by a"
87
- f" single colon, but got: {name}"
88
- )
89
79
  except ValueError as err:
90
80
  raise ValueError(
91
- f"Model name must be in the format 'organization/teamspace/model' but you provided '{name}'."
81
+ f"Model name must be in the format `organization/teamspace/model_name` but you provided '{name}'."
92
82
  ) from err
83
+ parts = model_name.split(":")
84
+ if len(parts) == 1:
85
+ return org_name, teamspace_name, parts[0], "default"
86
+ if len(parts) == 2:
87
+ return org_name, teamspace_name, parts[0], parts[1]
88
+ # The rest of the validation for name and version happens in the backend
89
+ raise ValueError(
90
+ "Model version is expected to be in the format `organization/teamspace/model_name:version`"
91
+ f" separated by a single colon, but got: {name}"
92
+ )
93
93
 
94
94
 
95
95
  def download_model(
@@ -105,6 +105,9 @@ def download_model(
105
105
  download_dir: The directory where the Model should be downloaded.
106
106
  progress_bar: Whether to show a progress bar when downloading.
107
107
  """
108
+ if "/" not in name: # do some magic if you run studio
109
+ teamspace = _resolve_teamspace(None, None, None)
110
+ name = f"{teamspace.owner.name}/{teamspace.name}/{name}"
108
111
  teamspace_owner_name, teamspace_name, model_name, version = _parse_model_name_and_version(name)
109
112
 
110
113
  download_dir = Path(download_dir)
@@ -144,6 +147,9 @@ def upload_model(
144
147
  If not provided, the default cloud account for the Teamspace will be used.
145
148
  progress_bar: Whether to show a progress bar for the upload.
146
149
  """
150
+ if "/" not in name: # do some magic if you run studio
151
+ teamspace = _resolve_teamspace(None, None, None)
152
+ name = f"{teamspace.owner.name}/{teamspace.name}/{name}"
147
153
  org_name, teamspace_name, model_name, _ = _parse_model_name_and_version(name)
148
154
  teamspace = _get_teamspace(name=teamspace_name, organization=org_name)
149
155
  return teamspace.upload_model(
@@ -0,0 +1,4 @@
1
+ from lightning_sdk.pipeline.pipeline import Pipeline
2
+ from lightning_sdk.pipeline.types import MMT, Deployment, Job
3
+
4
+ __all__ = ["Pipeline", "Job", "MMT", "Deployment"]
@@ -0,0 +1,109 @@
1
+ from typing import List, Optional, Union
2
+
3
+ from lightning_sdk.api import UserApi
4
+ from lightning_sdk.api.pipeline_api import PipelineApi
5
+ from lightning_sdk.lightning_cloud.login import Auth
6
+ from lightning_sdk.organization import Organization
7
+ from lightning_sdk.pipeline.types import MMT, Deployment, Job
8
+ from lightning_sdk.pipeline.utils import prepare_steps
9
+ from lightning_sdk.services.utilities import _get_cluster
10
+ from lightning_sdk.teamspace import Teamspace
11
+ from lightning_sdk.user import User
12
+ from lightning_sdk.utils.resolve import _resolve_org, _resolve_teamspace, _resolve_user
13
+
14
+
15
+ class Pipeline:
16
+ def __init__(
17
+ self,
18
+ name: str,
19
+ teamspace: Union[str, "Teamspace", None] = None,
20
+ org: Union[str, "Organization", None] = None,
21
+ user: Union[str, "User", None] = None,
22
+ cloud_account: Optional[str] = None,
23
+ shared_filesystem: Optional[bool] = None,
24
+ ) -> None:
25
+ """The Lightning Pipeline can be used to create complex DAG.
26
+
27
+ Arguments:
28
+ name: The desired name of the pipeline.
29
+ teamspace: The teamspace where the pipeline will be created.
30
+ org: The organization where the pipeline will be created.
31
+ user: The creator of the pipeline.
32
+ cloud_account: The cloud account to use for the entire pipeline.
33
+ shared_filesystem: Whether the pipeline should use a shared filesystem across all nodes.
34
+ Note: This forces the pipeline steps to be in the cloud_account and same region
35
+ """
36
+ self._auth = Auth()
37
+ self._user = None
38
+
39
+ try:
40
+ self._auth.authenticate()
41
+ if user is None:
42
+ self._user = User(name=UserApi()._get_user_by_id(self._auth.user_id).username)
43
+ except ConnectionError as e:
44
+ raise e
45
+
46
+ self._name = name
47
+ self._org = _resolve_org(org)
48
+ self._user = _resolve_user(self._user or user)
49
+
50
+ self._teamspace = _resolve_teamspace(
51
+ teamspace=teamspace,
52
+ org=self._org,
53
+ user=self._user,
54
+ )
55
+
56
+ self._pipeline_api = PipelineApi()
57
+ self._cloud_account = _get_cluster(
58
+ client=self._pipeline_api._client, project_id=self._teamspace.id, cluster_id=cloud_account
59
+ )
60
+ self._shared_filesystem = shared_filesystem
61
+ self._is_created = False
62
+
63
+ pipeline = None
64
+
65
+ if name.startswith("pip_"):
66
+ pipeline = self._pipeline_api.get_pipeline_by_id(name, self._teamspace.id)
67
+
68
+ if pipeline:
69
+ self._name = pipeline.name
70
+ self._is_created = True
71
+ self._pipeline = pipeline
72
+
73
+ def run(self, steps: List[Union[Job, Deployment, MMT]]) -> None:
74
+ if len(steps) == 0:
75
+ raise ValueError("The provided steps is empty")
76
+
77
+ for step_idx, step in enumerate(steps):
78
+ if step.name in [None, ""]:
79
+ raise ValueError(f"The step {step_idx} requires a name")
80
+
81
+ steps = [
82
+ step.to_proto(self._teamspace, self._cloud_account.cluster_id or "", self._shared_filesystem)
83
+ for step in steps
84
+ ]
85
+
86
+ self._pipeline = self._pipeline_api.create_pipeline(
87
+ self._name,
88
+ self._teamspace.id,
89
+ prepare_steps(steps),
90
+ self._shared_filesystem or False,
91
+ )
92
+
93
+ def stop(self) -> None:
94
+ if self._pipeline is None:
95
+ return
96
+
97
+ self._pipeline_api.stop(self._pipeline)
98
+
99
+ def delete(self) -> None:
100
+ if self._pipeline is None:
101
+ return
102
+
103
+ self._pipeline_api.delete(self._teamspace.id, self._pipeline.id)
104
+
105
+ @property
106
+ def name(self) -> str:
107
+ if self._pipeline:
108
+ return self._pipeline.name
109
+ return None
@@ -0,0 +1,268 @@
1
+ from typing import TYPE_CHECKING, Dict, List, Optional, Union
2
+
3
+ from lightning_sdk.api.deployment_api import (
4
+ AutoScaleConfig,
5
+ AutoScalingMetric,
6
+ BasicAuth,
7
+ Env,
8
+ ExecHealthCheck,
9
+ HttpHealthCheck,
10
+ ReleaseStrategy,
11
+ Secret,
12
+ TokenAuth,
13
+ to_autoscaling,
14
+ to_endpoint,
15
+ to_spec,
16
+ to_strategy,
17
+ )
18
+ from lightning_sdk.job.v2 import JobApiV2
19
+ from lightning_sdk.lightning_cloud.openapi.models import (
20
+ V1CreateDeploymentRequest,
21
+ V1PipelineStep,
22
+ V1PipelineStepType,
23
+ )
24
+ from lightning_sdk.mmt.v2 import MMTApiV2
25
+ from lightning_sdk.studio import Studio
26
+
27
+ if TYPE_CHECKING:
28
+ from lightning_sdk.machine import Machine
29
+ from lightning_sdk.organization import Organization
30
+ from lightning_sdk.teamspace import Teamspace
31
+ from lightning_sdk.user import User
32
+
33
+
34
+ from lightning_sdk.pipeline.utils import DEFAULT
35
+
36
+
37
+ class Deployment:
38
+ # Note: This class is only temporary while pipeline is wip
39
+
40
+ def __init__(
41
+ self,
42
+ name: Optional[str] = None,
43
+ machine: Optional["Machine"] = None,
44
+ image: Optional[str] = None,
45
+ autoscale: Optional["AutoScaleConfig"] = None,
46
+ ports: Optional[List[float]] = None,
47
+ release_strategy: Optional["ReleaseStrategy"] = None,
48
+ entrypoint: Optional[str] = None,
49
+ command: Optional[str] = None,
50
+ env: Union[List[Union["Secret", "Env"]], Dict[str, str], None] = None,
51
+ spot: Optional[bool] = None,
52
+ replicas: Optional[int] = None,
53
+ health_check: Optional[Union["HttpHealthCheck", "ExecHealthCheck"]] = None,
54
+ auth: Optional[Union["BasicAuth", "TokenAuth"]] = None,
55
+ cloud_account: Optional[str] = None,
56
+ custom_domain: Optional[str] = None,
57
+ quantity: Optional[int] = None,
58
+ wait_for: Union[str, List[str]] = DEFAULT,
59
+ ) -> None:
60
+ self.name = name
61
+ self.machine = machine
62
+ self.image = image
63
+ self.autoscale = autoscale or AutoScaleConfig(
64
+ min_replicas=0,
65
+ max_replicas=1,
66
+ target_metrics=[
67
+ AutoScalingMetric(
68
+ name="CPU" if machine.is_cpu() else "GPU",
69
+ target=80,
70
+ )
71
+ ],
72
+ )
73
+ self.ports = ports
74
+ self.release_strategy = release_strategy
75
+ self.entrypoint = entrypoint
76
+ self.command = command
77
+ self.env = env
78
+ self.spot = spot
79
+ self.replicas = replicas or 1
80
+ self.health_check = health_check
81
+ self.auth = auth
82
+ self.cloud_account = cloud_account or ""
83
+ self.custom_domain = custom_domain
84
+ self.quantity = quantity
85
+ self.wait_for = wait_for
86
+
87
+ def to_proto(self, teamspace: "Teamspace", cloud_account: str, shared_filesystem: bool) -> V1PipelineStep:
88
+ _validate_cloud_account(cloud_account, self.cloud_account, shared_filesystem)
89
+ return V1PipelineStep(
90
+ name=self.name,
91
+ type=V1PipelineStepType.DEPLOYMENT,
92
+ wait_for=to_wait_for(self.wait_for),
93
+ deployment=V1CreateDeploymentRequest(
94
+ autoscaling=to_autoscaling(self.autoscale, self.replicas),
95
+ endpoint=to_endpoint(self.ports, self.auth, self.custom_domain),
96
+ name=self.name,
97
+ project_id=teamspace.id,
98
+ replicas=self.replicas,
99
+ spec=to_spec(
100
+ cloud_account=self.cloud_account or cloud_account,
101
+ command=self.command,
102
+ entrypoint=self.entrypoint,
103
+ env=self.env,
104
+ image=self.image,
105
+ spot=self.spot,
106
+ machine=self.machine,
107
+ health_check=self.health_check,
108
+ quantity=self.quantity,
109
+ ),
110
+ strategy=to_strategy(self.release_strategy),
111
+ ),
112
+ )
113
+
114
+
115
+ class Job:
116
+ # Note: This class is only temporary while pipeline is wip
117
+
118
+ def __init__(
119
+ self,
120
+ machine: Union["Machine", str],
121
+ name: Optional[str] = None,
122
+ command: Optional[str] = None,
123
+ studio: Union["Studio", str, None] = None,
124
+ image: Union[str, None] = None,
125
+ teamspace: Union[str, "Teamspace", None] = None,
126
+ org: Union[str, "Organization", None] = None,
127
+ user: Union[str, "User", None] = None,
128
+ cloud_account: Optional[str] = None,
129
+ env: Optional[Dict[str, str]] = None,
130
+ interruptible: bool = False,
131
+ image_credentials: Optional[str] = None,
132
+ cloud_account_auth: bool = False,
133
+ entrypoint: str = "sh -c",
134
+ path_mappings: Optional[Dict[str, str]] = None,
135
+ wait_for: Union[str, List[str]] = DEFAULT,
136
+ ) -> None:
137
+ self.name = name
138
+ self.machine = machine
139
+ self.command = command
140
+ self.studio = studio
141
+ self.image = image
142
+ self.teamspace = teamspace
143
+ self.org = org
144
+ self.user = user
145
+ self.cloud_account = cloud_account
146
+ self.env = env
147
+ self.interruptible = interruptible
148
+ self.image_credentials = image_credentials
149
+ self.cloud_account_auth = cloud_account_auth
150
+ self.entrypoint = entrypoint
151
+ self.path_mappings = path_mappings
152
+ self.wait_for = wait_for
153
+
154
+ def to_proto(self, teamspace: "Teamspace", cloud_account: str, shared_filesystem: bool) -> V1PipelineStep:
155
+ _validate_cloud_account(cloud_account, self.cloud_account, shared_filesystem)
156
+ body = JobApiV2._create_job_body(
157
+ name=self.name,
158
+ command=self.command,
159
+ cloud_account=self.cloud_account or cloud_account,
160
+ studio_id=None,
161
+ image=self.image,
162
+ machine=self.machine,
163
+ interruptible=self.interruptible,
164
+ env=self.env,
165
+ image_credentials=self.image_credentials,
166
+ cloud_account_auth=self.cloud_account_auth,
167
+ entrypoint=self.entrypoint,
168
+ path_mappings=self.path_mappings,
169
+ artifacts_local=None,
170
+ artifacts_remote=None,
171
+ )
172
+
173
+ return V1PipelineStep(
174
+ name=self.name,
175
+ type=V1PipelineStepType.JOB,
176
+ wait_for=to_wait_for(self.wait_for),
177
+ job=body,
178
+ )
179
+
180
+
181
+ class MMT:
182
+ # Note: This class is only temporary while pipeline is wip
183
+
184
+ def __init__(
185
+ self,
186
+ name: str,
187
+ machine: Union["Machine", str],
188
+ num_machines: Optional[int] = 2,
189
+ command: Optional[str] = None,
190
+ studio: Union["Studio", str, None] = None,
191
+ image: Optional[str] = None,
192
+ teamspace: Union[str, "Teamspace", None] = None,
193
+ org: Union[str, "Organization", None] = None,
194
+ user: Union[str, "User", None] = None,
195
+ cloud_account: Optional[str] = None,
196
+ env: Optional[Dict[str, str]] = None,
197
+ interruptible: bool = False,
198
+ image_credentials: Optional[str] = None,
199
+ cloud_account_auth: bool = False,
200
+ entrypoint: str = "sh -c",
201
+ path_mappings: Optional[Dict[str, str]] = None,
202
+ wait_for: Union[str, List[str]] = DEFAULT,
203
+ ) -> None:
204
+ self.machine = machine
205
+ self.num_machines = num_machines
206
+ self.name = name
207
+ self.command = command
208
+ self.studio = studio
209
+ self.image = image
210
+ self.teamspace = teamspace
211
+ self.org = org
212
+ self.user = user
213
+ self.cloud_account = cloud_account
214
+ self.env = env
215
+ self.interruptible = interruptible
216
+ self.image_credentials = image_credentials
217
+ self.cloud_account_auth = cloud_account_auth
218
+ self.entrypoint = entrypoint
219
+ self.path_mappings = path_mappings
220
+ self.wait_for = wait_for
221
+
222
+ def to_proto(self, teamspace: "Teamspace", cloud_account: str, shared_filesystem: bool) -> V1PipelineStep:
223
+ _validate_cloud_account(cloud_account, self.cloud_account, shared_filesystem)
224
+ body = MMTApiV2._create_mmt_body(
225
+ name=self.name,
226
+ num_machines=self.num_machines,
227
+ command=self.command,
228
+ cloud_account=self.cloud_account or cloud_account,
229
+ studio_id=self.studio.studio_id if isinstance(self.studio, Studio) else None,
230
+ image=self.image,
231
+ machine=self.machine,
232
+ interruptible=self.interruptible,
233
+ env=self.env,
234
+ image_credentials=self.image_credentials,
235
+ cloud_account_auth=self.cloud_account_auth,
236
+ entrypoint=self.entrypoint,
237
+ path_mappings=self.path_mappings,
238
+ artifacts_local=None, # deprecated in favor of path_mappings
239
+ artifacts_remote=None, # deprecated in favor of path_mappings
240
+ )
241
+
242
+ return V1PipelineStep(
243
+ name=self.name,
244
+ type=V1PipelineStepType.MMT,
245
+ wait_for=to_wait_for(self.wait_for),
246
+ mmt=body,
247
+ )
248
+
249
+
250
+ def to_wait_for(wait_for: Optional[Union[str, List[str]]]) -> Optional[List[str]]:
251
+ if wait_for == DEFAULT:
252
+ return wait_for
253
+
254
+ if wait_for is None:
255
+ return []
256
+
257
+ return wait_for if isinstance(wait_for, list) else [wait_for]
258
+
259
+
260
+ def _validate_cloud_account(pipeline_cloud_account: str, step_cloud_account: str, shared_filesystem: bool) -> None:
261
+ if not shared_filesystem:
262
+ return
263
+
264
+ if pipeline_cloud_account != "" and step_cloud_account != "" and pipeline_cloud_account != step_cloud_account:
265
+ raise ValueError(
266
+ "With shared filesystem enabled, all the pipeline steps wait_for to be on the same cluster."
267
+ f" Found {pipeline_cloud_account} and {step_cloud_account}"
268
+ )
@@ -0,0 +1,69 @@
1
+ from typing import List
2
+
3
+ from lightning_sdk.lightning_cloud.openapi.models import V1PipelineStep, V1PipelineStepType
4
+
5
+ DEFAULT = "DEFAULT"
6
+
7
+
8
+ def prepare_steps(steps: List["V1PipelineStep"]) -> List["V1PipelineStep"]:
9
+ """The prepare_steps function is responsible for creating dependencies between steps.
10
+
11
+ The dependencies are based on whether a step wait_for to be executed before another.
12
+ """
13
+ name_to_step = {}
14
+ name_to_idx = {}
15
+
16
+ for current_step_idx, current_step in enumerate(steps):
17
+ if current_step.name not in name_to_step:
18
+ name_to_step[current_step.name] = current_step
19
+ name_to_idx[current_step.name] = current_step_idx
20
+ else:
21
+ raise ValueError(f"A step with the name {current_step.name} already exists.")
22
+
23
+ if steps[0].wait_for != DEFAULT:
24
+ raise ValueError("The first step isn't allowed to receive `wait_for=...`.")
25
+
26
+ steps[0].wait_for = []
27
+
28
+ # This implements a linear dependency between the steps as the default behaviour
29
+ for current_step_idx, current_step in reversed(list(enumerate(steps))):
30
+ if current_step_idx == 0:
31
+ continue
32
+
33
+ if current_step.wait_for == DEFAULT:
34
+ prev_step_idx = current_step_idx - 1
35
+ wait_for = []
36
+ while prev_step_idx > -1:
37
+ prev_step = steps[prev_step_idx]
38
+ wait_for.insert(0, steps[prev_step_idx].name)
39
+ if prev_step.wait_for != []:
40
+ break
41
+ prev_step_idx -= 1
42
+ current_step.wait_for = wait_for
43
+ else:
44
+ for name in current_step.wait_for:
45
+ if current_step.name == name:
46
+ raise ValueError("You can only reference prior steps")
47
+
48
+ if name not in name_to_step:
49
+ raise ValueError(f"The step {current_step_idx} doesn't have a valid wait_for. Found {name}")
50
+
51
+ if name_to_idx[name] >= name_to_idx[current_step.name]:
52
+ raise ValueError("You can only reference prior steps")
53
+
54
+ print()
55
+ print("===== Generated Pipeline =====")
56
+ for step_idx, step in enumerate(steps):
57
+ step_type = ""
58
+ if step.type == V1PipelineStepType.DEPLOYMENT:
59
+ step_type = "Deployment"
60
+ elif step.type == V1PipelineStepType.JOB:
61
+ step_type = "Job"
62
+ else:
63
+ step_type = "MMT"
64
+ wait_for = "nothing" if len(step.wait_for) == 0 else step.wait_for
65
+ print(f"{step_idx} - {step_type}['{step.name}'] wait_for {wait_for}")
66
+ print("===== ================== =====")
67
+ print()
68
+
69
+ return steps
lightning_sdk/plugin.py CHANGED
@@ -186,16 +186,15 @@ class MultiMachineTrainingPlugin(_Plugin):
186
186
 
187
187
  machine = _resolve_deprecated_cloud_compute(machine, cloud_compute)
188
188
 
189
- with forced_v1(MMT) as v1mmt:
190
- return v1mmt.run(
191
- name=name,
192
- num_machines=num_instances,
193
- machine=machine,
194
- command=command,
195
- studio=self._studio,
196
- teamspace=self._studio.teamspace,
197
- interruptible=interruptible,
198
- )
189
+ return MMT.run(
190
+ name=name,
191
+ num_machines=num_instances,
192
+ machine=machine,
193
+ command=command,
194
+ studio=self._studio,
195
+ teamspace=self._studio.teamspace,
196
+ interruptible=interruptible,
197
+ )
199
198
 
200
199
 
201
200
  class MultiMachineDataPrepPlugin(_Plugin):
@@ -6,7 +6,7 @@ import requests
6
6
  import urllib3
7
7
 
8
8
  from lightning_sdk.api.utils import _get_cloud_url
9
- from lightning_sdk.lightning_cloud.openapi import V1Membership
9
+ from lightning_sdk.lightning_cloud.openapi import V1Membership, V1ProjectClusterBinding
10
10
  from lightning_sdk.lightning_cloud.rest_client import LightningClient
11
11
 
12
12
  _CHUNK_SIZE = 1024 * 1024
@@ -35,7 +35,7 @@ def _get_project(client: LightningClient, project_name: Optional[str] = None) ->
35
35
  raise ValueError("No valid projects found. Please reach out to lightning.ai team to create a project")
36
36
 
37
37
 
38
- def _get_cluster(client: LightningClient, project_id: str, cluster_id: Optional[str] = None) -> V1Membership:
38
+ def _get_cluster(client: LightningClient, project_id: str, cluster_id: Optional[str] = None) -> V1ProjectClusterBinding:
39
39
  """Get a project membership for the user from the backend."""
40
40
  clusters = client.projects_service_list_project_cluster_bindings(project_id=project_id)
41
41
  if cluster_id: