lightning-sdk 0.2.12__py3-none-any.whl → 0.2.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34) hide show
  1. lightning_sdk/__init__.py +1 -1
  2. lightning_sdk/api/llm_api.py +28 -5
  3. lightning_sdk/api/studio_api.py +17 -0
  4. lightning_sdk/cli/entrypoint.py +1 -1
  5. lightning_sdk/cli/serve.py +149 -39
  6. lightning_sdk/deployment/deployment.py +2 -2
  7. lightning_sdk/lightning_cloud/openapi/__init__.py +6 -0
  8. lightning_sdk/lightning_cloud/openapi/api/__init__.py +1 -0
  9. lightning_sdk/lightning_cloud/openapi/api/git_credentials_service_api.py +497 -0
  10. lightning_sdk/lightning_cloud/openapi/api/jobs_service_api.py +14 -5
  11. lightning_sdk/lightning_cloud/openapi/models/__init__.py +5 -0
  12. lightning_sdk/lightning_cloud/openapi/models/deployments_id_body.py +27 -1
  13. lightning_sdk/lightning_cloud/openapi/models/v1_cluster_accelerator.py +27 -1
  14. lightning_sdk/lightning_cloud/openapi/models/v1_cluster_security_options.py +27 -1
  15. lightning_sdk/lightning_cloud/openapi/models/v1_cluster_spec.py +79 -1
  16. lightning_sdk/lightning_cloud/openapi/models/v1_create_git_credentials_request.py +175 -0
  17. lightning_sdk/lightning_cloud/openapi/models/v1_delete_git_credentials_response.py +97 -0
  18. lightning_sdk/lightning_cloud/openapi/models/v1_deployment.py +27 -1
  19. lightning_sdk/lightning_cloud/openapi/models/v1_deployment_state.py +2 -0
  20. lightning_sdk/lightning_cloud/openapi/models/v1_git_credentials.py +227 -0
  21. lightning_sdk/lightning_cloud/openapi/models/v1_list_git_credentials_response.py +123 -0
  22. lightning_sdk/lightning_cloud/openapi/models/v1_list_job_resources_response.py +15 -15
  23. lightning_sdk/lightning_cloud/openapi/models/v1_nebius_direct_v1.py +149 -0
  24. lightning_sdk/lightning_cloud/openapi/models/v1_user_features.py +53 -1
  25. lightning_sdk/llm/llm.py +134 -30
  26. lightning_sdk/plugin.py +19 -0
  27. lightning_sdk/studio.py +33 -0
  28. {lightning_sdk-0.2.12.dist-info → lightning_sdk-0.2.14.dist-info}/METADATA +1 -1
  29. {lightning_sdk-0.2.12.dist-info → lightning_sdk-0.2.14.dist-info}/RECORD +34 -28
  30. /lightning_sdk/cli/{docker.py → docker_cli.py} +0 -0
  31. {lightning_sdk-0.2.12.dist-info → lightning_sdk-0.2.14.dist-info}/LICENSE +0 -0
  32. {lightning_sdk-0.2.12.dist-info → lightning_sdk-0.2.14.dist-info}/WHEEL +0 -0
  33. {lightning_sdk-0.2.12.dist-info → lightning_sdk-0.2.14.dist-info}/entry_points.txt +0 -0
  34. {lightning_sdk-0.2.12.dist-info → lightning_sdk-0.2.14.dist-info}/top_level.txt +0 -0
@@ -85,6 +85,7 @@ class V1UserFeatures(object):
85
85
  'mmt_strategy_selector': 'bool',
86
86
  'multicloud_saas': 'bool',
87
87
  'multiple_studio_versions': 'bool',
88
+ 'nerf_fs_nonpaying': 'bool',
88
89
  'org_admin_alerts': 'bool',
89
90
  'org_level_member_permissions': 'bool',
90
91
  'org_usage_limits': 'bool',
@@ -110,6 +111,7 @@ class V1UserFeatures(object):
110
111
  'slurm_machine_selector': 'bool',
111
112
  'stop_ide_container_on_shutdown': 'bool',
112
113
  'studio_config': 'bool',
114
+ 'studio_deployment': 'bool',
113
115
  'studio_on_stop': 'bool',
114
116
  'studio_version_visibility': 'bool',
115
117
  'studios_dashboard': 'bool',
@@ -167,6 +169,7 @@ class V1UserFeatures(object):
167
169
  'mmt_strategy_selector': 'mmtStrategySelector',
168
170
  'multicloud_saas': 'multicloudSaas',
169
171
  'multiple_studio_versions': 'multipleStudioVersions',
172
+ 'nerf_fs_nonpaying': 'nerfFsNonpaying',
170
173
  'org_admin_alerts': 'orgAdminAlerts',
171
174
  'org_level_member_permissions': 'orgLevelMemberPermissions',
172
175
  'org_usage_limits': 'orgUsageLimits',
@@ -192,6 +195,7 @@ class V1UserFeatures(object):
192
195
  'slurm_machine_selector': 'slurmMachineSelector',
193
196
  'stop_ide_container_on_shutdown': 'stopIdeContainerOnShutdown',
194
197
  'studio_config': 'studioConfig',
198
+ 'studio_deployment': 'studioDeployment',
195
199
  'studio_on_stop': 'studioOnStop',
196
200
  'studio_version_visibility': 'studioVersionVisibility',
197
201
  'studios_dashboard': 'studiosDashboard',
@@ -204,7 +208,7 @@ class V1UserFeatures(object):
204
208
  'weka': 'weka'
205
209
  }
206
210
 
207
- def __init__(self, affiliate_links: 'bool' =None, agents_v2: 'bool' =None, ai_hub_monetization: 'bool' =None, auto_fast_load: 'bool' =None, auto_join_orgs: 'bool' =None, b2c_experience: 'bool' =None, byoc_litcr: 'bool' =None, cap_add: 'list[str]' =None, cap_drop: 'list[str]' =None, capacity_reservation_byoc: 'bool' =None, capacity_reservation_dry_run: 'bool' =None, chat_models: 'bool' =None, cloud_space_environment_templates: 'bool' =None, code_tab: 'bool' =None, collab_screen_sharing: 'bool' =None, concurrent_gpu_limit: 'bool' =None, cost_attribution_settings: 'bool' =None, custom_app_domain: 'bool' =None, custom_instance_types: 'bool' =None, datasets: 'bool' =None, default_one_cluster: 'bool' =None, deployment_alerts: 'bool' =None, deployment_persistent_disk: 'bool' =None, dgx_cloud: 'bool' =None, doc_helper_chat: 'bool' =None, docs_agent: 'bool' =None, drive_v2: 'bool' =None, enable_storage_limits: 'bool' =None, enterprise_compute_admin: 'bool' =None, fair_share: 'bool' =None, featured_studios_admin: 'bool' =None, filestore: 'bool' =None, gcp_local_disk_binding: 'bool' =None, inactive_notify_delete: 'bool' =None, instant_capacity_reservation: 'bool' =None, job_artifacts_v2: 'bool' =None, lambda_labs: 'bool' =None, landing_studios: 'bool' =None, lit_logger: 'bool' =None, marketplace: 'bool' =None, mmt_fault_tolerance: 'bool' =None, mmt_strategy_selector: 'bool' =None, multicloud_saas: 'bool' =None, multiple_studio_versions: 'bool' =None, org_admin_alerts: 'bool' =None, org_level_member_permissions: 'bool' =None, org_usage_limits: 'bool' =None, pipelines: 'bool' =None, plugin_distributed: 'bool' =None, plugin_inference: 'bool' =None, plugin_label_studio: 'bool' =None, plugin_langflow: 'bool' =None, plugin_python_profiler: 'bool' =None, plugin_service: 'bool' =None, plugin_sweeps: 'bool' =None, pricing_updates: 'bool' =None, product_generator: 'bool' =None, project_selector: 'bool' =None, publish_pipelines: 'bool' =None, r2_data_connections: 'bool' =None, 
restartable_jobs: 'bool' =None, runnable_public_studio_page: 'bool' =None, security_docs: 'bool' =None, show_dev_admin: 'bool' =None, single_wallet: 'bool' =None, slurm: 'bool' =None, slurm_machine_selector: 'bool' =None, stop_ide_container_on_shutdown: 'bool' =None, studio_config: 'bool' =None, studio_on_stop: 'bool' =None, studio_version_visibility: 'bool' =None, studios_dashboard: 'bool' =None, studios_dashboard_system_metrics: 'bool' =None, teamspace_storage_tab: 'bool' =None, trainium2: 'bool' =None, use_rclone_mounts_only: 'bool' =None, voltage_park: 'bool' =None, vultr: 'bool' =None, weka: 'bool' =None): # noqa: E501
211
+ def __init__(self, affiliate_links: 'bool' =None, agents_v2: 'bool' =None, ai_hub_monetization: 'bool' =None, auto_fast_load: 'bool' =None, auto_join_orgs: 'bool' =None, b2c_experience: 'bool' =None, byoc_litcr: 'bool' =None, cap_add: 'list[str]' =None, cap_drop: 'list[str]' =None, capacity_reservation_byoc: 'bool' =None, capacity_reservation_dry_run: 'bool' =None, chat_models: 'bool' =None, cloud_space_environment_templates: 'bool' =None, code_tab: 'bool' =None, collab_screen_sharing: 'bool' =None, concurrent_gpu_limit: 'bool' =None, cost_attribution_settings: 'bool' =None, custom_app_domain: 'bool' =None, custom_instance_types: 'bool' =None, datasets: 'bool' =None, default_one_cluster: 'bool' =None, deployment_alerts: 'bool' =None, deployment_persistent_disk: 'bool' =None, dgx_cloud: 'bool' =None, doc_helper_chat: 'bool' =None, docs_agent: 'bool' =None, drive_v2: 'bool' =None, enable_storage_limits: 'bool' =None, enterprise_compute_admin: 'bool' =None, fair_share: 'bool' =None, featured_studios_admin: 'bool' =None, filestore: 'bool' =None, gcp_local_disk_binding: 'bool' =None, inactive_notify_delete: 'bool' =None, instant_capacity_reservation: 'bool' =None, job_artifacts_v2: 'bool' =None, lambda_labs: 'bool' =None, landing_studios: 'bool' =None, lit_logger: 'bool' =None, marketplace: 'bool' =None, mmt_fault_tolerance: 'bool' =None, mmt_strategy_selector: 'bool' =None, multicloud_saas: 'bool' =None, multiple_studio_versions: 'bool' =None, nerf_fs_nonpaying: 'bool' =None, org_admin_alerts: 'bool' =None, org_level_member_permissions: 'bool' =None, org_usage_limits: 'bool' =None, pipelines: 'bool' =None, plugin_distributed: 'bool' =None, plugin_inference: 'bool' =None, plugin_label_studio: 'bool' =None, plugin_langflow: 'bool' =None, plugin_python_profiler: 'bool' =None, plugin_service: 'bool' =None, plugin_sweeps: 'bool' =None, pricing_updates: 'bool' =None, product_generator: 'bool' =None, project_selector: 'bool' =None, publish_pipelines: 'bool' =None, 
r2_data_connections: 'bool' =None, restartable_jobs: 'bool' =None, runnable_public_studio_page: 'bool' =None, security_docs: 'bool' =None, show_dev_admin: 'bool' =None, single_wallet: 'bool' =None, slurm: 'bool' =None, slurm_machine_selector: 'bool' =None, stop_ide_container_on_shutdown: 'bool' =None, studio_config: 'bool' =None, studio_deployment: 'bool' =None, studio_on_stop: 'bool' =None, studio_version_visibility: 'bool' =None, studios_dashboard: 'bool' =None, studios_dashboard_system_metrics: 'bool' =None, teamspace_storage_tab: 'bool' =None, trainium2: 'bool' =None, use_rclone_mounts_only: 'bool' =None, voltage_park: 'bool' =None, vultr: 'bool' =None, weka: 'bool' =None): # noqa: E501
208
212
  """V1UserFeatures - a model defined in Swagger""" # noqa: E501
209
213
  self._affiliate_links = None
210
214
  self._agents_v2 = None
@@ -250,6 +254,7 @@ class V1UserFeatures(object):
250
254
  self._mmt_strategy_selector = None
251
255
  self._multicloud_saas = None
252
256
  self._multiple_studio_versions = None
257
+ self._nerf_fs_nonpaying = None
253
258
  self._org_admin_alerts = None
254
259
  self._org_level_member_permissions = None
255
260
  self._org_usage_limits = None
@@ -275,6 +280,7 @@ class V1UserFeatures(object):
275
280
  self._slurm_machine_selector = None
276
281
  self._stop_ide_container_on_shutdown = None
277
282
  self._studio_config = None
283
+ self._studio_deployment = None
278
284
  self._studio_on_stop = None
279
285
  self._studio_version_visibility = None
280
286
  self._studios_dashboard = None
@@ -374,6 +380,8 @@ class V1UserFeatures(object):
374
380
  self.multicloud_saas = multicloud_saas
375
381
  if multiple_studio_versions is not None:
376
382
  self.multiple_studio_versions = multiple_studio_versions
383
+ if nerf_fs_nonpaying is not None:
384
+ self.nerf_fs_nonpaying = nerf_fs_nonpaying
377
385
  if org_admin_alerts is not None:
378
386
  self.org_admin_alerts = org_admin_alerts
379
387
  if org_level_member_permissions is not None:
@@ -424,6 +432,8 @@ class V1UserFeatures(object):
424
432
  self.stop_ide_container_on_shutdown = stop_ide_container_on_shutdown
425
433
  if studio_config is not None:
426
434
  self.studio_config = studio_config
435
+ if studio_deployment is not None:
436
+ self.studio_deployment = studio_deployment
427
437
  if studio_on_stop is not None:
428
438
  self.studio_on_stop = studio_on_stop
429
439
  if studio_version_visibility is not None:
@@ -1369,6 +1379,27 @@ class V1UserFeatures(object):
1369
1379
 
1370
1380
  self._multiple_studio_versions = multiple_studio_versions
1371
1381
 
1382
@property
def nerf_fs_nonpaying(self) -> 'bool':
    """Return the ``nerf_fs_nonpaying`` feature flag of this V1UserFeatures.  # noqa: E501

    :return: The nerf_fs_nonpaying of this V1UserFeatures.  # noqa: E501
    :rtype: bool
    """
    flag = self._nerf_fs_nonpaying
    return flag

@nerf_fs_nonpaying.setter
def nerf_fs_nonpaying(self, nerf_fs_nonpaying: 'bool'):
    """Store the ``nerf_fs_nonpaying`` feature flag on this V1UserFeatures.

    The value is stored as given; no validation happens here.

    :param nerf_fs_nonpaying: The nerf_fs_nonpaying of this V1UserFeatures.  # noqa: E501
    :type: bool
    """
    self._nerf_fs_nonpaying = nerf_fs_nonpaying
1402
+
1372
1403
  @property
1373
1404
  def org_admin_alerts(self) -> 'bool':
1374
1405
  """Gets the org_admin_alerts of this V1UserFeatures. # noqa: E501
@@ -1894,6 +1925,27 @@ class V1UserFeatures(object):
1894
1925
 
1895
1926
  self._studio_config = studio_config
1896
1927
 
1928
@property
def studio_deployment(self) -> 'bool':
    """Return the ``studio_deployment`` feature flag of this V1UserFeatures.  # noqa: E501

    :return: The studio_deployment of this V1UserFeatures.  # noqa: E501
    :rtype: bool
    """
    flag = self._studio_deployment
    return flag

@studio_deployment.setter
def studio_deployment(self, studio_deployment: 'bool'):
    """Store the ``studio_deployment`` feature flag on this V1UserFeatures.

    The value is stored as given; no validation happens here.

    :param studio_deployment: The studio_deployment of this V1UserFeatures.  # noqa: E501
    :type: bool
    """
    self._studio_deployment = studio_deployment
1948
+
1897
1949
  @property
1898
1950
  def studio_on_stop(self) -> 'bool':
1899
1951
  """Gets the studio_on_stop of this V1UserFeatures. # noqa: E501
lightning_sdk/llm/llm.py CHANGED
@@ -1,41 +1,64 @@
1
- from typing import Dict, List, Optional, Set, Tuple
1
+ from typing import Dict, List, Optional, Set, Tuple, Union
2
2
 
3
+ from lightning_sdk.api import UserApi
3
4
  from lightning_sdk.api.llm_api import LLMApi
5
+ from lightning_sdk.lightning_cloud.login import Auth
4
6
  from lightning_sdk.lightning_cloud.openapi import V1Assistant
7
+ from lightning_sdk.lightning_cloud.openapi.rest import ApiException
8
+ from lightning_sdk.organization import Organization
9
+ from lightning_sdk.user import User
10
+ from lightning_sdk.utils.resolve import _resolve_org, _resolve_user
5
11
 
6
12
 
7
13
  class LLM:
8
- def __init__(self, name: str) -> None:
14
def __init__(
    self,
    name: str,
    user: Union[str, "User", None] = None,
    org: Union[str, "Organization", None] = None,
) -> None:
    """Resolve ``name`` to an assistant available to the caller.

    Args:
        name: ``"<org>/<model>"`` or a bare ``"<model>"`` (a user or org model).
        user: User (name or ``User``) whose models are searched. When omitted, the
            authenticated user is used.
        org: Organization (name or ``Organization``) whose models are searched. An
            organization prefix in ``name`` takes precedence over this argument.

    Raises:
        ValueError: If ``name`` is malformed or no matching model is found.
        Any exception raised by ``Auth.authenticate`` propagates unchanged.
    """
    self._auth = Auth()
    self._auth.authenticate()
    if not user:
        # An explicitly passed ``user`` takes precedence; the authenticated account is
        # only the fallback (``self._user or user`` made the argument unreachable).
        user = User(name=UserApi()._get_user_by_id(self._auth.user_id).username)

    self._name = name
    try:
        self._user = _resolve_user(user)
    except ValueError:
        # no resolvable user: ``_get_user_models`` then returns no user models
        self._user = None

    self._org, self._model_name = self._parse_model_name(name)
    try:
        # check if it is an org model; an org prefix in ``name`` wins over ``org``
        self._org = _resolve_org(self._org or org)
    except ApiException:
        self._org = None

    self._llm_api = LLMApi()
    self._public_models = self._build_model_lookup(self._get_public_models())
    self._org_models = self._build_model_lookup(self._get_org_models())
    self._user_models = self._build_model_lookup(self._get_user_models())
    self._model = self._get_model()
    # conversation name -> backend conversation id, populated by ``chat``
    self._conversations: Dict[str, str] = {}
16
49
 
17
50
def _parse_model_name(self, name: str) -> Tuple[Optional[str], str]:
    """Split ``name`` into ``(org, model_name)``.

    ``"org/model"`` yields ``("org", "model")``; a bare ``"model"`` (a user or an
    org model) yields ``(None, "model")``.

    Raises:
        ValueError: If ``name`` contains more than one ``/`` or an empty segment
            (e.g. ``""``, ``"org/"``, ``"/model"``).
    """
    parts = name.split("/")
    # Reject empty segments up front instead of failing later with "Model '' not found".
    if len(parts) > 2 or not all(parts):
        raise ValueError(
            f"Model name must be in the format `organization/model_name` or `model_name`, but got '{name}'."
        )
    if len(parts) == 1:
        # a user model or a org model
        return None, parts[0]
    return parts[0], parts[1]
22
60
 
23
61
  def _build_model_lookup(self, endpoints: List[str]) -> Dict[str, Set[str]]:
24
- return {endpoint.id: {model.name for model in endpoint.models_metadata} for endpoint in endpoints}
25
-
26
- def _model_exists(self) -> bool:
27
- if self._org not in self._models:
28
- raise ValueError(
29
- f"Model provider {self._org} not found. Available models providers: {list(self._models.keys())}"
30
- )
31
-
32
- if self._model_name not in self._models[self._org]:
33
- raise ValueError(
34
- f"Model {self._model_name} not found. Available models by {self._org}: {self._models[self._org]}"
35
- )
36
- return True
37
-
38
- def _build_public_model_lookup(self, endpoints: List[str]) -> Dict[str, Set[str]]:
39
62
  result = {}
40
63
  for endpoint in endpoints:
41
64
  result.setdefault(endpoint.model, []).append(endpoint)
@@ -44,13 +67,94 @@ class LLM:
44
67
def _get_public_models(self) -> List[str]:
    """Fetch all publicly available model endpoints from the LLM API."""
    api = self._llm_api
    return api.get_public_models()

def _get_org_models(self) -> List[str]:
    """Fetch the resolved organization's model endpoints, or ``[]`` without an org."""
    if not self._org:
        return []
    return self._llm_api.get_org_models(self._org.id)

def _get_user_models(self) -> List[str]:
    """Fetch the resolved user's model endpoints, or ``[]`` without a user."""
    if not self._user:
        return []
    return self._llm_api.get_user_models(self._user.id)
75
+
76
def _get_model(self) -> V1Assistant:
    """Pick the assistant registered under ``self._model_name``.

    The lookups are consulted in priority order: public models, then the resolved
    organization's models, then the resolved user's models. Within a lookup the
    first endpoint registered under the name is returned.

    Raises:
        ValueError: If no lookup contains the name; the message lists the available names.
    """
    # TODO how to handle multiple models with same model type? For now, just use the first one
    for lookup in (self._public_models, self._org_models, self._user_models):
        if self._model_name in lookup:
            return lookup[self._model_name][0]

    listing = []
    if self._public_models:
        listing.append(f"Public Models: {', '.join(self._public_models.keys())}")
    if self._org and self._org_models:
        listing.append(f"Org ({self._org.name}) Models: {', '.join(self._org_models.keys())}")
    if self._user and self._user_models:
        listing.append(f"User ({self._user.name}) Models: {', '.join(self._user_models.keys())}")

    details = "\n".join(listing)
    raise ValueError(f"Model '{self._model_name}' not found. \nAvailable models: \n{details}")
97
+
98
def _get_conversations(self) -> Dict[str, str]:
    """Return the mapping of conversation names to backend conversation ids.

    For now this is only the locally tracked cache.
    """
    # TODO: after updating backend, this will fetch conversations from backend
    # conversations = self._llm_api.list_conversations(assistant_id=self._model.id)
    cached = self._conversations
    return cached

def _fetch_conversations(self) -> None:
    """Replace ``self._conversations`` with the result of ``_get_conversations``."""
    latest = self._get_conversations()
    self._conversations = latest
105
+
106
def chat(
    self,
    prompt: str,
    system_prompt: Optional[str] = None,
    max_completion_tokens: Optional[int] = 500,
    conversation: Optional[str] = None,
) -> str:
    """Send ``prompt`` to the model and return the content of the first reply choice.

    Args:
        prompt: The user message.
        system_prompt: Optional system prompt, forwarded to the API.
        max_completion_tokens: Forwarded to the API as ``max_completion_tokens``.
        conversation: Optional local name of a conversation. If no backend id is
            known for the name, the request is sent without a conversation id and the
            id returned in the response is remembered under that name; later calls
            with the same name reuse that id.

    Returns:
        ``choices[0].delta.content`` of the API response.
    """
    conversation_id = None
    if conversation:
        if conversation not in self._conversations:
            self._fetch_conversations()
        conversation_id = self._conversations.get(conversation)

    output = self._llm_api.start_conversation(
        prompt=prompt,
        system_prompt=system_prompt,
        max_completion_tokens=max_completion_tokens,
        assistant_id=self._model.id,
        conversation_id=conversation_id,
    )

    starts_new_conversation = bool(conversation) and not conversation_id
    if starts_new_conversation:
        self._conversations[conversation] = output.conversation_id
    return output.choices[0].delta.content
127
+
128
def list_conversations(self) -> List[str]:
    """Return the names of the conversations known to this ``LLM`` instance.

    The cache is refreshed first via ``_fetch_conversations``. The annotation was
    previously ``List[Dict]``, but the result has always been the conversation names.
    """
    self._fetch_conversations()
    return list(self._conversations)
131
+
132
def _get_conversation_messages(self, conversation_id: str) -> List:
    """Fetch the messages of ``conversation_id`` for the current assistant.

    Returns whatever ``LLMApi.get_conversation`` returns. ``get_history`` iterates
    it and reads ``author.role`` / ``content[0].parts[0]`` from each item, so it is
    presumably a list of message objects (TODO confirm against ``LLMApi``). It was
    previously annotated as ``Optional[str]``, which does not match that usage.
    """
    return self._llm_api.get_conversation(assistant_id=self._model.id, conversation_id=conversation_id)
134
+
135
def get_history(self, conversation: str) -> List[Dict[str, str]]:
    """Return the user/assistant messages of ``conversation`` as chat-style dicts.

    Each entry is ``{"role": "user" | "assistant", "content": <first content part>}``.
    Messages whose author role is anything else are skipped.

    Raises:
        ValueError: If ``conversation`` is unknown even after refreshing the cache.
    """
    # TODO: after updating backend, this will fetch conversation from backend
    if conversation not in self._conversations:
        self._fetch_conversations()

    if conversation not in self._conversations:
        # list(...) so the message shows names rather than a ``dict_keys([...])`` repr
        raise ValueError(
            f"Conversation '{conversation}' not found. \nAvailable conversations: {list(self._conversations)}"
        )

    messages = self._get_conversation_messages(self._conversations[conversation])
    return [
        {"role": message.author.role, "content": message.content[0].parts[0]}
        for message in messages
        if message.author.role in ("user", "assistant")
    ]
153
+
154
def reset_conversation(self, conversation: str) -> None:
    """Reset ``conversation`` on the backend and forget it locally.

    The local cache is refreshed before giving up on an unknown name, as ``chat``
    and ``get_history`` do. Names that are still unknown are ignored.
    """
    if conversation not in self._conversations:
        self._fetch_conversations()
    if conversation not in self._conversations:
        return

    # Reset remotely first, so the local entry survives if the API call raises.
    self._llm_api.reset_conversation(
        assistant_id=self._model.id,
        conversation_id=self._conversations[conversation],
    )
    del self._conversations[conversation]
lightning_sdk/plugin.py CHANGED
@@ -395,6 +395,25 @@ class SlurmJobsPlugin(_Plugin):
395
395
  return resp
396
396
 
397
397
 
398
class CustomPortPlugin(_Plugin):
    """Plugin that opens an additional, user-chosen port on the Studio it belongs to."""

    _plugin_run_name = "Custom Port"
    _slug_name = "custom-port"

    def run(self, name: Optional[str] = None, port: int = 8000) -> str:
        """Open ``port`` on this plugin's Studio.

        Args:
            name: Name for the new port; when omitted, ``_run_name("port")`` supplies one.
            port: Port number passed to the Studio API.

        Returns:
            The value returned by ``StudioApi.start_new_port``.
        """
        port_name = _run_name("port") if name is None else name
        studio = self._studio
        return studio._studio_api.start_new_port(
            teamspace_id=studio._teamspace.id,
            studio_id=studio._studio.id,
            name=port_name,
            port=port,
        )
415
+
416
+
398
417
  @runtime_checkable
399
418
  class _RunnablePlugin(Protocol):
400
419
  _plugin_run_name: str
lightning_sdk/studio.py CHANGED
@@ -1,7 +1,10 @@
1
+ import glob
1
2
  import os
2
3
  import warnings
3
4
  from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Tuple, Union
4
5
 
6
+ from tqdm.auto import tqdm
7
+
5
8
  from lightning_sdk.api.studio_api import StudioApi
6
9
  from lightning_sdk.api.utils import _machine_to_compute_name
7
10
  from lightning_sdk.constants import _LIGHTNING_DEBUG
@@ -265,6 +268,34 @@ class Studio:
265
268
  progress_bar=progress_bar,
266
269
  )
267
270
 
271
def upload_folder(self, folder_path: str, remote_path: Optional[str] = None, progress_bar: bool = True) -> None:
    """Upload every file below ``folder_path`` to the Studio, preserving the folder layout.

    Files are uploaded one at a time through ``self.upload_file``, ordered by their
    remote path.

    Args:
        folder_path: Local directory to upload.
        remote_path: Remote directory the files are placed under. When omitted, each
            file is uploaded to its path relative to ``folder_path``.
        progress_bar: Whether to show a per-file ``tqdm`` progress bar.

    Raises:
        ValueError: If ``folder_path`` is None.
        NotADirectoryError: If ``folder_path`` is a file or does not exist.
    """
    import posixpath

    if folder_path is None:
        raise ValueError("Cannot upload a folder that is None.")
    folder_path = os.path.normpath(folder_path)
    if os.path.isfile(folder_path):
        raise NotADirectoryError(f"Cannot upload a file as a folder. '{folder_path}' is a file.")
    if not os.path.exists(folder_path):
        raise NotADirectoryError(f"Cannot upload a folder that does not exist. '{folder_path}' is not a directory.")

    # os.walk instead of glob("**"): glob skips dot-files/dot-directories and would
    # interpret characters like "[" or "*" inside folder_path as a pattern.
    # followlinks=True so symlinked directories are still descended into, as with glob.
    uploads = []
    for dirpath, _dirnames, filenames in os.walk(folder_path, followlinks=True):
        for filename in filenames:
            local_file = os.path.join(dirpath, filename)
            if not os.path.isfile(local_file):
                continue  # e.g. broken symlinks
            # Remote paths always use "/" separators, whatever the local OS
            # (assumes a POSIX filesystem on the Studio side).
            rel_path = os.path.relpath(local_file, folder_path).replace(os.sep, "/")
            remote_file = posixpath.join(remote_path, rel_path) if remote_path else rel_path
            uploads.append((local_file, remote_file))
    uploads.sort(key=lambda item: item[1])

    bar = tqdm(total=len(uploads), desc="Uploading files", unit="file") if progress_bar else None
    try:
        for local_file, remote_file in uploads:
            if bar is not None:
                bar.set_description(f"Uploading {local_file}")
            self.upload_file(local_file, remote_path=remote_file, progress_bar=False)
            if bar is not None:
                bar.update(1)
    finally:
        # close the bar even when an upload raises, so it does not leak
        if bar is not None:
            bar.close()
298
+
268
299
  def download_file(self, remote_path: str, file_path: Optional[str] = None) -> None:
269
300
  """Downloads a file from the Studio to a given target path."""
270
301
  if file_path is None:
@@ -445,6 +476,7 @@ class Studio:
445
476
  def _add_plugin(self, plugin_name: str) -> None:
446
477
  """Adds the just installed plugin to the internal list of plugins."""
447
478
  from lightning_sdk.plugin import (
479
+ CustomPortPlugin,
448
480
  InferenceServerPlugin,
449
481
  JobsPlugin,
450
482
  MultiMachineTrainingPlugin,
@@ -458,6 +490,7 @@ class Studio:
458
490
  "jobs": JobsPlugin,
459
491
  "multi-machine-training": MultiMachineTrainingPlugin,
460
492
  "inference-server": InferenceServerPlugin,
493
+ "custom-port": CustomPortPlugin,
461
494
  }.get(plugin_name, Plugin)
462
495
 
463
496
  description = self._list_installed_plugins()[plugin_name]
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: lightning_sdk
3
- Version: 0.2.12
3
+ Version: 0.2.14
4
4
  Summary: SDK to develop using Lightning AI Studios
5
5
  Author-email: Lightning-AI <justus@lightning.ai>
6
6
  License: MIT License