lightning-sdk 0.1.40__py3-none-any.whl → 0.1.42__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lightning_sdk/__init__.py +1 -1
- lightning_sdk/ai_hub.py +8 -3
- lightning_sdk/api/ai_hub_api.py +3 -3
- lightning_sdk/api/deployment_api.py +6 -6
- lightning_sdk/api/job_api.py +32 -6
- lightning_sdk/api/mmt_api.py +60 -19
- lightning_sdk/api/studio_api.py +37 -19
- lightning_sdk/api/teamspace_api.py +34 -29
- lightning_sdk/api/utils.py +48 -35
- lightning_sdk/cli/ai_hub.py +3 -3
- lightning_sdk/cli/entrypoint.py +3 -1
- lightning_sdk/cli/mmt.py +11 -10
- lightning_sdk/cli/run.py +9 -8
- lightning_sdk/cli/serve.py +130 -0
- lightning_sdk/deployment/deployment.py +18 -12
- lightning_sdk/job/base.py +118 -24
- lightning_sdk/job/job.py +87 -9
- lightning_sdk/job/v1.py +75 -18
- lightning_sdk/job/v2.py +51 -15
- lightning_sdk/job/work.py +36 -7
- lightning_sdk/lightning_cloud/openapi/__init__.py +13 -0
- lightning_sdk/lightning_cloud/openapi/api/jobs_service_api.py +215 -5
- lightning_sdk/lightning_cloud/openapi/api/lit_logger_service_api.py +218 -0
- lightning_sdk/lightning_cloud/openapi/api/models_store_api.py +226 -0
- lightning_sdk/lightning_cloud/openapi/api/secret_service_api.py +5 -1
- lightning_sdk/lightning_cloud/openapi/api/snowflake_service_api.py +21 -1
- lightning_sdk/lightning_cloud/openapi/models/__init__.py +13 -0
- lightning_sdk/lightning_cloud/openapi/models/create_deployment_request_defines_a_spec_for_the_job_that_allows_for_autoscaling_jobs.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/deploymenttemplates_id_body.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/id_visibility_body.py +123 -0
- lightning_sdk/lightning_cloud/openapi/models/model_id_versions_body.py +29 -3
- lightning_sdk/lightning_cloud/openapi/models/project_id_multimachinejobs_body.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/project_id_snowflake_body.py +15 -67
- lightning_sdk/lightning_cloud/openapi/models/query_query_id_body.py +17 -69
- lightning_sdk/lightning_cloud/openapi/models/snowflake_export_body.py +29 -81
- lightning_sdk/lightning_cloud/openapi/models/snowflake_query_body.py +17 -69
- lightning_sdk/lightning_cloud/openapi/models/v1_deployment_api.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_deployment_spec.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_get_model_file_url_response.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_get_model_files_response.py +17 -17
- lightning_sdk/lightning_cloud/openapi/models/v1_get_model_files_url_response.py +149 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_get_project_balance_response.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_header.py +175 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_job_spec.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_list_multi_machine_job_events_response.py +123 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_managed_model.py +29 -3
- lightning_sdk/lightning_cloud/openapi/models/v1_metrics_stream.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_model_file.py +175 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_multi_machine_job.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_multi_machine_job_event.py +331 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_multi_machine_job_event_type.py +104 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_multi_machine_job_fault_tolerance.py +149 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_multi_machine_job_fault_tolerance_strategy.py +105 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_multi_machine_job_status.py +27 -1
- lightning_sdk/lightning_cloud/openapi/models/v1_rule_resource.py +2 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_secret_type.py +1 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_snowflake_data_connection.py +29 -81
- lightning_sdk/lightning_cloud/openapi/models/v1_system_metrics.py +29 -3
- lightning_sdk/lightning_cloud/openapi/models/v1_trainium_system_metrics.py +175 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_update_metrics_stream_visibility_response.py +97 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_user_features.py +41 -67
- lightning_sdk/lightning_cloud/openapi/models/v1_validate_deployment_image_request.py +149 -0
- lightning_sdk/lightning_cloud/openapi/models/v1_validate_deployment_image_response.py +97 -0
- lightning_sdk/lightning_cloud/rest_client.py +2 -0
- lightning_sdk/mmt/__init__.py +3 -0
- lightning_sdk/{_mmt → mmt}/base.py +20 -14
- lightning_sdk/{_mmt → mmt}/mmt.py +46 -17
- lightning_sdk/mmt/v1.py +129 -0
- lightning_sdk/{_mmt → mmt}/v2.py +16 -21
- lightning_sdk/plugin.py +43 -16
- lightning_sdk/services/file_endpoint.py +11 -5
- lightning_sdk/studio.py +16 -9
- lightning_sdk/teamspace.py +26 -14
- lightning_sdk/utils/resolve.py +18 -0
- {lightning_sdk-0.1.40.dist-info → lightning_sdk-0.1.42.dist-info}/METADATA +3 -1
- {lightning_sdk-0.1.40.dist-info → lightning_sdk-0.1.42.dist-info}/RECORD +80 -66
- lightning_sdk/_mmt/__init__.py +0 -3
- lightning_sdk/_mmt/v1.py +0 -69
- {lightning_sdk-0.1.40.dist-info → lightning_sdk-0.1.42.dist-info}/LICENSE +0 -0
- {lightning_sdk-0.1.40.dist-info → lightning_sdk-0.1.42.dist-info}/WHEEL +0 -0
- {lightning_sdk-0.1.40.dist-info → lightning_sdk-0.1.42.dist-info}/entry_points.txt +0 -0
- {lightning_sdk-0.1.40.dist-info → lightning_sdk-0.1.42.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,97 @@
|
|
|
1
|
+
# coding: utf-8
|
|
2
|
+
|
|
3
|
+
"""
|
|
4
|
+
external/v1/auth_service.proto
|
|
5
|
+
|
|
6
|
+
No description provided (generated by Swagger Codegen https://github.com/swagger-api/swagger-codegen) # noqa: E501
|
|
7
|
+
|
|
8
|
+
OpenAPI spec version: version not set
|
|
9
|
+
|
|
10
|
+
Generated by: https://github.com/swagger-api/swagger-codegen.git
|
|
11
|
+
|
|
12
|
+
NOTE
|
|
13
|
+
----
|
|
14
|
+
standard swagger-codegen-cli for this python client has been modified
|
|
15
|
+
by custom templates. The purpose of these templates is to include
|
|
16
|
+
typing information in the API and Model code. Please refer to the
|
|
17
|
+
main grid repository for more info
|
|
18
|
+
"""
|
|
19
|
+
|
|
20
|
+
import pprint
|
|
21
|
+
import re # noqa: F401
|
|
22
|
+
|
|
23
|
+
from typing import TYPE_CHECKING
|
|
24
|
+
|
|
25
|
+
import six
|
|
26
|
+
|
|
27
|
+
if TYPE_CHECKING:
|
|
28
|
+
from datetime import datetime
|
|
29
|
+
from lightning_sdk.lightning_cloud.openapi.models import *
|
|
30
|
+
|
|
31
|
+
class V1ValidateDeploymentImageResponse(object):
|
|
32
|
+
"""NOTE: This class is auto generated by the swagger code generator program.
|
|
33
|
+
|
|
34
|
+
Do not edit the class manually.
|
|
35
|
+
"""
|
|
36
|
+
"""
|
|
37
|
+
Attributes:
|
|
38
|
+
swagger_types (dict): The key is attribute name
|
|
39
|
+
and the value is attribute type.
|
|
40
|
+
attribute_map (dict): The key is attribute name
|
|
41
|
+
and the value is json key in definition.
|
|
42
|
+
"""
|
|
43
|
+
swagger_types = {
|
|
44
|
+
}
|
|
45
|
+
|
|
46
|
+
attribute_map = {
|
|
47
|
+
}
|
|
48
|
+
|
|
49
|
+
def __init__(self): # noqa: E501
|
|
50
|
+
"""V1ValidateDeploymentImageResponse - a model defined in Swagger""" # noqa: E501
|
|
51
|
+
self.discriminator = None
|
|
52
|
+
|
|
53
|
+
def to_dict(self) -> dict:
|
|
54
|
+
"""Returns the model properties as a dict"""
|
|
55
|
+
result = {}
|
|
56
|
+
|
|
57
|
+
for attr, _ in six.iteritems(self.swagger_types):
|
|
58
|
+
value = getattr(self, attr)
|
|
59
|
+
if isinstance(value, list):
|
|
60
|
+
result[attr] = list(map(
|
|
61
|
+
lambda x: x.to_dict() if hasattr(x, "to_dict") else x,
|
|
62
|
+
value
|
|
63
|
+
))
|
|
64
|
+
elif hasattr(value, "to_dict"):
|
|
65
|
+
result[attr] = value.to_dict()
|
|
66
|
+
elif isinstance(value, dict):
|
|
67
|
+
result[attr] = dict(map(
|
|
68
|
+
lambda item: (item[0], item[1].to_dict())
|
|
69
|
+
if hasattr(item[1], "to_dict") else item,
|
|
70
|
+
value.items()
|
|
71
|
+
))
|
|
72
|
+
else:
|
|
73
|
+
result[attr] = value
|
|
74
|
+
if issubclass(V1ValidateDeploymentImageResponse, dict):
|
|
75
|
+
for key, value in self.items():
|
|
76
|
+
result[key] = value
|
|
77
|
+
|
|
78
|
+
return result
|
|
79
|
+
|
|
80
|
+
def to_str(self) -> str:
|
|
81
|
+
"""Returns the string representation of the model"""
|
|
82
|
+
return pprint.pformat(self.to_dict())
|
|
83
|
+
|
|
84
|
+
def __repr__(self) -> str:
|
|
85
|
+
"""For `print` and `pprint`"""
|
|
86
|
+
return self.to_str()
|
|
87
|
+
|
|
88
|
+
def __eq__(self, other: 'V1ValidateDeploymentImageResponse') -> bool:
|
|
89
|
+
"""Returns true if both objects are equal"""
|
|
90
|
+
if not isinstance(other, V1ValidateDeploymentImageResponse):
|
|
91
|
+
return False
|
|
92
|
+
|
|
93
|
+
return self.__dict__ == other.__dict__
|
|
94
|
+
|
|
95
|
+
def __ne__(self, other: 'V1ValidateDeploymentImageResponse') -> bool:
|
|
96
|
+
"""Returns true if both objects are not equal"""
|
|
97
|
+
return not self == other
|
|
@@ -31,6 +31,7 @@ from lightning_sdk.lightning_cloud.openapi import (
|
|
|
31
31
|
AssistantsServiceApi,
|
|
32
32
|
StorageServiceApi,
|
|
33
33
|
DeploymentTemplatesServiceApi,
|
|
34
|
+
ModelsStoreApi,
|
|
34
35
|
)
|
|
35
36
|
from lightning_sdk.lightning_cloud.openapi.rest import ApiException
|
|
36
37
|
from lightning_sdk.lightning_cloud.source_code.logs_socket_api import LightningLogsSocketAPI
|
|
@@ -89,6 +90,7 @@ class GridRestClient(
|
|
|
89
90
|
AssistantsServiceApi,
|
|
90
91
|
StorageServiceApi,
|
|
91
92
|
DeploymentTemplatesServiceApi,
|
|
93
|
+
ModelsStoreApi,
|
|
92
94
|
):
|
|
93
95
|
|
|
94
96
|
def __init__(self, api_client: Optional[ApiClient] = None):
|
|
@@ -11,6 +11,7 @@ if TYPE_CHECKING:
|
|
|
11
11
|
|
|
12
12
|
from lightning_sdk.job.base import _BaseJob
|
|
13
13
|
from lightning_sdk.job.job import Job
|
|
14
|
+
from lightning_sdk.utils.resolve import _resolve_deprecated_cluster
|
|
14
15
|
|
|
15
16
|
|
|
16
17
|
class _BaseMMT(_BaseJob):
|
|
@@ -26,16 +27,19 @@ class _BaseMMT(_BaseJob):
|
|
|
26
27
|
teamspace: Union[str, "Teamspace", None] = None,
|
|
27
28
|
org: Union[str, "Organization", None] = None,
|
|
28
29
|
user: Union[str, "User", None] = None,
|
|
29
|
-
|
|
30
|
+
cloud_account: Optional[str] = None,
|
|
30
31
|
env: Optional[Dict[str, str]] = None,
|
|
31
32
|
interruptible: bool = False,
|
|
32
33
|
image_credentials: Optional[str] = None,
|
|
33
|
-
|
|
34
|
+
cloud_account_auth: bool = False,
|
|
34
35
|
artifacts_local: Optional[str] = None,
|
|
35
36
|
artifacts_remote: Optional[str] = None,
|
|
37
|
+
cluster: Optional[str] = None, # deprecated in favor of cloud_account
|
|
36
38
|
) -> "_BaseMMT":
|
|
37
39
|
from lightning_sdk.studio import Studio
|
|
38
40
|
|
|
41
|
+
cloud_account = _resolve_deprecated_cluster(cloud_account, cluster)
|
|
42
|
+
|
|
39
43
|
if num_machines <= 1:
|
|
40
44
|
raise ValueError("Multi-Machine training cannot be run with less than 2 Machines")
|
|
41
45
|
|
|
@@ -44,7 +48,9 @@ class _BaseMMT(_BaseJob):
|
|
|
44
48
|
|
|
45
49
|
if image is None:
|
|
46
50
|
if not isinstance(studio, Studio):
|
|
47
|
-
studio = Studio(
|
|
51
|
+
studio = Studio(
|
|
52
|
+
name=studio, teamspace=teamspace, org=org, user=user, cloud_account=cloud_account, create_ok=False
|
|
53
|
+
)
|
|
48
54
|
|
|
49
55
|
# studio is a Studio instance at this point
|
|
50
56
|
if teamspace is None:
|
|
@@ -58,20 +64,20 @@ class _BaseMMT(_BaseJob):
|
|
|
58
64
|
"Can only run jobs with Studio envs in the teamspace of that Studio."
|
|
59
65
|
)
|
|
60
66
|
|
|
61
|
-
if
|
|
62
|
-
|
|
67
|
+
if cloud_account is None:
|
|
68
|
+
cloud_account = studio.cloud_account
|
|
63
69
|
|
|
64
|
-
if
|
|
70
|
+
if cloud_account != studio.cloud_account:
|
|
65
71
|
raise ValueError(
|
|
66
|
-
"Studio
|
|
67
|
-
"Can only run jobs with Studio envs in the same
|
|
72
|
+
"Studio cloud_account does not match provided cloud_account. "
|
|
73
|
+
"Can only run jobs with Studio envs in the same cloud_account."
|
|
68
74
|
)
|
|
69
75
|
|
|
70
76
|
if image_credentials is not None:
|
|
71
77
|
raise ValueError("image_credentials is only supported when using a custom image")
|
|
72
78
|
|
|
73
|
-
if
|
|
74
|
-
raise ValueError("
|
|
79
|
+
if cloud_account_auth:
|
|
80
|
+
raise ValueError("cloud_account_auth is only supported when using a custom image")
|
|
75
81
|
|
|
76
82
|
if artifacts_local is not None or artifacts_remote is not None:
|
|
77
83
|
raise ValueError(
|
|
@@ -99,14 +105,14 @@ class _BaseMMT(_BaseJob):
|
|
|
99
105
|
inst._submit(
|
|
100
106
|
num_machines=num_machines,
|
|
101
107
|
machine=machine,
|
|
102
|
-
|
|
108
|
+
cloud_account=cloud_account,
|
|
103
109
|
command=command,
|
|
104
110
|
studio=studio,
|
|
105
111
|
image=image,
|
|
106
112
|
env=env,
|
|
107
113
|
interruptible=interruptible,
|
|
108
114
|
image_credentials=image_credentials,
|
|
109
|
-
|
|
115
|
+
cloud_account_auth=cloud_account_auth,
|
|
110
116
|
artifacts_local=artifacts_local,
|
|
111
117
|
artifacts_remote=artifacts_remote,
|
|
112
118
|
)
|
|
@@ -122,9 +128,9 @@ class _BaseMMT(_BaseJob):
|
|
|
122
128
|
image: Optional[str] = None,
|
|
123
129
|
env: Optional[Dict[str, str]] = None,
|
|
124
130
|
interruptible: bool = False,
|
|
125
|
-
|
|
131
|
+
cloud_account: Optional[str] = None,
|
|
126
132
|
image_credentials: Optional[str] = None,
|
|
127
|
-
|
|
133
|
+
cloud_account_auth: bool = False,
|
|
128
134
|
artifacts_local: Optional[str] = None,
|
|
129
135
|
artifacts_remote: Optional[str] = None,
|
|
130
136
|
) -> None:
|
|
@@ -1,12 +1,11 @@
|
|
|
1
|
-
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
|
|
1
|
+
from typing import TYPE_CHECKING, Any, Dict, Optional, Protocol, Tuple, Union
|
|
2
2
|
|
|
3
|
-
from lightning_sdk._mmt.base import _BaseMMT
|
|
4
|
-
from lightning_sdk._mmt.v1 import _MMTV1
|
|
5
|
-
from lightning_sdk._mmt.v2 import _MMTV2
|
|
6
3
|
from lightning_sdk.job.job import _has_jobs_v2
|
|
4
|
+
from lightning_sdk.mmt.base import _BaseMMT
|
|
5
|
+
from lightning_sdk.mmt.v1 import _MMTV1
|
|
6
|
+
from lightning_sdk.mmt.v2 import _MMTV2
|
|
7
7
|
|
|
8
8
|
if TYPE_CHECKING:
|
|
9
|
-
from lightning_sdk.job import Job
|
|
10
9
|
from lightning_sdk.machine import Machine
|
|
11
10
|
from lightning_sdk.organization import Organization
|
|
12
11
|
from lightning_sdk.status import Status
|
|
@@ -15,7 +14,31 @@ if TYPE_CHECKING:
|
|
|
15
14
|
from lightning_sdk.user import User
|
|
16
15
|
|
|
17
16
|
|
|
17
|
+
class MMTMachine(Protocol):
|
|
18
|
+
"""A single machine in multi-machine training."""
|
|
19
|
+
|
|
20
|
+
@property
|
|
21
|
+
def name(self) -> str:
|
|
22
|
+
...
|
|
23
|
+
|
|
24
|
+
@property
|
|
25
|
+
def machine(self) -> "Machine":
|
|
26
|
+
...
|
|
27
|
+
|
|
28
|
+
@property
|
|
29
|
+
def artifact_path(self) -> Optional[str]:
|
|
30
|
+
...
|
|
31
|
+
|
|
32
|
+
@property
|
|
33
|
+
def status(self) -> "Status":
|
|
34
|
+
...
|
|
35
|
+
|
|
36
|
+
|
|
18
37
|
class MMT(_BaseMMT):
|
|
38
|
+
_force_v1: (
|
|
39
|
+
bool
|
|
40
|
+
) = False # required for studio plugin still working correctly as v2 currently does not support the studio env
|
|
41
|
+
|
|
19
42
|
def __init__(
|
|
20
43
|
self,
|
|
21
44
|
name: str,
|
|
@@ -25,7 +48,7 @@ class MMT(_BaseMMT):
|
|
|
25
48
|
*,
|
|
26
49
|
_fetch_job: bool = True,
|
|
27
50
|
) -> None:
|
|
28
|
-
internal_mmt_cls = _MMTV2 if _has_jobs_v2() else _MMTV1
|
|
51
|
+
internal_mmt_cls = _MMTV2 if _has_jobs_v2() and not self._force_v1 else _MMTV1
|
|
29
52
|
|
|
30
53
|
self._internal_mmt = internal_mmt_cls(
|
|
31
54
|
name=name,
|
|
@@ -47,13 +70,14 @@ class MMT(_BaseMMT):
|
|
|
47
70
|
teamspace: Union[str, "Teamspace", None] = None,
|
|
48
71
|
org: Union[str, "Organization", None] = None,
|
|
49
72
|
user: Union[str, "User", None] = None,
|
|
50
|
-
|
|
73
|
+
cloud_account: Optional[str] = None,
|
|
51
74
|
env: Optional[Dict[str, str]] = None,
|
|
52
75
|
interruptible: bool = False,
|
|
53
76
|
image_credentials: Optional[str] = None,
|
|
54
|
-
|
|
77
|
+
cloud_account_auth: bool = False,
|
|
55
78
|
artifacts_local: Optional[str] = None,
|
|
56
79
|
artifacts_remote: Optional[str] = None,
|
|
80
|
+
cluster: Optional[str] = None, # deprecated in favor of cloud_account
|
|
57
81
|
) -> "MMT":
|
|
58
82
|
ret_val = super().run(
|
|
59
83
|
name=name,
|
|
@@ -65,13 +89,14 @@ class MMT(_BaseMMT):
|
|
|
65
89
|
teamspace=teamspace,
|
|
66
90
|
org=org,
|
|
67
91
|
user=user,
|
|
68
|
-
|
|
92
|
+
cloud_account=cloud_account,
|
|
69
93
|
env=env,
|
|
70
94
|
interruptible=interruptible,
|
|
71
95
|
image_credentials=image_credentials,
|
|
72
|
-
|
|
96
|
+
cloud_account_auth=cloud_account_auth,
|
|
73
97
|
artifacts_local=artifacts_local,
|
|
74
98
|
artifacts_remote=artifacts_remote,
|
|
99
|
+
cluster=cluster, # deprecated in favor of cloud_account
|
|
75
100
|
)
|
|
76
101
|
# required for typing with "Job"
|
|
77
102
|
assert isinstance(ret_val, cls)
|
|
@@ -86,23 +111,23 @@ class MMT(_BaseMMT):
|
|
|
86
111
|
image: Optional[str] = None,
|
|
87
112
|
env: Optional[Dict[str, str]] = None,
|
|
88
113
|
interruptible: bool = False,
|
|
89
|
-
|
|
114
|
+
cloud_account: Optional[str] = None,
|
|
90
115
|
image_credentials: Optional[str] = None,
|
|
91
|
-
|
|
116
|
+
cloud_account_auth: bool = False,
|
|
92
117
|
artifacts_local: Optional[str] = None,
|
|
93
118
|
artifacts_remote: Optional[str] = None,
|
|
94
119
|
) -> "MMT":
|
|
95
120
|
self._job = self._internal_mmt._submit(
|
|
96
121
|
num_machines=num_machines,
|
|
97
122
|
machine=machine,
|
|
98
|
-
|
|
123
|
+
cloud_account=cloud_account,
|
|
99
124
|
command=command,
|
|
100
125
|
studio=studio,
|
|
101
126
|
image=image,
|
|
102
127
|
env=env,
|
|
103
128
|
interruptible=interruptible,
|
|
104
129
|
image_credentials=image_credentials,
|
|
105
|
-
|
|
130
|
+
cloud_account_auth=cloud_account_auth,
|
|
106
131
|
artifacts_local=artifacts_local,
|
|
107
132
|
artifacts_remote=artifacts_remote,
|
|
108
133
|
)
|
|
@@ -119,7 +144,7 @@ class MMT(_BaseMMT):
|
|
|
119
144
|
return self._internal_mmt.status
|
|
120
145
|
|
|
121
146
|
@property
|
|
122
|
-
def machines(self) -> Tuple[
|
|
147
|
+
def machines(self) -> Tuple[MMTMachine, ...]:
|
|
123
148
|
return self._internal_mmt.machines
|
|
124
149
|
|
|
125
150
|
@property
|
|
@@ -150,8 +175,8 @@ class MMT(_BaseMMT):
|
|
|
150
175
|
return self._internal_mmt._teamspace
|
|
151
176
|
|
|
152
177
|
@property
|
|
153
|
-
def
|
|
154
|
-
return self._internal_mmt.
|
|
178
|
+
def cloud_account(self) -> Optional[str]:
|
|
179
|
+
return self._internal_mmt.cloud_account
|
|
155
180
|
|
|
156
181
|
def __getattr__(self, key: str) -> Any:
|
|
157
182
|
"""Forward the attribute lookup to the internal job implementation."""
|
|
@@ -159,3 +184,7 @@ class MMT(_BaseMMT):
|
|
|
159
184
|
return getattr(super(), key)
|
|
160
185
|
except AttributeError:
|
|
161
186
|
return getattr(self._internal_mmt, key)
|
|
187
|
+
|
|
188
|
+
@property
|
|
189
|
+
def _guaranteed_job(self) -> Any:
|
|
190
|
+
return self._internal_mmt._guaranteed_job
|
lightning_sdk/mmt/v1.py
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
1
|
+
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union
|
|
2
|
+
|
|
3
|
+
from lightning_sdk.api.mmt_api import MMTApiV1
|
|
4
|
+
from lightning_sdk.job.v1 import _internal_status_to_external_status
|
|
5
|
+
from lightning_sdk.job.work import Work
|
|
6
|
+
|
|
7
|
+
if TYPE_CHECKING:
|
|
8
|
+
from lightning_sdk.machine import Machine
|
|
9
|
+
from lightning_sdk.organization import Organization
|
|
10
|
+
from lightning_sdk.status import Status
|
|
11
|
+
from lightning_sdk.studio import Studio
|
|
12
|
+
from lightning_sdk.teamspace import Teamspace
|
|
13
|
+
from lightning_sdk.user import User
|
|
14
|
+
|
|
15
|
+
from lightning_sdk.mmt.base import _BaseMMT
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class _MMTV1(_BaseMMT):
|
|
19
|
+
def __init__(
|
|
20
|
+
self,
|
|
21
|
+
name: str,
|
|
22
|
+
teamspace: Union[str, "Teamspace", None] = None,
|
|
23
|
+
org: Union[str, "Organization", None] = None,
|
|
24
|
+
user: Union[str, "User", None] = None,
|
|
25
|
+
*,
|
|
26
|
+
_fetch_job: bool = True,
|
|
27
|
+
) -> None:
|
|
28
|
+
self._job_api = MMTApiV1()
|
|
29
|
+
super().__init__(name=name, teamspace=teamspace, org=org, user=user, _fetch_job=_fetch_job)
|
|
30
|
+
|
|
31
|
+
def _submit(
|
|
32
|
+
self,
|
|
33
|
+
num_machines: int,
|
|
34
|
+
machine: "Machine",
|
|
35
|
+
command: Optional[str] = None,
|
|
36
|
+
studio: Optional["Studio"] = None,
|
|
37
|
+
image: Optional[str] = None,
|
|
38
|
+
env: Optional[Dict[str, str]] = None,
|
|
39
|
+
interruptible: bool = False,
|
|
40
|
+
cloud_account: Optional[str] = None,
|
|
41
|
+
image_credentials: Optional[str] = None,
|
|
42
|
+
cloud_account_auth: bool = False,
|
|
43
|
+
artifacts_local: Optional[str] = None,
|
|
44
|
+
artifacts_remote: Optional[str] = None,
|
|
45
|
+
) -> "_MMTV1":
|
|
46
|
+
if studio is None:
|
|
47
|
+
raise ValueError("Studio is required for submitting jobs")
|
|
48
|
+
if image is not None or image_credentials is not None or cloud_account_auth:
|
|
49
|
+
raise ValueError("Image is not supported for submitting jobs")
|
|
50
|
+
|
|
51
|
+
if artifacts_local is not None or artifacts_remote is not None:
|
|
52
|
+
raise ValueError("Specifying how to persist artifacts is not yet supported with jobs")
|
|
53
|
+
|
|
54
|
+
if env is not None:
|
|
55
|
+
raise ValueError("Environment variables are not supported for submitting jobs")
|
|
56
|
+
if command is None:
|
|
57
|
+
raise ValueError("Command is required for submitting multi-machine jobs")
|
|
58
|
+
|
|
59
|
+
_submitted = self._job_api.submit_job(
|
|
60
|
+
name=self._name,
|
|
61
|
+
num_machines=num_machines,
|
|
62
|
+
command=command,
|
|
63
|
+
studio_id=studio._studio.id,
|
|
64
|
+
teamspace_id=self._teamspace.id,
|
|
65
|
+
cloud_account=cloud_account or "",
|
|
66
|
+
machine=machine,
|
|
67
|
+
interruptible=interruptible,
|
|
68
|
+
strategy="parallel",
|
|
69
|
+
)
|
|
70
|
+
|
|
71
|
+
self._name = _submitted.name
|
|
72
|
+
self._job = _submitted
|
|
73
|
+
return self
|
|
74
|
+
|
|
75
|
+
def _update_internal_job(self) -> None:
|
|
76
|
+
try:
|
|
77
|
+
self._job = self._job_api.get_job(self._name, self.teamspace.id)
|
|
78
|
+
except ValueError as e:
|
|
79
|
+
raise ValueError(f"Job {self._name} does not exist in Teamspace {self.teamspace.name}") from e
|
|
80
|
+
|
|
81
|
+
@property
|
|
82
|
+
def machines(self) -> Tuple["Work", ...]:
|
|
83
|
+
works = self._job_api.list_works(self._guaranteed_job.id, self.teamspace.id)
|
|
84
|
+
|
|
85
|
+
return tuple(Work(w.id, self, self.teamspace) for w in works)
|
|
86
|
+
|
|
87
|
+
def stop(self) -> None:
|
|
88
|
+
self._job_api.stop_job(self._guaranteed_job.id, self.teamspace.id)
|
|
89
|
+
|
|
90
|
+
def delete(self) -> None:
|
|
91
|
+
self._job_api.delete_job(self._guaranteed_job.id, self.teamspace.id)
|
|
92
|
+
|
|
93
|
+
@property
|
|
94
|
+
def status(self) -> "Status":
|
|
95
|
+
try:
|
|
96
|
+
status = self._job_api.get_job_status(self._job.id, self.teamspace.id)
|
|
97
|
+
return _internal_status_to_external_status(status)
|
|
98
|
+
except Exception:
|
|
99
|
+
raise RuntimeError(
|
|
100
|
+
f"MMT {self._name} does not exist in Teamspace {self.teamspace.name}. Did you delete it?"
|
|
101
|
+
) from None
|
|
102
|
+
|
|
103
|
+
@property
|
|
104
|
+
def artifact_path(self) -> Optional[str]:
|
|
105
|
+
return f"/teamspace/jobs/{self.name}"
|
|
106
|
+
|
|
107
|
+
@property
|
|
108
|
+
def snapshot_path(self) -> Optional[str]:
|
|
109
|
+
return f"/teamspace/jobs/{self.name}/snapshot"
|
|
110
|
+
|
|
111
|
+
@property
|
|
112
|
+
def machine(self) -> "Machine":
|
|
113
|
+
return self.machines[0].machine
|
|
114
|
+
|
|
115
|
+
@property
|
|
116
|
+
def name(self) -> str:
|
|
117
|
+
return self._name
|
|
118
|
+
|
|
119
|
+
@property
|
|
120
|
+
def teamspace(self) -> "Teamspace":
|
|
121
|
+
return self._teamspace
|
|
122
|
+
|
|
123
|
+
# the following and functions are solely to make the Work class function
|
|
124
|
+
@property
|
|
125
|
+
def _id(self) -> str:
|
|
126
|
+
return self._guaranteed_job.id
|
|
127
|
+
|
|
128
|
+
def _name_filter(self, name: str) -> str:
|
|
129
|
+
return name.replace("root.", "")
|
lightning_sdk/{_mmt → mmt}/v2.py
RENAMED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
|
|
2
2
|
|
|
3
|
-
from lightning_sdk.api.mmt_api import
|
|
3
|
+
from lightning_sdk.api.mmt_api import MMTApiV2
|
|
4
4
|
|
|
5
5
|
if TYPE_CHECKING:
|
|
6
6
|
from lightning_sdk.job.job import Job
|
|
@@ -11,7 +11,7 @@ if TYPE_CHECKING:
|
|
|
11
11
|
from lightning_sdk.teamspace import Teamspace
|
|
12
12
|
from lightning_sdk.user import User
|
|
13
13
|
|
|
14
|
-
from lightning_sdk.
|
|
14
|
+
from lightning_sdk.mmt.base import _BaseMMT
|
|
15
15
|
|
|
16
16
|
|
|
17
17
|
class _MMTV2(_BaseMMT):
|
|
@@ -24,7 +24,7 @@ class _MMTV2(_BaseMMT):
|
|
|
24
24
|
*,
|
|
25
25
|
_fetch_job: bool = True,
|
|
26
26
|
) -> None:
|
|
27
|
-
self._job_api =
|
|
27
|
+
self._job_api = MMTApiV2()
|
|
28
28
|
super().__init__(name=name, teamspace=teamspace, org=org, user=user, _fetch_job=_fetch_job)
|
|
29
29
|
|
|
30
30
|
def _submit(
|
|
@@ -36,9 +36,9 @@ class _MMTV2(_BaseMMT):
|
|
|
36
36
|
image: Optional[str] = None,
|
|
37
37
|
env: Optional[Dict[str, str]] = None,
|
|
38
38
|
interruptible: bool = False,
|
|
39
|
-
|
|
39
|
+
cloud_account: Optional[str] = None,
|
|
40
40
|
image_credentials: Optional[str] = None,
|
|
41
|
-
|
|
41
|
+
cloud_account_auth: bool = False,
|
|
42
42
|
artifacts_local: Optional[str] = None,
|
|
43
43
|
artifacts_remote: Optional[str] = None,
|
|
44
44
|
) -> "_MMTV2":
|
|
@@ -62,7 +62,7 @@ class _MMTV2(_BaseMMT):
|
|
|
62
62
|
name=self.name,
|
|
63
63
|
num_machines=num_machines,
|
|
64
64
|
command=command,
|
|
65
|
-
|
|
65
|
+
cloud_account=cloud_account,
|
|
66
66
|
teamspace_id=self._teamspace.id,
|
|
67
67
|
studio_id=studio_id,
|
|
68
68
|
image=image,
|
|
@@ -70,7 +70,7 @@ class _MMTV2(_BaseMMT):
|
|
|
70
70
|
interruptible=interruptible,
|
|
71
71
|
env=env,
|
|
72
72
|
image_credentials=image_credentials,
|
|
73
|
-
|
|
73
|
+
cloud_account_auth=cloud_account_auth,
|
|
74
74
|
artifacts_local=artifacts_local,
|
|
75
75
|
artifacts_remote=artifacts_remote,
|
|
76
76
|
)
|
|
@@ -80,7 +80,12 @@ class _MMTV2(_BaseMMT):
|
|
|
80
80
|
|
|
81
81
|
@property
|
|
82
82
|
def machines(self) -> Tuple["Job", ...]:
|
|
83
|
-
|
|
83
|
+
from lightning_sdk.job import Job
|
|
84
|
+
|
|
85
|
+
return tuple(
|
|
86
|
+
Job(name=j.name, teamspace=self.teamspace)
|
|
87
|
+
for j in self._job_api.list_mmt_subjobs(self._guaranteed_job.id, self.teamspace.id)
|
|
88
|
+
)
|
|
84
89
|
|
|
85
90
|
def stop(self) -> None:
|
|
86
91
|
self._job_api.stop_job(job_id=self._guaranteed_job.id, teamspace_id=self._teamspace.id)
|
|
@@ -97,28 +102,18 @@ class _MMTV2(_BaseMMT):
|
|
|
97
102
|
self._update_internal_job()
|
|
98
103
|
return self._job
|
|
99
104
|
|
|
100
|
-
@property
|
|
101
|
-
def _guaranteed_job(self) -> Any:
|
|
102
|
-
"""Guarantees that the job was fetched at some point before returning it.
|
|
103
|
-
|
|
104
|
-
Doesn't guarantee to have the lastest version of the job. Use _latest_job for that.
|
|
105
|
-
"""
|
|
106
|
-
if getattr(self, "_job", None) is None:
|
|
107
|
-
self._update_internal_job()
|
|
108
|
-
|
|
109
|
-
return self._job
|
|
110
|
-
|
|
111
105
|
@property
|
|
112
106
|
def status(self) -> "Status":
|
|
113
|
-
|
|
114
|
-
return self._job_api._job_state_to_external(self._latest_job.desired_state)
|
|
107
|
+
return self._job_api._job_state_to_external(self._latest_job.state)
|
|
115
108
|
|
|
116
109
|
@property
|
|
117
110
|
def artifact_path(self) -> Optional[str]:
|
|
111
|
+
# TODO: Since grouping for those is not done yet on the BE, we cannot yet have a unified link here
|
|
118
112
|
raise NotImplementedError
|
|
119
113
|
|
|
120
114
|
@property
|
|
121
115
|
def snapshot_path(self) -> Optional[str]:
|
|
116
|
+
# TODO: Since grouping for those is not done yet on the BE, we cannot yet have a unified link here
|
|
122
117
|
raise NotImplementedError
|
|
123
118
|
|
|
124
119
|
@property
|