skypilot-nightly 1.0.0.dev20240926__py3-none-any.whl → 1.0.0.dev20240928__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sky/clouds/oci.py CHANGED
@@ -431,14 +431,17 @@ class OCI(clouds.Cloud):
431
431
 
432
432
  def get_credential_file_mounts(self) -> Dict[str, str]:
433
433
  """Returns a dict of credential file paths to mount paths."""
434
- oci_cfg_file = oci_adaptor.get_config_file()
435
- # Pass-in a profile parameter so that multiple profile in oci
436
- # config file is supported (2023/06/09).
437
- oci_cfg = oci_adaptor.get_oci_config(
438
- profile=oci_utils.oci_config.get_profile())
439
- api_key_file = oci_cfg[
440
- 'key_file'] if 'key_file' in oci_cfg else 'BadConf'
441
- sky_cfg_file = oci_utils.oci_config.get_sky_user_config_file()
434
+ try:
435
+ oci_cfg_file = oci_adaptor.get_config_file()
436
+ # Pass-in a profile parameter so that multiple profile in oci
437
+ # config file is supported (2023/06/09).
438
+ oci_cfg = oci_adaptor.get_oci_config(
439
+ profile=oci_utils.oci_config.get_profile())
440
+ api_key_file = oci_cfg[
441
+ 'key_file'] if 'key_file' in oci_cfg else 'BadConf'
442
+ sky_cfg_file = oci_utils.oci_config.get_sky_user_config_file()
443
+ except ImportError:
444
+ return {}
442
445
 
443
446
  # OCI config and API key file are mandatory
444
447
  credential_files = [oci_cfg_file, api_key_file]
@@ -68,26 +68,35 @@ def list_accelerators_realtime(
68
68
  # TODO(romilb): This should be refactored to use get_kubernetes_node_info()
69
69
  # function from kubernetes_utils.
70
70
  del all_regions, require_price # Unused.
71
+ # TODO(zhwu): this should return all accelerators in multiple kubernetes
72
+ # clusters defined by allowed_contexts.
73
+ if region_filter is None:
74
+ context = kubernetes_utils.get_current_kube_config_context_name()
75
+ else:
76
+ context = region_filter
77
+ if context is None:
78
+ return {}, {}, {}
79
+
71
80
  k8s_cloud = Kubernetes()
72
81
  if not any(
73
82
  map(k8s_cloud.is_same_cloud,
74
83
  sky_check.get_cached_enabled_clouds_or_refresh())
75
- ) or not kubernetes_utils.check_credentials()[0]:
84
+ ) or not kubernetes_utils.check_credentials(context)[0]:
76
85
  return {}, {}, {}
77
86
 
78
- has_gpu = kubernetes_utils.detect_gpu_resource()
87
+ has_gpu = kubernetes_utils.detect_gpu_resource(context)
79
88
  if not has_gpu:
80
89
  return {}, {}, {}
81
90
 
82
- label_formatter, _ = kubernetes_utils.detect_gpu_label_formatter()
91
+ label_formatter, _ = kubernetes_utils.detect_gpu_label_formatter(context)
83
92
  if not label_formatter:
84
93
  return {}, {}, {}
85
94
 
86
95
  accelerators_qtys: Set[Tuple[str, int]] = set()
87
96
  key = label_formatter.get_label_key()
88
- nodes = kubernetes_utils.get_kubernetes_nodes()
97
+ nodes = kubernetes_utils.get_kubernetes_nodes(context)
89
98
  # Get the pods to get the real-time GPU usage
90
- pods = kubernetes_utils.get_kubernetes_pods()
99
+ pods = kubernetes_utils.get_all_pods_in_kubernetes_cluster(context)
91
100
  # Total number of GPUs in the cluster
92
101
  total_accelerators_capacity: Dict[str, int] = {}
93
102
  # Total number of GPUs currently available in the cluster
@@ -160,7 +169,7 @@ def list_accelerators_realtime(
160
169
  memory=None,
161
170
  price=0.0,
162
171
  spot_price=0.0,
163
- region='kubernetes'))
172
+ region=context))
164
173
 
165
174
  df = pd.DataFrame(result,
166
175
  columns=[
@@ -175,7 +184,6 @@ def list_accelerators_realtime(
175
184
  qtys_map = common.list_accelerators_impl('Kubernetes', df, gpus_only,
176
185
  name_filter, region_filter,
177
186
  quantity_filter, case_sensitive)
178
-
179
187
  return qtys_map, total_accelerators_capacity, total_accelerators_available
180
188
 
181
189
 
@@ -79,13 +79,14 @@ def _open_ports_using_ingress(
79
79
  )
80
80
 
81
81
  # Prepare service names, ports, for template rendering
82
- service_details = [(f'{cluster_name_on_cloud}--skypilot-svc--{port}', port,
83
- _PATH_PREFIX.format(
84
- cluster_name_on_cloud=cluster_name_on_cloud,
85
- port=port,
86
- namespace=kubernetes_utils.
87
- get_current_kube_config_context_namespace()).rstrip(
88
- '/').lstrip('/')) for port in ports]
82
+ service_details = [
83
+ (f'{cluster_name_on_cloud}--skypilot-svc--{port}', port,
84
+ _PATH_PREFIX.format(
85
+ cluster_name_on_cloud=cluster_name_on_cloud,
86
+ port=port,
87
+ namespace=kubernetes_utils.get_kube_config_context_namespace(
88
+ context)).rstrip('/').lstrip('/')) for port in ports
89
+ ]
89
90
 
90
91
  # Generate ingress and services specs
91
92
  # We batch ingress rule creation because each rule triggers a hot reload of
@@ -171,7 +172,8 @@ def _cleanup_ports_for_ingress(
171
172
  for port in ports:
172
173
  service_name = f'{cluster_name_on_cloud}--skypilot-svc--{port}'
173
174
  network_utils.delete_namespaced_service(
174
- namespace=provider_config.get('namespace', 'default'),
175
+ namespace=provider_config.get('namespace',
176
+ kubernetes_utils.DEFAULT_NAMESPACE),
175
177
  service_name=service_name,
176
178
  )
177
179
 
@@ -208,11 +210,13 @@ def query_ports(
208
210
  return _query_ports_for_ingress(
209
211
  cluster_name_on_cloud=cluster_name_on_cloud,
210
212
  ports=ports,
213
+ provider_config=provider_config,
211
214
  )
212
215
  elif port_mode == kubernetes_enums.KubernetesPortMode.PODIP:
213
216
  return _query_ports_for_podip(
214
217
  cluster_name_on_cloud=cluster_name_on_cloud,
215
218
  ports=ports,
219
+ provider_config=provider_config,
216
220
  )
217
221
  else:
218
222
  return {}
@@ -231,8 +235,14 @@ def _query_ports_for_loadbalancer(
231
235
  result: Dict[int, List[common.Endpoint]] = {}
232
236
  service_name = _LOADBALANCER_SERVICE_NAME.format(
233
237
  cluster_name_on_cloud=cluster_name_on_cloud)
238
+ context = provider_config.get(
239
+ 'context', kubernetes_utils.get_current_kube_config_context_name())
240
+ namespace = provider_config.get(
241
+ 'namespace',
242
+ kubernetes_utils.get_kube_config_context_namespace(context))
234
243
  external_ip = network_utils.get_loadbalancer_ip(
235
- namespace=provider_config.get('namespace', 'default'),
244
+ context=context,
245
+ namespace=namespace,
236
246
  service_name=service_name,
237
247
  # Timeout is set so that we can retry the query when the
238
248
  # cluster is firstly created and the load balancer is not ready yet.
@@ -251,19 +261,24 @@ def _query_ports_for_loadbalancer(
251
261
  def _query_ports_for_ingress(
252
262
  cluster_name_on_cloud: str,
253
263
  ports: List[int],
264
+ provider_config: Dict[str, Any],
254
265
  ) -> Dict[int, List[common.Endpoint]]:
255
- ingress_details = network_utils.get_ingress_external_ip_and_ports()
266
+ context = provider_config.get(
267
+ 'context', kubernetes_utils.get_current_kube_config_context_name())
268
+ ingress_details = network_utils.get_ingress_external_ip_and_ports(context)
256
269
  external_ip, external_ports = ingress_details
257
270
  if external_ip is None:
258
271
  return {}
259
272
 
273
+ namespace = provider_config.get(
274
+ 'namespace',
275
+ kubernetes_utils.get_kube_config_context_namespace(context))
260
276
  result: Dict[int, List[common.Endpoint]] = {}
261
277
  for port in ports:
262
278
  path_prefix = _PATH_PREFIX.format(
263
279
  cluster_name_on_cloud=cluster_name_on_cloud,
264
280
  port=port,
265
- namespace=kubernetes_utils.
266
- get_current_kube_config_context_namespace())
281
+ namespace=namespace)
267
282
 
268
283
  http_port, https_port = external_ports \
269
284
  if external_ports is not None else (None, None)
@@ -282,10 +297,15 @@ def _query_ports_for_ingress(
282
297
  def _query_ports_for_podip(
283
298
  cluster_name_on_cloud: str,
284
299
  ports: List[int],
300
+ provider_config: Dict[str, Any],
285
301
  ) -> Dict[int, List[common.Endpoint]]:
286
- namespace = kubernetes_utils.get_current_kube_config_context_namespace()
302
+ context = provider_config.get(
303
+ 'context', kubernetes_utils.get_current_kube_config_context_name())
304
+ namespace = provider_config.get(
305
+ 'namespace',
306
+ kubernetes_utils.get_kube_config_context_namespace(context))
287
307
  pod_name = kubernetes_utils.get_head_pod_name(cluster_name_on_cloud)
288
- pod_ip = network_utils.get_pod_ip(namespace, pod_name)
308
+ pod_ip = network_utils.get_pod_ip(context, namespace, pod_name)
289
309
 
290
310
  result: Dict[int, List[common.Endpoint]] = {}
291
311
  if pod_ip is None:
@@ -220,10 +220,11 @@ def ingress_controller_exists(context: str,
220
220
 
221
221
 
222
222
  def get_ingress_external_ip_and_ports(
223
+ context: str,
223
224
  namespace: str = 'ingress-nginx'
224
225
  ) -> Tuple[Optional[str], Optional[Tuple[int, int]]]:
225
226
  """Returns external ip and ports for the ingress controller."""
226
- core_api = kubernetes.core_api()
227
+ core_api = kubernetes.core_api(context)
227
228
  ingress_services = [
228
229
  item for item in core_api.list_namespaced_service(
229
230
  namespace, _request_timeout=kubernetes.API_TIMEOUT).items
@@ -257,11 +258,12 @@ def get_ingress_external_ip_and_ports(
257
258
  return external_ip, None
258
259
 
259
260
 
260
- def get_loadbalancer_ip(namespace: str,
261
+ def get_loadbalancer_ip(context: str,
262
+ namespace: str,
261
263
  service_name: str,
262
264
  timeout: int = 0) -> Optional[str]:
263
265
  """Returns the IP address of the load balancer."""
264
- core_api = kubernetes.core_api()
266
+ core_api = kubernetes.core_api(context)
265
267
 
266
268
  ip = None
267
269
 
@@ -282,9 +284,9 @@ def get_loadbalancer_ip(namespace: str,
282
284
  return ip
283
285
 
284
286
 
285
- def get_pod_ip(namespace: str, pod_name: str) -> Optional[str]:
287
+ def get_pod_ip(context: str, namespace: str, pod_name: str) -> Optional[str]:
286
288
  """Returns the IP address of the pod."""
287
- core_api = kubernetes.core_api()
289
+ core_api = kubernetes.core_api(context)
288
290
  pod = core_api.read_namespaced_pod(pod_name,
289
291
  namespace,
290
292
  _request_timeout=kubernetes.API_TIMEOUT)