skypilot-nightly 1.0.0.dev20250215__py3-none-any.whl → 1.0.0.dev20250217__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (172)
  1. sky/__init__.py +48 -22
  2. sky/adaptors/aws.py +2 -1
  3. sky/adaptors/azure.py +4 -4
  4. sky/adaptors/cloudflare.py +4 -4
  5. sky/adaptors/kubernetes.py +8 -8
  6. sky/authentication.py +42 -45
  7. sky/backends/backend.py +2 -2
  8. sky/backends/backend_utils.py +108 -221
  9. sky/backends/cloud_vm_ray_backend.py +283 -282
  10. sky/benchmark/benchmark_utils.py +6 -2
  11. sky/check.py +40 -28
  12. sky/cli.py +1213 -1116
  13. sky/client/__init__.py +1 -0
  14. sky/client/cli.py +5644 -0
  15. sky/client/common.py +345 -0
  16. sky/client/sdk.py +1757 -0
  17. sky/cloud_stores.py +12 -6
  18. sky/clouds/__init__.py +0 -2
  19. sky/clouds/aws.py +20 -13
  20. sky/clouds/azure.py +5 -3
  21. sky/clouds/cloud.py +1 -1
  22. sky/clouds/cudo.py +2 -1
  23. sky/clouds/do.py +2 -1
  24. sky/clouds/fluidstack.py +3 -2
  25. sky/clouds/gcp.py +10 -8
  26. sky/clouds/ibm.py +8 -7
  27. sky/clouds/kubernetes.py +7 -6
  28. sky/clouds/lambda_cloud.py +8 -7
  29. sky/clouds/oci.py +4 -3
  30. sky/clouds/paperspace.py +2 -1
  31. sky/clouds/runpod.py +2 -1
  32. sky/clouds/scp.py +8 -7
  33. sky/clouds/service_catalog/__init__.py +3 -3
  34. sky/clouds/service_catalog/aws_catalog.py +7 -1
  35. sky/clouds/service_catalog/common.py +4 -2
  36. sky/clouds/service_catalog/data_fetchers/fetch_gcp.py +2 -2
  37. sky/clouds/utils/oci_utils.py +1 -1
  38. sky/clouds/vast.py +2 -1
  39. sky/clouds/vsphere.py +2 -1
  40. sky/core.py +263 -99
  41. sky/dag.py +4 -0
  42. sky/data/mounting_utils.py +2 -1
  43. sky/data/storage.py +97 -35
  44. sky/data/storage_utils.py +69 -9
  45. sky/exceptions.py +138 -5
  46. sky/execution.py +47 -50
  47. sky/global_user_state.py +105 -22
  48. sky/jobs/__init__.py +12 -14
  49. sky/jobs/client/__init__.py +0 -0
  50. sky/jobs/client/sdk.py +296 -0
  51. sky/jobs/constants.py +30 -1
  52. sky/jobs/controller.py +12 -6
  53. sky/jobs/dashboard/dashboard.py +2 -6
  54. sky/jobs/recovery_strategy.py +22 -29
  55. sky/jobs/server/__init__.py +1 -0
  56. sky/jobs/{core.py → server/core.py} +101 -34
  57. sky/jobs/server/dashboard_utils.py +64 -0
  58. sky/jobs/server/server.py +182 -0
  59. sky/jobs/utils.py +32 -23
  60. sky/models.py +27 -0
  61. sky/optimizer.py +9 -11
  62. sky/provision/__init__.py +6 -3
  63. sky/provision/aws/config.py +2 -2
  64. sky/provision/aws/instance.py +1 -1
  65. sky/provision/azure/instance.py +1 -1
  66. sky/provision/cudo/instance.py +1 -1
  67. sky/provision/do/instance.py +1 -1
  68. sky/provision/do/utils.py +0 -5
  69. sky/provision/fluidstack/fluidstack_utils.py +4 -3
  70. sky/provision/fluidstack/instance.py +4 -2
  71. sky/provision/gcp/instance.py +1 -1
  72. sky/provision/instance_setup.py +2 -2
  73. sky/provision/kubernetes/constants.py +8 -0
  74. sky/provision/kubernetes/instance.py +1 -1
  75. sky/provision/kubernetes/utils.py +67 -76
  76. sky/provision/lambda_cloud/instance.py +3 -15
  77. sky/provision/logging.py +1 -1
  78. sky/provision/oci/instance.py +7 -4
  79. sky/provision/paperspace/instance.py +1 -1
  80. sky/provision/provisioner.py +3 -2
  81. sky/provision/runpod/instance.py +1 -1
  82. sky/provision/vast/instance.py +1 -1
  83. sky/provision/vast/utils.py +2 -1
  84. sky/provision/vsphere/instance.py +2 -11
  85. sky/resources.py +55 -40
  86. sky/serve/__init__.py +6 -10
  87. sky/serve/client/__init__.py +0 -0
  88. sky/serve/client/sdk.py +366 -0
  89. sky/serve/constants.py +3 -0
  90. sky/serve/replica_managers.py +10 -10
  91. sky/serve/serve_utils.py +56 -36
  92. sky/serve/server/__init__.py +0 -0
  93. sky/serve/{core.py → server/core.py} +37 -17
  94. sky/serve/server/server.py +117 -0
  95. sky/serve/service.py +8 -1
  96. sky/server/__init__.py +1 -0
  97. sky/server/common.py +441 -0
  98. sky/server/constants.py +21 -0
  99. sky/server/html/log.html +174 -0
  100. sky/server/requests/__init__.py +0 -0
  101. sky/server/requests/executor.py +462 -0
  102. sky/server/requests/payloads.py +481 -0
  103. sky/server/requests/queues/__init__.py +0 -0
  104. sky/server/requests/queues/mp_queue.py +76 -0
  105. sky/server/requests/requests.py +567 -0
  106. sky/server/requests/serializers/__init__.py +0 -0
  107. sky/server/requests/serializers/decoders.py +192 -0
  108. sky/server/requests/serializers/encoders.py +166 -0
  109. sky/server/server.py +1095 -0
  110. sky/server/stream_utils.py +144 -0
  111. sky/setup_files/MANIFEST.in +1 -0
  112. sky/setup_files/dependencies.py +12 -4
  113. sky/setup_files/setup.py +1 -1
  114. sky/sky_logging.py +9 -13
  115. sky/skylet/autostop_lib.py +2 -2
  116. sky/skylet/constants.py +46 -12
  117. sky/skylet/events.py +5 -6
  118. sky/skylet/job_lib.py +78 -66
  119. sky/skylet/log_lib.py +17 -11
  120. sky/skypilot_config.py +79 -94
  121. sky/task.py +119 -73
  122. sky/templates/aws-ray.yml.j2 +4 -4
  123. sky/templates/azure-ray.yml.j2 +3 -2
  124. sky/templates/cudo-ray.yml.j2 +3 -2
  125. sky/templates/fluidstack-ray.yml.j2 +3 -2
  126. sky/templates/gcp-ray.yml.j2 +3 -2
  127. sky/templates/ibm-ray.yml.j2 +3 -2
  128. sky/templates/jobs-controller.yaml.j2 +1 -12
  129. sky/templates/kubernetes-ray.yml.j2 +3 -2
  130. sky/templates/lambda-ray.yml.j2 +3 -2
  131. sky/templates/oci-ray.yml.j2 +3 -2
  132. sky/templates/paperspace-ray.yml.j2 +3 -2
  133. sky/templates/runpod-ray.yml.j2 +3 -2
  134. sky/templates/scp-ray.yml.j2 +3 -2
  135. sky/templates/skypilot-server-kubernetes-proxy.sh +36 -0
  136. sky/templates/vsphere-ray.yml.j2 +4 -2
  137. sky/templates/websocket_proxy.py +64 -0
  138. sky/usage/constants.py +8 -0
  139. sky/usage/usage_lib.py +45 -11
  140. sky/utils/accelerator_registry.py +33 -53
  141. sky/utils/admin_policy_utils.py +2 -1
  142. sky/utils/annotations.py +51 -0
  143. sky/utils/cli_utils/status_utils.py +33 -3
  144. sky/utils/cluster_utils.py +356 -0
  145. sky/utils/command_runner.py +69 -14
  146. sky/utils/common.py +74 -0
  147. sky/utils/common_utils.py +133 -93
  148. sky/utils/config_utils.py +204 -0
  149. sky/utils/control_master_utils.py +2 -3
  150. sky/utils/controller_utils.py +133 -147
  151. sky/utils/dag_utils.py +72 -24
  152. sky/utils/kubernetes/deploy_remote_cluster.sh +2 -2
  153. sky/utils/kubernetes/exec_kubeconfig_converter.py +73 -0
  154. sky/utils/kubernetes/kubernetes_deploy_utils.py +228 -0
  155. sky/utils/log_utils.py +83 -23
  156. sky/utils/message_utils.py +81 -0
  157. sky/utils/registry.py +127 -0
  158. sky/utils/resources_utils.py +2 -2
  159. sky/utils/rich_utils.py +213 -34
  160. sky/utils/schemas.py +19 -2
  161. sky/{status_lib.py → utils/status_lib.py} +12 -7
  162. sky/utils/subprocess_utils.py +51 -35
  163. sky/utils/timeline.py +7 -2
  164. sky/utils/ux_utils.py +95 -25
  165. {skypilot_nightly-1.0.0.dev20250215.dist-info → skypilot_nightly-1.0.0.dev20250217.dist-info}/METADATA +8 -3
  166. {skypilot_nightly-1.0.0.dev20250215.dist-info → skypilot_nightly-1.0.0.dev20250217.dist-info}/RECORD +170 -132
  167. sky/clouds/cloud_registry.py +0 -76
  168. sky/utils/cluster_yaml_utils.py +0 -24
  169. {skypilot_nightly-1.0.0.dev20250215.dist-info → skypilot_nightly-1.0.0.dev20250217.dist-info}/LICENSE +0 -0
  170. {skypilot_nightly-1.0.0.dev20250215.dist-info → skypilot_nightly-1.0.0.dev20250217.dist-info}/WHEEL +0 -0
  171. {skypilot_nightly-1.0.0.dev20250215.dist-info → skypilot_nightly-1.0.0.dev20250217.dist-info}/entry_points.txt +0 -0
  172. {skypilot_nightly-1.0.0.dev20250215.dist-info → skypilot_nightly-1.0.0.dev20250217.dist-info}/top_level.txt +0 -0
sky/utils/common_utils.py CHANGED
@@ -5,7 +5,7 @@ import functools
5
5
  import getpass
6
6
  import hashlib
7
7
  import inspect
8
- import json
8
+ import io
9
9
  import os
10
10
  import platform
11
11
  import random
@@ -23,6 +23,8 @@ import yaml
23
23
  from sky import exceptions
24
24
  from sky import sky_logging
25
25
  from sky.skylet import constants
26
+ from sky.usage import constants as usage_constants
27
+ from sky.utils import annotations
26
28
  from sky.utils import ux_utils
27
29
  from sky.utils import validator
28
30
 
@@ -36,16 +38,12 @@ CLUSTER_NAME_HASH_LENGTH = 2
36
38
 
37
39
  # Matches ANSI escape sequences ending in 'm' (e.g. terminal color codes);
  # remove_color() replaces every match with ''.
  _COLOR_PATTERN = re.compile(r'\x1b[^m]*m')
38
40
 
39
# NOTE(review): the <sky-payload> pattern/template below are removed in this
# release together with encode_payload()/decode_payload(), their only users
# visible in this diff.
- _PAYLOAD_PATTERN = re.compile(r'<sky-payload>(.*)</sky-payload>')
40
- _PAYLOAD_STR = '<sky-payload>{}</sky-payload>'
41
-
42
41
  # Identifier-style environment variable name: a letter or underscore
  # followed by letters, digits or underscores.
  _VALID_ENV_VAR_REGEX = '[a-zA-Z_][a-zA-Z0-9_]*'
43
42
 
44
43
  logger = sky_logging.init_logger(__name__)
45
44
 
46
# The module-level run-id cache below is removed; get_usage_run_id() now
# reads the usage run-id environment variable and is wrapped in
# annotations.lru_cache(scope='request') instead.
- _usage_run_id = None
47
-
48
45
 
46
@annotations.lru_cache(scope='request')
def get_usage_run_id() -> str:
    """Returns the usage run id identifying one logical 'run' of SkyPilot.

    One user invocation of the CLI or programmatic APIs is one run; for
    example, two successive `sky launch` commands are two runs. If a run id
    was handed down through the usage run-id environment variable it is
    reused verbatim, otherwise a fresh UUID4 string is generated. Results are
    memoized by `annotations.lru_cache(scope='request')`.
    """
    inherited_run_id = os.getenv(usage_constants.USAGE_RUN_ID_ENV_VAR)
    return (inherited_run_id
            if inherited_run_id is not None else str(uuid.uuid4()))
58
+
59
+
60
def _is_valid_user_hash(user_hash: Optional[str]) -> bool:
    """Returns True iff `user_hash` is a hex string of USER_HASH_LENGTH chars.

    Args:
        user_hash: candidate hash, e.g. read from an environment variable or
            from the cached user-hash file. Non-string values (including
            None) are rejected.

    Returns:
        True if every character is an ASCII hex digit and the length equals
        USER_HASH_LENGTH, False otherwise.
    """
    if not isinstance(user_hash, str):
        return False
    if len(user_hash) != USER_HASH_LENGTH:
        return False
    # Check characters explicitly: `int(user_hash, 16)` would also accept a
    # sign, a '0x' prefix, underscores, surrounding whitespace and non-ASCII
    # digits, none of which belong in a hex hash.
    return all(char in '0123456789abcdefABCDEF' for char in user_hash)
68
+
69
+
70
def generate_user_hash() -> str:
    """Generates a user-machine specific hash without consulting any cache.

    The hash is the first USER_HASH_LENGTH hex characters of the MD5 digest
    of `user_and_hostname_hash()`. Should that fail validation, a random
    UUID4-derived hex string of the same length is returned instead.
    """
    digest = hashlib.md5(user_and_hostname_hash().encode()).hexdigest()
    candidate = digest[:USER_HASH_LENGTH]
    if _is_valid_user_hash(candidate):
        return candidate
    # A fallback in case the derived hash is invalid.
    return uuid.uuid4().hex[:USER_HASH_LENGTH]
60
78
 
61
79
 
62
# get_user_hash resolves the hash in this order: a valid hash from the
# USER_ID_ENV_VAR environment variable; else a valid hash cached in
# _USER_HASH_FILE; else a newly generated hash, which is then written to
# _USER_HASH_FILE. This release drops the `force_fresh_hash` parameter
# (see the removed lines below).
- def get_user_hash(force_fresh_hash: bool = False) -> str:
80
+ def get_user_hash() -> str:
63
81
  """Returns a unique user-machine specific hash as a user id.
64
82
 
65
83
  We cache the user hash in a file to avoid potential user_name or
66
84
  hostname changes causing a new user hash to be generated.
67
-
68
- Args:
69
- force_fresh_hash: Bypasses the cached hash in USER_HASH_FILE and the
70
- hash in the USER_ID_ENV_VAR and forces a fresh user-machine hash
71
- to be generated. Used by `kubernetes.ssh_key_secret_field_name` to
72
- avoid controllers sharing the same ssh key field name as the
73
- local client.
74
85
  """
86
# 1) Prefer a valid hash passed in through the USER_ID_ENV_VAR env var.
+ user_hash = os.getenv(constants.USER_ID_ENV_VAR)
87
+ if _is_valid_user_hash(user_hash):
88
+ assert user_hash is not None
89
+ return user_hash
75
90
 
76
- def _is_valid_user_hash(user_hash: Optional[str]) -> bool:
77
- if user_hash is None:
78
- return False
79
- try:
80
- int(user_hash, 16)
81
- except (TypeError, ValueError):
82
- return False
83
- return len(user_hash) == USER_HASH_LENGTH
84
-
85
- if not force_fresh_hash:
86
- user_hash = os.getenv(constants.USER_ID_ENV_VAR)
87
- if _is_valid_user_hash(user_hash):
88
- assert user_hash is not None
89
- return user_hash
90
-
91
- if not force_fresh_hash and os.path.exists(_USER_HASH_FILE):
91
# 2) Otherwise reuse the hash cached in _USER_HASH_FILE, if it is valid.
+ if os.path.exists(_USER_HASH_FILE):
92
92
  # Read from cached user hash file.
93
93
  with open(_USER_HASH_FILE, 'r', encoding='utf-8') as f:
94
94
  # Remove invalid characters.
  # NOTE(review): the statement that reads `f` into `user_hash` (file line
  # 95) is hidden by the hunk boundary below; confirm what it does.
@@ -96,19 +96,10 @@ def get_user_hash(force_fresh_hash: bool = False) -> str:
96
96
  if _is_valid_user_hash(user_hash):
97
97
  return user_hash
98
98
 
99
- hash_str = user_and_hostname_hash()
100
- user_hash = hashlib.md5(hash_str.encode()).hexdigest()[:USER_HASH_LENGTH]
101
- if not _is_valid_user_hash(user_hash):
102
- # A fallback in case the hash is invalid.
103
- user_hash = uuid.uuid4().hex[:USER_HASH_LENGTH]
99
# 3) No valid env-var or cached hash: generate one and cache it to disk so
# later calls return the same value.
+ user_hash = generate_user_hash()
104
100
  os.makedirs(os.path.dirname(_USER_HASH_FILE), exist_ok=True)
105
- if not force_fresh_hash:
106
- # Do not cache to file if force_fresh_hash is True since the file may
107
- # be intentionally using a different hash, e.g. we want to keep the
108
- # user_hash for usage collection the same on the jobs/serve controller
109
- # as users' local client.
110
- with open(_USER_HASH_FILE, 'w', encoding='utf-8') as f:
111
- f.write(user_hash)
101
+ with open(_USER_HASH_FILE, 'w', encoding='utf-8') as f:
102
+ f.write(user_hash)
112
103
  return user_hash
113
104
 
114
105
 
@@ -253,7 +244,46 @@ class Backoff:
253
244
  return self._backoff
254
245
 
255
246
 
256
- def get_pretty_entry_point() -> str:
247
# Client-supplied overrides for the command/entrypoint; None means "not set".
_current_command: Optional[str] = None
_current_client_entrypoint: Optional[str] = None


def set_client_entrypoint_and_command(client_entrypoint: Optional[str],
                                      client_command: Optional[str]):
    """Records the client's entrypoint and command for this process.

    Intended for the SkyPilot API server side, where the entrypoint and
    command come from the client rather than from this process's argv.
    Passing None clears the corresponding override.
    """
    global _current_command, _current_client_entrypoint
    _current_client_entrypoint = client_entrypoint
    _current_command = client_command


def get_current_command() -> str:
    """Returns the command associated with the current operation.

    Uses the client command recorded by set_client_entrypoint_and_command()
    when present; otherwise falls back to get_pretty_entrypoint_cmd().
    """
    command = _current_command
    if command is None:
        command = get_pretty_entrypoint_cmd()
    return command


def get_current_client_entrypoint(server_entrypoint: str) -> str:
    """Returns the recorded client entrypoint, or `server_entrypoint` if unset.
    """
    entrypoint = _current_client_entrypoint
    return server_entrypoint if entrypoint is None else entrypoint
284
+
285
+
286
+ def get_pretty_entrypoint_cmd() -> str:
257
287
  """Returns the prettified entry point of this process (sys.argv).
258
288
 
259
289
  Example return values:
@@ -298,29 +328,51 @@ def user_and_hostname_hash() -> str:
298
328
  return f'{getpass.getuser()}-{hostname_hash}'
299
329
 
300
330
 
301
def read_yaml(path: Optional[str]) -> Dict[str, Any]:
    """Loads the (single-document) YAML file at `path` via yaml.safe_load.

    Raises:
        ValueError: if `path` is None.
    """
    if path is None:
        raise ValueError('Attempted to read a None YAML.')
    with open(path, 'r', encoding='utf-8') as yaml_file:
        return yaml.safe_load(yaml_file)
305
337
 
306
338
 
339
def read_yaml_all_str(yaml_str: str) -> List[Dict[str, Any]]:
    """Parses every YAML document contained in `yaml_str`.

    Documents are loaded with yaml.safe_load_all. A string holding no
    documents (e.g. an empty YAML file) yields `[{}]`.
    """
    documents = list(yaml.safe_load_all(io.StringIO(yaml_str)))
    return documents if documents else [{}]
347
+
348
+
307
def read_yaml_all(path: str) -> List[Dict[str, Any]]:
    """Reads every YAML document stored in the file at `path`.

    The whole file is read (and closed) first, then parsed by
    read_yaml_all_str().
    """
    with open(path, 'r', encoding='utf-8') as yaml_file:
        contents = yaml_file.read()
    return read_yaml_all_str(contents)
315
352
 
316
353
 
317
def dump_yaml(path: str, config: Union[List[Dict[str, Any]],
                                       Dict[str, Any]]) -> None:
    """Dumps a YAML file.

    The config is serialized with dump_yaml_str() before the file is opened,
    so a serialization error leaves an existing file at `path` untouched
    rather than truncated to empty.

    Args:
        path: the path to the YAML file.
        config: the configuration to dump.
    """
    yaml_str = dump_yaml_str(config)
    with open(path, 'w', encoding='utf-8') as f:
        f.write(yaml_str)
321
364
 
322
365
 
323
366
  def dump_yaml_str(config: Union[List[Dict[str, Any]], Dict[str, Any]]) -> str:
367
+ """Dumps a YAML string.
368
+
369
+ Args:
370
+ config: the configuration to dump.
371
+
372
+ Returns:
373
+ The YAML string.
374
+ """
375
+
324
376
  # https://github.com/yaml/pyyaml/issues/127
325
377
  class LineBreakDumper(yaml.SafeDumper):
326
378
 
@@ -408,43 +460,6 @@ def retry(method, max_retries=3, initial_backoff=1):
408
460
  return method_with_retries
409
461
 
410
462
 
411
# NOTE(review): encode_payload() is removed from common_utils in this release
# (along with _PAYLOAD_STR). sky/utils/message_utils.py is new in this
# release and may host the replacement; confirm.
- def encode_payload(payload: Any) -> str:
412
- """Encode a payload to make it more robust for parsing.
413
-
414
- This makes message transfer more robust to any additional strings added to
415
- the message during transfer.
416
-
417
- An example message that is polluted by the system warning:
418
- "LC_ALL: cannot change locale (en_US.UTF-8)\n<sky-payload>hello, world</sky-payload>" # pylint: disable=line-too-long
419
-
420
- Args:
421
- payload: A str, dict or list to be encoded.
422
-
423
- Returns:
424
- A string that is encoded from the payload.
425
- """
426
# JSON-encode, then wrap in <sky-payload>...</sky-payload> markers.
- payload_str = json.dumps(payload)
427
- payload_str = _PAYLOAD_STR.format(payload_str)
428
- return payload_str
429
-
430
-
431
# NOTE(review): decode_payload() is removed from common_utils in this release
# (along with _PAYLOAD_PATTERN). sky/utils/message_utils.py is new in this
# release and may host the replacement; confirm.
- def decode_payload(payload_str: str) -> Any:
432
- """Decode a payload string.
433
-
434
- Args:
435
- payload_str: A string that is encoded from a payload.
436
-
437
- Returns:
438
- A str, dict or list that is decoded from the payload string.
439
- """
440
# Only the text inside the first <sky-payload> marker pair is JSON-decoded,
# so unrelated text around it (e.g. locale warnings) is ignored.
- matched = _PAYLOAD_PATTERN.findall(payload_str)
441
- if not matched:
442
- raise ValueError(f'Invalid payload string: \n{payload_str}')
443
- payload_str = matched[0]
444
- payload = json.loads(payload_str)
445
- return payload
446
-
447
-
448
463
  def class_fullname(cls, skip_builtins: bool = True):
449
464
  """Get the full name of a class.
450
465
 
@@ -492,12 +507,14 @@ def remove_color(s: str):
492
507
  return _COLOR_PATTERN.sub('', s)
493
508
 
494
509
 
495
- def remove_file_if_exists(path: str):
510
+ def remove_file_if_exists(path: Optional[str]):
496
511
  """Delete a file if it exists.
497
512
 
498
513
  Args:
499
514
  path: The path to the file.
500
515
  """
516
+ if path is None:
517
+ return
501
518
  try:
502
519
  os.remove(path)
503
520
  except FileNotFoundError:
@@ -600,7 +617,7 @@ def validate_schema(obj, schema, err_msg_prefix='', skip_none=True):
600
617
 
601
618
  if err_msg:
602
619
  with ux_utils.print_exception_no_traceback():
603
- raise ValueError(err_msg)
620
+ raise exceptions.InvalidSkyPilotConfigError(err_msg)
604
621
 
605
622
 
606
623
  def get_cleaned_username(username: str = '') -> str:
@@ -715,3 +732,26 @@ def hash_file(path: str, hash_alg: str) -> 'hashlib._Hash':
715
732
  break
716
733
  file_hash.update(view[:size])
717
734
  return file_hash
735
+
736
+
737
def is_port_available(port: int, reuse_addr: bool = True) -> bool:
    """Reports whether a TCP socket can currently bind to ('localhost', port).

    Args:
        port: TCP port number to probe.
        reuse_addr: whether to set SO_REUSEADDR on the probe socket, allowing
            ports in TIME_WAIT to count as available. Servers such as
            multiprocessing.Manager set SO_REUSEADDR by default to restart
            faster, so the probe should use the same setting they will.

    Returns:
        True if the bind succeeded, False if it raised OSError.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    with probe:
        if reuse_addr:
            probe.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        try:
            probe.bind(('localhost', port))
        except OSError:
            return False
        return True
sky/utils/config_utils.py ADDED
@@ -0,0 +1,204 @@
1
+ """Utilities for nested config."""
2
+ import copy
3
+ from typing import Any, Dict, List, Optional, Tuple
4
+
5
+ from sky import sky_logging
6
+
7
+ logger = sky_logging.init_logger(__name__)
8
+
9
+
10
class Config(Dict[str, Any]):
    """SkyPilot config that supports setting/getting values with nested keys.

    Nested keys are addressed with tuples, e.g. ('kubernetes', 'pod_config').
    """

    def get_nested(
            self,
            keys: Tuple[str, ...],
            default_value: Any,
            override_configs: Optional[Dict[str, Any]] = None,
            allowed_override_keys: Optional[List[Tuple[str, ...]]] = None,
            disallowed_override_keys: Optional[List[Tuple[str,
                                                          ...]]] = None) -> Any:
        """Gets a nested key.

        If any key is not found, or any intermediate key does not point to a
        dict value, returns 'default_value'. When `override_configs` is given
        it is merged into a deep copy of this config first, so this config
        itself is never modified.

        Args:
            keys: A tuple of strings representing the nested keys.
            default_value: The default value to return if the key is not found.
            override_configs: A dict of override configs with the same schema as
                the config file, but only containing the keys to override.
            allowed_override_keys: A list of keys that are allowed to be
                overridden.
            disallowed_override_keys: A list of keys that are disallowed to be
                overridden.

        Returns:
            The value of the nested key, or 'default_value' if not found.

        Raises:
            ValueError: if `override_configs` uses a key that is not in
                `allowed_override_keys` or that is in
                `disallowed_override_keys`.
        """
        config = copy.deepcopy(self)
        if override_configs is not None:
            config = _recursive_update(config, override_configs,
                                       allowed_override_keys,
                                       disallowed_override_keys)
        return _get_nested(config, keys, default_value, pop=False)

    def set_nested(self, keys: Tuple[str, ...], value: Any) -> None:
        """In-place sets a nested key to value.

        Like get_nested(), if any key is not found, this will not raise an
        error: missing intermediate dicts are created. The update goes through
        _recursive_update(), so an existing 'kubernetes' dict is merged with
        merge_k8s_configs() rather than overwritten. An empty `keys` tuple is a
        no-op.
        """
        if not keys:
            return
        # Build {k0: {k1: ... {kN: value}}} from the innermost key outwards.
        override: Any = value
        for key in reversed(keys):
            override = {key: override}
        _recursive_update(self, override)

    def pop_nested(self, keys: Tuple[str, ...], default_value: Any) -> Any:
        """Pops a nested key, returning `default_value` if it is absent."""
        return _get_nested(self, keys, default_value, pop=True)

    @classmethod
    def from_dict(cls, config: Optional[Dict[str, Any]]) -> 'Config':
        """Creates a Config holding a shallow copy of `config` (None -> empty).
        """
        if config is None:
            return cls()
        # Pass the mapping positionally: `cls(**config)` raises TypeError for
        # non-string keys (e.g. integer keys parsed from YAML).
        return cls(config)
69
+
70
+
71
+ def _check_allowed_and_disallowed_override_keys(
72
+ key: str,
73
+ allowed_override_keys: Optional[List[Tuple[str, ...]]] = None,
74
+ disallowed_override_keys: Optional[List[Tuple[str, ...]]] = None
75
+ ) -> Tuple[Optional[List[Tuple[str, ...]]], Optional[List[Tuple[str, ...]]]]:
76
+ allowed_keys_with_matched_prefix: Optional[List[Tuple[str, ...]]] = []
77
+ disallowed_keys_with_matched_prefix: Optional[List[Tuple[str, ...]]] = []
78
+ if allowed_override_keys is not None:
79
+ for nested_key in allowed_override_keys:
80
+ if key == nested_key[0]:
81
+ if len(nested_key) == 1:
82
+ # Allowed key is fully matched, no need to check further.
83
+ allowed_keys_with_matched_prefix = None
84
+ break
85
+ assert allowed_keys_with_matched_prefix is not None
86
+ allowed_keys_with_matched_prefix.append(nested_key[1:])
87
+ if (allowed_keys_with_matched_prefix is not None and
88
+ not allowed_keys_with_matched_prefix):
89
+ raise ValueError(f'Key {key} is not in allowed override keys: '
90
+ f'{allowed_override_keys}')
91
+ else:
92
+ allowed_keys_with_matched_prefix = None
93
+
94
+ if disallowed_override_keys is not None:
95
+ for nested_key in disallowed_override_keys:
96
+ if key == nested_key[0]:
97
+ if len(nested_key) == 1:
98
+ raise ValueError(
99
+ f'Key {key} is in disallowed override keys: '
100
+ f'{disallowed_override_keys}')
101
+ assert disallowed_keys_with_matched_prefix is not None
102
+ disallowed_keys_with_matched_prefix.append(nested_key[1:])
103
+ else:
104
+ disallowed_keys_with_matched_prefix = None
105
+ return allowed_keys_with_matched_prefix, disallowed_keys_with_matched_prefix
106
+
107
+
108
def _recursive_update(
        base_config: Config,
        override_config: Dict[str, Any],
        allowed_override_keys: Optional[List[Tuple[str, ...]]] = None,
        disallowed_override_keys: Optional[List[Tuple[str,
                                                      ...]]] = None) -> Config:
    """Recursively updates base configuration with override configuration.

    `base_config` is modified in place and also returned. A dict value present
    on both sides is merged recursively; a 'kubernetes' dict is merged with
    merge_k8s_configs() (which merges/extends lists instead of replacing
    them). Any other override value replaces the base value.

    Raises:
        ValueError: if an override key is not allowed or is disallowed (see
            _check_allowed_and_disallowed_override_keys).
    """
    for key, value in override_config.items():
        (next_allowed_override_keys, next_disallowed_override_keys
        ) = _check_allowed_and_disallowed_override_keys(
            key, allowed_override_keys, disallowed_override_keys)
        # Merge only when both sides are dicts. Previously a 'kubernetes' key
        # was always handed to merge_k8s_configs(), which crashed when either
        # side was not a dict (e.g. a `kubernetes` value of None).
        mergeable = (isinstance(value, dict) and
                     isinstance(base_config.get(key), dict))
        if mergeable and key == 'kubernetes':
            merge_k8s_configs(base_config[key], value,
                              next_allowed_override_keys,
                              next_disallowed_override_keys)
        elif mergeable:
            _recursive_update(base_config[key], value,
                              next_allowed_override_keys,
                              next_disallowed_override_keys)
        else:
            base_config[key] = value
    return base_config
131
+
132
+
133
+ def _get_nested(configs: Optional[Dict[str, Any]],
134
+ keys: Tuple[str, ...],
135
+ default_value: Any,
136
+ pop: bool = False) -> Any:
137
+ if configs is None:
138
+ return default_value
139
+ curr = configs
140
+ for i, key in enumerate(keys):
141
+ if isinstance(curr, dict) and key in curr:
142
+ value = curr[key]
143
+ if i == len(keys) - 1:
144
+ if pop:
145
+ curr.pop(key, default_value)
146
+ curr = value
147
+ else:
148
+ return default_value
149
+ logger.debug(f'User config: {".".join(keys)} -> {curr}')
150
+ return curr
151
+
152
+
153
def merge_k8s_configs(
        base_config: Dict[Any, Any],
        override_config: Dict[Any, Any],
        allowed_override_keys: Optional[List[Tuple[str, ...]]] = None,
        disallowed_override_keys: Optional[List[Tuple[str,
                                                      ...]]] = None) -> None:
    """Merges `override_config` into `base_config` in place.

    Nested dicts are merged recursively instead of being replaced, and lists
    are extended, with these exceptions:

    - 'containers' / 'imagePullSecrets': the override list must hold exactly
      one entry, which is merged into the first entry of the base list (only
      one container per pod is supported).
    - 'volumes' / 'volumeMounts': an override entry whose 'name' matches an
      existing base entry is merged into that entry; any other entry
      (including one without a 'name') is appended.

    Every key visited is checked with
    _check_allowed_and_disallowed_override_keys(), including keys inside
    merged container and volume entries.

    Raises:
        ValueError: if an override key is not allowed or is disallowed.
        AssertionError: if a base value for a list override is not a list, or
            a 'containers'/'imagePullSecrets' override has != 1 entries.
    """
    for key, value in override_config.items():
        (next_allowed_override_keys, next_disallowed_override_keys
        ) = _check_allowed_and_disallowed_override_keys(
            key, allowed_override_keys, disallowed_override_keys)
        if isinstance(value, dict) and key in base_config:
            merge_k8s_configs(base_config[key], value,
                              next_allowed_override_keys,
                              next_disallowed_override_keys)
        elif isinstance(value, list) and key in base_config:
            assert isinstance(base_config[key], list), \
                f'Expected {key} to be a list, found {base_config[key]}'
            if key in ('containers', 'imagePullSecrets'):
                # Take the first and only entry and merge it into the base's
                # first entry, as we only support one container per pod.
                assert len(value) == 1, \
                    f'Expected only one entry in {key}, found {value}'
                merge_k8s_configs(base_config[key][0], value[0],
                                  next_allowed_override_keys,
                                  next_disallowed_override_keys)
            elif key in ('volumes', 'volumeMounts'):
                for new_volume in value:
                    new_volume_name = new_volume.get('name')
                    destination_volume = None
                    if new_volume_name is not None:
                        destination_volume = next(
                            (v for v in base_config[key]
                             if v.get('name') == new_volume_name), None)
                    if destination_volume is not None:
                        # Propagate the override-key restrictions, as the
                        # 'containers' branch does; dropping them here let
                        # volume merges bypass the allow/deny checks.
                        merge_k8s_configs(destination_volume, new_volume,
                                          next_allowed_override_keys,
                                          next_disallowed_override_keys)
                    else:
                        # New or unnamed entries are appended instead of
                        # being silently discarded.
                        base_config[key].append(new_volume)
            else:
                base_config[key].extend(value)
        else:
            base_config[key] = value
sky/utils/control_master_utils.py CHANGED
@@ -1,8 +1,7 @@
1
1
  """Utils to check if the ssh control master should be disabled."""
2
2
 
3
- import functools
4
-
5
3
  from sky import sky_logging
4
+ from sky.utils import annotations
6
5
  from sky.utils import subprocess_utils
7
6
 
8
7
  logger = sky_logging.init_logger(__name__)
@@ -34,7 +33,7 @@ def is_tmp_9p_filesystem() -> bool:
34
33
  return filesystem_types[1].lower() == '9p'
35
34
 
36
35
 
37
- @functools.lru_cache
36
+ @annotations.lru_cache(scope='global')
38
37
  def should_disable_control_master() -> bool:
39
38
  """Whether disable ssh control master based on file system.
40
39