skypilot-nightly 1.0.0.dev20250520__py3-none-any.whl → 1.0.0.dev20250522__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (85)
  1. sky/__init__.py +2 -2
  2. sky/backends/backend_utils.py +4 -1
  3. sky/backends/cloud_vm_ray_backend.py +56 -37
  4. sky/check.py +3 -3
  5. sky/cli.py +89 -16
  6. sky/client/cli.py +89 -16
  7. sky/client/sdk.py +92 -4
  8. sky/clouds/__init__.py +2 -0
  9. sky/clouds/cloud.py +6 -0
  10. sky/clouds/gcp.py +156 -21
  11. sky/clouds/service_catalog/__init__.py +3 -0
  12. sky/clouds/service_catalog/common.py +9 -2
  13. sky/clouds/service_catalog/constants.py +1 -0
  14. sky/core.py +6 -8
  15. sky/dashboard/out/404.html +1 -1
  16. sky/dashboard/out/_next/static/CzOVV6JpRQBRt5GhZuhyK/_buildManifest.js +1 -0
  17. sky/dashboard/out/_next/static/chunks/236-1a3a9440417720eb.js +6 -0
  18. sky/dashboard/out/_next/static/chunks/37-d584022b0da4ac3b.js +6 -0
  19. sky/dashboard/out/_next/static/chunks/393-e1eaa440481337ec.js +1 -0
  20. sky/dashboard/out/_next/static/chunks/480-f28cd152a98997de.js +1 -0
  21. sky/dashboard/out/_next/static/chunks/{678-206dddca808e6d16.js → 582-683f4f27b81996dc.js} +2 -2
  22. sky/dashboard/out/_next/static/chunks/pages/_app-8cfab319f9fb3ae8.js +1 -0
  23. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]/[job]-33bc2bec322249b1.js +1 -0
  24. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]-e2fc2dd1955e6c36.js +1 -0
  25. sky/dashboard/out/_next/static/chunks/pages/clusters-3a748bd76e5c2984.js +1 -0
  26. sky/dashboard/out/_next/static/chunks/pages/infra-9180cd91cee64b96.js +1 -0
  27. sky/dashboard/out/_next/static/chunks/pages/jobs/[job]-70756c2dad850a7e.js +1 -0
  28. sky/dashboard/out/_next/static/chunks/pages/jobs-ecd804b9272f4a7c.js +1 -0
  29. sky/dashboard/out/_next/static/css/7e7ce4ff31d3977b.css +3 -0
  30. sky/dashboard/out/clusters/[cluster]/[job].html +1 -1
  31. sky/dashboard/out/clusters/[cluster].html +1 -1
  32. sky/dashboard/out/clusters.html +1 -1
  33. sky/dashboard/out/index.html +1 -1
  34. sky/dashboard/out/infra.html +1 -0
  35. sky/dashboard/out/jobs/[job].html +1 -1
  36. sky/dashboard/out/jobs.html +1 -1
  37. sky/data/storage.py +1 -0
  38. sky/execution.py +57 -8
  39. sky/jobs/server/core.py +5 -3
  40. sky/jobs/utils.py +38 -7
  41. sky/optimizer.py +41 -39
  42. sky/provision/gcp/constants.py +147 -4
  43. sky/provision/gcp/instance_utils.py +10 -0
  44. sky/provision/gcp/volume_utils.py +247 -0
  45. sky/provision/provisioner.py +16 -7
  46. sky/resources.py +233 -18
  47. sky/serve/serve_utils.py +5 -13
  48. sky/serve/server/core.py +2 -4
  49. sky/server/common.py +60 -14
  50. sky/server/constants.py +2 -0
  51. sky/server/html/token_page.html +154 -0
  52. sky/server/requests/executor.py +3 -6
  53. sky/server/requests/payloads.py +3 -3
  54. sky/server/server.py +40 -8
  55. sky/skypilot_config.py +117 -31
  56. sky/task.py +24 -1
  57. sky/templates/gcp-ray.yml.j2 +44 -1
  58. sky/templates/nebius-ray.yml.j2 +0 -2
  59. sky/utils/admin_policy_utils.py +26 -22
  60. sky/utils/cli_utils/status_utils.py +95 -56
  61. sky/utils/common_utils.py +35 -2
  62. sky/utils/context.py +36 -6
  63. sky/utils/context_utils.py +15 -0
  64. sky/utils/infra_utils.py +175 -0
  65. sky/utils/resources_utils.py +55 -21
  66. sky/utils/schemas.py +111 -5
  67. {skypilot_nightly-1.0.0.dev20250520.dist-info → skypilot_nightly-1.0.0.dev20250522.dist-info}/METADATA +1 -1
  68. {skypilot_nightly-1.0.0.dev20250520.dist-info → skypilot_nightly-1.0.0.dev20250522.dist-info}/RECORD +73 -68
  69. {skypilot_nightly-1.0.0.dev20250520.dist-info → skypilot_nightly-1.0.0.dev20250522.dist-info}/WHEEL +1 -1
  70. sky/dashboard/out/_next/static/8hlc2dkbIDDBOkxtEW7X6/_buildManifest.js +0 -1
  71. sky/dashboard/out/_next/static/chunks/236-f49500b82ad5392d.js +0 -6
  72. sky/dashboard/out/_next/static/chunks/37-0a572fe0dbb89c4d.js +0 -6
  73. sky/dashboard/out/_next/static/chunks/845-0ca6f2c1ba667c3b.js +0 -1
  74. sky/dashboard/out/_next/static/chunks/979-7bf73a4c7cea0f5c.js +0 -1
  75. sky/dashboard/out/_next/static/chunks/pages/_app-e6b013bc3f77ad60.js +0 -1
  76. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]/[job]-e15db85d0ea1fbe1.js +0 -1
  77. sky/dashboard/out/_next/static/chunks/pages/clusters/[cluster]-f383db7389368ea7.js +0 -1
  78. sky/dashboard/out/_next/static/chunks/pages/clusters-a93b93e10b8b074e.js +0 -1
  79. sky/dashboard/out/_next/static/chunks/pages/jobs/[job]-03f279c6741fb48b.js +0 -1
  80. sky/dashboard/out/_next/static/chunks/pages/jobs-a75029b67aab6a2e.js +0 -1
  81. sky/dashboard/out/_next/static/css/c6933bbb2ce7f4dd.css +0 -3
  82. /sky/dashboard/out/_next/static/{8hlc2dkbIDDBOkxtEW7X6 → CzOVV6JpRQBRt5GhZuhyK}/_ssgManifest.js +0 -0
  83. {skypilot_nightly-1.0.0.dev20250520.dist-info → skypilot_nightly-1.0.0.dev20250522.dist-info}/entry_points.txt +0 -0
  84. {skypilot_nightly-1.0.0.dev20250520.dist-info → skypilot_nightly-1.0.0.dev20250522.dist-info}/licenses/LICENSE +0 -0
  85. {skypilot_nightly-1.0.0.dev20250520.dist-info → skypilot_nightly-1.0.0.dev20250522.dist-info}/top_level.txt +0 -0
sky/skypilot_config.py CHANGED
@@ -52,6 +52,7 @@ import contextlib
52
52
  import copy
53
53
  import json
54
54
  import os
55
+ import tempfile
55
56
  import threading
56
57
  import typing
57
58
  from typing import Any, Dict, Iterator, List, Optional, Tuple
@@ -62,6 +63,7 @@ from sky.adaptors import common as adaptors_common
62
63
  from sky.skylet import constants
63
64
  from sky.utils import common_utils
64
65
  from sky.utils import config_utils
66
+ from sky.utils import context
65
67
  from sky.utils import schemas
66
68
  from sky.utils import ux_utils
67
69
 
@@ -105,13 +107,66 @@ ENV_VAR_PROJECT_CONFIG = f'{constants.SKYPILOT_ENV_VAR_PREFIX}PROJECT_CONFIG'
105
107
_GLOBAL_CONFIG_PATH = '~/.sky/config.yaml'
_PROJECT_CONFIG_PATH = '.sky.yaml'


class ConfigContext:
    """Holds one loaded SkyPilot config together with its provenance.

    Attributes:
        config: The loaded (possibly overlaid) config.
        config_path: Path the config was loaded from, or None.
        config_overridden: Whether the config is currently replaced by
            `override_skypilot_config`.
    """

    def __init__(self,
                 config: Optional[config_utils.Config] = None,
                 config_path: Optional[str] = None,
                 config_overridden: bool = False):
        # Build a fresh Config per instance. A `config_utils.Config()` default
        # argument would be evaluated once at definition time, so every
        # instance created without an explicit config would share (and could
        # mutate) the very same object.
        if config is None:
            config = config_utils.Config()
        self.config = config
        self.config_path = config_path
        self.config_overridden = config_overridden


# The global loaded config.
_global_config_context = ConfigContext()
_reload_config_lock = threading.Lock()
113
125
 
114
126
 
127
def _get_config_context() -> ConfigContext:
    """Return the ConfigContext that applies to the caller.

    Outside of any `sky.utils.context` context, the process-wide global
    config context is returned. Inside a context, that context's own
    ConfigContext is used. On first access it is seeded from the global one,
    with the config deep-copied, so later changes stay local to the context.
    """
    current = context.get()
    if not current:
        return _global_config_context
    if current.config_context is None:
        # Lazily inherit a private snapshot of the global state.
        base = _global_config_context
        current.config_context = ConfigContext(
            config=copy.deepcopy(base.config),
            config_path=base.config_path,
            config_overridden=base.config_overridden,
        )
    return current.config_context
144
+
145
+
146
def _get_loaded_config() -> config_utils.Config:
    """Return the config held by the caller's config context."""
    ctx = _get_config_context()
    return ctx.config


def _set_loaded_config(config: config_utils.Config) -> None:
    """Replace the config held by the caller's config context."""
    ctx = _get_config_context()
    ctx.config = config


def _get_loaded_config_path() -> Optional[str]:
    """Return the recorded config path of the caller's config context."""
    ctx = _get_config_context()
    return ctx.config_path


def _set_loaded_config_path(path: Optional[str]) -> None:
    """Record the config path on the caller's config context."""
    ctx = _get_config_context()
    ctx.config_path = path


def _is_config_overridden() -> bool:
    """Return the overridden flag of the caller's config context."""
    ctx = _get_config_context()
    return ctx.config_overridden


def _set_config_overridden(config_overridden: bool) -> None:
    """Set the overridden flag on the caller's config context."""
    ctx = _get_config_context()
    ctx.config_overridden = config_overridden
168
+
169
+
115
170
  def get_user_config_path() -> str:
116
171
  """Returns the path to the user config file."""
117
172
  return _GLOBAL_CONFIG_PATH
@@ -224,7 +279,7 @@ def get_nested(keys: Tuple[str, ...],
224
279
  Returns:
225
280
  The value of the nested key, or 'default_value' if not found.
226
281
  """
227
- return _dict.get_nested(
282
+ return _get_loaded_config().get_nested(
228
283
  keys,
229
284
  default_value,
230
285
  override_configs,
@@ -237,14 +292,14 @@ def set_nested(keys: Tuple[str, ...], value: Any) -> Dict[str, Any]:
237
292
 
238
293
  Like get_nested(), if any key is not found, this will not raise an error.
239
294
  """
240
- copied_dict = copy.deepcopy(_dict)
295
+ copied_dict = copy.deepcopy(_get_loaded_config())
241
296
  copied_dict.set_nested(keys, value)
242
297
  return dict(**copied_dict)
243
298
 
244
299
 
245
300
def to_dict() -> config_utils.Config:
    """Return a deep copy of the config visible to the current context.

    The copy is independent of the loaded config, so callers may mutate it.
    """
    snapshot = copy.deepcopy(_get_loaded_config())
    return snapshot
248
303
 
249
304
 
250
305
  def _get_config_file_path(envvar: str) -> Optional[str]:
@@ -345,10 +400,9 @@ def _parse_dotlist(dotlist: List[str]) -> config_utils.Config:
345
400
 
346
401
 
347
402
  def _reload_config_from_internal_file(internal_config_path: str) -> None:
348
- global _dict, _loaded_config_path
349
403
  # Reset the global variables, to avoid using stale values.
350
- _dict = config_utils.Config()
351
- _loaded_config_path = None
404
+ _set_loaded_config(config_utils.Config())
405
+ _set_loaded_config_path(None)
352
406
 
353
407
  config_path = os.path.expanduser(internal_config_path)
354
408
  if not os.path.exists(config_path):
@@ -359,14 +413,13 @@ def _reload_config_from_internal_file(internal_config_path: str) -> None:
359
413
  'exist. Please double check the path or unset the env var: '
360
414
  f'unset {ENV_VAR_SKYPILOT_CONFIG}')
361
415
  logger.debug(f'Using config path: {config_path}')
362
- _dict = parse_config_file(config_path)
363
- _loaded_config_path = config_path
416
+ _set_loaded_config(parse_config_file(config_path))
417
+ _set_loaded_config_path(config_path)
364
418
 
365
419
 
366
420
  def _reload_config_as_server() -> None:
367
- global _dict
368
421
  # Reset the global variables, to avoid using stale values.
369
- _dict = config_utils.Config()
422
+ _set_loaded_config(config_utils.Config())
370
423
 
371
424
  overrides: List[config_utils.Config] = []
372
425
  server_config = get_server_config()
@@ -382,13 +435,12 @@ def _reload_config_as_server() -> None:
382
435
  logger.debug(
383
436
  f'server config: \n'
384
437
  f'{common_utils.dump_yaml_str(dict(overlaid_server_config))}')
385
- _dict = overlaid_server_config
438
+ _set_loaded_config(overlaid_server_config)
386
439
 
387
440
 
388
441
  def _reload_config_as_client() -> None:
389
- global _dict
390
442
  # Reset the global variables, to avoid using stale values.
391
- _dict = config_utils.Config()
443
+ _set_loaded_config(config_utils.Config())
392
444
 
393
445
  overrides: List[config_utils.Config] = []
394
446
  user_config = get_user_config()
@@ -407,15 +459,15 @@ def _reload_config_as_client() -> None:
407
459
  logger.debug(
408
460
  f'client config (before task and CLI overrides): \n'
409
461
  f'{common_utils.dump_yaml_str(dict(overlaid_client_config))}')
410
- _dict = overlaid_client_config
462
+ _set_loaded_config(overlaid_client_config)
411
463
 
412
464
 
413
465
def loaded_config_path() -> Optional[str]:
    """Return where the active config came from.

    Returns '<overridden>' while the config is overridden. Otherwise it
    returns the recorded config path, which may be None.
    """
    return ('<overridden>'
            if _is_config_overridden() else _get_loaded_config_path())
419
471
 
420
472
 
421
473
  # Load on import, synchronization is guaranteed by python interpreter.
@@ -424,21 +476,20 @@ _reload_config()
424
476
 
425
477
def loaded() -> bool:
    """Return whether user configurations are loaded (config is truthy)."""
    current_config = _get_loaded_config()
    return bool(current_config)
428
480
 
429
481
 
430
482
  @contextlib.contextmanager
431
483
  def override_skypilot_config(
432
484
  override_configs: Optional[Dict[str, Any]]) -> Iterator[None]:
433
485
  """Overrides the user configurations."""
434
- global _dict, _config_overridden
435
486
  # TODO(SKY-1215): allow admin user to extend the disallowed keys or specify
436
487
  # allowed keys.
437
488
  if not override_configs:
438
489
  # If no override configs (None or empty dict), do nothing.
439
490
  yield
440
491
  return
441
- original_config = _dict
492
+ original_config = _get_loaded_config()
442
493
  override_configs = config_utils.Config(override_configs)
443
494
  disallowed_diff_keys = []
444
495
  for key in constants.SKIPPED_CLIENT_OVERRIDE_KEYS:
@@ -455,7 +506,7 @@ def override_skypilot_config(
455
506
  'and will be ignored. Remove these keys to disable this warning. '
456
507
  'If you want to specify it, please modify it on server side or '
457
508
  'contact your administrator.')
458
- config = _dict.get_nested(
509
+ config = original_config.get_nested(
459
510
  keys=tuple(),
460
511
  default_value=None,
461
512
  override_configs=dict(override_configs),
@@ -469,8 +520,8 @@ def override_skypilot_config(
469
520
  'https://docs.skypilot.co/en/latest/reference/config.html. ' # pylint: disable=line-too-long
470
521
  'Error: ',
471
522
  skip_none=False)
472
- _config_overridden = True
473
- _dict = config
523
+ _set_config_overridden(True)
524
+ _set_loaded_config(config)
474
525
  yield
475
526
  except exceptions.InvalidSkyPilotConfigError as e:
476
527
  with ux_utils.print_exception_no_traceback():
@@ -483,8 +534,43 @@ def override_skypilot_config(
483
534
  f'{common_utils.dump_yaml_str(dict(override_configs))}\n'
484
535
  f'Details: {e}') from e
485
536
  finally:
486
- _dict = original_config
487
- _config_overridden = False
537
+ _set_loaded_config(original_config)
538
+ _set_config_overridden(False)
539
+
540
+
541
@contextlib.contextmanager
def replace_skypilot_config(new_configs: config_utils.Config) -> Iterator[None]:
    """Replaces the global config with the new configs.

    Inside the context, the in-memory config of the current process/context
    is `new_configs`. ENV_VAR_SKYPILOT_CONFIG points to a temporary YAML dump
    of it, so spawned sub-processes use the same config. Both the config and
    the env var are restored on exit, including when the body raises. If
    `new_configs` equals the current config, nothing is changed.

    This function is concurrent safe when it is:
    1. called in different processes;
    2. or called in a same process but with different context, refer to
       sky.utils.context for more details.

    Args:
        new_configs: The config to use while the context is active.
    """
    original_config = _get_loaded_config()
    original_env_var = os.environ.get(ENV_VAR_SKYPILOT_CONFIG)
    if new_configs == original_config:
        # Nothing to replace.
        yield
        return
    try:
        # Modify the global config of current process or context.
        _set_loaded_config(new_configs)
        # delete=False: the file must outlive this handle because it is read
        # by path (via the env var below) by spawned sub-processes.
        with tempfile.NamedTemporaryFile(delete=False,
                                         mode='w',
                                         prefix='mutated-skypilot-config-',
                                         suffix='.yaml') as temp_file:
            common_utils.dump_yaml(temp_file.name, dict(**new_configs))
        # Modify the env var of current process or context so that the
        # new config will be used by spawned sub-processes.
        # Note that this code modifies os.environ directly because it
        # will be hijacked to be context-aware if a context is active.
        os.environ[ENV_VAR_SKYPILOT_CONFIG] = temp_file.name
        yield
    finally:
        # Restore the original config and env var even if the body raised,
        # so a failed request does not leak its config into later ones.
        _set_loaded_config(original_config)
        if original_env_var is not None:
            os.environ[ENV_VAR_SKYPILOT_CONFIG] = original_env_var
        else:
            os.environ.pop(ENV_VAR_SKYPILOT_CONFIG, None)
488
574
 
489
575
 
490
576
  def _compose_cli_config(cli_config: Optional[List[str]]) -> config_utils.Config:
@@ -529,11 +615,11 @@ def apply_cli_config(cli_config: Optional[List[str]]) -> Dict[str, Any]:
529
615
  cli_config: A path to a config file or a comma-separated
530
616
  list of key-value pairs.
531
617
  """
532
- global _dict
533
618
  parsed_config = _compose_cli_config(cli_config)
534
619
  if sky_logging.logging_enabled(logger, sky_logging.DEBUG):
535
620
  logger.debug(f'applying following CLI overrides: \n'
536
621
  f'{common_utils.dump_yaml_str(dict(parsed_config))}')
537
- _dict = overlay_skypilot_config(original_config=_dict,
538
- override_configs=parsed_config)
622
+ _set_loaded_config(
623
+ overlay_skypilot_config(original_config=_get_loaded_config(),
624
+ override_configs=parsed_config))
539
625
  return parsed_config
sky/task.py CHANGED
@@ -512,6 +512,7 @@ class Task:
512
512
  # storage objects with the storage/storage_mount objects.
513
513
  fm_storages = []
514
514
  file_mounts = config.pop('file_mounts', None)
515
+ volumes = []
515
516
  if file_mounts is not None:
516
517
  copy_mounts = {}
517
518
  for dst_path, src in file_mounts.items():
@@ -521,7 +522,27 @@ class Task:
521
522
  # If the src is not a str path, it is likely a dict. Try to
522
523
  # parse storage object.
523
524
  elif isinstance(src, dict):
524
- fm_storages.append((dst_path, src))
525
+ if (src.get('store') ==
526
+ storage_lib.StoreType.VOLUME.value.lower()):
527
+ # Build the volumes config for resources.
528
+ volume_config = {
529
+ 'path': dst_path,
530
+ }
531
+ if src.get('name'):
532
+ volume_config['name'] = src.get('name')
533
+ persistent = src.get('persistent', False)
534
+ volume_config['auto_delete'] = not persistent
535
+ volume_config_detail = src.get('config', {})
536
+ volume_config.update(volume_config_detail)
537
+ volumes.append(volume_config)
538
+ source_path = src.get('source')
539
+ if source_path:
540
+ # For volume, copy the source path to the
541
+ # data directory of the volume mount point.
542
+ copy_mounts[
543
+ f'{dst_path.rstrip("/")}/data'] = source_path
544
+ else:
545
+ fm_storages.append((dst_path, src))
525
546
  else:
526
547
  with ux_utils.print_exception_no_traceback():
527
548
  raise ValueError(f'Unable to parse file_mount '
@@ -599,6 +620,8 @@ class Task:
599
620
  'experimental.config_overrides')
600
621
  resources_config[
601
622
  '_cluster_config_overrides'] = cluster_config_override
623
+ if volumes:
624
+ resources_config['volumes'] = volumes
602
625
  task.set_resources(sky.Resources.from_yaml_config(resources_config))
603
626
 
604
627
  service = config.pop('service', None)
@@ -109,12 +109,27 @@ available_node_types:
109
109
  {%- if tpu_vm %}
110
110
  acceleratorType: {{tpu_type}}
111
111
  runtimeVersion: {{runtime_version}}
112
+ {%- if volumes %}
113
+ dataDisks:
114
+ {%- for volume in volumes %}
115
+ {%- if volume.source %}
116
+ - sourceDisk: {{volume.source}}
117
+ {%- endif %}
118
+ {%- if volume.attach_mode %}
119
+ mode: {{volume.attach_mode}}
120
+ {%- endif %}
121
+ {%- endfor %}
122
+ {%- endif %}
112
123
  metadata:
113
124
  # TPU VM's metadata has different format than normal VMs.
114
125
  # After replacing the variables, this will become username:ssh_public_key_content.
115
126
  # This is a specific syntax required by GCP https://cloud.google.com/compute/docs/connect/add-ssh-keys
116
127
  ssh-keys: |-
117
128
  skypilot:ssh_user:skypilot:ssh_public_key_content
129
+ {%- if user_data is not none %}
130
+ startup-script: |-
131
+ {{ user_data | indent(10) }}
132
+ {%- endif %}
118
133
  {%- if use_spot %}
119
134
  schedulingConfig:
120
135
  preemptible: true
@@ -138,6 +153,34 @@ available_node_types:
138
153
  {%- if disk_iops %}
139
154
  provisionedIops: {{disk_iops}}
140
155
  {%- endif %}
156
+ {%- for volume in volumes %}
157
+ - boot: false
158
+ autoDelete: {{volume.auto_delete}}
159
+ type: {{volume.storage_type}}
160
+ deviceName: {{volume.device_name}}
161
+ {%- if volume.source %}
162
+ source: {{volume.source}}
163
+ {%- endif %}
164
+ {%- if volume.attach_mode %}
165
+ mode: {{volume.attach_mode}}
166
+ {%- endif %}
167
+ {%- if volume.interface_type %}
168
+ interface: {{volume.interface_type}}
169
+ {%- endif %}
170
+ {%- if volume.disk_tier %}
171
+ initializeParams:
172
+ diskType: zones/{{zones}}/diskTypes/{{volume.disk_tier}}
173
+ {%- endif %}
174
+ {%- if volume.disk_name %}
175
+ diskName: {{volume.disk_name}}
176
+ {%- endif %}
177
+ {%- if volume.disk_size %}
178
+ diskSizeGb: {{volume.disk_size}}
179
+ {%- endif %}
180
+ {%- if volume.iops %}
181
+ provisionedIops: {{volume.iops}}
182
+ {%- endif %}
183
+ {%- endfor %}
141
184
  {%- if gpu is not none %}
142
185
  guestAccelerators:
143
186
  - acceleratorType: projects/{{gcp_project_id}}/zones/{{zones}}/acceleratorTypes/{{gpu}}
@@ -157,7 +200,7 @@ available_node_types:
157
200
  {%- if user_data is not none %}
158
201
  - key: user-data
159
202
  value: |-
160
- {{ user_data | indent(10) }}
203
+ {{ user_data | indent(14) }}
161
204
  {%- endif %}
162
205
  {%- if use_spot or gpu is not none %}
163
206
  scheduling:
@@ -47,11 +47,9 @@ available_node_types:
47
47
  ImageId: {{image_id}}
48
48
  DiskSize: {{disk_size}}
49
49
  UserData: |
50
- {%- if docker_image is not none %}
51
50
  runcmd:
52
51
  - sudo sed -i 's/^#\?AllowTcpForwarding.*/AllowTcpForwarding yes/' /etc/ssh/sshd_config
53
52
  - systemctl restart sshd
54
- {%- endif %}
55
53
 
56
54
  {# Two available OS images:
57
55
  1. ubuntu22.04-driverless - requires Docker installation
@@ -1,9 +1,8 @@
1
1
  """Admin policy utils."""
2
+ import contextlib
2
3
  import copy
3
4
  import importlib
4
- import os
5
- import tempfile
6
- from typing import Optional, Tuple, Union
5
+ from typing import Iterator, Optional, Tuple, Union
7
6
 
8
7
  import colorama
9
8
 
@@ -52,9 +51,31 @@ def _get_policy_cls(
52
51
  return policy_cls
53
52
 
54
53
 
54
@contextlib.contextmanager
def apply_and_use_config_in_current_request(
    entrypoint: Union['dag_lib.Dag', 'task_lib.Task'],
    request_options: Optional[admin_policy.RequestOptions] = None,
) -> Iterator['dag_lib.Dag']:
    """Apply the admin policy and use its config for the current request.

    This is a context-manager wrapper around `apply()`. The policy-mutated
    DAG is yielded. If the policy also changed the SkyPilot config, that
    config is swapped in via `skypilot_config.replace_skypilot_config` while
    the `with` block runs, which restores the previous config on exit.

    Refer to `apply()` for more details.
    """
    config_before = skypilot_config.to_dict()
    mutated_dag, config_after = apply(entrypoint, request_options)
    if config_after == config_before:
        # The policy left the config untouched; nothing to swap in.
        yield mutated_dag
        return
    with skypilot_config.replace_skypilot_config(config_after):
        yield mutated_dag
75
+
76
+
55
77
  def apply(
56
78
  entrypoint: Union['dag_lib.Dag', 'task_lib.Task'],
57
- use_mutated_config_in_current_request: bool = True,
58
79
  request_options: Optional[admin_policy.RequestOptions] = None,
59
80
  ) -> Tuple['dag_lib.Dag', config_utils.Config]:
60
81
  """Applies an admin policy (if registered) to a DAG or a task.
@@ -85,8 +106,7 @@ def apply(
85
106
  return dag, skypilot_config.to_dict()
86
107
 
87
108
  logger.info(f'Applying policy: {policy}')
88
- original_config = skypilot_config.to_dict()
89
- config = copy.deepcopy(original_config)
109
+ config = copy.deepcopy(skypilot_config.to_dict())
90
110
  mutated_dag = dag_lib.Dag()
91
111
  mutated_dag.name = dag.name
92
112
 
@@ -126,22 +146,6 @@ def apply(
126
146
  mutated_dag.graph.add_edge(mutated_dag.tasks[u_idx],
127
147
  mutated_dag.tasks[v_idx])
128
148
 
129
- if (use_mutated_config_in_current_request and
130
- original_config != mutated_config):
131
- with tempfile.NamedTemporaryFile(
132
- delete=False,
133
- mode='w',
134
- prefix='policy-mutated-skypilot-config-',
135
- suffix='.yaml') as temp_file:
136
-
137
- common_utils.dump_yaml(temp_file.name, dict(**mutated_config))
138
- os.environ[skypilot_config.ENV_VAR_SKYPILOT_CONFIG] = temp_file.name
139
- logger.debug(f'Updated SkyPilot config: {temp_file.name}')
140
- # TODO(zhwu): This is not a clean way to update the SkyPilot config,
141
- # because we are resetting the global context for a single DAG,
142
- # which is conceptually weird.
143
- importlib.reload(skypilot_config)
144
-
145
149
  logger.debug(f'Mutated user request: {mutated_user_request}')
146
150
  mutated_dag.policy_applied = True
147
151
  return mutated_dag, mutated_config