lightning-sdk 0.2.14__py3-none-any.whl → 0.2.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (122) hide show
  1. lightning_sdk/__init__.py +1 -1
  2. lightning_sdk/api/base_studio_api.py +79 -0
  3. lightning_sdk/api/cluster_api.py +83 -1
  4. lightning_sdk/api/license_api.py +48 -0
  5. lightning_sdk/api/llm_api.py +73 -12
  6. lightning_sdk/api/studio_api.py +50 -1
  7. lightning_sdk/api/teamspace_api.py +127 -1
  8. lightning_sdk/api/utils.py +4 -0
  9. lightning_sdk/base_studio.py +83 -0
  10. lightning_sdk/cli/create.py +21 -1
  11. lightning_sdk/cli/delete.py +6 -8
  12. lightning_sdk/cli/deploy/__init__.py +0 -0
  13. lightning_sdk/cli/deploy/_auth.py +189 -0
  14. lightning_sdk/cli/deploy/devbox.py +157 -0
  15. lightning_sdk/cli/{serve.py → deploy/serve.py} +22 -281
  16. lightning_sdk/cli/download.py +69 -16
  17. lightning_sdk/cli/entrypoint.py +1 -1
  18. lightning_sdk/cli/open.py +21 -2
  19. lightning_sdk/cli/start.py +12 -3
  20. lightning_sdk/cli/teamspace_menu.py +9 -1
  21. lightning_sdk/cli/upload.py +2 -5
  22. lightning_sdk/lightning_cloud/openapi/__init__.py +29 -0
  23. lightning_sdk/lightning_cloud/openapi/api/__init__.py +1 -0
  24. lightning_sdk/lightning_cloud/openapi/api/assistants_service_api.py +121 -0
  25. lightning_sdk/lightning_cloud/openapi/api/billing_service_api.py +9 -1
  26. lightning_sdk/lightning_cloud/openapi/api/cloud_space_service_api.py +226 -0
  27. lightning_sdk/lightning_cloud/openapi/api/cluster_service_api.py +105 -0
  28. lightning_sdk/lightning_cloud/openapi/api/file_system_service_api.py +178 -0
  29. lightning_sdk/lightning_cloud/openapi/api/jobs_service_api.py +984 -101
  30. lightning_sdk/lightning_cloud/openapi/api/product_license_service_api.py +525 -0
  31. lightning_sdk/lightning_cloud/openapi/api/storage_service_api.py +93 -0
  32. lightning_sdk/lightning_cloud/openapi/configuration.py +1 -1
  33. lightning_sdk/lightning_cloud/openapi/models/__init__.py +28 -0
  34. lightning_sdk/lightning_cloud/openapi/models/assistant_id_conversations_body.py +79 -1
  35. lightning_sdk/lightning_cloud/openapi/models/cloudspaces_id_body.py +53 -1
  36. lightning_sdk/lightning_cloud/openapi/models/deployment_id_alertingpolicies_body.py +331 -0
  37. lightning_sdk/lightning_cloud/openapi/models/deployment_id_alertingpolicies_body1.py +305 -0
  38. lightning_sdk/lightning_cloud/openapi/models/deployments_id_body.py +53 -1
  39. lightning_sdk/lightning_cloud/openapi/models/endpoints_id_body.py +27 -1
  40. lightning_sdk/lightning_cloud/openapi/models/model_id_versions_body.py +27 -1
  41. lightning_sdk/lightning_cloud/openapi/models/models_id_body.py +123 -0
  42. lightning_sdk/lightning_cloud/openapi/models/orgs_id_body.py +183 -1
  43. lightning_sdk/lightning_cloud/openapi/models/pipelines_id_body.py +6 -6
  44. lightning_sdk/lightning_cloud/openapi/models/project_id_cloudspaces_body.py +27 -1
  45. lightning_sdk/lightning_cloud/openapi/models/project_id_storage_body.py +27 -1
  46. lightning_sdk/lightning_cloud/openapi/models/projects_id_body.py +107 -3
  47. lightning_sdk/lightning_cloud/openapi/models/storage_complete_body.py +27 -1
  48. lightning_sdk/lightning_cloud/openapi/models/update.py +79 -1
  49. lightning_sdk/lightning_cloud/openapi/models/uploads_upload_id_body1.py +55 -3
  50. lightning_sdk/lightning_cloud/openapi/models/v1_aws_direct_v1.py +53 -1
  51. lightning_sdk/lightning_cloud/openapi/models/v1_cloud_provider.py +3 -0
  52. lightning_sdk/lightning_cloud/openapi/models/v1_cloud_space.py +79 -1
  53. lightning_sdk/lightning_cloud/openapi/models/v1_cloud_space_environment_config.py +123 -0
  54. lightning_sdk/lightning_cloud/openapi/models/v1_cloud_space_environment_template_config.py +79 -1
  55. lightning_sdk/lightning_cloud/openapi/models/v1_cloud_space_environment_type.py +104 -0
  56. lightning_sdk/lightning_cloud/openapi/models/v1_cloud_space_source_type.py +103 -0
  57. lightning_sdk/lightning_cloud/openapi/models/v1_cloudflare_v1.py +66 -66
  58. lightning_sdk/lightning_cloud/openapi/models/v1_cluster_spec.py +27 -1
  59. lightning_sdk/lightning_cloud/openapi/models/v1_cluster_tagging_options.py +27 -1
  60. lightning_sdk/lightning_cloud/openapi/models/v1_cluster_upload.py +149 -0
  61. lightning_sdk/lightning_cloud/openapi/models/v1_complete_upload.py +55 -3
  62. lightning_sdk/lightning_cloud/openapi/models/v1_conversation.py +27 -1
  63. lightning_sdk/lightning_cloud/openapi/models/v1_create_cloud_space_environment_template_request.py +79 -1
  64. lightning_sdk/lightning_cloud/openapi/models/v1_delete_deployment_alerting_policy_response.py +175 -0
  65. lightning_sdk/lightning_cloud/openapi/models/v1_deployment.py +53 -1
  66. lightning_sdk/lightning_cloud/openapi/models/v1_deployment_alerting_event.py +487 -0
  67. lightning_sdk/lightning_cloud/openapi/models/v1_deployment_alerting_policy.py +383 -0
  68. lightning_sdk/lightning_cloud/openapi/models/v1_deployment_alerting_policy_frequency.py +105 -0
  69. lightning_sdk/lightning_cloud/openapi/models/v1_deployment_alerting_policy_operation.py +105 -0
  70. lightning_sdk/lightning_cloud/openapi/models/v1_deployment_alerting_policy_severity.py +106 -0
  71. lightning_sdk/lightning_cloud/openapi/models/v1_deployment_alerting_policy_type.py +111 -0
  72. lightning_sdk/lightning_cloud/openapi/models/v1_deployment_api.py +27 -1
  73. lightning_sdk/lightning_cloud/openapi/models/v1_deployment_state.py +4 -4
  74. lightning_sdk/lightning_cloud/openapi/models/v1_endpoint.py +27 -1
  75. lightning_sdk/lightning_cloud/openapi/models/v1_external_search_user.py +27 -1
  76. lightning_sdk/lightning_cloud/openapi/models/v1_ge_list_deployment_routing_telemetry_response.py +123 -0
  77. lightning_sdk/lightning_cloud/openapi/models/v1_get_cloud_space_instance_open_ports_response.py +123 -0
  78. lightning_sdk/lightning_cloud/openapi/models/v1_get_deployment_routing_telemetry_content_response.py +123 -0
  79. lightning_sdk/lightning_cloud/openapi/models/v1_get_job_stats_response.py +53 -1
  80. lightning_sdk/lightning_cloud/openapi/models/v1_get_organization_storage_metadata_response.py +331 -0
  81. lightning_sdk/lightning_cloud/openapi/models/v1_get_project_balance_response.py +1 -27
  82. lightning_sdk/lightning_cloud/openapi/models/v1_google_cloud_direct_v1.py +27 -1
  83. lightning_sdk/lightning_cloud/openapi/models/v1_job_type.py +1 -0
  84. lightning_sdk/lightning_cloud/openapi/models/v1_list_deployment_alerting_events_response.py +123 -0
  85. lightning_sdk/lightning_cloud/openapi/models/v1_list_deployment_alerting_policies_response.py +175 -0
  86. lightning_sdk/lightning_cloud/openapi/models/v1_list_product_licenses_response.py +123 -0
  87. lightning_sdk/lightning_cloud/openapi/models/v1_managed_model.py +27 -1
  88. lightning_sdk/lightning_cloud/openapi/models/v1_membership.py +43 -17
  89. lightning_sdk/lightning_cloud/openapi/models/v1_modify_filesystem_volume_response.py +97 -0
  90. lightning_sdk/lightning_cloud/openapi/models/v1_organization.py +183 -1
  91. lightning_sdk/lightning_cloud/openapi/models/v1_pipeline.py +6 -6
  92. lightning_sdk/lightning_cloud/openapi/models/v1_pipeline_state.py +111 -0
  93. lightning_sdk/lightning_cloud/openapi/models/v1_presigned_url.py +53 -1
  94. lightning_sdk/lightning_cloud/openapi/models/v1_product_license.py +409 -0
  95. lightning_sdk/lightning_cloud/openapi/models/v1_product_license_check_response.py +123 -0
  96. lightning_sdk/lightning_cloud/openapi/models/v1_project.py +27 -1
  97. lightning_sdk/lightning_cloud/openapi/models/v1_project_membership.py +43 -17
  98. lightning_sdk/lightning_cloud/openapi/models/v1_project_settings.py +107 -3
  99. lightning_sdk/lightning_cloud/openapi/models/v1_project_storage.py +53 -1
  100. lightning_sdk/lightning_cloud/openapi/models/v1_r2_data_connection.py +53 -1
  101. lightning_sdk/lightning_cloud/openapi/models/v1_routing_telemetry.py +253 -0
  102. lightning_sdk/lightning_cloud/openapi/models/v1_secret_type.py +1 -0
  103. lightning_sdk/lightning_cloud/openapi/models/v1_server_alert_type.py +2 -0
  104. lightning_sdk/lightning_cloud/openapi/models/v1_sleep_server_response.py +97 -0
  105. lightning_sdk/lightning_cloud/openapi/models/v1_trigger_filesystem_upgrade_response.py +123 -0
  106. lightning_sdk/lightning_cloud/openapi/models/v1_upload_project_artifact_response.py +27 -1
  107. lightning_sdk/lightning_cloud/openapi/models/v1_usage_report.py +79 -1
  108. lightning_sdk/lightning_cloud/openapi/models/v1_user_features.py +347 -113
  109. lightning_sdk/lightning_cloud/openapi/models/v1_user_requested_compute_config.py +27 -1
  110. lightning_sdk/lightning_cloud/rest_client.py +4 -0
  111. lightning_sdk/llm/llm.py +132 -40
  112. lightning_sdk/services/__init__.py +1 -1
  113. lightning_sdk/services/license.py +236 -0
  114. lightning_sdk/studio.py +62 -1
  115. lightning_sdk/teamspace.py +68 -0
  116. {lightning_sdk-0.2.14.dist-info → lightning_sdk-0.2.16.dist-info}/METADATA +1 -1
  117. {lightning_sdk-0.2.14.dist-info → lightning_sdk-0.2.16.dist-info}/RECORD +122 -86
  118. /lightning_sdk/services/{finetune/__init__.py → finetune_llm.py} +0 -0
  119. {lightning_sdk-0.2.14.dist-info → lightning_sdk-0.2.16.dist-info}/LICENSE +0 -0
  120. {lightning_sdk-0.2.14.dist-info → lightning_sdk-0.2.16.dist-info}/WHEEL +0 -0
  121. {lightning_sdk-0.2.14.dist-info → lightning_sdk-0.2.16.dist-info}/entry_points.txt +0 -0
  122. {lightning_sdk-0.2.14.dist-info → lightning_sdk-0.2.16.dist-info}/top_level.txt +0 -0
@@ -42,6 +42,7 @@ class V1UserRequestedComputeConfig(object):
42
42
  """
43
43
  swagger_types = {
44
44
  'affinity_identifier': 'str',
45
+ 'cluster_override': 'str',
45
46
  'count': 'int',
46
47
  'cpu_image_override': 'str',
47
48
  'disk_size': 'int',
@@ -56,6 +57,7 @@ class V1UserRequestedComputeConfig(object):
56
57
 
57
58
  attribute_map = {
58
59
  'affinity_identifier': 'affinityIdentifier',
60
+ 'cluster_override': 'clusterOverride',
59
61
  'count': 'count',
60
62
  'cpu_image_override': 'cpuImageOverride',
61
63
  'disk_size': 'diskSize',
@@ -68,9 +70,10 @@ class V1UserRequestedComputeConfig(object):
68
70
  'spot': 'spot'
69
71
  }
70
72
 
71
- def __init__(self, affinity_identifier: 'str' =None, count: 'int' =None, cpu_image_override: 'str' =None, disk_size: 'int' =None, gpu_image_override: 'str' =None, id: 'str' =None, name: 'str' =None, requested_run_duration_seconds: 'str' =None, same_compute_on_resume: 'bool' =None, shm_size: 'int' =None, spot: 'bool' =None): # noqa: E501
73
+ def __init__(self, affinity_identifier: 'str' =None, cluster_override: 'str' =None, count: 'int' =None, cpu_image_override: 'str' =None, disk_size: 'int' =None, gpu_image_override: 'str' =None, id: 'str' =None, name: 'str' =None, requested_run_duration_seconds: 'str' =None, same_compute_on_resume: 'bool' =None, shm_size: 'int' =None, spot: 'bool' =None): # noqa: E501
72
74
  """V1UserRequestedComputeConfig - a model defined in Swagger""" # noqa: E501
73
75
  self._affinity_identifier = None
76
+ self._cluster_override = None
74
77
  self._count = None
75
78
  self._cpu_image_override = None
76
79
  self._disk_size = None
@@ -84,6 +87,8 @@ class V1UserRequestedComputeConfig(object):
84
87
  self.discriminator = None
85
88
  if affinity_identifier is not None:
86
89
  self.affinity_identifier = affinity_identifier
90
+ if cluster_override is not None:
91
+ self.cluster_override = cluster_override
87
92
  if count is not None:
88
93
  self.count = count
89
94
  if cpu_image_override is not None:
@@ -128,6 +133,27 @@ class V1UserRequestedComputeConfig(object):
128
133
 
129
134
  self._affinity_identifier = affinity_identifier
130
135
 
136
@property
def cluster_override(self) -> 'str':
    """Return the cluster override requested for this compute config.  # noqa: E501

    :return: The cluster_override of this V1UserRequestedComputeConfig.  # noqa: E501
    :rtype: str
    """
    value = self._cluster_override
    return value


@cluster_override.setter
def cluster_override(self, cluster_override: 'str'):
    """Store the cluster override requested for this compute config.

    :param cluster_override: The cluster_override of this V1UserRequestedComputeConfig.  # noqa: E501
    :type: str
    """
    setattr(self, "_cluster_override", cluster_override)
156
+
131
157
  @property
132
158
  def count(self) -> 'int':
133
159
  """Gets the count of this V1UserRequestedComputeConfig. # noqa: E501
@@ -35,6 +35,8 @@ from lightning_sdk.lightning_cloud.openapi import (
35
35
  LitRegistryServiceApi,
36
36
  PipelinesServiceApi,
37
37
  SchedulesServiceApi,
38
+ ProductLicenseServiceApi,
39
+ CloudSpaceEnvironmentTemplateServiceApi
38
40
  )
39
41
  from lightning_sdk.lightning_cloud.openapi.rest import ApiException
40
42
  from lightning_sdk.lightning_cloud.source_code.logs_socket_api import LightningLogsSocketAPI
@@ -97,6 +99,8 @@ class GridRestClient(
97
99
  LitRegistryServiceApi,
98
100
  PipelinesServiceApi,
99
101
  SchedulesServiceApi,
102
+ ProductLicenseServiceApi,
103
+ CloudSpaceEnvironmentTemplateServiceApi
100
104
  ):
101
105
 
102
106
  def __init__(self, api_client: Optional[ApiClient] = None):
lightning_sdk/llm/llm.py CHANGED
@@ -1,52 +1,113 @@
1
- from typing import Dict, List, Optional, Set, Tuple, Union
1
+ import os
2
+ import warnings
3
+ from typing import Dict, Generator, List, Optional, Set, Tuple, Union
2
4
 
3
- from lightning_sdk.api import UserApi
4
5
  from lightning_sdk.api.llm_api import LLMApi
5
- from lightning_sdk.lightning_cloud.login import Auth
6
+ from lightning_sdk.cli.teamspace_menu import _TeamspacesMenu
6
7
  from lightning_sdk.lightning_cloud.openapi import V1Assistant
8
+ from lightning_sdk.lightning_cloud.openapi.models.v1_conversation_response_chunk import V1ConversationResponseChunk
7
9
  from lightning_sdk.lightning_cloud.openapi.rest import ApiException
8
10
  from lightning_sdk.organization import Organization
9
- from lightning_sdk.user import User
10
- from lightning_sdk.utils.resolve import _resolve_org, _resolve_user
11
+ from lightning_sdk.owner import Owner
12
+ from lightning_sdk.teamspace import Teamspace
13
+ from lightning_sdk.utils.resolve import _get_authed_user, _resolve_org, _resolve_teamspace
11
14
 
12
15
 
13
16
  class LLM:
14
17
  def __init__(
15
18
  self,
16
19
  name: str,
17
- user: Union[str, "User", None] = None,
18
- org: Union[str, "Organization", None] = None,
20
+ teamspace: Optional[str] = None,
19
21
  ) -> None:
20
- self._auth = Auth()
21
- self._user = None
22
+ """Initializes the LLM instance with teamspace information, which is required for billing purposes.
23
+
24
+ Teamspace information is resolved through the following methods:
25
+ 1. `.lightning/credentials.json` - Attempts to retrieve the teamspace from the local credentials file.
26
+ 2. Environment Variables - Checks for `LIGHTNING_*` environment variables.
27
+ 3. User Authentication - Redirects the user to the login page if teamspace information is not found.
28
+
29
+ Args:
30
+ name (str): The name of the model or resource.
31
+ teamspace (Optional[str]): The specified teamspace for billing. If not provided, it will be resolved
32
+ through the above methods.
33
+
34
+ Raises:
35
+ ValueError: If teamspace information cannot be resolved.
36
+ """
37
+ menu = _TeamspacesMenu()
38
+ user = _get_authed_user()
39
+ possible_teamspaces = menu._get_possible_teamspaces(user)
40
+ if teamspace is None:
41
+ # get current teamspace
42
+ self._teamspace = _resolve_teamspace(teamspace=None, org=None, user=None)
43
+ else:
44
+ self._teamspace = Teamspace(**menu._get_teamspace_from_name(teamspace, possible_teamspaces))
45
+
46
+ if self._teamspace is None:
47
+ # select the first available teamspace
48
+ first_teamspace = next(iter(possible_teamspaces.values()), None)
49
+
50
+ if first_teamspace:
51
+ self._teamspace = Teamspace(
52
+ name=first_teamspace["name"],
53
+ org=first_teamspace["org"],
54
+ user=first_teamspace["user"],
55
+ )
56
+ warnings.warn(
57
+ f"No teamspace given. Using teamspace: {self._teamspace.name}.",
58
+ UserWarning,
59
+ stacklevel=2,
60
+ )
61
+
62
+ if self._teamspace is None:
63
+ raise ValueError("Teamspace is required for billing but could not be resolved. ")
64
+
65
+ self._user = user
66
+
67
+ self._model_provider, self._model_name = self._parse_model_name(name)
22
68
 
23
- try:
24
- self._auth.authenticate()
25
- self._user = User(name=UserApi()._get_user_by_id(self._auth.user_id).username)
26
- except ConnectionError as e:
27
- raise e
28
-
29
- self._name = name
30
- try:
31
- self._user = _resolve_user(self._user or user)
32
- except ValueError:
33
- self._user = None
69
+ self._llm_api = LLMApi()
34
70
 
35
- self._name = name
36
- self._org, self._model_name = self._parse_model_name(name)
37
71
  try:
38
72
  # check if it is a org model
39
- self._org = _resolve_org(self._org or org)
73
+ self._org = _resolve_org(self._model_provider)
74
+
75
+ try:
76
+ # check if user has access to the org
77
+ self._org_models = self._build_model_lookup(self._get_org_models())
78
+ except ApiException:
79
+ warnings.warn(
80
+ f"User is not authenticated to access the model in organization: '{self._model_provider}'.\n"
81
+ " Proceeding with appropriate org models, user models, or public models.",
82
+ UserWarning,
83
+ stacklevel=2,
84
+ )
85
+ self._model_provider = None
86
+ raise
40
87
  except ApiException:
41
- self._org = None
88
+ if isinstance(self._teamspace.owner, Organization):
89
+ self._org = self._teamspace.owner
90
+ else:
91
+ self._org = None
92
+ self._org_models = self._build_model_lookup(self._get_org_models())
42
93
 
43
- self._llm_api = LLMApi()
44
94
  self._public_models = self._build_model_lookup(self._get_public_models())
45
- self._org_models = self._build_model_lookup(self._get_org_models())
46
95
  self._user_models = self._build_model_lookup(self._get_user_models())
47
96
  self._model = self._get_model()
48
97
  self._conversations = {}
49
98
 
99
@property
def name(self) -> str:
    """The model name: the second element returned by ``_parse_model_name`` for the constructor's ``name``."""
    model_name = self._model_name
    return model_name
102
+
103
@property
def provider(self) -> Optional[str]:
    """The provider prefix parsed from the model name.

    Returns ``None`` whenever ``_model_provider`` is ``None``. The constructor sets it to ``None``
    when listing the provider organization's models raises ``ApiException``, so the return type is
    ``Optional[str]`` rather than ``str``.
    """
    return self._model_provider
106
+
107
@property
def owner(self) -> Optional[Owner]:
    """The owner (user or organization) of the teamspace this LLM client bills to."""
    billing_teamspace = self._teamspace
    return billing_teamspace.owner
110
+
50
111
  def _parse_model_name(self, name: str) -> Tuple[str, str]:
51
112
  parts = name.split("/")
52
113
  if len(parts) == 1:
@@ -95,47 +156,76 @@ class LLM:
95
156
  available_models_str = "\n".join(available_models)
96
157
  raise ValueError(f"Model '{self._model_name}' not found. \nAvailable models: \n{available_models_str}")
97
158
 
98
- def _get_conversations(self) -> Dict[str, str]:
99
- # TODO: after updating backend, this will fetch conversations from backend
100
- # conversations = self._llm_api.list_conversations(assistant_id=self._model.id)
101
- return self._conversations
159
+ def _get_conversations(self) -> None:
160
+ conversations = self._llm_api.list_conversations(assistant_id=self._model.id)
161
+ for conversation in conversations:
162
+ if conversation.name and conversation.name not in self._conversations:
163
+ self._conversations[conversation.name] = conversation.id
164
+
165
def _stream_chat_response(
    self, result: Generator[V1ConversationResponseChunk, None, None], conversation: Optional[str] = None
) -> Generator[str, None, None]:
    """Yield the content delta of each streamed response chunk.

    If ``conversation`` is given and the first chunk carries a ``conversation_id``, that id is
    stored in the local cache under the conversation name.

    Args:
        result: Generator of streamed response chunks returned by the API.
        conversation: Optional conversation name to register the id under.
    """
    head = next(result, None)
    if head:
        if conversation and head.conversation_id:
            self._conversations[conversation] = head.conversation_id
        yield head.choices[0].delta.content
    # Remaining chunks: forward only their content.
    yield from (chunk.choices[0].delta.content for chunk in result)
105
176
 
106
177
  def chat(
107
178
  self,
108
179
  prompt: str,
109
180
  system_prompt: Optional[str] = None,
110
181
  max_completion_tokens: Optional[int] = 500,
182
+ images: Optional[Union[List[str], str]] = None,
111
183
  conversation: Optional[str] = None,
112
- ) -> str:
184
+ metadata: Optional[Dict[str, str]] = None,
185
+ stream: bool = False,
186
+ upload_local_images: bool = False,
187
+ ) -> Union[str, Generator[str, None, None]]:
113
188
  if conversation and conversation not in self._conversations:
114
- self._fetch_conversations()
189
+ self._get_conversations()
190
+
191
+ if images:
192
+ if isinstance(images, str):
193
+ images = [images]
194
+ for image in images:
195
+ if not isinstance(image, str):
196
+ raise NotImplementedError(f"Image type {type(image)} are not supported yet.")
197
+ if not image.startswith("http") and upload_local_images:
198
+ self._teamspace.upload_file(file_path=image, remote_path=f"images/{os.path.basename(image)}")
115
199
 
116
200
  conversation_id = self._conversations.get(conversation) if conversation else None
117
201
  output = self._llm_api.start_conversation(
118
202
  prompt=prompt,
119
203
  system_prompt=system_prompt,
120
204
  max_completion_tokens=max_completion_tokens,
205
+ images=images,
121
206
  assistant_id=self._model.id,
122
207
  conversation_id=conversation_id,
208
+ billing_project_id=self._teamspace.id,
209
+ metadata=metadata,
210
+ name=conversation,
211
+ stream=stream,
123
212
  )
124
- if conversation and not conversation_id:
125
- self._conversations[conversation] = output.conversation_id
126
- return output.choices[0].delta.content
213
+ if not stream:
214
+ if conversation and not conversation_id:
215
+ self._conversations[conversation] = output.conversation_id
216
+ return output.choices[0].delta.content
217
+ return self._stream_chat_response(output, conversation=conversation)
127
218
 
128
219
  def list_conversations(self) -> List[Dict]:
129
- self._fetch_conversations()
220
+ self._get_conversations()
130
221
  return list(self._conversations.keys())
131
222
 
132
223
  def _get_conversation_messages(self, conversation_id: str) -> Optional[str]:
133
224
  return self._llm_api.get_conversation(assistant_id=self._model.id, conversation_id=conversation_id)
134
225
 
135
226
  def get_history(self, conversation: str) -> Optional[List[Dict]]:
136
- # TODO: after updating backend, this will fetch conversation from backend
137
227
  if conversation not in self._conversations:
138
- self._fetch_conversations()
228
+ self._get_conversations()
139
229
 
140
230
  if conversation not in self._conversations:
141
231
  raise ValueError(
@@ -152,6 +242,8 @@ class LLM:
152
242
  return history
153
243
 
154
244
  def reset_conversation(self, conversation: str) -> None:
245
+ if conversation not in self._conversations:
246
+ self._get_conversations()
155
247
  if conversation in self._conversations:
156
248
  self._llm_api.reset_conversation(
157
249
  assistant_id=self._model.id,
@@ -1,5 +1,5 @@
1
1
  from lightning_sdk.services.file_endpoint import Client
2
- from lightning_sdk.services.finetune import LLMFinetune
2
+ from lightning_sdk.services.finetune_llm import LLMFinetune
3
3
  from lightning_sdk.services.utilities import download_file
4
4
 
5
5
  # Public names exported by ``from lightning_sdk.services import *``.
  __all__ = ["LLMFinetune", "Client", "download_file"]
@@ -0,0 +1,236 @@
1
+ import importlib
2
+ import json
3
+ import os
4
+ import socket
5
+ import threading
6
+ from functools import partial
7
+ from pathlib import Path
8
+ from typing import Optional
9
+
10
+ from lightning_sdk.api.license_api import LicenseApi
11
+
12
+
13
class LightningLicense:
    """Manage and validate a product license key for the Lightning SDK.

    The key can be passed explicitly. Otherwise it is looked up lazily, first in a
    ``.license_key`` file inside the installed package directory, then in
    ``~/.lightning/licenses.json``. Validation is delegated to ``LicenseApi`` and
    requires internet access.
    """

    # Class-level defaults; every instance overwrites them in ``__init__``.
    _is_valid: Optional[bool] = None
    # String forward reference: the class body does not need ``LicenseApi`` at definition time.
    _license_api: Optional["LicenseApi"] = None
    _stream_messages: Optional[callable] = None

    def __init__(
        self,
        name: str,
        license_key: Optional[str] = None,
        product_version: Optional[str] = None,
        product_type: str = "package",
        stream_messages: callable = print,
    ) -> None:
        """Store the license parameters; nothing is validated until ``is_valid`` is read.

        Args:
            name: The product (e.g. package) name.
            license_key: The license key. When omitted it is searched for lazily (see ``license_key``).
            product_version: The product version. When omitted and ``product_type`` is ``"package"``,
                it is read lazily from the package's ``__version__``.
            product_type: The product type, forwarded to the license API.
            stream_messages: Callable used to emit user-facing messages (``print`` by default).
        """
        self._product_name = name
        self._license_key = license_key
        self._product_version = product_version
        self.product_type = product_type
        self._is_valid = None
        self._license_api = None
        self._stream_messages = stream_messages

    def validate_license(self) -> bool:
        """Validate the license key against the Lightning backend.

        Returns:
            The result of ``LicenseApi.valid_license`` for this product.

        Raises:
            ConnectionError: If ``is_online`` reports no connectivity.
        """
        if not self.is_online():
            raise ConnectionError("No internet connection.")

        self._license_api = LicenseApi()
        return self._license_api.valid_license(
            license_key=self.license_key,
            product_name=self.product_name,
            product_version=self.product_version,
            product_type=self.product_type,
        )

    @staticmethod
    def is_online(timeout: float = 2.0) -> bool:
        """Check internet connectivity by opening a TCP connection to Google's public DNS (8.8.8.8:53).

        Args:
            timeout: Timeout in seconds for the connection attempt.

        Returns:
            True if the connection could be established, False on any ``OSError``.
        """
        try:
            # Only reachability matters; close the probe socket right away instead of leaking it.
            with socket.create_connection(("8.8.8.8", 53), timeout=timeout):
                return True
        except OSError:
            return False

    @property
    def is_valid(self) -> Optional[bool]:
        """Whether the license key is valid, or ``None`` if that could not be determined.

        Behaviour:
          - online with a key            -> validated via ``validate_license``
          - offline with a key           -> message that the key could not be verified; returns None
          - no key (online or offline)   -> message about the missing license; returns None
        A message is also emitted when the product version is unknown. A boolean result is cached.
        """
        if isinstance(self._is_valid, bool):
            # Already validated: return the cached result.
            return self._is_valid
        if not self.product_version:
            self._stream_messages("Product version is not set correctly, consider leave it empty for auto-determine.")
        if not self.license_key:
            self._stream_messages(
                "License key is not set neither cannot be found in the package root or user home."
                " Please make sure you have signed the license agreement and set the license key."
                " For more information, please refer to the documentation.",
            )
            # Nothing to verify without a key: skip the network probe and the "key is set" message.
            return self._is_valid
        if self.is_online():
            self._is_valid = self.validate_license()
        else:
            self._stream_messages(
                "License key is set but the system is offline. "
                "Please make sure you have a valid license key and the system is online."
            )
        return self._is_valid

    @property
    def has_required_details(self) -> bool:
        """Check that the license key, product name and product type are all set."""
        return bool(self.license_key and self.product_name and self.product_type)

    @staticmethod
    def _find_package_license_key(package_name: str) -> Optional[str]:
        """Read the license key from a ``.license_key`` file in the package's root directory.

        Args:
            package_name: The importable package name. Falsy values return None.

        Returns:
            The stripped file content, or None if the package or file cannot be found or read.
        """
        # ``import importlib`` alone does not guarantee the ``util`` submodule is loaded.
        import importlib.util

        if not package_name:
            return None
        try:
            spec = importlib.util.find_spec(package_name)
        except (ModuleNotFoundError, ValueError):
            # Missing parent package, or a module in sys.modules without a __spec__.
            return None
        # ``find_spec`` returns None for unknown top-level modules; plain modules have no search path.
        if spec is None or not spec.submodule_search_locations:
            return None
        package_dir = next(iter(spec.submodule_search_locations))
        license_file = os.path.join(package_dir, ".license_key")
        try:
            with open(license_file) as fp:
                return fp.read().strip()
        except OSError:
            return None

    @staticmethod
    def _find_user_license_key(package_name: str) -> Optional[str]:
        """Look up the license key in ``~/.lightning/licenses.json``.

        The file is expected to be a JSON object mapping package names to keys. The name is
        lowercased and tried as-is, with ``-`` replaced by ``_``, and with ``_`` replaced by ``-``.

        Args:
            package_name: The name of the package.

        Returns:
            The stored key, or None if the file is missing, unreadable, not a JSON object,
            or has no entry for the package.
        """
        package_name = package_name.lower()
        license_file = os.path.join(str(Path.home()), ".lightning", "licenses.json")
        try:
            with open(license_file) as fp:
                licenses = json.load(fp)
        except (OSError, json.JSONDecodeError):
            return None
        if not isinstance(licenses, dict):
            # e.g. a JSON list: indexing it by name would raise TypeError.
            return None
        for name in (package_name, package_name.replace("-", "_"), package_name.replace("_", "-")):
            if name in licenses:
                return licenses[name]
        return None

    @staticmethod
    def _determine_package_version(package_name: str) -> Optional[str]:
        """Import ``package_name`` and return its ``__version__``, or None if unavailable.

        Args:
            package_name: The importable package name.
        """
        try:
            pkg = importlib.import_module(package_name)
            return getattr(pkg, "__version__", None)
        except ImportError:
            return None

    @property
    def license_key(self) -> Optional[str]:
        """The license key; if not set, lazily looked up in the package root, then in the user home."""
        if not self._license_key:
            # If the license key is not set, first try to find it in the package root.
            self._license_key = self._find_package_license_key(self.product_name.replace("-", "_"))
            # If not found, try to find it in the user home.
            if not self._license_key:
                self._license_key = self._find_user_license_key(self.product_name)
        return self._license_key

    @property
    def product_name(self) -> str:
        """The product name given at construction."""
        return self._product_name

    @property
    def product_version(self) -> Optional[str]:
        """The product version; for packages, lazily read from the package's ``__version__``."""
        if not self._product_version and self.product_type == "package":
            self._product_version = self._determine_package_version(self.product_name.replace("-", "_"))
        return self._product_version
176
+
177
+
178
def check_license(
    name: str,
    license_key: Optional[str] = None,
    product_version: Optional[str] = None,
    product_type: str = "package",
    stream_messages: callable = print,
) -> None:
    """Check a product's license once and report a definitely invalid key.

    Args:
        name: The name of the product.
        license_key: The license key to check; looked up automatically when omitted.
        product_version: The version of the product.
        product_type: The type of the product.
        stream_messages: Callable that receives every user-facing message.
    """
    checker = LightningLicense(
        name=name,
        license_key=license_key,
        product_version=product_version,
        product_type=product_type,
        stream_messages=stream_messages,
    )
    # ``is_valid`` is None when validation was impossible (no key, or offline); only report a definite failure.
    if checker.is_valid is not False:
        return
    stream_messages(
        "License key is not valid.\n"
        f" Key: {checker.license_key}\n"
        " Please make sure you have a valid license key."
    )
207
+
208
+
209
def check_license_in_background(
    name: str,
    license_key: Optional[str] = None,
    product_version: Optional[str] = None,
    product_type: str = "package",
    stream_messages: callable = print,
) -> threading.Thread:
    """Run ``check_license`` in a daemon thread and return the already-started thread.

    As a daemon thread, it does not keep the interpreter alive. Callers may ``join()`` it to wait
    for the check to finish.

    Args:
        name: The name of the product.
        license_key: The license key to check.
        product_version: The version of the product.
        product_type: The type of the product.
        stream_messages: Callable that receives every user-facing message.
    """
    worker = threading.Thread(
        target=partial(
            check_license,
            name=name,
            license_key=license_key,
            product_version=product_version,
            product_type=product_type,
            stream_messages=stream_messages,
        ),
        daemon=True,
    )
    worker.start()
    return worker
lightning_sdk/studio.py CHANGED
@@ -1,13 +1,16 @@
1
1
  import glob
2
2
  import os
3
3
  import warnings
4
+ from enum import Enum
4
5
  from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Tuple, Union
5
6
 
6
7
  from tqdm.auto import tqdm
7
8
 
9
+ from lightning_sdk.api.cluster_api import ClusterApi
8
10
  from lightning_sdk.api.studio_api import StudioApi
9
11
  from lightning_sdk.api.utils import _machine_to_compute_name
10
12
  from lightning_sdk.constants import _LIGHTNING_DEBUG
13
+ from lightning_sdk.lightning_cloud.openapi import V1CloudSpaceSourceType
11
14
  from lightning_sdk.machine import Machine
12
15
  from lightning_sdk.organization import Organization
13
16
  from lightning_sdk.owner import Owner
@@ -24,6 +27,19 @@ if TYPE_CHECKING:
24
27
  _logger = _setup_logger(__name__)
25
28
 
26
29
 
30
class Provider(Enum):
    """Machine providers a Studio can be placed on.

    The members mirror ``v1CloudProvider``. Each member's value is identical
    to its name, so lookup by name and lookup by value both resolve to the
    same member.
    """

    # Hyperscale clouds
    AWS = "AWS"
    GCP = "GCP"
    # Specialised GPU clouds / neoclouds
    VULTR = "VULTR"
    LAMBDA_LABS = "LAMBDA_LABS"
    DGX = "DGX"
    VOLTAGE_PARK = "VOLTAGE_PARK"
    NEBIUS = "NEBIUS"
    CLOUDFLARE = "CLOUDFLARE"
    # Lightning-managed capacity
    LIGHTNING = "LIGHTNING"
+
42
+
27
43
  class Studio:
28
44
  """A single Lightning AI Studio.
29
45
 
@@ -38,6 +54,8 @@ class Studio:
38
54
  cloud_account: the name of the cloud account, the studio should be created on.
39
55
  Doesn't matter when the studio already exists.
40
56
  create_ok: whether the studio will be created if it does not yet exist. Defaults to True
57
+ provider: the provider of the machine, the studio should be created on.
58
+
41
59
  Note:
42
60
  Since a teamspace can either be owned by an org or by a user directly,
43
61
  only one of the arguments can be provided.
@@ -56,8 +74,11 @@ class Studio:
56
74
  cloud_account: Optional[str] = None,
57
75
  create_ok: bool = True,
58
76
  cluster: Optional[str] = None, # deprecated in favor of cloud_account
77
+ provider: Optional[str] = None,
78
+ source: Optional[V1CloudSpaceSourceType] = None,
59
79
  ) -> None:
60
80
  self._studio_api = StudioApi()
81
+ self._cluster_api = ClusterApi()
61
82
 
62
83
  self._teamspace = _resolve_teamspace(teamspace=teamspace, org=org, user=user)
63
84
  self._cloud_account = _resolve_deprecated_cluster(cloud_account, cluster)
@@ -65,6 +86,16 @@ class Studio:
65
86
 
66
87
  self._plugins = {}
67
88
 
89
+ if provider is not None:
90
+ if isinstance(provider, str) and provider in Provider.__members__:
91
+ provider = Provider(provider)
92
+ else:
93
+ raise ValueError(f"Invalid provider: {provider}. Must be one of {Provider.__members__.keys()}.")
94
+ self._cloud_account = self._cluster_api.get_cluster_provider_mapping(
95
+ self._teamspace.id,
96
+ self._teamspace.owner.id,
97
+ )[provider.value]
98
+
68
99
  if name is None:
69
100
  studio_id = os.environ.get("LIGHTNING_CLOUD_SPACE_ID", None)
70
101
  if studio_id is None:
@@ -76,7 +107,7 @@ class Studio:
76
107
  except ValueError as e:
77
108
  if create_ok:
78
109
  self._studio = self._studio_api.create_studio(
79
- name, self._teamspace.id, cloud_account=self._cloud_account
110
+ name, self._teamspace.id, cloud_account=self._cloud_account, source=source
80
111
  )
81
112
  else:
82
113
  raise ValueError(f"Studio {name} does not exist.") from e
@@ -221,6 +252,36 @@ class Studio:
221
252
  self._studio.id, self._teamspace.id, machine, interruptible=interruptible
222
253
  )
223
254
 
255
def run_and_detach(
    self, *commands: str, timeout: float = 10, check_interval: float = 1
) -> Tuple[str, Optional[int]]:
    """Run the given commands on the Studio, streaming their output as it arrives.

    The original intent is that the commands continue running on the Studio in
    the background. This method prints and collects output until
    ``StudioApi.run_studio_commands_and_yield`` stops yielding, which is presumably
    bounded by ``timeout`` — TODO confirm against StudioApi.

    Args:
        *commands: the commands to run on the Studio.
        timeout: wait for this many seconds for the command to finish.
        check_interval: check the status of the command every this many seconds.
            Must not be greater than ``timeout``.

    Returns:
        A tuple ``(output, exit_code)``. ``output`` is the concatenation of every
        yielded output line. ``exit_code`` is the exit code paired with the last
        yielded line, or ``None`` if nothing was yielded.

    Raises:
        ValueError: if ``check_interval`` is greater than ``timeout``.
        RuntimeError: if the Studio is not in the ``Running`` state.
    """
    # Equality is allowed, so the message states the real constraint.
    if check_interval > timeout:
        raise ValueError("check_interval must be less than or equal to timeout")

    if _LIGHTNING_DEBUG:
        print(f"Running {commands=}")
    status = self.status
    if status != Status.Running:
        raise RuntimeError(f"Cannot run a command in a studio that is not running. Studio {self.name} is {status}.")

    iter_output = self._studio_api.run_studio_commands_and_yield(
        self._studio.id, self._teamspace.id, *commands, timeout=timeout, check_interval=check_interval
    )

    # Collect chunks and join once, instead of quadratic string concatenation.
    chunks = []
    code = None
    for line, exit_code in iter_output:
        print(line)
        chunks.append(line)
        code = exit_code
    return "".join(chunks), code
284
+
224
285
  def run_with_exit_code(self, *commands: str) -> Tuple[str, int]:
225
286
  """Runs given commands on the Studio while returning output and exit code.
226
287