skypilot-nightly 1.0.0.dev20250215__py3-none-any.whl → 1.0.0.dev20250217__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sky/__init__.py +48 -22
- sky/adaptors/aws.py +2 -1
- sky/adaptors/azure.py +4 -4
- sky/adaptors/cloudflare.py +4 -4
- sky/adaptors/kubernetes.py +8 -8
- sky/authentication.py +42 -45
- sky/backends/backend.py +2 -2
- sky/backends/backend_utils.py +108 -221
- sky/backends/cloud_vm_ray_backend.py +283 -282
- sky/benchmark/benchmark_utils.py +6 -2
- sky/check.py +40 -28
- sky/cli.py +1213 -1116
- sky/client/__init__.py +1 -0
- sky/client/cli.py +5644 -0
- sky/client/common.py +345 -0
- sky/client/sdk.py +1757 -0
- sky/cloud_stores.py +12 -6
- sky/clouds/__init__.py +0 -2
- sky/clouds/aws.py +20 -13
- sky/clouds/azure.py +5 -3
- sky/clouds/cloud.py +1 -1
- sky/clouds/cudo.py +2 -1
- sky/clouds/do.py +2 -1
- sky/clouds/fluidstack.py +3 -2
- sky/clouds/gcp.py +10 -8
- sky/clouds/ibm.py +8 -7
- sky/clouds/kubernetes.py +7 -6
- sky/clouds/lambda_cloud.py +8 -7
- sky/clouds/oci.py +4 -3
- sky/clouds/paperspace.py +2 -1
- sky/clouds/runpod.py +2 -1
- sky/clouds/scp.py +8 -7
- sky/clouds/service_catalog/__init__.py +3 -3
- sky/clouds/service_catalog/aws_catalog.py +7 -1
- sky/clouds/service_catalog/common.py +4 -2
- sky/clouds/service_catalog/data_fetchers/fetch_gcp.py +2 -2
- sky/clouds/utils/oci_utils.py +1 -1
- sky/clouds/vast.py +2 -1
- sky/clouds/vsphere.py +2 -1
- sky/core.py +263 -99
- sky/dag.py +4 -0
- sky/data/mounting_utils.py +2 -1
- sky/data/storage.py +97 -35
- sky/data/storage_utils.py +69 -9
- sky/exceptions.py +138 -5
- sky/execution.py +47 -50
- sky/global_user_state.py +105 -22
- sky/jobs/__init__.py +12 -14
- sky/jobs/client/__init__.py +0 -0
- sky/jobs/client/sdk.py +296 -0
- sky/jobs/constants.py +30 -1
- sky/jobs/controller.py +12 -6
- sky/jobs/dashboard/dashboard.py +2 -6
- sky/jobs/recovery_strategy.py +22 -29
- sky/jobs/server/__init__.py +1 -0
- sky/jobs/{core.py → server/core.py} +101 -34
- sky/jobs/server/dashboard_utils.py +64 -0
- sky/jobs/server/server.py +182 -0
- sky/jobs/utils.py +32 -23
- sky/models.py +27 -0
- sky/optimizer.py +9 -11
- sky/provision/__init__.py +6 -3
- sky/provision/aws/config.py +2 -2
- sky/provision/aws/instance.py +1 -1
- sky/provision/azure/instance.py +1 -1
- sky/provision/cudo/instance.py +1 -1
- sky/provision/do/instance.py +1 -1
- sky/provision/do/utils.py +0 -5
- sky/provision/fluidstack/fluidstack_utils.py +4 -3
- sky/provision/fluidstack/instance.py +4 -2
- sky/provision/gcp/instance.py +1 -1
- sky/provision/instance_setup.py +2 -2
- sky/provision/kubernetes/constants.py +8 -0
- sky/provision/kubernetes/instance.py +1 -1
- sky/provision/kubernetes/utils.py +67 -76
- sky/provision/lambda_cloud/instance.py +3 -15
- sky/provision/logging.py +1 -1
- sky/provision/oci/instance.py +7 -4
- sky/provision/paperspace/instance.py +1 -1
- sky/provision/provisioner.py +3 -2
- sky/provision/runpod/instance.py +1 -1
- sky/provision/vast/instance.py +1 -1
- sky/provision/vast/utils.py +2 -1
- sky/provision/vsphere/instance.py +2 -11
- sky/resources.py +55 -40
- sky/serve/__init__.py +6 -10
- sky/serve/client/__init__.py +0 -0
- sky/serve/client/sdk.py +366 -0
- sky/serve/constants.py +3 -0
- sky/serve/replica_managers.py +10 -10
- sky/serve/serve_utils.py +56 -36
- sky/serve/server/__init__.py +0 -0
- sky/serve/{core.py → server/core.py} +37 -17
- sky/serve/server/server.py +117 -0
- sky/serve/service.py +8 -1
- sky/server/__init__.py +1 -0
- sky/server/common.py +441 -0
- sky/server/constants.py +21 -0
- sky/server/html/log.html +174 -0
- sky/server/requests/__init__.py +0 -0
- sky/server/requests/executor.py +462 -0
- sky/server/requests/payloads.py +481 -0
- sky/server/requests/queues/__init__.py +0 -0
- sky/server/requests/queues/mp_queue.py +76 -0
- sky/server/requests/requests.py +567 -0
- sky/server/requests/serializers/__init__.py +0 -0
- sky/server/requests/serializers/decoders.py +192 -0
- sky/server/requests/serializers/encoders.py +166 -0
- sky/server/server.py +1095 -0
- sky/server/stream_utils.py +144 -0
- sky/setup_files/MANIFEST.in +1 -0
- sky/setup_files/dependencies.py +12 -4
- sky/setup_files/setup.py +1 -1
- sky/sky_logging.py +9 -13
- sky/skylet/autostop_lib.py +2 -2
- sky/skylet/constants.py +46 -12
- sky/skylet/events.py +5 -6
- sky/skylet/job_lib.py +78 -66
- sky/skylet/log_lib.py +17 -11
- sky/skypilot_config.py +79 -94
- sky/task.py +119 -73
- sky/templates/aws-ray.yml.j2 +4 -4
- sky/templates/azure-ray.yml.j2 +3 -2
- sky/templates/cudo-ray.yml.j2 +3 -2
- sky/templates/fluidstack-ray.yml.j2 +3 -2
- sky/templates/gcp-ray.yml.j2 +3 -2
- sky/templates/ibm-ray.yml.j2 +3 -2
- sky/templates/jobs-controller.yaml.j2 +1 -12
- sky/templates/kubernetes-ray.yml.j2 +3 -2
- sky/templates/lambda-ray.yml.j2 +3 -2
- sky/templates/oci-ray.yml.j2 +3 -2
- sky/templates/paperspace-ray.yml.j2 +3 -2
- sky/templates/runpod-ray.yml.j2 +3 -2
- sky/templates/scp-ray.yml.j2 +3 -2
- sky/templates/skypilot-server-kubernetes-proxy.sh +36 -0
- sky/templates/vsphere-ray.yml.j2 +4 -2
- sky/templates/websocket_proxy.py +64 -0
- sky/usage/constants.py +8 -0
- sky/usage/usage_lib.py +45 -11
- sky/utils/accelerator_registry.py +33 -53
- sky/utils/admin_policy_utils.py +2 -1
- sky/utils/annotations.py +51 -0
- sky/utils/cli_utils/status_utils.py +33 -3
- sky/utils/cluster_utils.py +356 -0
- sky/utils/command_runner.py +69 -14
- sky/utils/common.py +74 -0
- sky/utils/common_utils.py +133 -93
- sky/utils/config_utils.py +204 -0
- sky/utils/control_master_utils.py +2 -3
- sky/utils/controller_utils.py +133 -147
- sky/utils/dag_utils.py +72 -24
- sky/utils/kubernetes/deploy_remote_cluster.sh +2 -2
- sky/utils/kubernetes/exec_kubeconfig_converter.py +73 -0
- sky/utils/kubernetes/kubernetes_deploy_utils.py +228 -0
- sky/utils/log_utils.py +83 -23
- sky/utils/message_utils.py +81 -0
- sky/utils/registry.py +127 -0
- sky/utils/resources_utils.py +2 -2
- sky/utils/rich_utils.py +213 -34
- sky/utils/schemas.py +19 -2
- sky/{status_lib.py → utils/status_lib.py} +12 -7
- sky/utils/subprocess_utils.py +51 -35
- sky/utils/timeline.py +7 -2
- sky/utils/ux_utils.py +95 -25
- {skypilot_nightly-1.0.0.dev20250215.dist-info → skypilot_nightly-1.0.0.dev20250217.dist-info}/METADATA +8 -3
- {skypilot_nightly-1.0.0.dev20250215.dist-info → skypilot_nightly-1.0.0.dev20250217.dist-info}/RECORD +170 -132
- sky/clouds/cloud_registry.py +0 -76
- sky/utils/cluster_yaml_utils.py +0 -24
- {skypilot_nightly-1.0.0.dev20250215.dist-info → skypilot_nightly-1.0.0.dev20250217.dist-info}/LICENSE +0 -0
- {skypilot_nightly-1.0.0.dev20250215.dist-info → skypilot_nightly-1.0.0.dev20250217.dist-info}/WHEEL +0 -0
- {skypilot_nightly-1.0.0.dev20250215.dist-info → skypilot_nightly-1.0.0.dev20250217.dist-info}/entry_points.txt +0 -0
- {skypilot_nightly-1.0.0.dev20250215.dist-info → skypilot_nightly-1.0.0.dev20250217.dist-info}/top_level.txt +0 -0
sky/utils/common_utils.py
CHANGED
@@ -5,7 +5,7 @@ import functools
|
|
5
5
|
import getpass
|
6
6
|
import hashlib
|
7
7
|
import inspect
|
8
|
-
import
|
8
|
+
import io
|
9
9
|
import os
|
10
10
|
import platform
|
11
11
|
import random
|
@@ -23,6 +23,8 @@ import yaml
|
|
23
23
|
from sky import exceptions
|
24
24
|
from sky import sky_logging
|
25
25
|
from sky.skylet import constants
|
26
|
+
from sky.usage import constants as usage_constants
|
27
|
+
from sky.utils import annotations
|
26
28
|
from sky.utils import ux_utils
|
27
29
|
from sky.utils import validator
|
28
30
|
|
@@ -36,16 +38,12 @@ CLUSTER_NAME_HASH_LENGTH = 2
|
|
36
38
|
|
37
39
|
_COLOR_PATTERN = re.compile(r'\x1b[^m]*m')
|
38
40
|
|
39
|
-
_PAYLOAD_PATTERN = re.compile(r'<sky-payload>(.*)</sky-payload>')
|
40
|
-
_PAYLOAD_STR = '<sky-payload>{}</sky-payload>'
|
41
|
-
|
42
41
|
_VALID_ENV_VAR_REGEX = '[a-zA-Z_][a-zA-Z0-9_]*'
|
43
42
|
|
44
43
|
logger = sky_logging.init_logger(__name__)
|
45
44
|
|
46
|
-
_usage_run_id = None
|
47
|
-
|
48
45
|
|
46
|
+
@annotations.lru_cache(scope='request')
|
49
47
|
def get_usage_run_id() -> str:
|
50
48
|
"""Returns a unique run id for each 'run'.
|
51
49
|
|
@@ -53,42 +51,44 @@ def get_usage_run_id() -> str:
|
|
53
51
|
and has called its CLI or programmatic APIs. For example, two successive
|
54
52
|
`sky launch` are two runs.
|
55
53
|
"""
|
56
|
-
|
57
|
-
if
|
58
|
-
|
59
|
-
return
|
54
|
+
usage_run_id = os.getenv(usage_constants.USAGE_RUN_ID_ENV_VAR)
|
55
|
+
if usage_run_id is not None:
|
56
|
+
return usage_run_id
|
57
|
+
return str(uuid.uuid4())
|
58
|
+
|
59
|
+
|
60
|
+
def _is_valid_user_hash(user_hash: Optional[str]) -> bool:
|
61
|
+
if user_hash is None:
|
62
|
+
return False
|
63
|
+
try:
|
64
|
+
int(user_hash, 16)
|
65
|
+
except (TypeError, ValueError):
|
66
|
+
return False
|
67
|
+
return len(user_hash) == USER_HASH_LENGTH
|
68
|
+
|
69
|
+
|
70
|
+
def generate_user_hash() -> str:
|
71
|
+
"""Generates a unique user-machine specific hash."""
|
72
|
+
hash_str = user_and_hostname_hash()
|
73
|
+
user_hash = hashlib.md5(hash_str.encode()).hexdigest()[:USER_HASH_LENGTH]
|
74
|
+
if not _is_valid_user_hash(user_hash):
|
75
|
+
# A fallback in case the hash is invalid.
|
76
|
+
user_hash = uuid.uuid4().hex[:USER_HASH_LENGTH]
|
77
|
+
return user_hash
|
60
78
|
|
61
79
|
|
62
|
-
def get_user_hash(
|
80
|
+
def get_user_hash() -> str:
|
63
81
|
"""Returns a unique user-machine specific hash as a user id.
|
64
82
|
|
65
83
|
We cache the user hash in a file to avoid potential user_name or
|
66
84
|
hostname changes causing a new user hash to be generated.
|
67
|
-
|
68
|
-
Args:
|
69
|
-
force_fresh_hash: Bypasses the cached hash in USER_HASH_FILE and the
|
70
|
-
hash in the USER_ID_ENV_VAR and forces a fresh user-machine hash
|
71
|
-
to be generated. Used by `kubernetes.ssh_key_secret_field_name` to
|
72
|
-
avoid controllers sharing the same ssh key field name as the
|
73
|
-
local client.
|
74
85
|
"""
|
86
|
+
user_hash = os.getenv(constants.USER_ID_ENV_VAR)
|
87
|
+
if _is_valid_user_hash(user_hash):
|
88
|
+
assert user_hash is not None
|
89
|
+
return user_hash
|
75
90
|
|
76
|
-
|
77
|
-
if user_hash is None:
|
78
|
-
return False
|
79
|
-
try:
|
80
|
-
int(user_hash, 16)
|
81
|
-
except (TypeError, ValueError):
|
82
|
-
return False
|
83
|
-
return len(user_hash) == USER_HASH_LENGTH
|
84
|
-
|
85
|
-
if not force_fresh_hash:
|
86
|
-
user_hash = os.getenv(constants.USER_ID_ENV_VAR)
|
87
|
-
if _is_valid_user_hash(user_hash):
|
88
|
-
assert user_hash is not None
|
89
|
-
return user_hash
|
90
|
-
|
91
|
-
if not force_fresh_hash and os.path.exists(_USER_HASH_FILE):
|
91
|
+
if os.path.exists(_USER_HASH_FILE):
|
92
92
|
# Read from cached user hash file.
|
93
93
|
with open(_USER_HASH_FILE, 'r', encoding='utf-8') as f:
|
94
94
|
# Remove invalid characters.
|
@@ -96,19 +96,10 @@ def get_user_hash(force_fresh_hash: bool = False) -> str:
|
|
96
96
|
if _is_valid_user_hash(user_hash):
|
97
97
|
return user_hash
|
98
98
|
|
99
|
-
|
100
|
-
user_hash = hashlib.md5(hash_str.encode()).hexdigest()[:USER_HASH_LENGTH]
|
101
|
-
if not _is_valid_user_hash(user_hash):
|
102
|
-
# A fallback in case the hash is invalid.
|
103
|
-
user_hash = uuid.uuid4().hex[:USER_HASH_LENGTH]
|
99
|
+
user_hash = generate_user_hash()
|
104
100
|
os.makedirs(os.path.dirname(_USER_HASH_FILE), exist_ok=True)
|
105
|
-
|
106
|
-
|
107
|
-
# be intentionally using a different hash, e.g. we want to keep the
|
108
|
-
# user_hash for usage collection the same on the jobs/serve controller
|
109
|
-
# as users' local client.
|
110
|
-
with open(_USER_HASH_FILE, 'w', encoding='utf-8') as f:
|
111
|
-
f.write(user_hash)
|
101
|
+
with open(_USER_HASH_FILE, 'w', encoding='utf-8') as f:
|
102
|
+
f.write(user_hash)
|
112
103
|
return user_hash
|
113
104
|
|
114
105
|
|
@@ -253,7 +244,46 @@ class Backoff:
|
|
253
244
|
return self._backoff
|
254
245
|
|
255
246
|
|
256
|
-
|
247
|
+
_current_command: Optional[str] = None
|
248
|
+
_current_client_entrypoint: Optional[str] = None
|
249
|
+
|
250
|
+
|
251
|
+
def set_client_entrypoint_and_command(client_entrypoint: Optional[str],
|
252
|
+
client_command: Optional[str]):
|
253
|
+
"""Override the current client entrypoint and command.
|
254
|
+
|
255
|
+
This is useful when we are on the SkyPilot API server side and we have a
|
256
|
+
client entrypoint and command from the client.
|
257
|
+
"""
|
258
|
+
global _current_command, _current_client_entrypoint
|
259
|
+
_current_command = client_command
|
260
|
+
_current_client_entrypoint = client_entrypoint
|
261
|
+
|
262
|
+
|
263
|
+
def get_current_command() -> str:
|
264
|
+
"""Returns the command related to this operation.
|
265
|
+
|
266
|
+
Normally uses get_pretty_entry_point(), but will use the client command on
|
267
|
+
the server side.
|
268
|
+
"""
|
269
|
+
if _current_command is not None:
|
270
|
+
return _current_command
|
271
|
+
|
272
|
+
return get_pretty_entrypoint_cmd()
|
273
|
+
|
274
|
+
|
275
|
+
def get_current_client_entrypoint(server_entrypoint: str) -> str:
|
276
|
+
"""Returns the current client entrypoint.
|
277
|
+
|
278
|
+
Gets the client entrypoint from the context, if it is not set, returns the
|
279
|
+
server entrypoint.
|
280
|
+
"""
|
281
|
+
if _current_client_entrypoint is not None:
|
282
|
+
return _current_client_entrypoint
|
283
|
+
return server_entrypoint
|
284
|
+
|
285
|
+
|
286
|
+
def get_pretty_entrypoint_cmd() -> str:
|
257
287
|
"""Returns the prettified entry point of this process (sys.argv).
|
258
288
|
|
259
289
|
Example return values:
|
@@ -298,29 +328,51 @@ def user_and_hostname_hash() -> str:
|
|
298
328
|
return f'{getpass.getuser()}-{hostname_hash}'
|
299
329
|
|
300
330
|
|
301
|
-
def read_yaml(path: str) -> Dict[str, Any]:
|
331
|
+
def read_yaml(path: Optional[str]) -> Dict[str, Any]:
|
332
|
+
if path is None:
|
333
|
+
raise ValueError('Attempted to read a None YAML.')
|
302
334
|
with open(path, 'r', encoding='utf-8') as f:
|
303
335
|
config = yaml.safe_load(f)
|
304
336
|
return config
|
305
337
|
|
306
338
|
|
339
|
+
def read_yaml_all_str(yaml_str: str) -> List[Dict[str, Any]]:
|
340
|
+
stream = io.StringIO(yaml_str)
|
341
|
+
config = yaml.safe_load_all(stream)
|
342
|
+
configs = list(config)
|
343
|
+
if not configs:
|
344
|
+
# Empty YAML file.
|
345
|
+
return [{}]
|
346
|
+
return configs
|
347
|
+
|
348
|
+
|
307
349
|
def read_yaml_all(path: str) -> List[Dict[str, Any]]:
|
308
350
|
with open(path, 'r', encoding='utf-8') as f:
|
309
|
-
|
310
|
-
configs = list(config)
|
311
|
-
if not configs:
|
312
|
-
# Empty YAML file.
|
313
|
-
return [{}]
|
314
|
-
return configs
|
351
|
+
return read_yaml_all_str(f.read())
|
315
352
|
|
316
353
|
|
317
354
|
def dump_yaml(path: str, config: Union[List[Dict[str, Any]],
|
318
355
|
Dict[str, Any]]) -> None:
|
356
|
+
"""Dumps a YAML file.
|
357
|
+
|
358
|
+
Args:
|
359
|
+
path: the path to the YAML file.
|
360
|
+
config: the configuration to dump.
|
361
|
+
"""
|
319
362
|
with open(path, 'w', encoding='utf-8') as f:
|
320
363
|
f.write(dump_yaml_str(config))
|
321
364
|
|
322
365
|
|
323
366
|
def dump_yaml_str(config: Union[List[Dict[str, Any]], Dict[str, Any]]) -> str:
|
367
|
+
"""Dumps a YAML string.
|
368
|
+
|
369
|
+
Args:
|
370
|
+
config: the configuration to dump.
|
371
|
+
|
372
|
+
Returns:
|
373
|
+
The YAML string.
|
374
|
+
"""
|
375
|
+
|
324
376
|
# https://github.com/yaml/pyyaml/issues/127
|
325
377
|
class LineBreakDumper(yaml.SafeDumper):
|
326
378
|
|
@@ -408,43 +460,6 @@ def retry(method, max_retries=3, initial_backoff=1):
|
|
408
460
|
return method_with_retries
|
409
461
|
|
410
462
|
|
411
|
-
def encode_payload(payload: Any) -> str:
|
412
|
-
"""Encode a payload to make it more robust for parsing.
|
413
|
-
|
414
|
-
This makes message transfer more robust to any additional strings added to
|
415
|
-
the message during transfer.
|
416
|
-
|
417
|
-
An example message that is polluted by the system warning:
|
418
|
-
"LC_ALL: cannot change locale (en_US.UTF-8)\n<sky-payload>hello, world</sky-payload>" # pylint: disable=line-too-long
|
419
|
-
|
420
|
-
Args:
|
421
|
-
payload: A str, dict or list to be encoded.
|
422
|
-
|
423
|
-
Returns:
|
424
|
-
A string that is encoded from the payload.
|
425
|
-
"""
|
426
|
-
payload_str = json.dumps(payload)
|
427
|
-
payload_str = _PAYLOAD_STR.format(payload_str)
|
428
|
-
return payload_str
|
429
|
-
|
430
|
-
|
431
|
-
def decode_payload(payload_str: str) -> Any:
|
432
|
-
"""Decode a payload string.
|
433
|
-
|
434
|
-
Args:
|
435
|
-
payload_str: A string that is encoded from a payload.
|
436
|
-
|
437
|
-
Returns:
|
438
|
-
A str, dict or list that is decoded from the payload string.
|
439
|
-
"""
|
440
|
-
matched = _PAYLOAD_PATTERN.findall(payload_str)
|
441
|
-
if not matched:
|
442
|
-
raise ValueError(f'Invalid payload string: \n{payload_str}')
|
443
|
-
payload_str = matched[0]
|
444
|
-
payload = json.loads(payload_str)
|
445
|
-
return payload
|
446
|
-
|
447
|
-
|
448
463
|
def class_fullname(cls, skip_builtins: bool = True):
|
449
464
|
"""Get the full name of a class.
|
450
465
|
|
@@ -492,12 +507,14 @@ def remove_color(s: str):
|
|
492
507
|
return _COLOR_PATTERN.sub('', s)
|
493
508
|
|
494
509
|
|
495
|
-
def remove_file_if_exists(path: str):
|
510
|
+
def remove_file_if_exists(path: Optional[str]):
|
496
511
|
"""Delete a file if it exists.
|
497
512
|
|
498
513
|
Args:
|
499
514
|
path: The path to the file.
|
500
515
|
"""
|
516
|
+
if path is None:
|
517
|
+
return
|
501
518
|
try:
|
502
519
|
os.remove(path)
|
503
520
|
except FileNotFoundError:
|
@@ -600,7 +617,7 @@ def validate_schema(obj, schema, err_msg_prefix='', skip_none=True):
|
|
600
617
|
|
601
618
|
if err_msg:
|
602
619
|
with ux_utils.print_exception_no_traceback():
|
603
|
-
raise
|
620
|
+
raise exceptions.InvalidSkyPilotConfigError(err_msg)
|
604
621
|
|
605
622
|
|
606
623
|
def get_cleaned_username(username: str = '') -> str:
|
@@ -715,3 +732,26 @@ def hash_file(path: str, hash_alg: str) -> 'hashlib._Hash':
|
|
715
732
|
break
|
716
733
|
file_hash.update(view[:size])
|
717
734
|
return file_hash
|
735
|
+
|
736
|
+
|
737
|
+
def is_port_available(port: int, reuse_addr: bool = True) -> bool:
|
738
|
+
"""Check if a TCP port is available for binding on localhost.
|
739
|
+
|
740
|
+
Args:
|
741
|
+
port: The port number to check.
|
742
|
+
reuse_addr: If True, sets SO_REUSEADDR socket option to allow reusing
|
743
|
+
ports in TIME_WAIT state. Servers like multiprocessing.Manager set
|
744
|
+
SO_REUSEADDR by default to accelerate restart. The option should be
|
745
|
+
coordinated in check.
|
746
|
+
|
747
|
+
Returns:
|
748
|
+
bool: True if the port is available for binding, False otherwise.
|
749
|
+
"""
|
750
|
+
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
751
|
+
if reuse_addr:
|
752
|
+
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
753
|
+
try:
|
754
|
+
s.bind(('localhost', port))
|
755
|
+
return True
|
756
|
+
except OSError:
|
757
|
+
return False
|
@@ -0,0 +1,204 @@
|
|
1
|
+
"""Utilities for nested config."""
|
2
|
+
import copy
|
3
|
+
from typing import Any, Dict, List, Optional, Tuple
|
4
|
+
|
5
|
+
from sky import sky_logging
|
6
|
+
|
7
|
+
logger = sky_logging.init_logger(__name__)
|
8
|
+
|
9
|
+
|
10
|
+
class Config(Dict[str, Any]):
|
11
|
+
"""SkyPilot config that supports setting/getting values with nested keys."""
|
12
|
+
|
13
|
+
def get_nested(
|
14
|
+
self,
|
15
|
+
keys: Tuple[str, ...],
|
16
|
+
default_value: Any,
|
17
|
+
override_configs: Optional[Dict[str, Any]] = None,
|
18
|
+
allowed_override_keys: Optional[List[Tuple[str, ...]]] = None,
|
19
|
+
disallowed_override_keys: Optional[List[Tuple[str,
|
20
|
+
...]]] = None) -> Any:
|
21
|
+
"""Gets a nested key.
|
22
|
+
|
23
|
+
If any key is not found, or any intermediate key does not point to a
|
24
|
+
dict value, returns 'default_value'.
|
25
|
+
|
26
|
+
Args:
|
27
|
+
keys: A tuple of strings representing the nested keys.
|
28
|
+
default_value: The default value to return if the key is not found.
|
29
|
+
override_configs: A dict of override configs with the same schema as
|
30
|
+
the config file, but only containing the keys to override.
|
31
|
+
allowed_override_keys: A list of keys that are allowed to be
|
32
|
+
overridden.
|
33
|
+
disallowed_override_keys: A list of keys that are disallowed to be
|
34
|
+
overridden.
|
35
|
+
|
36
|
+
Returns:
|
37
|
+
The value of the nested key, or 'default_value' if not found.
|
38
|
+
"""
|
39
|
+
config = copy.deepcopy(self)
|
40
|
+
if override_configs is not None:
|
41
|
+
config = _recursive_update(config, override_configs,
|
42
|
+
allowed_override_keys,
|
43
|
+
disallowed_override_keys)
|
44
|
+
return _get_nested(config, keys, default_value, pop=False)
|
45
|
+
|
46
|
+
def set_nested(self, keys: Tuple[str, ...], value: Any) -> None:
|
47
|
+
"""In-place sets a nested key to value.
|
48
|
+
|
49
|
+
Like get_nested(), if any key is not found, this will not raise an
|
50
|
+
error.
|
51
|
+
"""
|
52
|
+
override = {}
|
53
|
+
for i, key in enumerate(reversed(keys)):
|
54
|
+
if i == 0:
|
55
|
+
override = {key: value}
|
56
|
+
else:
|
57
|
+
override = {key: override}
|
58
|
+
_recursive_update(self, override)
|
59
|
+
|
60
|
+
def pop_nested(self, keys: Tuple[str, ...], default_value: Any) -> Any:
|
61
|
+
"""Pops a nested key."""
|
62
|
+
return _get_nested(self, keys, default_value, pop=True)
|
63
|
+
|
64
|
+
@classmethod
|
65
|
+
def from_dict(cls, config: Optional[Dict[str, Any]]) -> 'Config':
|
66
|
+
if config is None:
|
67
|
+
return cls()
|
68
|
+
return cls(**config)
|
69
|
+
|
70
|
+
|
71
|
+
def _check_allowed_and_disallowed_override_keys(
|
72
|
+
key: str,
|
73
|
+
allowed_override_keys: Optional[List[Tuple[str, ...]]] = None,
|
74
|
+
disallowed_override_keys: Optional[List[Tuple[str, ...]]] = None
|
75
|
+
) -> Tuple[Optional[List[Tuple[str, ...]]], Optional[List[Tuple[str, ...]]]]:
|
76
|
+
allowed_keys_with_matched_prefix: Optional[List[Tuple[str, ...]]] = []
|
77
|
+
disallowed_keys_with_matched_prefix: Optional[List[Tuple[str, ...]]] = []
|
78
|
+
if allowed_override_keys is not None:
|
79
|
+
for nested_key in allowed_override_keys:
|
80
|
+
if key == nested_key[0]:
|
81
|
+
if len(nested_key) == 1:
|
82
|
+
# Allowed key is fully matched, no need to check further.
|
83
|
+
allowed_keys_with_matched_prefix = None
|
84
|
+
break
|
85
|
+
assert allowed_keys_with_matched_prefix is not None
|
86
|
+
allowed_keys_with_matched_prefix.append(nested_key[1:])
|
87
|
+
if (allowed_keys_with_matched_prefix is not None and
|
88
|
+
not allowed_keys_with_matched_prefix):
|
89
|
+
raise ValueError(f'Key {key} is not in allowed override keys: '
|
90
|
+
f'{allowed_override_keys}')
|
91
|
+
else:
|
92
|
+
allowed_keys_with_matched_prefix = None
|
93
|
+
|
94
|
+
if disallowed_override_keys is not None:
|
95
|
+
for nested_key in disallowed_override_keys:
|
96
|
+
if key == nested_key[0]:
|
97
|
+
if len(nested_key) == 1:
|
98
|
+
raise ValueError(
|
99
|
+
f'Key {key} is in disallowed override keys: '
|
100
|
+
f'{disallowed_override_keys}')
|
101
|
+
assert disallowed_keys_with_matched_prefix is not None
|
102
|
+
disallowed_keys_with_matched_prefix.append(nested_key[1:])
|
103
|
+
else:
|
104
|
+
disallowed_keys_with_matched_prefix = None
|
105
|
+
return allowed_keys_with_matched_prefix, disallowed_keys_with_matched_prefix
|
106
|
+
|
107
|
+
|
108
|
+
def _recursive_update(
|
109
|
+
base_config: Config,
|
110
|
+
override_config: Dict[str, Any],
|
111
|
+
allowed_override_keys: Optional[List[Tuple[str, ...]]] = None,
|
112
|
+
disallowed_override_keys: Optional[List[Tuple[str,
|
113
|
+
...]]] = None) -> Config:
|
114
|
+
"""Recursively updates base configuration with override configuration"""
|
115
|
+
for key, value in override_config.items():
|
116
|
+
(next_allowed_override_keys, next_disallowed_override_keys
|
117
|
+
) = _check_allowed_and_disallowed_override_keys(
|
118
|
+
key, allowed_override_keys, disallowed_override_keys)
|
119
|
+
if key == 'kubernetes' and key in base_config:
|
120
|
+
merge_k8s_configs(base_config[key], value,
|
121
|
+
next_allowed_override_keys,
|
122
|
+
next_disallowed_override_keys)
|
123
|
+
elif (isinstance(value, dict) and key in base_config and
|
124
|
+
isinstance(base_config[key], dict)):
|
125
|
+
_recursive_update(base_config[key], value,
|
126
|
+
next_allowed_override_keys,
|
127
|
+
next_disallowed_override_keys)
|
128
|
+
else:
|
129
|
+
base_config[key] = value
|
130
|
+
return base_config
|
131
|
+
|
132
|
+
|
133
|
+
def _get_nested(configs: Optional[Dict[str, Any]],
|
134
|
+
keys: Tuple[str, ...],
|
135
|
+
default_value: Any,
|
136
|
+
pop: bool = False) -> Any:
|
137
|
+
if configs is None:
|
138
|
+
return default_value
|
139
|
+
curr = configs
|
140
|
+
for i, key in enumerate(keys):
|
141
|
+
if isinstance(curr, dict) and key in curr:
|
142
|
+
value = curr[key]
|
143
|
+
if i == len(keys) - 1:
|
144
|
+
if pop:
|
145
|
+
curr.pop(key, default_value)
|
146
|
+
curr = value
|
147
|
+
else:
|
148
|
+
return default_value
|
149
|
+
logger.debug(f'User config: {".".join(keys)} -> {curr}')
|
150
|
+
return curr
|
151
|
+
|
152
|
+
|
153
|
+
def merge_k8s_configs(
|
154
|
+
base_config: Dict[Any, Any],
|
155
|
+
override_config: Dict[Any, Any],
|
156
|
+
allowed_override_keys: Optional[List[Tuple[str, ...]]] = None,
|
157
|
+
disallowed_override_keys: Optional[List[Tuple[str,
|
158
|
+
...]]] = None) -> None:
|
159
|
+
"""Merge two configs into the base_config.
|
160
|
+
|
161
|
+
Updates nested dictionaries instead of replacing them.
|
162
|
+
If a list is encountered, it will be appended to the base_config list.
|
163
|
+
|
164
|
+
An exception is when the key is 'containers', in which case the
|
165
|
+
first container in the list will be fetched and merge_dict will be
|
166
|
+
called on it with the first container in the base_config list.
|
167
|
+
"""
|
168
|
+
for key, value in override_config.items():
|
169
|
+
(next_allowed_override_keys, next_disallowed_override_keys
|
170
|
+
) = _check_allowed_and_disallowed_override_keys(
|
171
|
+
key, allowed_override_keys, disallowed_override_keys)
|
172
|
+
if isinstance(value, dict) and key in base_config:
|
173
|
+
merge_k8s_configs(base_config[key], value,
|
174
|
+
next_allowed_override_keys,
|
175
|
+
next_disallowed_override_keys)
|
176
|
+
elif isinstance(value, list) and key in base_config:
|
177
|
+
assert isinstance(base_config[key], list), \
|
178
|
+
f'Expected {key} to be a list, found {base_config[key]}'
|
179
|
+
if key in ['containers', 'imagePullSecrets']:
|
180
|
+
# If the key is 'containers' or 'imagePullSecrets, we take the
|
181
|
+
# first and only container/secret in the list and merge it, as
|
182
|
+
# we only support one container per pod.
|
183
|
+
assert len(value) == 1, \
|
184
|
+
f'Expected only one container, found {value}'
|
185
|
+
merge_k8s_configs(base_config[key][0], value[0],
|
186
|
+
next_allowed_override_keys,
|
187
|
+
next_disallowed_override_keys)
|
188
|
+
elif key in ['volumes', 'volumeMounts']:
|
189
|
+
# If the key is 'volumes' or 'volumeMounts', we search for
|
190
|
+
# item with the same name and merge it.
|
191
|
+
for new_volume in value:
|
192
|
+
new_volume_name = new_volume.get('name')
|
193
|
+
if new_volume_name is not None:
|
194
|
+
destination_volume = next(
|
195
|
+
(v for v in base_config[key]
|
196
|
+
if v.get('name') == new_volume_name), None)
|
197
|
+
if destination_volume is not None:
|
198
|
+
merge_k8s_configs(destination_volume, new_volume)
|
199
|
+
else:
|
200
|
+
base_config[key].append(new_volume)
|
201
|
+
else:
|
202
|
+
base_config[key].extend(value)
|
203
|
+
else:
|
204
|
+
base_config[key] = value
|
@@ -1,8 +1,7 @@
|
|
1
1
|
"""Utils to check if the ssh control master should be disabled."""
|
2
2
|
|
3
|
-
import functools
|
4
|
-
|
5
3
|
from sky import sky_logging
|
4
|
+
from sky.utils import annotations
|
6
5
|
from sky.utils import subprocess_utils
|
7
6
|
|
8
7
|
logger = sky_logging.init_logger(__name__)
|
@@ -34,7 +33,7 @@ def is_tmp_9p_filesystem() -> bool:
|
|
34
33
|
return filesystem_types[1].lower() == '9p'
|
35
34
|
|
36
35
|
|
37
|
-
@
|
36
|
+
@annotations.lru_cache(scope='global')
|
38
37
|
def should_disable_control_master() -> bool:
|
39
38
|
"""Whether disable ssh control master based on file system.
|
40
39
|
|