konduktor-nightly 0.1.0.dev20250209104336__py3-none-any.whl → 0.1.0.dev20250313070642__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62) hide show
  1. konduktor/__init__.py +16 -6
  2. konduktor/adaptors/__init__.py +0 -0
  3. konduktor/adaptors/common.py +88 -0
  4. konduktor/adaptors/gcp.py +112 -0
  5. konduktor/backends/__init__.py +8 -0
  6. konduktor/backends/backend.py +86 -0
  7. konduktor/backends/jobset.py +218 -0
  8. konduktor/backends/jobset_utils.py +447 -0
  9. konduktor/check.py +192 -0
  10. konduktor/cli.py +790 -0
  11. konduktor/cloud_stores.py +158 -0
  12. konduktor/config.py +420 -0
  13. konduktor/constants.py +36 -0
  14. konduktor/controller/constants.py +6 -6
  15. konduktor/controller/launch.py +3 -3
  16. konduktor/controller/node.py +5 -5
  17. konduktor/controller/parse.py +23 -23
  18. konduktor/dashboard/backend/main.py +57 -57
  19. konduktor/dashboard/backend/sockets.py +19 -19
  20. konduktor/data/__init__.py +9 -0
  21. konduktor/data/constants.py +12 -0
  22. konduktor/data/data_utils.py +223 -0
  23. konduktor/data/gcp/__init__.py +19 -0
  24. konduktor/data/gcp/constants.py +42 -0
  25. konduktor/data/gcp/gcs.py +906 -0
  26. konduktor/data/gcp/utils.py +9 -0
  27. konduktor/data/storage.py +799 -0
  28. konduktor/data/storage_utils.py +500 -0
  29. konduktor/execution.py +444 -0
  30. konduktor/kube_client.py +153 -48
  31. konduktor/logging.py +49 -5
  32. konduktor/manifests/dmesg_daemonset.yaml +8 -0
  33. konduktor/manifests/pod_cleanup_controller.yaml +129 -0
  34. konduktor/resource.py +478 -0
  35. konduktor/task.py +867 -0
  36. konduktor/templates/jobset.yaml.j2 +31 -0
  37. konduktor/templates/pod.yaml.j2 +185 -0
  38. konduktor/usage/__init__.py +0 -0
  39. konduktor/usage/constants.py +21 -0
  40. konduktor/utils/__init__.py +0 -0
  41. konduktor/utils/accelerator_registry.py +21 -0
  42. konduktor/utils/annotations.py +62 -0
  43. konduktor/utils/base64_utils.py +93 -0
  44. konduktor/utils/common_utils.py +393 -0
  45. konduktor/utils/constants.py +5 -0
  46. konduktor/utils/env_options.py +55 -0
  47. konduktor/utils/exceptions.py +226 -0
  48. konduktor/utils/kubernetes_enums.py +8 -0
  49. konduktor/utils/kubernetes_utils.py +652 -0
  50. konduktor/utils/log_utils.py +251 -0
  51. konduktor/utils/loki_utils.py +85 -0
  52. konduktor/utils/rich_utils.py +123 -0
  53. konduktor/utils/schemas.py +581 -0
  54. konduktor/utils/subprocess_utils.py +273 -0
  55. konduktor/utils/ux_utils.py +216 -0
  56. konduktor/utils/validator.py +20 -0
  57. {konduktor_nightly-0.1.0.dev20250209104336.dist-info → konduktor_nightly-0.1.0.dev20250313070642.dist-info}/LICENSE +0 -1
  58. {konduktor_nightly-0.1.0.dev20250209104336.dist-info → konduktor_nightly-0.1.0.dev20250313070642.dist-info}/METADATA +13 -2
  59. konduktor_nightly-0.1.0.dev20250313070642.dist-info/RECORD +94 -0
  60. konduktor_nightly-0.1.0.dev20250209104336.dist-info/RECORD +0 -48
  61. {konduktor_nightly-0.1.0.dev20250209104336.dist-info → konduktor_nightly-0.1.0.dev20250313070642.dist-info}/WHEEL +0 -0
  62. {konduktor_nightly-0.1.0.dev20250209104336.dist-info → konduktor_nightly-0.1.0.dev20250313070642.dist-info}/entry_points.txt +0 -0
konduktor/task.py ADDED
@@ -0,0 +1,867 @@
1
+ # Proprietary Changes made for Trainy under the Trainy Software License
2
+ # Original source: konduktor: https://github.com/konduktor-org/konduktor
3
+ # which is Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ # Unless required by applicable law or agreed to in writing, software
8
+ # distributed under the License is distributed on an "AS IS" BASIS,
9
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10
+ # See the License for the specific language governing permissions and
11
+ # limitations under the License.
12
+
13
+ """Task: a coarse-grained stage in an application."""
14
+
15
+ import inspect
16
+ import json
17
+ import os
18
+ import re
19
+ import typing
20
+ from typing import Any, Dict, List, Optional, Tuple, Union
21
+
22
+ import yaml
23
+
24
+ if typing.TYPE_CHECKING:
25
+ import konduktor.resource as resources_lib
26
+
27
+ import konduktor
28
+ from konduktor import constants, logging
29
+ from konduktor.data import data_utils
30
+ from konduktor.data import storage as storage_lib
31
+ from konduktor.utils import common_utils, exceptions, schemas, ux_utils
32
+
33
+ logger = logging.get_logger(__name__)
34
+
35
# A valid task name is one or more alphanumeric runs, joined by one or two
# characters out of '.', '_' and '-'. It is matched with ``re.fullmatch``.
_VALID_NAME_REGEX = '[a-zA-Z0-9]+(?:[._-]{1,2}[a-zA-Z0-9]+)*'
# Human-readable description of the rule above, used in error messages.
_VALID_NAME_DESCR = (
    'ASCII characters and may contain lowercase and'
    ' uppercase letters, digits, underscores, periods,'
    ' and dashes. Must start and end with alphanumeric'
    ' characters. No triple dashes or underscores.'
)

# Error template for an invalid ``run`` command generator. It must be
# formatted with the keyword argument ``run_sig``. The first fragment ends
# with a space so the two fragments do not fuse into "anda list".
_RUN_FN_CHECK_FAIL_MSG = (
    'run command generator must take exactly 2 arguments: node_rank (int) and '
    'a list of node ip addresses (List[str]). Got {run_sig}'
)
47
+
48
+
49
def _is_valid_name(name: Optional[str]) -> bool:
    """Reports whether ``name`` is an acceptable task name.

    A name is accepted only if it is a string that matches
    ``_VALID_NAME_REGEX`` in full: ASCII letters and digits, optionally
    separated by one or two of underscore, period or dash. It must start and
    end with an alphanumeric character. ``None`` is rejected.

    Accepted examples:
        some_name_here
        some-name-here
        some__name__here
        some--name--here
        some__name--here
        some.name.here
        some-name_he.re

    Rejected examples:
        this---shouldnt--work
        this___shouldnt_work
        _thisshouldntwork
        thisshouldntwork_
    """
    return name is not None and re.fullmatch(_VALID_NAME_REGEX, name) is not None
73
+
74
+
75
def _fill_in_env_vars(
    yaml_field: Dict[str, Any],
    task_envs: Dict[str, str],
) -> Dict[str, Any]:
    """Detects env vars in yaml field and fills them with task_envs.

    Use cases of env vars in file_mounts:
    - dst/src paths; e.g.,
        /model_path/llama-${SIZE}b: s3://llama-weights/llama-${SIZE}b
    - storage's name (bucket name)
    - storage's source (local path)

    The field is dumped to a JSON string, references are replaced with a
    regex, and the result is parsed back. Both keys and values are therefore
    subject to substitution.

    Env vars of the following forms are detected:
        - ${ENV}
        - $ENV
    References whose name is not in ``task_envs`` (or whose value is None)
    are left untouched. Unbalanced forms such as ``${ENV`` are not treated as
    references, and a ``}`` following ``$ENV`` is kept as literal text.

    Args:
        yaml_field: any JSON-serializable value (a dict of file mounts, or a
            plain string such as a workdir path).
        task_envs: mapping from env var name to its value.

    Returns:
        A new object of the same shape as ``yaml_field`` with the references
        substituted.
    """
    serialized = json.dumps(yaml_field)

    def replace_var(match: 're.Match[str]') -> str:
        # Group 1 is the ${NAME} form, group 2 the bare $NAME form.
        var_name = match.group(1) or match.group(2)
        value = task_envs.get(var_name)
        if value is None:
            # Unknown variable: keep the reference exactly as written.
            return match.group(0)
        # The replacement lands inside a JSON string literal, so quotes,
        # backslashes and control characters must be JSON-escaped. Dump the
        # value and strip the surrounding quotes.
        return json.dumps(str(value))[1:-1]

    # Valid env var names in bash; braces must be balanced to count.
    name = r'[a-zA-Z_][a-zA-Z0-9_]*'
    pattern = r'\$(?:\{(' + name + r')\}|(' + name + r'))'
    return json.loads(re.sub(pattern, replace_var, serialized))
118
+
119
+
120
class Task:
    """Task: a computation to be run on the cloud."""

    def __init__(
        self,
        name: str,
        *,
        setup: Optional[str] = None,
        run: Optional[str] = None,
        envs: Optional[Dict[str, str]] = None,
        workdir: Optional[str] = None,
        num_nodes: Optional[int] = None,
    ):
        """Initializes a Task.

        ``Task.run`` is the actual program: either a shell command to run
        (str) or a command generator for different nodes (callable; see
        below). Validation rejects a task whose ``run`` is None.

        Call ``Task.set_resources()`` to set the resource requirements for
        this task; ``Task.resources`` stays None until then.

        All setters of this class, ``Task.set_*()``, return ``self``, i.e.,
        they are fluent APIs and can be chained together.

        Example:
            .. code-block:: python

                # A Task that will sync up local workdir '.', containing
                # train.py.
                konduktor.Task('train', run='python train.py', workdir='.')

                # Chaining setters.
                task = konduktor.Task('train', run='python train.py')
                task.set_resources(...).set_file_mounts(...)

        Args:
          name: A string name for the Task. Required; must satisfy the task
            name rules checked during validation.
          setup: Not supported; passing anything other than None raises.
          run: The actual command for the task. Either a shell command (str)
            or a command generator (callable). If latter, it must take a
            node rank and a list of node addresses as input and return a
            shell command (str). The command generator must be self
            contained (no nonlocals or globals).
          envs: A dictionary of environment variables to set before running
            the run commands. The dictionary is copied.
          workdir: The local working directory. Must be an existing
            directory (or a symlink to one).
          num_nodes: The number of nodes to provision for this Task. If None,
            treated as 1 node.

        Raises:
          ValueError: if ``setup`` is given, or if validation of the name,
            ``run``, ``workdir`` or ``num_nodes`` fails.
        """
        assert name is not None, 'Task name is required'
        self.name = name
        if setup is not None:
            raise ValueError('`setup` is being deprecated and not supported')
        self.run = run
        self.storage_mounts: Dict[str, storage_lib.Storage] = {}
        self.storage_plans: Dict[storage_lib.Storage, storage_lib.StoreType] = {}
        # Copy so that later update_envs() calls never mutate the caller's
        # dictionary.
        self._envs = dict(envs) if envs else {}
        self.workdir = workdir
        # Ignore type error due to a mypy bug.
        # https://github.com/python/mypy/issues/3004
        self._num_nodes = 1
        self.num_nodes = num_nodes  # type: ignore
        # Unset until set_resources() is called.
        self.resources: Optional[konduktor.Resources] = None

        self.file_mounts: Optional[Dict[str, str]] = None
        self.best_resources = None  # (asaiacai): this is unused consider removing

        # Check if the task is legal.
        self._validate()

    def _validate(self):
        """Checks if the Task fields are valid.

        Raises:
            ValueError: if the name is invalid, ``run`` is missing, is neither
                a str nor a well-formed self-contained command generator, or
                ``workdir`` is not an existing directory.
        """

        # TODO(asaiacai): add validations here to check that valid
        # kueue resources are specified as labels
        if not _is_valid_name(self.name):
            with ux_utils.print_exception_no_traceback():
                raise ValueError(
                    f'Invalid task name {self.name}. Valid name: '
                    f'{_VALID_NAME_DESCR}'
                )

        # Check self.run
        if callable(self.run):
            run_sig = inspect.signature(self.run)
            # The template uses the named placeholder {run_sig}, so it must
            # be formatted by keyword; a positional argument raises KeyError.
            sig_error = _RUN_FN_CHECK_FAIL_MSG.format(run_sig=run_sig)
            # Check that run is a function with 2 arguments.
            if len(run_sig.parameters) != 2:
                with ux_utils.print_exception_no_traceback():
                    raise ValueError(sig_error)

            type_list = [int, List[str]]
            # Check annotations, if exists
            for i, param in enumerate(run_sig.parameters.values()):
                if param.annotation != inspect.Parameter.empty:
                    if param.annotation != type_list[i]:
                        with ux_utils.print_exception_no_traceback():
                            raise ValueError(sig_error)

            # Check self containedness.
            run_closure = inspect.getclosurevars(self.run)
            if run_closure.nonlocals:
                with ux_utils.print_exception_no_traceback():
                    raise ValueError(
                        'run command generator must be self contained. '
                        f'Found nonlocals: {run_closure.nonlocals}'
                    )
            if run_closure.globals:
                with ux_utils.print_exception_no_traceback():
                    raise ValueError(
                        'run command generator must be self contained. '
                        f'Found globals: {run_closure.globals}'
                    )
            # run_closure.unbound is deliberately not checked: import
            # statements, which are allowed, are reported as unbound.
        elif self.run is not None and not isinstance(self.run, str):
            with ux_utils.print_exception_no_traceback():
                raise ValueError(
                    'run must be a shell script (str), ' f'Got {type(self.run)}'
                )
        elif self.run is None:
            with ux_utils.print_exception_no_traceback():
                raise ValueError('run commands are empty')

        # Workdir.
        if self.workdir is not None:
            full_workdir = os.path.abspath(os.path.expanduser(self.workdir))
            if not os.path.isdir(full_workdir):
                # Symlink to a dir is legal (isdir() follows symlinks).
                with ux_utils.print_exception_no_traceback():
                    raise ValueError(
                        'Workdir must exist and must be a directory (or '
                        f'a symlink to a directory). {self.workdir} not found.'
                    )

    @staticmethod
    def from_yaml_config(
        config: Dict[str, Any],
        env_overrides: Optional[List[Tuple[str, str]]] = None,
    ) -> 'Task':
        """Builds a Task from a parsed task YAML dict.

        ``config`` is consumed: recognised keys are popped from it, and its
        ``envs``, ``file_mounts`` and ``workdir`` entries may be rewritten in
        place.

        Args:
            config: the task configuration; validated against the task schema.
            env_overrides: optional ``(name, value)`` pairs applied on top of
                ``config['envs']`` before anything else reads the envs.

        Returns:
            The constructed Task with file mounts, storage mounts and
            resources set.

        Raises:
            ValueError: if an env var is left as None or a file mount entry
                is neither a str nor a dict.
            exceptions.StorageSourceError: re-raised from Storage parsing
                with the mount path filled into the message.
        """
        # More robust handling for 'envs': explicitly convert keys and values to
        # str, since users may pass '123' as keys/values which will get parsed
        # as int causing validate_schema() to fail.
        envs = config.get('envs')
        if envs is not None and isinstance(envs, dict):
            new_envs: Dict[str, Optional[str]] = {
                str(k): (str(v) if v is not None else None) for k, v in envs.items()
            }
            config['envs'] = new_envs
        common_utils.validate_schema(
            config, schemas.get_task_schema(), 'Invalid task YAML: '
        )
        if env_overrides is not None:
            # We must override env vars before constructing the Task, because
            # the Storage object creation is eager and it (its name/source
            # fields) may depend on env vars.
            # ``or {}`` also covers an explicit ``envs: null`` entry, for
            # which ``config.get('envs', {})`` would return None.
            new_envs = config.get('envs') or {}
            new_envs.update(env_overrides)
            config['envs'] = new_envs

        task_envs = config.get('envs') or {}
        for k, v in task_envs.items():
            if v is None:
                with ux_utils.print_exception_no_traceback():
                    raise ValueError(
                        f'Environment variable {k!r} is None. Please set a '
                        'value for it in task YAML or with --env flag. '
                        f'To set it to be empty, use an empty string ({k}: "" '
                        f'in task YAML or --env {k}="" in CLI).'
                    )

        # Fill in any Task.envs into file_mounts (src/dst paths, storage
        # name/source).
        if config.get('file_mounts') is not None:
            config['file_mounts'] = _fill_in_env_vars(config['file_mounts'], task_envs)

        # Fill in any Task.envs into workdir
        if config.get('workdir') is not None:
            config['workdir'] = _fill_in_env_vars(config['workdir'], task_envs)

        task = Task(
            config.pop('name', None),
            run=config.pop('run', None),
            workdir=config.pop('workdir', None),
            num_nodes=config.pop('num_nodes', None),
            envs=config.pop('envs', None),
        )

        # Storage objects inlined in file_mounts are kept as
        # (mount path, storage config dict) pairs and parsed below; plain
        # str sources become regular file mounts.
        fm_storages = []
        file_mounts = config.pop('file_mounts', None)
        if file_mounts is not None:
            copy_mounts = {}
            for dst_path, src in file_mounts.items():
                if isinstance(src, str):
                    copy_mounts[dst_path] = src
                elif isinstance(src, dict):
                    fm_storages.append((dst_path, src))
                else:
                    with ux_utils.print_exception_no_traceback():
                        raise ValueError(
                            f'Unable to parse file_mount ' f'{dst_path}:{src}'
                        )
            task.set_file_mounts(copy_mounts)

        task_storage_mounts: Dict[str, storage_lib.Storage] = {}
        for mount_path, storage_config in fm_storages:
            assert mount_path, 'Storage mount path cannot be empty.'
            try:
                storage_obj = storage_lib.Storage.from_yaml_config(storage_config)
            except exceptions.StorageSourceError as e:
                # Patch the error message to include the mount path, if included
                e.args = (
                    e.args[0].replace('<destination_path>', mount_path),
                ) + e.args[1:]
                raise e
            task_storage_mounts[mount_path] = storage_obj
        task.set_storage_mounts(task_storage_mounts)

        # Experimental configs.
        experimental_configs = config.pop('experimental', None)
        cluster_config_override = None
        if experimental_configs is not None:
            cluster_config_override = experimental_configs.pop(
                'config_overrides', None
            )
            logger.debug(
                'Overriding konduktor config with task-level config: '
                f'{cluster_config_override}'
            )
            assert not experimental_configs, (
                'Invalid task args: ' f'{experimental_configs.keys()}'
            )

        # Parse resources field.
        resources_config = config.pop('resources', {})
        if cluster_config_override is not None:
            assert resources_config.get('_cluster_config_overrides') is None, (
                'Cannot set _cluster_config_overrides in both resources and '
                'experimental.config_overrides'
            )
            resources_config['_cluster_config_overrides'] = cluster_config_override
        task.set_resources(konduktor.Resources.from_yaml_config(resources_config))
        # Every recognised key has been popped by now; leftovers are unknown.
        assert not config, f'Invalid task args: {config.keys()}'
        return task

    @staticmethod
    def from_yaml(yaml_path: str) -> 'Task':
        """Initializes a task from a task YAML.

        Example:
            .. code-block:: python

                task = konduktor.Task.from_yaml('/path/to/task.yaml')

        Args:
          yaml_path: file path to a valid task yaml file.

        Raises:
          ValueError: if the path gets loaded into a str instead of a dict; or
            if there are any other parsing errors.
        """
        with open(os.path.expanduser(yaml_path), 'r', encoding='utf-8') as f:
            # NOTE(review): yaml.safe_load does not reject duplicate keys.
            # https://github.com/yaml/pyyaml/issues/165#issuecomment-430074049
            # describes a loader that raises on them.
            config = yaml.safe_load(f)

        if isinstance(config, str):
            with ux_utils.print_exception_no_traceback():
                raise ValueError(
                    'YAML loaded as str, not as dict. '
                    f'Is it correct? Path: {yaml_path}'
                )

        # An empty file loads as None.
        if config is None:
            config = {}
        return Task.from_yaml_config(config)

    @property
    def num_nodes(self) -> int:
        """The number of nodes for this task (always a positive int)."""
        return self._num_nodes

    @num_nodes.setter
    def num_nodes(self, num_nodes: Optional[int]) -> None:
        """Sets the node count; None means 1.

        Raises:
            ValueError: if the value is not a positive int.
        """
        if num_nodes is None:
            num_nodes = 1
        if not isinstance(num_nodes, int) or num_nodes <= 0:
            with ux_utils.print_exception_no_traceback():
                raise ValueError(
                    f'num_nodes should be a positive int. Got: {num_nodes}'
                )
        self._num_nodes = num_nodes

    @property
    def envs(self) -> Dict[str, str]:
        """The environment variables of this task (the live dict, not a copy)."""
        return self._envs

    def update_envs(
        self, envs: Union[None, List[Tuple[str, str]], Dict[str, str]]
    ) -> 'Task':
        """Updates environment variables for use inside the run commands.

        Args:
          envs: (optional) either a list of ``(env_name, value)`` or a dict
            ``{env_name: value}``.

        Returns:
          self: The current task, with envs updated.

        Raises:
          ValueError: if various invalid inputs errors are detected.
        """
        if envs is None:
            envs = {}
        if isinstance(envs, (list, tuple)):
            keys = set(env[0] for env in envs)
            if len(keys) != len(envs):
                with ux_utils.print_exception_no_traceback():
                    raise ValueError('Duplicate env keys provided.')
            envs = dict(envs)
        if isinstance(envs, dict):
            for key in envs:
                if not isinstance(key, str):
                    with ux_utils.print_exception_no_traceback():
                        raise ValueError('Env keys must be strings.')
                if not common_utils.is_valid_env_var(key):
                    with ux_utils.print_exception_no_traceback():
                        raise ValueError(f'Invalid env key: {key}')
        else:
            with ux_utils.print_exception_no_traceback():
                raise ValueError(
                    'envs must be List[Tuple[str, str]] or Dict[str, str]: ' f'{envs}'
                )
        self._envs.update(envs)
        return self

    def set_resources(
        self,
        resources: 'resources_lib.Resources',
    ) -> 'Task':
        """Sets the required resources to execute this task.

        Also re-checks that every storage mount currently on the task uses a
        known storage mode.

        Args:
          resources: the ``konduktor.Resources`` for this task.

        Returns:
          self: The current task, with resources set.

        Raises:
          ValueError: if a storage mount has an unsupported mode.
        """
        self.resources = resources

        # TODO(asaiacai): we're only going to support COPY for now.
        # MOUNT is not supported. Evaluate if the task requires
        # FUSE and set the requires_fuse flag
        for storage_obj in self.storage_mounts.values():
            if storage_obj.mode not in storage_lib.StorageMode:
                with ux_utils.print_exception_no_traceback():
                    raise ValueError(
                        f'Storage Type {storage_obj.mode} '
                        'not supported. '
                        f'Only {", ".join([mode.name for mode in storage_lib.StorageMode])} '  # noqa: E501
                        'is supported!'
                    )

        return self

    def set_resources_override(self, override_params: Dict[str, Any]) -> 'Task':
        """Replaces the resources with a copy that has ``override_params`` applied.

        Args:
          override_params: keyword arguments forwarded to ``Resources.copy``.

        Returns:
          self: The current task, with the overridden resources set.
        """
        assert self.resources is not None, 'Resources are required'
        new_resources = self.resources.copy(**override_params)

        self.set_resources(new_resources)
        return self

    def set_file_mounts(self, file_mounts: Optional[Dict[str, str]]) -> 'Task':
        """Sets the file mounts for this task.

        Useful for syncing datasets, dotfiles, etc.

        File mounts are a dictionary: ``{remote_path: local_path/cloud URI}``.

        Neither source nor destination paths can end with a slash.

        Example:
            .. code-block:: python

                task.set_file_mounts({
                    '~/.dotfile': '/local/.dotfile',
                    # /remote/dir/ will contain the contents of /local/dir/.
                    '/remote/dir': '/local/dir',
                })

        Args:
          file_mounts: an optional dict of ``{remote_path: local_path/cloud
            URI}``, where remote means where this Task will eventually run,
            and local means the node from which the task is launched. None
            clears the file mounts.

        Returns:
          self: the current task, with file mounts set.

        Raises:
          ValueError: if input paths are invalid.
        """
        if file_mounts is None:
            self.file_mounts = None
            return self
        for target, source in file_mounts.items():
            if target.endswith('/') or source.endswith('/'):
                with ux_utils.print_exception_no_traceback():
                    raise ValueError(
                        'File mount paths cannot end with a slash '
                        '(try "/mydir: /mydir" or "/myfile: /myfile"). '
                        f'Found: target={target} source={source}'
                    )
            if data_utils.is_cloud_store_url(target):
                with ux_utils.print_exception_no_traceback():
                    raise ValueError(
                        'File mount destination paths cannot be cloud storage'
                    )
            if not data_utils.is_cloud_store_url(source):
                # Local sources must exist, unless they use the
                # 'konduktor:' prefix.
                if not os.path.exists(
                    os.path.abspath(os.path.expanduser(source))
                ) and not source.startswith('konduktor:'):
                    with ux_utils.print_exception_no_traceback():
                        raise ValueError(
                            f'File mount source {source!r} does not exist '
                            'locally. To fix: check if it exists, and correct '
                            'the path.'
                        )
            if (
                target == constants.KONDUKTOR_REMOTE_WORKDIR
                and self.workdir is not None
            ):
                with ux_utils.print_exception_no_traceback():
                    raise ValueError(
                        f'Cannot use {constants.KONDUKTOR_REMOTE_WORKDIR!r} as a '
                        'destination path of a file mount, as it will be used '
                        'by the workdir. If uploading a file/folder to the '
                        'workdir is needed, please specify the full path to '
                        'the file/folder.'
                    )

        self.file_mounts = file_mounts
        return self

    def _get_preferred_store(self) -> Tuple[storage_lib.StoreType, Optional[str]]:
        """Returns the preferred store type and region for this task.

        The resources' cloud is used when it is among the enabled storage
        clouds; otherwise the first enabled storage cloud is used. The region
        is always None, i.e. the Store class picks its default region.
        """
        enabled_storage_clouds = (
            storage_lib.get_cached_enabled_storage_clouds_or_refresh(
                raise_if_no_cloud_access=True
            )
        )
        resources = self.resources
        storage_cloud = resources.cloud if resources is not None else None
        if storage_cloud is not None and (
            str(storage_cloud) not in enabled_storage_clouds
        ):
            storage_cloud = None

        # Bound on every path; previously it was only assigned in the
        # fallback branch, so returning it raised UnboundLocalError when the
        # resources' cloud was usable.
        storage_region = None  # Use default region in the Store class
        if storage_cloud is None:
            storage_cloud_str = enabled_storage_clouds[0]
            assert storage_cloud_str is not None, enabled_storage_clouds[0]
        else:
            storage_cloud_str = str(storage_cloud)

        store_type = storage_lib.StoreType.from_cloud(storage_cloud_str)
        return store_type, storage_region

    def sync_storage_mounts(self) -> None:
        """(INTERNAL) Eagerly syncs storage mounts to cloud storage.

        Every storage mount without a store gets the preferred store added;
        the store chosen for each storage is recorded in ``storage_plans``.

        NOTE(review): the second loop below is kept exactly as shipped. Its
        ``elif`` pairs with the COPY check, so a COPY-mode storage only has
        its store type looked up and is NOT translated into a ``gs://`` file
        mount; the translation branch is reachable only for non-COPY modes.
        This looks like a mis-nested branch, but changing it would alter what
        ends up in ``file_mounts`` — confirm how the backend consumes
        ``storage_mounts`` before fixing.
        """
        for storage in self.storage_mounts.values():
            if len(storage.stores) == 0:
                store_type, store_region = self._get_preferred_store()
                self.storage_plans[storage] = store_type
                storage.add_store(store_type, store_region)
            else:
                # We will download the first store that is added to remote.
                self.storage_plans[storage] = list(storage.stores.keys())[0]

        storage_mounts = self.storage_mounts
        storage_plans = self.storage_plans
        for mnt_path, storage in storage_mounts.items():
            if storage.mode == storage_lib.StorageMode.COPY:
                store_type = storage_plans[storage]
            # TODO(asaiacai): add back other stores here
            elif store_type is storage_lib.StoreType.GCS:
                if isinstance(storage.source, str) and storage.source.startswith(
                    'gs://'
                ):
                    blob_path = storage.source
                else:
                    assert storage.name is not None, storage
                    blob_path = 'gs://' + storage.name
                self.update_file_mounts(
                    {
                        mnt_path: blob_path,
                    }
                )
            else:
                with ux_utils.print_exception_no_traceback():
                    raise ValueError(f'Storage Type {store_type} ' 'does not exist!')

    def update_storage_mounts(
        self, storage_mounts: Dict[str, storage_lib.Storage]
    ) -> 'Task':
        """Updates the storage mounts for this task.

        Different from set_storage_mounts(), this function merges into the
        existing storage_mounts rather than overwriting them. The merge is
        validated before it is stored, so a rejected update leaves the task
        unchanged.

        Args:
          storage_mounts: an optional dict of ``{mount_path: Storage
            object}``. Empty or None is a no-op.

        Returns:
          self: The current task, with storage mounts updated.

        Raises:
          ValueError: if input paths are invalid.
        """
        if not storage_mounts:
            return self
        # Merge into a copy so that a validation failure cannot leave
        # partially-applied mounts on the task.
        merged = dict(self.storage_mounts) if self.storage_mounts else {}
        merged.update(storage_mounts)
        return self.set_storage_mounts(merged)

    def update_file_mounts(self, file_mounts: Dict[str, str]) -> 'Task':
        """Updates the file mounts for this task.

        Different from set_file_mounts(), this function merges into the
        existing file_mounts rather than overwriting them. The merge is
        validated before it is stored, so a rejected update leaves the task
        unchanged.

        Example:
            .. code-block:: python

                task.update_file_mounts({
                    '~/.config': '~/Documents/config',
                    '/tmp/workdir': '/local/workdir/cnn-cifar10',
                })

        Args:
          file_mounts: a dict of ``{remote_path: local_path/cloud URI}``.

        Returns:
          self: the current task, with file mounts updated.

        Raises:
          ValueError: if input paths are invalid.
        """
        # Merge into a copy; set_file_mounts() validates the full result and
        # only then stores it.
        merged: Dict[str, str] = dict(self.file_mounts or {})
        merged.update(file_mounts)
        return self.set_file_mounts(merged)

    def set_storage_mounts(
        self,
        storage_mounts: Optional[Dict[str, storage_lib.Storage]],
    ) -> 'Task':
        """Sets the storage mounts for this task.

        Storage mounts are a dictionary: ``{mount_path: Storage object}``.
        Only COPY-mode Storage objects are accepted.

        Args:
          storage_mounts: an optional dict of ``{mount_path: Storage
            object}``. None clears the storage mounts.

        Returns:
          self: The current task, with storage mounts set.

        Raises:
          ValueError: if input paths are invalid.
          AssertionError: if a Storage object is not in COPY mode.
        """
        if storage_mounts is None:
            # Clearing the mounts does not depend on resources; the former
            # "resources are required" assertion was a leftover of removed
            # FUSE handling and made clearing fail on a fresh task.
            self.storage_mounts = {}
            return self
        for target, storage_obj in storage_mounts.items():
            # TODO(zhwu): /home/username/sky_workdir as the target path need
            # to be filtered out as well.
            if (
                target == constants.KONDUKTOR_REMOTE_WORKDIR
                and self.workdir is not None
            ):
                with ux_utils.print_exception_no_traceback():
                    raise ValueError(
                        f'Cannot use {constants.KONDUKTOR_REMOTE_WORKDIR!r} as a '
                        'destination path of a file mount, as it will be used '
                        'by the workdir. If uploading a file/folder to the '
                        'workdir is needed, please specify the full path to '
                        'the file/folder.'
                    )

            if data_utils.is_cloud_store_url(target):
                with ux_utils.print_exception_no_traceback():
                    raise ValueError(
                        'Storage mount destination path cannot be cloud storage'
                    )

            assert (
                storage_obj.mode == storage_lib.StorageMode.COPY
            ), 'Only COPY mode is supported for storage mounts'

        # Storage source validation is done in Storage object
        self.storage_mounts = storage_mounts
        return self

    def get_local_to_remote_file_mounts(self) -> Optional[Dict[str, str]]:
        """Returns file mounts of the form (dst=VM path, src=local path).

        Any cloud object store URIs (gs://, s3://, etc.), either as source or
        destination, are not included. Returns None when no file mounts are
        set.

        INTERNAL: this method is internal-facing.
        """
        if self.file_mounts is None:
            return None
        return {
            dst: src
            for dst, src in self.file_mounts.items()
            if not data_utils.is_cloud_store_url(dst)
            and not data_utils.is_cloud_store_url(src)
        }

    def to_yaml_config(self) -> Dict[str, Any]:
        """Returns a yaml-style dict representation of the task.

        Storage mounts are emitted inside ``file_mounts`` as their own
        yaml-style dicts. ``file_mounts`` is always present, possibly empty.

        INTERNAL: this method is internal-facing.
        """
        config: Dict[str, Any] = {}

        def add_if_not_none(key, value, no_empty: bool = False):
            # Skip None always; skip falsy values too when no_empty is set.
            if no_empty and not value:
                return
            if value is not None:
                config[key] = value

        add_if_not_none('name', self.name)

        assert self.resources is not None, 'Resources are not defined'
        add_if_not_none('resources', self.resources.to_yaml_config())

        add_if_not_none('num_nodes', self.num_nodes)

        add_if_not_none('run', self.run)
        add_if_not_none('workdir', self.workdir)
        add_if_not_none('envs', self.envs, no_empty=True)

        file_mounts: Dict[str, Any] = {}
        if self.file_mounts is not None:
            file_mounts.update(self.file_mounts)
        if self.storage_mounts is not None:
            file_mounts.update(
                {
                    mount_path: storage.to_yaml_config()
                    for mount_path, storage in self.storage_mounts.items()
                }
            )
        config['file_mounts'] = file_mounts
        return config