lightning-sdk 2025.7.17__py3-none-any.whl → 2025.7.30rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (99)
  1. lightning_sdk/__init__.py +3 -2
  2. lightning_sdk/api/cloud_account_api.py +204 -0
  3. lightning_sdk/api/deployment_api.py +11 -0
  4. lightning_sdk/api/job_api.py +82 -10
  5. lightning_sdk/api/llm_api.py +1 -1
  6. lightning_sdk/api/mmt_api.py +44 -5
  7. lightning_sdk/api/pipeline_api.py +4 -3
  8. lightning_sdk/api/studio_api.py +51 -8
  9. lightning_sdk/api/utils.py +6 -2
  10. lightning_sdk/cli/clusters_menu.py +3 -3
  11. lightning_sdk/cli/create.py +25 -11
  12. lightning_sdk/cli/deploy/_auth.py +19 -3
  13. lightning_sdk/cli/deploy/serve.py +21 -5
  14. lightning_sdk/cli/download.py +25 -1
  15. lightning_sdk/cli/entrypoint.py +4 -2
  16. lightning_sdk/cli/list.py +5 -1
  17. lightning_sdk/cli/run.py +3 -1
  18. lightning_sdk/cli/start.py +40 -8
  19. lightning_sdk/cli/switch.py +3 -1
  20. lightning_sdk/deployment/deployment.py +8 -0
  21. lightning_sdk/job/base.py +27 -3
  22. lightning_sdk/job/job.py +28 -4
  23. lightning_sdk/job/v1.py +10 -1
  24. lightning_sdk/job/v2.py +22 -2
  25. lightning_sdk/job/work.py +5 -1
  26. lightning_sdk/lightning_cloud/openapi/__init__.py +14 -1
  27. lightning_sdk/lightning_cloud/openapi/api/assistants_service_api.py +428 -0
  28. lightning_sdk/lightning_cloud/openapi/api/billing_service_api.py +153 -48
  29. lightning_sdk/lightning_cloud/openapi/api/cloudy_service_api.py +295 -0
  30. lightning_sdk/lightning_cloud/openapi/api/cluster_service_api.py +93 -0
  31. lightning_sdk/lightning_cloud/openapi/models/__init__.py +14 -1
  32. lightning_sdk/lightning_cloud/openapi/models/agentmanagedendpoints_id_body.py +27 -1
  33. lightning_sdk/lightning_cloud/openapi/models/blogposts_id_body.py +53 -1
  34. lightning_sdk/lightning_cloud/openapi/models/conversations_id_body1.py +123 -0
  35. lightning_sdk/lightning_cloud/openapi/models/messages_id_body.py +123 -0
  36. lightning_sdk/lightning_cloud/openapi/models/metricsstream_id_body.py +27 -1
  37. lightning_sdk/lightning_cloud/openapi/models/project_id_schedules_body.py +81 -3
  38. lightning_sdk/lightning_cloud/openapi/models/schedules_id_body.py +79 -1
  39. lightning_sdk/lightning_cloud/openapi/models/user_id_upgradetrigger_body.py +201 -0
  40. lightning_sdk/lightning_cloud/openapi/models/user_user_id_body.py +201 -0
  41. lightning_sdk/lightning_cloud/openapi/models/v1_billing_subscription.py +27 -1
  42. lightning_sdk/lightning_cloud/openapi/models/v1_blog_post.py +53 -1
  43. lightning_sdk/lightning_cloud/openapi/models/v1_cloudy_settings.py +227 -0
  44. lightning_sdk/lightning_cloud/openapi/models/v1_cluster_spec.py +27 -1
  45. lightning_sdk/lightning_cloud/openapi/models/v1_conversation.py +27 -1
  46. lightning_sdk/lightning_cloud/openapi/models/v1_conversation_response_chunk.py +27 -1
  47. lightning_sdk/lightning_cloud/openapi/models/v1_create_billing_upgrade_trigger_record_response.py +97 -0
  48. lightning_sdk/lightning_cloud/openapi/models/v1_create_blog_post_request.py +53 -1
  49. lightning_sdk/lightning_cloud/openapi/models/v1_create_checkout_session_request.py +27 -1
  50. lightning_sdk/lightning_cloud/openapi/models/v1_create_subscription_checkout_session_request.py +55 -3
  51. lightning_sdk/lightning_cloud/openapi/models/v1_function_call.py +149 -0
  52. lightning_sdk/lightning_cloud/openapi/models/{v1_get_clickhouse_assistant_session_daily_aggregated_response.py → v1_get_assistant_session_daily_aggregated_response.py} +22 -22
  53. lightning_sdk/lightning_cloud/openapi/models/v1_get_cluster_health_response.py +149 -0
  54. lightning_sdk/lightning_cloud/openapi/models/v1_get_user_response.py +27 -1
  55. lightning_sdk/lightning_cloud/openapi/models/v1_job_spec.py +27 -1
  56. lightning_sdk/lightning_cloud/openapi/models/v1_kubernetes_direct_v1.py +105 -1
  57. lightning_sdk/lightning_cloud/openapi/models/v1_like_status.py +104 -0
  58. lightning_sdk/lightning_cloud/openapi/models/v1_list_published_managed_endpoints_response.py +123 -0
  59. lightning_sdk/lightning_cloud/openapi/models/v1_managed_endpoint.py +27 -1
  60. lightning_sdk/lightning_cloud/openapi/models/v1_managed_model.py +95 -17
  61. lightning_sdk/lightning_cloud/openapi/models/v1_message.py +27 -1
  62. lightning_sdk/lightning_cloud/openapi/models/v1_quote_subscription_response.py +27 -1
  63. lightning_sdk/lightning_cloud/openapi/models/v1_resource_visibility.py +27 -1
  64. lightning_sdk/lightning_cloud/openapi/models/v1_response_choice.py +29 -3
  65. lightning_sdk/lightning_cloud/openapi/models/v1_schedule.py +79 -1
  66. lightning_sdk/lightning_cloud/openapi/models/v1_service_health.py +27 -1
  67. lightning_sdk/lightning_cloud/openapi/models/v1_slurm_v1.py +79 -1
  68. lightning_sdk/lightning_cloud/openapi/models/v1_slurm_v1_status.py +79 -1
  69. lightning_sdk/lightning_cloud/openapi/models/v1_tool_call.py +175 -0
  70. lightning_sdk/lightning_cloud/openapi/models/v1_update_conversation_like_response.py +149 -0
  71. lightning_sdk/lightning_cloud/openapi/models/v1_update_conversation_message_like_response.py +149 -0
  72. lightning_sdk/lightning_cloud/openapi/models/v1_user_features.py +79 -313
  73. lightning_sdk/lightning_cloud/openapi/models/v1_volume_state.py +1 -0
  74. lightning_sdk/llm/llm.py +69 -11
  75. lightning_sdk/llm/public_assistants.json +32 -8
  76. lightning_sdk/machine.py +151 -43
  77. lightning_sdk/mmt/base.py +20 -2
  78. lightning_sdk/mmt/mmt.py +25 -3
  79. lightning_sdk/mmt/v1.py +7 -1
  80. lightning_sdk/mmt/v2.py +27 -3
  81. lightning_sdk/models.py +1 -1
  82. lightning_sdk/organization.py +4 -0
  83. lightning_sdk/pipeline/pipeline.py +16 -5
  84. lightning_sdk/pipeline/printer.py +5 -3
  85. lightning_sdk/pipeline/schedule.py +844 -1
  86. lightning_sdk/pipeline/steps.py +19 -4
  87. lightning_sdk/sandbox.py +4 -1
  88. lightning_sdk/serve.py +2 -0
  89. lightning_sdk/studio.py +91 -44
  90. lightning_sdk/teamspace.py +19 -10
  91. lightning_sdk/utils/resolve.py +37 -2
  92. {lightning_sdk-2025.7.17.dist-info → lightning_sdk-2025.7.30rc0.dist-info}/METADATA +7 -5
  93. {lightning_sdk-2025.7.17.dist-info → lightning_sdk-2025.7.30rc0.dist-info}/RECORD +98 -85
  94. lightning_sdk/api/cluster_api.py +0 -119
  95. lightning_sdk/cli/{inspect.py → inspection.py} +0 -0
  96. {lightning_sdk-2025.7.17.dist-info → lightning_sdk-2025.7.30rc0.dist-info}/LICENSE +0 -0
  97. {lightning_sdk-2025.7.17.dist-info → lightning_sdk-2025.7.30rc0.dist-info}/WHEEL +0 -0
  98. {lightning_sdk-2025.7.17.dist-info → lightning_sdk-2025.7.30rc0.dist-info}/entry_points.txt +0 -0
  99. {lightning_sdk-2025.7.17.dist-info → lightning_sdk-2025.7.30rc0.dist-info}/top_level.txt +0 -0
lightning_sdk/llm/llm.py CHANGED
@@ -2,6 +2,7 @@ import json
2
2
  import os
3
3
  from typing import Any, AsyncGenerator, ClassVar, Dict, Generator, List, Optional, Tuple, Union
4
4
 
5
+ from lightning_sdk.api import TeamspaceApi, UserApi
5
6
  from lightning_sdk.api.llm_api import LLMApi
6
7
  from lightning_sdk.lightning_cloud.openapi.models.v1_conversation_response_chunk import V1ConversationResponseChunk
7
8
 
@@ -13,7 +14,7 @@ PUBLIC_MODEL_PROVIDERS: Dict[str, str] = {
13
14
  }
14
15
 
15
16
 
16
- def _load_public_assistants() -> Dict[str, str]:
17
+ def _load_public_assistants() -> Dict[str, Dict[str, Any]]:
17
18
  """Load public assistants from a JSON file."""
18
19
  try:
19
20
  json_path = os.path.join(os.path.dirname(__file__), "public_assistants.json")
@@ -28,7 +29,7 @@ class LLM:
28
29
  _auth_info_cached: ClassVar[bool] = False
29
30
  _cached_auth_info: ClassVar[Dict[str, Optional[str]]] = {}
30
31
  _llm_api_cache: ClassVar[Dict[Optional[str], LLMApi]] = {}
31
- _public_assistants: ClassVar[Optional[Dict[str, str]]] = None
32
+ _public_assistants: ClassVar[Optional[Dict[str, Dict[str, Any]]]] = None
32
33
 
33
34
  def __new__(cls, name: str, teamspace: Optional[str] = None, enable_async: Optional[bool] = False) -> "LLM":
34
35
  return super().__new__(cls)
@@ -55,8 +56,18 @@ class LLM:
55
56
  Raises:
56
57
  ValueError: If teamspace information cannot be resolved.
57
58
  """
58
- # TODO support user input teamspace
59
- self._get_auth_info()
59
+ teamspace_name = None
60
+ if teamspace:
61
+ try:
62
+ owner, teamspace_name = teamspace.split("/", maxsplit=1)
63
+ except ValueError as e:
64
+ raise ValueError(
65
+ f"Invalid teamspace format: '{teamspace}'. "
66
+ "Teamspace should be specified as '{teamspace_owner}/{teamspace_name}' "
67
+ "(e.g., 'my-org/my-teamspace')."
68
+ ) from e
69
+
70
+ self._get_auth_info(teamspace_name)
60
71
 
61
72
  self._model_provider, self._model_name = self._parse_model_name(name)
62
73
  self._enable_async = enable_async
@@ -66,6 +77,7 @@ class LLM:
66
77
  LLM._llm_api_cache[teamspace] = LLMApi()
67
78
  self._llm_api = LLM._llm_api_cache[teamspace]
68
79
 
80
+ self._context_length = None
69
81
  self._model_id = self._get_model_id()
70
82
  self._conversations = {}
71
83
 
@@ -77,14 +89,45 @@ class LLM:
77
89
  def provider(self) -> str:
78
90
  return self._model_provider
79
91
 
80
- def _get_auth_info(self) -> None:
92
+ def context_length(self, model: Optional[str] = None) -> Optional[int]:
93
+ if model is None:
94
+ return self._context_length
95
+
96
+ context_info = self._public_assistants.get(model)
97
+ if context_info is None or "context_length" not in context_info:
98
+ raise ValueError(f"Cannot access context length of model '{model}'.")
99
+
100
+ return int(context_info["context_length"])
101
+
102
+ def _get_auth_info(self, teamspace_name: Optional[str] = None) -> None:
103
+ # TODO: Validate user input teamspace name
81
104
  if not LLM._auth_info_cached:
82
- teamspace_name = os.environ.get("LIGHTNING_TEAMSPACE", None)
83
105
  if teamspace_name is None:
84
- raise ValueError(
85
- "Teamspace name must be provided either through "
86
- "the environment variable LIGHTNING_TEAMSPACE or as an argument."
87
- )
106
+ # studio users
107
+ teamspace_name = os.environ.get("LIGHTNING_TEAMSPACE", None)
108
+
109
+ if teamspace_name is None:
110
+ # local users with no given teamspace
111
+ try:
112
+ teamspace_api = TeamspaceApi()
113
+ user_api = UserApi()
114
+ authed_user = user_api._client.auth_service_get_user()
115
+ default_teamspace = teamspace_api.list_teamspaces(owner_id=authed_user.id)[0]
116
+ teamspace_name = default_teamspace.name
117
+ teamspace_id = default_teamspace.id
118
+ os.environ["LIGHTNING_CLOUD_PROJECT_ID"] = teamspace_id
119
+ os.environ["LIGHTNING_TEAMSPACE"] = teamspace_name
120
+ except Exception as err:
121
+ # throw an appropriate error that guides users to login through the platform
122
+ raise ValueError(
123
+ "Teamspace information is missing. "
124
+ "If this is your first time using LitAI, please log in at https://lightning.ai/sign-up "
125
+ "and re-run your script, or set the environment variable LIGHTNING_TEAMSPACE=<your-teamspace>."
126
+ ) from err
127
+
128
+ # TODO: when teamspace_name is given, we don't know the teamspace_id yet
129
+ # TODO: if LIGHTNING_CLOUD_PROJECT_ID does not exist, we have to get the id from the teamspace name
130
+
88
131
  LLM._cached_auth_info = {
89
132
  "teamspace_name": teamspace_name,
90
133
  "teamspace_id": os.environ.get("LIGHTNING_CLOUD_PROJECT_ID", None),
@@ -125,7 +168,10 @@ class LLM:
125
168
  and LLM._public_assistants
126
169
  and f"{self._model_provider}/{self._model_name}" in LLM._public_assistants
127
170
  ):
128
- return LLM._public_assistants[f"{self._model_provider}/{self._model_name}"]
171
+ self._context_length = int(
172
+ LLM._public_assistants[f"{self._model_provider}/{self._model_name}"]["context_length"]
173
+ )
174
+ return LLM._public_assistants[f"{self._model_provider}/{self._model_name}"]["id"]
129
175
  try:
130
176
  return self._llm_api.get_assistant(
131
177
  model_provider=PUBLIC_MODEL_PROVIDERS[self._model_provider],
@@ -139,6 +185,18 @@ class LLM:
139
185
  "Please check the model name or provider."
140
186
  ) from e
141
187
 
188
+ if self._model_provider == "lightning-ai":
189
+ # Try model provider model
190
+ try:
191
+ return self._llm_api.get_assistant(
192
+ model_provider=self._model_provider,
193
+ model_name=self._model_name,
194
+ user_name="",
195
+ org_name="",
196
+ )
197
+ except Exception:
198
+ pass
199
+
142
200
  # Try organization model
143
201
  try:
144
202
  return self._llm_api.get_assistant(
lightning_sdk/llm/public_assistants.json CHANGED
@@ -1,10 +1,34 @@
1
1
  {
2
- "openai/gpt-4o": "ast_01jdjds71fs8gt47jexzed4czs",
3
- "openai/gpt-4": "ast_01jd38ze6tjbrcd4942nhz41zn",
4
- "openai/o3-mini": "ast_01jz3t13fhnjhh11t1k8b5gyp1",
5
- "anthropic/claude-3-5-sonnet-20240620": "ast_01jd3923a6p98rqwh3dpj686pq",
6
- "google/gemini-2.5-pro": "ast_01jz3tdb1fhey798k95pv61v57",
7
- "google/gemini-2.5-flash": "ast_01jz3thxskg4fcdk4xhkjkym5a",
8
- "google/gemini-2.5-flash-lite-preview-06-17": "ast_01jz3thxskg4fcdk4xhkjkym5b",
9
- "lightning-ai/llama4-maverick": "ast_01k073vsqs66tenpns02cf5jnq"
2
+ "openai/gpt-4o": {
3
+ "id": "ast_01jdjds71fs8gt47jexzed4czs",
4
+ "context_length": 128000
5
+ },
6
+ "openai/gpt-4": {
7
+ "id": "ast_01jd38ze6tjbrcd4942nhz41zn",
8
+ "context_length": 8192
9
+ },
10
+ "openai/o3-mini": {
11
+ "id": "ast_01jz3t13fhnjhh11t1k8b5gyp1",
12
+ "context_length": 128000
13
+ },
14
+ "anthropic/claude-3-5-sonnet-20240620": {
15
+ "id": "ast_01jd3923a6p98rqwh3dpj686pq",
16
+ "context_length": 200000
17
+ },
18
+ "google/gemini-2.5-pro": {
19
+ "id": "ast_01jz3tdb1fhey798k95pv61v57",
20
+ "context_length": 1048576
21
+ },
22
+ "google/gemini-2.5-flash": {
23
+ "id": "ast_01jz3thxskg4fcdk4xhkjkym5a",
24
+ "context_length": 8000
25
+ },
26
+ "google/gemini-2.5-flash-lite-preview-06-17": {
27
+ "id": "ast_01jz3thxskg4fcdk4xhkjkym5b",
28
+ "context_length": 8000
29
+ },
30
+ "lightning-ai/llama4-maverick": {
31
+ "id": "ast_01k0wgg56tm8mv9n12aq2mnxas",
32
+ "context_length": 100000
33
+ }
10
34
  }
lightning_sdk/machine.py CHANGED
@@ -1,60 +1,113 @@
1
1
  from dataclasses import dataclass
2
+ from enum import Enum
2
3
  from typing import Any, ClassVar, Optional, Tuple
3
4
 
4
5
 
6
+ class CloudProvider(Enum):
7
+ AWS = "AWS"
8
+ GCP = "GCP"
9
+ VULTR = "VULTR"
10
+ LAMBDA_LABS = "LAMBDA_LABS"
11
+ DGX = "DGX"
12
+ VOLTAGE_PARK = "VOLTAGE_PARK"
13
+ NEBIUS = "NEBIUS"
14
+ LIGHTNING = "LIGHTNING"
15
+
16
+ def __str__(self) -> str:
17
+ """Converts the CloudProvider to a str."""
18
+ return self.value
19
+
20
+
5
21
  @dataclass(frozen=True)
6
22
  class Machine:
7
- # Default Machines
8
- CPU: ClassVar["Machine"]
23
+ # supported CPU variations
24
+ CPU_X_2: ClassVar["Machine"]
25
+ CPU_X_4: ClassVar["Machine"]
26
+ CPU_X_8: ClassVar["Machine"]
27
+ CPU_X_16: ClassVar["Machine"]
28
+ # default CPU machines
9
29
  CPU_SMALL: ClassVar["Machine"]
30
+ CPU: ClassVar["Machine"]
31
+ # supported data-prep variations (big disk)
10
32
  DATA_PREP: ClassVar["Machine"]
11
33
  DATA_PREP_MAX: ClassVar["Machine"]
12
34
  DATA_PREP_ULTRA: ClassVar["Machine"]
35
+
36
+ # supported GPU types
37
+ # supported T4 variations
13
38
  T4: ClassVar["Machine"]
39
+ T4_X_2: ClassVar["Machine"]
14
40
  T4_X_4: ClassVar["Machine"]
41
+ T4_X_8: ClassVar["Machine"]
42
+ # supported L4 variations
15
43
  L4: ClassVar["Machine"]
16
44
  L4_X_2: ClassVar["Machine"]
17
45
  L4_X_4: ClassVar["Machine"]
18
46
  L4_X_8: ClassVar["Machine"]
19
- A10G: ClassVar["Machine"]
20
- A10G_X_4: ClassVar["Machine"]
21
- A10G_X_8: ClassVar["Machine"]
47
+ # supported L40S variations
22
48
  L40S: ClassVar["Machine"]
49
+ L40S_X_2: ClassVar["Machine"]
23
50
  L40S_X_4: ClassVar["Machine"]
24
51
  L40S_X_8: ClassVar["Machine"]
52
+ # supported A100 variations
53
+ # defaults, can be either A100 type depending on cloud provider availability
54
+ A100: ClassVar["Machine"]
25
55
  A100_X_2: ClassVar["Machine"]
26
56
  A100_X_4: ClassVar["Machine"]
27
57
  A100_X_8: ClassVar["Machine"]
28
- B200_X_8: ClassVar["Machine"]
58
+ # A100 40GB versions
59
+ A100_40GB: ClassVar["Machine"]
60
+ A100_40GB_X_2: ClassVar["Machine"]
61
+ A100_40GB_X_4: ClassVar["Machine"]
62
+ A100_40GB_X_8: ClassVar["Machine"]
63
+ # A100 80GB versions
64
+ A100_80GB: ClassVar["Machine"]
65
+ A100_80GB_X_2: ClassVar["Machine"]
66
+ A100_80GB_X_4: ClassVar["Machine"]
67
+ A100_80GB_X_8: ClassVar["Machine"]
68
+
69
+ H100: ClassVar["Machine"]
70
+ H100_X_2: ClassVar["Machine"]
71
+ H100_X_4: ClassVar["Machine"]
29
72
  H100_X_8: ClassVar["Machine"]
73
+
74
+ H200: ClassVar["Machine"]
30
75
  H200_X_8: ClassVar["Machine"]
76
+ B200_X_8: ClassVar["Machine"]
77
+
78
+ # Specialized Machines
31
79
 
32
80
  name: str
33
- instance_type: str
81
+ slug: str
82
+ instance_type: Optional[str] = None
83
+ family: Optional[str] = None
84
+ accelerator_count: Optional[int] = None
34
85
  cost: Optional[float] = None
35
86
  interruptible_cost: Optional[float] = None
36
87
  wait_time: Optional[float] = None
37
88
  interruptible_wait_time: Optional[float] = None
89
+ _include_in_cli: bool = True
38
90
 
39
91
  def __str__(self) -> str:
40
92
  """String representation of the Machine."""
41
- return str(self.name) if self.name else str(self.instance_type)
93
+ return str(self.name) if self.name else (self.slug if self.slug else str(self.instance_type))
42
94
 
43
95
  def __eq__(self, other: object) -> bool:
44
96
  """Machines are equal if the instance type is equal."""
45
97
  if isinstance(other, Machine):
46
- return self.instance_type == other.instance_type
98
+ return (
99
+ # equality based on raw instance type (provider specific)
100
+ (self.instance_type and self.instance_type == other.instance_type)
101
+ # equality based on slug (provider agnostic)
102
+ or self.slug == other.slug
103
+ # equality based on machine specs (e.g. A100_80GB_X_8 == A100_X_8)
104
+ or (self.family == other.family and self.accelerator_count == other.accelerator_count)
105
+ )
47
106
  return False
48
107
 
49
108
  def is_cpu(self) -> bool:
50
109
  """Whether the machine is a CPU."""
51
- return (
52
- self == Machine.CPU
53
- or self == Machine.CPU_SMALL
54
- or self == Machine.DATA_PREP
55
- or self == Machine.DATA_PREP_MAX
56
- or self == Machine.DATA_PREP_ULTRA
57
- )
110
+ return self.family in ("CPU", "DATA_PREP")
58
111
 
59
112
  @classmethod
60
113
  def from_str(cls, machine: str, *additional_machine_ids: Any) -> "Machine":
@@ -63,34 +116,89 @@ class Machine:
63
116
  )
64
117
  for m in possible_values:
65
118
  for machine_id in [machine, *additional_machine_ids]:
66
- if machine_id in (getattr(m, "name", None), getattr(m, "instance_type", None)):
119
+ if machine_id in (
120
+ getattr(m, "name", None),
121
+ getattr(m, "instance_type", None),
122
+ getattr(m, "slug", None),
123
+ ):
67
124
  return m
68
125
 
69
126
  if additional_machine_ids:
70
127
  return cls(machine, *additional_machine_ids)
71
- return cls(machine, machine)
72
-
73
-
74
- Machine.CPU = Machine(name="CPU", instance_type="cpu-4")
75
- Machine.CPU_SMALL = Machine(name="CPU_SMALL", instance_type="n2d-standard-2") # GCP
76
- Machine.DATA_PREP = Machine(name="DATA_PREP", instance_type="data-large")
77
- Machine.DATA_PREP_MAX = Machine(name="DATA_PREP_MAX", instance_type="data-max")
78
- Machine.DATA_PREP_ULTRA = Machine(name="DATA_PREP_ULTRA", instance_type="data-ultra")
79
- Machine.T4 = Machine(name="T4", instance_type="g4dn.2xlarge")
80
- Machine.T4_X_4 = Machine(name="T4_X_4", instance_type="g4dn.12xlarge")
81
- Machine.L4 = Machine(name="L4", instance_type="g6.4xlarge")
82
- Machine.L4_X_2 = Machine(name="L4_X_2", instance_type="g2-standard-24") # GCP
83
- Machine.L4_X_4 = Machine(name="L4_X_4", instance_type="g6.12xlarge")
84
- Machine.L4_X_8 = Machine(name="L4_X_8", instance_type="g6.48xlarge")
85
- Machine.A10G = Machine(name="A10G", instance_type="g5.8xlarge")
86
- Machine.A10G_X_4 = Machine(name="A10G_X_4", instance_type="g5.12xlarge")
87
- Machine.A10G_X_8 = Machine(name="A10G_X_8", instance_type="g5.48xlarge")
88
- Machine.L40S = Machine(name="L40S", instance_type="g6e.4xlarge")
89
- Machine.L40S_X_4 = Machine(name="L40S_X_4", instance_type="g6e.12xlarge")
90
- Machine.L40S_X_8 = Machine(name="L40S_X_8", instance_type="g6e.48xlarge")
91
- Machine.A100_X_2 = Machine(name="A100_X_2", instance_type="a2-ultragpu-2g") # GCP
92
- Machine.A100_X_4 = Machine(name="A100_X_4", instance_type="a2-ultragpu-4g") # GCP
93
- Machine.A100_X_8 = Machine(name="A100_X_8", instance_type="p4d.24xlarge")
94
- Machine.B200_X_8 = Machine(name="B200_X_8", instance_type="a4-highgpu-8g") # GCP
95
- Machine.H100_X_8 = Machine(name="H100_X_8", instance_type="p5.48xlarge")
96
- Machine.H200_X_8 = Machine(name="H200_X_8", instance_type="p5en.48xlarge")
128
+ return cls(machine, machine, machine)
129
+
130
+
131
+ # CPU machines
132
+ # default CPU machines
133
+ Machine.CPU_SMALL = Machine(name="CPU_SMALL", slug="cpu-2", family="CPU", accelerator_count=2)
134
+ Machine.CPU = Machine(name="CPU", slug="cpu-4", family="CPU", accelerator_count=4)
135
+ # available CPU variations
136
+ Machine.CPU_X_2 = Machine(name="CPU_X_2", slug="cpu-2", family="CPU", accelerator_count=2)
137
+ Machine.CPU_X_4 = Machine(name="CPU_X_4", slug="cpu-4", family="CPU", accelerator_count=4)
138
+ Machine.CPU_X_8 = Machine(name="CPU_X_8", slug="cpu-8", family="CPU", accelerator_count=8)
139
+ Machine.CPU_X_16 = Machine(name="CPU_X_16", slug="cpu-16", family="CPU", accelerator_count=16)
140
+ # available data-prep (big disk) machines
141
+ Machine.DATA_PREP = Machine(name="DATA_PREP", slug="data-prep-mid", family="DATA_PREP", accelerator_count=32)
142
+ Machine.DATA_PREP_MAX = Machine(
143
+ name="DATA_PREP_MAX", slug="data-prep-max-large", family="DATA_PREP", accelerator_count=64
144
+ )
145
+ Machine.DATA_PREP_ULTRA = Machine(
146
+ name="DATA_PREP_ULTRA", slug="data-prep-ultra-extra-large", family="DATA_PREP", accelerator_count=96
147
+ )
148
+
149
+ # GPU machines
150
+ # available T4 machines
151
+ Machine.T4 = Machine(name="T4", slug="lit-t4-1", family="T4", accelerator_count=1)
152
+ Machine.T4_X_2 = Machine(name="T4_X_2", slug="lit-t4-2", family="T4", accelerator_count=2)
153
+ Machine.T4_X_4 = Machine(name="T4_X_4", slug="lit-t4-4", family="T4", accelerator_count=4)
154
+ Machine.T4_X_8 = Machine(name="T4_X_8", slug="lit-t4-8", family="T4", accelerator_count=8)
155
+ # available L4 machines
156
+ Machine.L4 = Machine(name="L4", slug="lit-l4-1", family="L4", accelerator_count=1)
157
+ Machine.L4_X_2 = Machine(name="L4_X_2", slug="lit-l4-2", family="L4", accelerator_count=2)
158
+ Machine.L4_X_4 = Machine(name="L4_X_4", slug="lit-l4-4", family="L4", accelerator_count=4)
159
+ Machine.L4_X_8 = Machine(name="L4_X_8", slug="lit-l4-8", family="L4", accelerator_count=8)
160
+ # available L40S machines
161
+ Machine.L40S = Machine(name="L40S", slug="lit-l40s-1", family="L40S", accelerator_count=1)
162
+ Machine.L40S_X_2 = Machine(name="L40S_X_2", slug="lit-l40s-2", family="L40S", accelerator_count=2)
163
+ Machine.L40S_X_4 = Machine(name="L40S_X_4", slug="lit-l40s-4", family="L40S", accelerator_count=4)
164
+ Machine.L40S_X_8 = Machine(name="L40S_X_8", slug="lit-l40s-8", family="L40S", accelerator_count=8)
165
+ # available A100 Machines
166
+ Machine.A100 = Machine(name="A100", slug="lit-a100-1", family="A100", accelerator_count=1)
167
+ Machine.A100_X_2 = Machine(name="A100_X_2", slug="lit-a100-2", family="A100", accelerator_count=2)
168
+ Machine.A100_X_4 = Machine(name="A100_X_4", slug="lit-a100-4", family="A100", accelerator_count=4)
169
+ Machine.A100_X_8 = Machine(name="A100_X_8", slug="lit-a100-8", family="A100", accelerator_count=8)
170
+ # don't include variants in cli, only default types that can match for all variants
171
+ Machine.A100_40GB = Machine(
172
+ name="A100_40GB", slug="lit-a100-40gb-1", family="A100", accelerator_count=1, _include_in_cli=False
173
+ )
174
+ Machine.A100_40GB_X_2 = Machine(
175
+ name="A100_40GB_X_2", slug="lit-a100-40gb-2", family="A100", accelerator_count=2, _include_in_cli=False
176
+ )
177
+ Machine.A100_40GB_X_4 = Machine(
178
+ name="A100_40GB_X_4", slug="lit-a100-40gb-4", family="A100", accelerator_count=4, _include_in_cli=False
179
+ )
180
+ Machine.A100_40GB_X_8 = Machine(
181
+ name="A100_40GB_X_8", slug="lit-a100-40gb-8", family="A100", accelerator_count=8, _include_in_cli=False
182
+ )
183
+ Machine.A100_80GB = Machine(
184
+ name="A100_80GB", slug="lit-a100-80gb-1", family="A100", accelerator_count=1, _include_in_cli=False
185
+ )
186
+ Machine.A100_80GB_X_2 = Machine(
187
+ name="A100_80GB_X_2", slug="lit-a100-80gb-2", family="A100", accelerator_count=2, _include_in_cli=False
188
+ )
189
+ Machine.A100_80GB_X_4 = Machine(
190
+ name="A100_80GB_X_4", slug="lit-a100-80gb-4", family="A100", accelerator_count=4, _include_in_cli=False
191
+ )
192
+ Machine.A100_80GB_X_8 = Machine(
193
+ name="A100_80GB_X_8", slug="lit-a100-80gb-8", family="A100", accelerator_count=8, _include_in_cli=False
194
+ )
195
+ # available H100 machines
196
+ Machine.H100 = Machine(name="H100", slug="lit-h100-1", family="H100", accelerator_count=1)
197
+ Machine.H100_X_2 = Machine(name="H100_X_2", slug="lit-h100-2", family="H100", accelerator_count=2)
198
+ Machine.H100_X_4 = Machine(name="H100_X_4", slug="lit-h100-4", family="H100", accelerator_count=4)
199
+ Machine.H100_X_8 = Machine(name="H100_X_8", slug="lit-h100-8", family="H100", accelerator_count=8)
200
+ # available H200 machines
201
+ Machine.H200 = Machine(name="H200", slug="lit-h200x-1", family="H200", accelerator_count=1)
202
+ Machine.H200_X_8 = Machine(name="H200_X_8", slug="lit-h200x-8", family="H200", accelerator_count=8)
203
+ # available B200 machines
204
+ Machine.B200_X_8 = Machine(name="B200_X_8", slug="lit-b200x-8", family="B200", accelerator_count=8)
lightning_sdk/mmt/base.py CHANGED
@@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Protocol, Tuple, Union
4
4
 
5
5
  if TYPE_CHECKING:
6
6
  from lightning_sdk.job.base import MachineDict
7
- from lightning_sdk.machine import Machine
7
+ from lightning_sdk.machine import CloudProvider, Machine
8
8
  from lightning_sdk.organization import Organization
9
9
  from lightning_sdk.status import Status
10
10
  from lightning_sdk.studio import Studio
@@ -64,12 +64,14 @@ class _BaseMMT(_BaseJob):
64
64
  org: Union[str, "Organization", None] = None,
65
65
  user: Union[str, "User", None] = None,
66
66
  cloud_account: Optional[str] = None,
67
+ cloud_provider: Optional[Union["CloudProvider", str]] = None,
67
68
  env: Optional[Dict[str, str]] = None,
68
69
  interruptible: bool = False,
69
70
  image_credentials: Optional[str] = None,
70
71
  cloud_account_auth: bool = False,
71
72
  entrypoint: str = "sh -c",
72
73
  path_mappings: Optional[Dict[str, str]] = None,
74
+ max_runtime: Optional[int] = None,
73
75
  artifacts_local: Optional[str] = None, # deprecated in favor of path_mappings
74
76
  artifacts_remote: Optional[str] = None, # deprecated in favor of path_mappings
75
77
  cluster: Optional[str] = None, # deprecated in favor of cloud_account
@@ -89,7 +91,11 @@ class _BaseMMT(_BaseJob):
89
91
  user: The user owning the teamspace (if any). Defaults to the current user.
90
92
  cloud_account: The cloud account to run the job on.
91
93
  Defaults to the studio cloud account if running with studio compute env.
92
- If not provided will fall back to the teamspaces default cloud account.
94
+ If not provided and `cloud_account_provider` is set, will resolve cluster from this, else
95
+ will fall back to the teamspaces default cloud account.
96
+ cloud_account_provider: The provider to select the cloud-account from.
97
+ If set, must be in agreement with the provider from the cloud_account (if specified).
98
+ If not specified, falls backto the teamspace default cloud account.
93
99
  env: Environment variables to set inside the job.
94
100
  interruptible: Whether the job should run on interruptible instances. They are cheaper but can be preempted.
95
101
  image_credentials: The credentials used to pull the image. Required if the image is private.
@@ -109,6 +115,10 @@ class _BaseMMT(_BaseJob):
109
115
  }
110
116
  If the path inside the connection is omitted it's assumed to be the root path of that connection.
111
117
  Only applicable when submitting docker jobs.
118
+ max_runtime: the duration (in seconds) for which to allocate the machine.
119
+ Irrelevant for most machines, required for some of the top-end machines on GCP.
120
+ If in doubt, set it. Won't have an effect on machines not requiring it.
121
+ Defaults to 3h
112
122
  """
113
123
  from lightning_sdk.lightning_cloud.openapi.rest import ApiException
114
124
  from lightning_sdk.studio import Studio
@@ -191,6 +201,7 @@ class _BaseMMT(_BaseJob):
191
201
  num_machines=num_machines,
192
202
  machine=machine,
193
203
  cloud_account=cloud_account,
204
+ cloud_provider=cloud_provider,
194
205
  command=command,
195
206
  studio=studio,
196
207
  image=image,
@@ -202,6 +213,7 @@ class _BaseMMT(_BaseJob):
202
213
  path_mappings=path_mappings,
203
214
  artifacts_local=artifacts_local,
204
215
  artifacts_remote=artifacts_remote,
216
+ max_runtime=max_runtime,
205
217
  )
206
218
  return inst
207
219
 
@@ -216,12 +228,14 @@ class _BaseMMT(_BaseJob):
216
228
  env: Optional[Dict[str, str]] = None,
217
229
  interruptible: bool = False,
218
230
  cloud_account: Optional[str] = None,
231
+ cloud_provider: Optional[Union["CloudProvider", str]] = None,
219
232
  image_credentials: Optional[str] = None,
220
233
  cloud_account_auth: bool = False,
221
234
  entrypoint: str = "sh -c",
222
235
  path_mappings: Optional[Dict[str, str]] = None,
223
236
  artifacts_local: Optional[str] = None, # deprecated in favor of path_mappings
224
237
  artifacts_remote: Optional[str] = None, # deprecated in favor of path_mappings
238
+ max_runtime: Optional[int] = None,
225
239
  ) -> None:
226
240
  """Submit a new multi-machine job to the Lightning AI platform.
227
241
 
@@ -253,6 +267,10 @@ class _BaseMMT(_BaseJob):
253
267
  }
254
268
  If the path inside the connection is omitted it's assumed to be the root path of that connection.
255
269
  Only applicable when submitting docker jobs.
270
+ max_runtime: the duration (in seconds) for which to allocate the machine.
271
+ Irrelevant for most machines, required for some of the top-end machines on GCP.
272
+ If in doubt, set it. Won't have an effect on machines not requiring it.
273
+ Defaults to 3h
256
274
  """
257
275
 
258
276
  @property
lightning_sdk/mmt/mmt.py CHANGED
@@ -1,12 +1,13 @@
1
1
  from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
2
2
 
3
+ from lightning_sdk.api.cloud_account_api import CloudAccountApi
3
4
  from lightning_sdk.mmt.base import MMTMachine, _BaseMMT
4
5
  from lightning_sdk.mmt.v1 import _MMTV1
5
6
  from lightning_sdk.mmt.v2 import _MMTV2
6
7
  from lightning_sdk.utils.resolve import _setup_logger
7
8
 
8
9
  if TYPE_CHECKING:
9
- from lightning_sdk.machine import Machine
10
+ from lightning_sdk.machine import CloudProvider, Machine
10
11
  from lightning_sdk.organization import Organization
11
12
  from lightning_sdk.status import Status
12
13
  from lightning_sdk.studio import Studio
@@ -75,6 +76,7 @@ class MMT(_BaseMMT):
75
76
  )
76
77
 
77
78
  self._internal_mmt = mmt
79
+ self._cloud_account_api = CloudAccountApi()
78
80
 
79
81
  @classmethod
80
82
  def run(
@@ -89,12 +91,14 @@ class MMT(_BaseMMT):
89
91
  org: Union[str, "Organization", None] = None,
90
92
  user: Union[str, "User", None] = None,
91
93
  cloud_account: Optional[str] = None,
94
+ cloud_provider: Optional[Union["CloudProvider", str]] = None,
92
95
  env: Optional[Dict[str, str]] = None,
93
96
  interruptible: bool = False,
94
97
  image_credentials: Optional[str] = None,
95
98
  cloud_account_auth: bool = False,
96
99
  entrypoint: str = "sh -c",
97
100
  path_mappings: Optional[Dict[str, str]] = None,
101
+ max_runtime: Optional[int] = None,
98
102
  artifacts_local: Optional[str] = None,
99
103
  artifacts_remote: Optional[str] = None,
100
104
  cluster: Optional[str] = None, # deprecated in favor of cloud_account
@@ -114,7 +118,11 @@ class MMT(_BaseMMT):
114
118
  user: The user owning the teamspace (if any). Defaults to the current user.
115
119
  cloud_account: The cloud account to run the job on.
116
120
  Defaults to the studio cloud account if running with studio compute env.
117
- If not provided will fall back to the teamspaces default cloud account.
121
+ If not provided and `cloud_provider` is set, will resolve the cloud account from it, else
122
+ will fall back to the teamspaces default cloud account.
123
+ cloud_provider: The provider to select the cloud account from.
124
+ If set, must be in agreement with the provider from the cloud_account (if specified).
125
+ If not specified, falls back to the teamspace default cloud account.
118
126
  env: Environment variables to set inside the job.
119
127
  interruptible: Whether the job should run on interruptible instances. They are cheaper but can be preempted.
120
128
  image_credentials: The credentials used to pull the image. Required if the image is private.
@@ -145,6 +153,7 @@ class MMT(_BaseMMT):
145
153
  org=org,
146
154
  user=user,
147
155
  cloud_account=cloud_account,
156
+ cloud_provider=cloud_provider,
148
157
  env=env,
149
158
  interruptible=interruptible,
150
159
  image_credentials=image_credentials,
@@ -154,6 +163,7 @@ class MMT(_BaseMMT):
154
163
  artifacts_local=artifacts_local,
155
164
  artifacts_remote=artifacts_remote,
156
165
  cluster=cluster, # deprecated in favor of cloud_account
166
+ max_runtime=max_runtime,
157
167
  )
158
168
  # required for typing with "MMT"
159
169
  assert isinstance(ret_val, cls)
@@ -173,10 +183,12 @@ class MMT(_BaseMMT):
173
183
  env: Optional[Dict[str, str]] = None,
174
184
  interruptible: bool = False,
175
185
  cloud_account: Optional[str] = None,
186
+ cloud_provider: Optional[Union["CloudProvider", str]] = None,
176
187
  image_credentials: Optional[str] = None,
177
188
  cloud_account_auth: bool = False,
178
189
  entrypoint: str = "sh -c",
179
190
  path_mappings: Optional[Dict[str, str]] = None,
191
+ max_runtime: Optional[int] = None,
180
192
  artifacts_local: Optional[str] = None, # deprecated in favor of path_mappings
181
193
  artifacts_remote: Optional[str] = None, # deprecated in favor of path_mappings
182
194
  ) -> "MMT":
@@ -193,7 +205,11 @@ class MMT(_BaseMMT):
193
205
  interruptible: Whether the job should run on interruptible instances. They are cheaper but can be preempted.
194
206
  cloud_account: The cloud account to run the job on.
195
207
  Defaults to the studio cloud account if running with studio compute env.
196
- If not provided will fall back to the teamspaces default cloud account.
208
+ If not provided and `cloud_provider` is set, will resolve the cloud account from it, else
209
+ will fall back to the teamspaces default cloud account.
210
+ cloud_provider: The provider to select the cloud account from.
211
+ If set, must be in agreement with the provider from the cloud_account (if specified).
212
+ If not specified, falls back to the teamspace default cloud account.
197
213
  image_credentials: The credentials used to pull the image. Required if the image is private.
198
214
  This should be the name of the respective credentials secret created on the Lightning AI platform.
199
215
  cloud_account_auth: Whether to authenticate with the cloud account to pull the image.
@@ -211,11 +227,16 @@ class MMT(_BaseMMT):
211
227
  }
212
228
  If the path inside the connection is omitted it's assumed to be the root path of that connection.
213
229
  Only applicable when submitting docker jobs.
230
+ max_runtime: the duration (in seconds) for which to allocate the machine.
231
+ Irrelevant for most machines, required for some of the top-end machines on GCP.
232
+ If in doubt, set it. Won't have an effect on machines not requiring it.
233
+ Defaults to 3h
214
234
  """
215
235
  self._job = self._internal_mmt._submit(
216
236
  num_machines=num_machines,
217
237
  machine=machine,
218
238
  cloud_account=cloud_account,
239
+ cloud_provider=cloud_provider,
219
240
  command=command,
220
241
  studio=studio,
221
242
  image=image,
@@ -227,6 +248,7 @@ class MMT(_BaseMMT):
227
248
  path_mappings=path_mappings,
228
249
  artifacts_local=artifacts_local,
229
250
  artifacts_remote=artifacts_remote,
251
+ max_runtime=max_runtime,
230
252
  )
231
253
  return self
232
254