konduktor-nightly 0.1.0.dev20250209104336__py3-none-any.whl → 0.1.0.dev20250313070642__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. konduktor/__init__.py +16 -6
  2. konduktor/adaptors/__init__.py +0 -0
  3. konduktor/adaptors/common.py +88 -0
  4. konduktor/adaptors/gcp.py +112 -0
  5. konduktor/backends/__init__.py +8 -0
  6. konduktor/backends/backend.py +86 -0
  7. konduktor/backends/jobset.py +218 -0
  8. konduktor/backends/jobset_utils.py +447 -0
  9. konduktor/check.py +192 -0
  10. konduktor/cli.py +790 -0
  11. konduktor/cloud_stores.py +158 -0
  12. konduktor/config.py +420 -0
  13. konduktor/constants.py +36 -0
  14. konduktor/controller/constants.py +6 -6
  15. konduktor/controller/launch.py +3 -3
  16. konduktor/controller/node.py +5 -5
  17. konduktor/controller/parse.py +23 -23
  18. konduktor/dashboard/backend/main.py +57 -57
  19. konduktor/dashboard/backend/sockets.py +19 -19
  20. konduktor/data/__init__.py +9 -0
  21. konduktor/data/constants.py +12 -0
  22. konduktor/data/data_utils.py +223 -0
  23. konduktor/data/gcp/__init__.py +19 -0
  24. konduktor/data/gcp/constants.py +42 -0
  25. konduktor/data/gcp/gcs.py +906 -0
  26. konduktor/data/gcp/utils.py +9 -0
  27. konduktor/data/storage.py +799 -0
  28. konduktor/data/storage_utils.py +500 -0
  29. konduktor/execution.py +444 -0
  30. konduktor/kube_client.py +153 -48
  31. konduktor/logging.py +49 -5
  32. konduktor/manifests/dmesg_daemonset.yaml +8 -0
  33. konduktor/manifests/pod_cleanup_controller.yaml +129 -0
  34. konduktor/resource.py +478 -0
  35. konduktor/task.py +867 -0
  36. konduktor/templates/jobset.yaml.j2 +31 -0
  37. konduktor/templates/pod.yaml.j2 +185 -0
  38. konduktor/usage/__init__.py +0 -0
  39. konduktor/usage/constants.py +21 -0
  40. konduktor/utils/__init__.py +0 -0
  41. konduktor/utils/accelerator_registry.py +21 -0
  42. konduktor/utils/annotations.py +62 -0
  43. konduktor/utils/base64_utils.py +93 -0
  44. konduktor/utils/common_utils.py +393 -0
  45. konduktor/utils/constants.py +5 -0
  46. konduktor/utils/env_options.py +55 -0
  47. konduktor/utils/exceptions.py +226 -0
  48. konduktor/utils/kubernetes_enums.py +8 -0
  49. konduktor/utils/kubernetes_utils.py +652 -0
  50. konduktor/utils/log_utils.py +251 -0
  51. konduktor/utils/loki_utils.py +85 -0
  52. konduktor/utils/rich_utils.py +123 -0
  53. konduktor/utils/schemas.py +581 -0
  54. konduktor/utils/subprocess_utils.py +273 -0
  55. konduktor/utils/ux_utils.py +216 -0
  56. konduktor/utils/validator.py +20 -0
  57. {konduktor_nightly-0.1.0.dev20250209104336.dist-info → konduktor_nightly-0.1.0.dev20250313070642.dist-info}/LICENSE +0 -1
  58. {konduktor_nightly-0.1.0.dev20250209104336.dist-info → konduktor_nightly-0.1.0.dev20250313070642.dist-info}/METADATA +13 -2
  59. konduktor_nightly-0.1.0.dev20250313070642.dist-info/RECORD +94 -0
  60. konduktor_nightly-0.1.0.dev20250209104336.dist-info/RECORD +0 -48
  61. {konduktor_nightly-0.1.0.dev20250209104336.dist-info → konduktor_nightly-0.1.0.dev20250313070642.dist-info}/WHEEL +0 -0
  62. {konduktor_nightly-0.1.0.dev20250209104336.dist-info → konduktor_nightly-0.1.0.dev20250313070642.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,447 @@
1
+ """Jobset utils: wraps CRUD operations for jobsets"""
2
+
3
+ import enum
4
+ import json
5
+ import os
6
+ import tempfile
7
+ import typing
8
+ from datetime import datetime, timezone
9
+ from typing import Any, Dict, Optional
10
+ from urllib.parse import urlparse
11
+
12
+ import colorama
13
+
14
+ import konduktor
15
+ from konduktor import cloud_stores, constants, kube_client, logging
16
+ from konduktor.utils import common_utils, kubernetes_utils, log_utils
17
+
18
+ if typing.TYPE_CHECKING:
19
+ pass
20
+
21
+ logger = logging.get_logger(__name__)
22
+
23
+ JOBSET_API_GROUP = 'jobset.x-k8s.io'
24
+ JOBSET_API_VERSION = 'v1alpha2'
25
+ JOBSET_PLURAL = 'jobsets'
26
+
27
+ JOBSET_NAME_LABEL = 'trainy.ai/job-name'
28
+ JOBSET_USERID_LABEL = 'trainy.ai/user-id'
29
+ JOBSET_USER_LABEL = 'trainy.ai/username'
30
+ JOBSET_ACCELERATOR_LABEL = 'trainy.ai/accelerator'
31
+ JOBSET_NUM_ACCELERATORS_LABEL = 'trainy.ai/num-accelerators'
32
+
33
+ _JOBSET_METADATA_LABELS = {
34
+ 'jobset_name_label': JOBSET_NAME_LABEL,
35
+ 'jobset_userid_label': JOBSET_USERID_LABEL,
36
+ 'jobset_user_label': JOBSET_USER_LABEL,
37
+ 'jobset_accelerator_label': JOBSET_ACCELERATOR_LABEL,
38
+ 'jobset_num_accelerators_label': JOBSET_NUM_ACCELERATORS_LABEL,
39
+ }
40
+
41
+
42
class JobNotFoundError(Exception):
    """Raised when the requested jobset does not exist in the cluster."""
44
+
45
+
46
# Lifecycle states a jobset can be in. Each member's value equals its name
# so status strings reported by the jobset controller can be compared
# directly after upper-casing.
JobStatus = enum.Enum(
    'JobStatus',
    [
        (state, state)
        for state in ('SUSPENDED', 'ACTIVE', 'COMPLETED', 'FAILED', 'PENDING')
    ],
)
52
+
53
+
54
+ if typing.TYPE_CHECKING:
55
+ import konduktor
56
+
57
+
58
def create_pod_spec(task: 'konduktor.Task') -> Dict[str, Any]:
    """Merges the task definition with config to create a final pod spec
    dict for the job.

    Args:
        task: Task whose resources, mounts, env vars and run command are
            templated into the pod spec. NOTE: ``task.name`` is mutated
            here to append a unique run-id suffix.

    Returns:
        Dict[str, Any]: k8s pod spec
    """

    # fill out the templating variables
    assert task.resources is not None, 'Task resources are required'
    if task.resources.accelerators:
        num_gpus = list(task.resources.accelerators.values())[0]
    else:
        num_gpus = 0
    # suffix a short run id so repeated submissions get unique job names
    task.name = f'{task.name}-{common_utils.get_usage_run_id()[:4]}'
    node_hostnames = ','.join(
        [f'{task.name}-workers-0-{idx}.{task.name}' for idx in range(task.num_nodes)]
    )
    master_addr = f'{task.name}-workers-0-0.{task.name}'

    if task.resources.accelerators:
        accelerator_type = list(task.resources.accelerators.keys())[0]
    else:
        accelerator_type = None

    # template the commands to run on the container for syncing files. At this
    # point task.stores is Dict[str, storage_utils.Storage] which is
    # (dst, storage_obj_src). first we iterate through storage_mounts and
    # then file_mounts.
    sync_commands = []
    mkdir_commands = []
    storage_secrets = {}

    # first do storage_mount sync
    for dst, store in task.storage_mounts.items():
        # TODO(asaiacai) idk why but theres an extra storage mount for the
        # file mounts. Should be cleaned up eventually in
        # maybe_translate_local_file_mounts_and_sync_up
        assert store.source is not None and isinstance(
            store.source, str
        ), 'Store source is required'
        store_scheme = urlparse(store.source).scheme
        if '/tmp/konduktor-job-filemounts-files' in dst:
            continue
        # should implement a method here instead of raw dict access
        cloud_store = cloud_stores._REGISTRY[store_scheme]
        storage_secrets[store_scheme] = cloud_store._STORE.get_k8s_credential_name()
        mkdir_commands.append(
            f'cd {constants.KONDUKTOR_REMOTE_WORKDIR};' f'mkdir -p {dst}'
        )
        assert store._bucket_sub_path is not None
        sync_commands.append(
            cloud_store.make_sync_dir_command(
                os.path.join(store.source, store._bucket_sub_path), dst
            )
        )

    # then do file_mount sync.
    assert task.file_mounts is not None
    for dst, src in task.file_mounts.items():
        # BUGFIX: derive the scheme from this mount's own source (`src`).
        # Previously this parsed `store.source`, i.e. the leftover loop
        # variable from the storage_mounts loop above, which picked the
        # wrong scheme for mixed-scheme mounts and raised NameError when
        # storage_mounts was empty.
        store_scheme = str(urlparse(src).scheme)
        cloud_store = cloud_stores._REGISTRY[store_scheme]
        mkdir_commands.append(
            f'cd {constants.KONDUKTOR_REMOTE_WORKDIR};'
            f'mkdir -p {os.path.dirname(dst)}'
        )
        storage_secrets[store_scheme] = cloud_store._STORE.get_k8s_credential_name()
        sync_commands.append(cloud_store.make_sync_file_command(src, dst))

    assert task.resources is not None, 'Task resources are required'
    assert task.resources.cpus is not None, 'Task resources cpus are required'
    assert task.resources.memory is not None, 'Task resources memory are required'
    assert task.resources.image_id is not None, 'Task resources image_id are required'
    with tempfile.NamedTemporaryFile() as temp:
        common_utils.fill_template(
            'pod.yaml.j2',
            {
                # TODO(asaiacai) need to parse/round these numbers and
                # sanity check
                'cpu': kubernetes_utils.parse_cpu_or_gpu_resource(task.resources.cpus),
                'memory': kubernetes_utils.parse_memory_resource(task.resources.memory),
                'image_id': task.resources.image_id,
                'num_gpus': num_gpus,
                'master_addr': master_addr,
                'num_nodes': task.num_nodes,
                'job_name': task.name,  # append timestamp and user id here?
                'run_cmd': task.run,
                'node_hostnames': node_hostnames,
                'accelerator_type': accelerator_type,
                'sync_commands': sync_commands,
                'mkdir_commands': mkdir_commands,
                'mount_secrets': storage_secrets,
                'remote_workdir': constants.KONDUKTOR_REMOTE_WORKDIR,
            },
            temp.name,
        )
        pod_config = common_utils.read_yaml(temp.name)
        # merge with `~/.konduktor/config.yaml`
        kubernetes_utils.combine_pod_config_fields(temp.name, pod_config)
        pod_config = common_utils.read_yaml(temp.name)

    # surface the task's env vars on the first container of the pod
    for k, v in task.envs.items():
        pod_config['kubernetes']['pod_config']['spec']['containers'][0][
            'env'
        ].append({'name': k, 'value': v})

    # TODO(asaiacai): have some schema validations. see
    # https://github.com/skypilot-org/skypilot/pull/4466
    # TODO(asaiacai): where can we include policies for the pod spec.

    return pod_config
166
+
167
+
168
def create_jobset(
    namespace: str,
    task: 'konduktor.Task',
    pod_spec: Dict[str, Any],
    dryrun: bool = False,
) -> Optional[Dict[str, Any]]:
    """Creates a jobset based on the task definition and pod spec
    and returns the created jobset spec.

    Args:
        namespace: Namespace to create the jobset in.
        task: Task providing the job name, node count and accelerators.
        pod_spec: Pod template spec embedded in the jobset's replicated job.
        dryrun: If True, the API server validates the request but does not
            persist the object.

    Returns:
        The created jobset object, or None when the API error body could
        not be parsed as JSON (see NOTE in the except block).
    """
    assert task.resources is not None, 'Task resources are undefined'
    if task.resources.accelerators:
        accelerator_type = list(task.resources.accelerators.keys())[0]
        num_accelerators = list(task.resources.accelerators.values())[0]
    else:
        # string 'None' so the accelerator label template always has a value
        accelerator_type = 'None'
        num_accelerators = 0
    with tempfile.NamedTemporaryFile() as temp:
        common_utils.fill_template(
            'jobset.yaml.j2',
            {
                'job_name': task.name,
                'user_id': common_utils.user_and_hostname_hash(),
                'num_nodes': task.num_nodes,
                'user': common_utils.get_cleaned_username(),
                'accelerator_type': accelerator_type,
                'num_accelerators': num_accelerators,
                **_JOBSET_METADATA_LABELS,
            },
            temp.name,
        )
        jobset_spec = common_utils.read_yaml(temp.name)
    # splice the rendered pod template into the jobset's first replicated job
    jobset_spec['jobset']['spec']['replicatedJobs'][0]['template']['spec'][
        'template'
    ] = pod_spec  # noqa: E501
    try:
        jobset = kube_client.crd_api().create_namespaced_custom_object(
            group=JOBSET_API_GROUP,
            version=JOBSET_API_VERSION,
            namespace=namespace,
            plural=JOBSET_PLURAL,
            body=jobset_spec['jobset'],
            dry_run='All' if dryrun else None,
        )
        logger.info(
            f'task {colorama.Fore.CYAN}{colorama.Style.BRIGHT}'
            f'{task.name}{colorama.Style.RESET_ALL} created'
        )
        return jobset
    except kube_client.api_exception() as err:
        try:
            error_body = json.loads(err.body)
            error_message = error_body.get('message', '')
            logger.error(f'error creating jobset: {error_message}')
        except json.JSONDecodeError:
            error_message = str(err.body)
            logger.error(f'error creating jobset: {error_message}')
        else:
            # NOTE(review): this `else` runs when the error body parsed
            # successfully, so a well-formed API error is logged *and*
            # re-raised, while an unparseable body is swallowed and None
            # is returned. The comment below suggests the opposite intent
            # — confirm which branch should re-raise.
            # Re-raise the exception if it's a different error
            raise err
    return None
228
+
229
+
230
def list_jobset(namespace: str) -> Optional[Dict[str, Any]]:
    """Lists all jobsets in this namespace.

    Args:
        namespace: Namespace whose jobsets are listed.

    Returns:
        The JobSet list response from the k8s API, or None when the API
        error body could not be parsed as JSON.
    """
    try:
        response = kube_client.crd_api().list_namespaced_custom_object(
            group=JOBSET_API_GROUP,
            version=JOBSET_API_VERSION,
            namespace=namespace,
            plural=JOBSET_PLURAL,
        )
        return response
    except kube_client.api_exception() as err:
        try:
            error_body = json.loads(err.body)
            error_message = error_body.get('message', '')
            logger.error(f'error listing jobset: {error_message}')
        except json.JSONDecodeError:
            error_message = str(err.body)
            # BUGFIX: was 'error creating jobset' — copy-paste from
            # create_jobset; this function lists jobsets.
            logger.error(f'error listing jobset: {error_message}')
        else:
            # Re-raise the exception if it's a different error
            raise err
    return None
252
+
253
+
254
def get_jobset(namespace: str, job_name: str) -> Optional[Dict[str, Any]]:
    """Retrieves jobset in this namespace.

    Args:
        namespace: Namespace where the jobset lives.
        job_name: Name of the jobset to fetch.

    Returns:
        The jobset object, or None when the API error body could not be
        parsed as JSON.

    Raises:
        JobNotFoundError: If no jobset named *job_name* exists in
            *namespace* (API returned 404).
    """
    try:
        response = kube_client.crd_api().get_namespaced_custom_object(
            group=JOBSET_API_GROUP,
            version=JOBSET_API_VERSION,
            namespace=namespace,
            plural=JOBSET_PLURAL,
            name=job_name,
        )
        return response
    except kube_client.api_exception() as err:
        if err.status == 404:
            raise JobNotFoundError(
                f"Jobset '{job_name}' " f"not found in namespace '{namespace}'."
            )
        try:
            error_body = json.loads(err.body)
            error_message = error_body.get('message', '')
            logger.error(f'error getting jobset: {error_message}')
        except json.JSONDecodeError:
            error_message = str(err.body)
            # BUGFIX: was 'error creating jobset' — copy-paste from
            # create_jobset; this function gets a jobset.
            logger.error(f'error getting jobset: {error_message}')
        else:
            # Re-raise the exception if it's a different error
            raise err
    return None
281
+
282
+
283
def delete_jobset(namespace: str, job_name: str) -> Optional[Dict[str, Any]]:
    """Deletes jobset in this namespace

    Args:
        namespace: Namespace where jobset exists
        job_name: Name of jobset to delete

    Returns:
        Response from delete operation, or None when the API error body
        could not be parsed as JSON.
    """
    try:
        response = kube_client.crd_api().delete_namespaced_custom_object(
            group=JOBSET_API_GROUP,
            version=JOBSET_API_VERSION,
            namespace=namespace,
            plural=JOBSET_PLURAL,
            name=job_name,
        )
        return response
    except kube_client.api_exception() as err:
        try:
            error_body = json.loads(err.body)
            error_message = error_body.get('message', '')
            logger.error(f'error deleting jobset: {error_message}')
        except json.JSONDecodeError:
            error_message = str(err.body)
            logger.error(f'error deleting jobset: {error_message}')
        else:
            # NOTE(review): this `else` fires when the error body parsed
            # successfully, so well-formed API errors are logged *and*
            # re-raised while unparseable ones are swallowed (None is
            # returned). The comment below suggests the opposite intent
            # — confirm which branch should re-raise.
            # Re-raise the exception if it's a different error
            raise err
    return None
314
+
315
+
316
def get_job(namespace: str, job_name: str) -> Optional[Dict[str, Any]]:
    """Gets the worker job belonging to a jobset.

    The underlying k8s Job is named ``{jobset-name}-workers-0``.

    Args:
        namespace: Namespace where job exists
        job_name: Name of jobset containing the job

    Returns:
        Job object if found, otherwise None (when the API error body
        could not be parsed as JSON).
    """
    try:
        # Get the job object using the job name
        # pattern {jobset-name}-workers-0
        job_name = f'{job_name}-workers-0'
        response = kube_client.batch_api().read_namespaced_job(
            name=job_name, namespace=namespace
        )
        return response
    except kube_client.api_exception() as err:
        try:
            error_body = json.loads(err.body)
            error_message = error_body.get('message', '')
            logger.error(f'error getting job: {error_message}')
        except json.JSONDecodeError:
            error_message = str(err.body)
            logger.error(f'error getting job: {error_message}')
        else:
            # NOTE(review): this `else` runs on a successfully-parsed
            # error body, so such errors are logged *and* re-raised while
            # unparseable ones are swallowed — confirm which branch
            # should re-raise.
            # Re-raise the exception if it's a different error
            raise err
    return None
347
+
348
+
349
def show_status_table(namespace: str, all_users: bool):
    """Prints a table of jobsets in *namespace*.

    BUGFIX(docs): the previous docstring was copied from skypilot and
    described autostop clusters and a return count; this function only
    prints a table and returns None.

    Args:
        namespace: Namespace whose jobsets are listed.
        all_users: If True, show every user's jobs (with a USER column);
            otherwise only jobs belonging to the current user.
    """

    def _get_status_string_colorized(status: Dict[str, Any]) -> str:
        # Map the jobset status payload to a colorized JobStatus name.
        # terminalState wins; otherwise fall back to the first replicated
        # job's active/suspended counters, defaulting to PENDING.
        terminalState = status.get('terminalState', None)
        if terminalState and terminalState.upper() == JobStatus.COMPLETED.name.upper():
            return (
                f'{colorama.Fore.GREEN}'
                f'{JobStatus.COMPLETED.name}{colorama.Style.RESET_ALL}'
            )
        elif terminalState and terminalState.upper() == JobStatus.FAILED.name.upper():
            return (
                f'{colorama.Fore.RED}'
                f'{JobStatus.FAILED.name}{colorama.Style.RESET_ALL}'
            )
        elif status['replicatedJobsStatus'][0]['active']:
            return (
                f'{colorama.Fore.CYAN}'
                f'{JobStatus.ACTIVE.name}{colorama.Style.RESET_ALL}'
            )
        elif status['replicatedJobsStatus'][0]['suspended']:
            return (
                f'{colorama.Fore.GREEN}'
                f'{JobStatus.SUSPENDED.name}{colorama.Style.RESET_ALL}'
            )
        else:
            return (
                f'{colorama.Fore.BLUE}'
                f'{JobStatus.PENDING.name}{colorama.Style.RESET_ALL}'
            )

    def _get_time_delta(timestamp: str) -> str:
        # Human-readable time elapsed since *timestamp* (k8s RFC3339 UTC).
        delta = datetime.now(timezone.utc) - datetime.strptime(
            timestamp, '%Y-%m-%dT%H:%M:%SZ'
        ).replace(tzinfo=timezone.utc)
        total_seconds = int(delta.total_seconds())

        days, remainder = divmod(total_seconds, 86400)  # 86400 seconds in a day
        hours, remainder = divmod(remainder, 3600)  # 3600 seconds in an hour
        minutes, _ = divmod(remainder, 60)  # 60 seconds in a minute

        # BUGFIX: build parts and join them, so the result neither ends
        # with a dangling ', ' (e.g. '2 days, 3 hours, ' when minutes==0)
        # nor collapses to '' for jobs submitted under a minute ago.
        parts = []
        if days > 0:
            parts.append(f'{days} days')
        if hours > 0:
            parts.append(f'{hours} hours')
        if minutes > 0:
            parts.append(f'{minutes} minutes')
        return ', '.join(parts) if parts else '< 1 minute'

    def _get_resources(job: Dict[str, Any]) -> str:
        # Summarize pod parallelism and the first container's resource limits.
        num_pods = int(
            job['spec']['replicatedJobs'][0]['template']['spec']['parallelism']
        )  # noqa: E501
        resources = job['spec']['replicatedJobs'][0]['template']['spec']['template'][
            'spec'
        ]['containers'][0]['resources']['limits']  # noqa: E501
        cpu, memory = resources['cpu'], resources['memory']
        accelerator = job['metadata']['labels'].get(JOBSET_ACCELERATOR_LABEL, None)
        if accelerator:
            return f'{num_pods}x({cpu}CPU, memory {memory}, {accelerator})'
        else:
            # NOTE(review): the 'GB' suffix is inconsistent with the
            # accelerator branch, and k8s memory limits usually already
            # carry a unit (e.g. '8Gi') — confirm and drop if unintended.
            return f'{num_pods}x({cpu}CPU, memory {memory}GB)'

    if all_users:
        columns = ['NAME', 'USER', 'STATUS', 'RESOURCES', 'SUBMITTED']
    else:
        columns = ['NAME', 'STATUS', 'RESOURCES', 'SUBMITTED']
    job_table = log_utils.create_table(columns)
    job_specs = list_jobset(namespace)
    assert job_specs is not None, 'Retrieving jobs failed'
    for job in job_specs['items']:
        if all_users:
            job_table.add_row(
                [
                    job['metadata']['name'],
                    job['metadata']['labels'][JOBSET_USERID_LABEL],
                    _get_status_string_colorized(job['status']),
                    _get_resources(job),
                    _get_time_delta(job['metadata']['creationTimestamp']),
                ]
            )
        # `not all_users` was redundant here (this is already the elif of
        # `if all_users`), so it was dropped.
        elif (
            job['metadata']['labels'][JOBSET_USER_LABEL]
            == common_utils.get_cleaned_username()
        ):
            job_table.add_row(
                [
                    job['metadata']['name'],
                    _get_status_string_colorized(job['status']),
                    _get_resources(job),
                    _get_time_delta(job['metadata']['creationTimestamp']),
                ]
            )
    print(job_table)
konduktor/check.py ADDED
@@ -0,0 +1,192 @@
1
+ # Proprietary Changes made for Trainy under the Trainy Software License
2
+ # Original source: skypilot: https://github.com/skypilot-org/skypilot
3
+ # which is Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ # Unless required by applicable law or agreed to in writing, software
8
+ # distributed under the License is distributed on an "AS IS" BASIS,
9
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10
+ # See the License for the specific language governing permissions and
11
+ # limitations under the License.
12
+
13
+ """Credential checks: check cloud credentials and enable clouds.
14
+
15
+ Our architecture is client-server and requires that credentials are stored
16
+ as a secret in the cluster. This makes it so that cluster admins can just
17
+ deploy credentials (s3, gcs, r2) once to the namespace. Users then during job
18
+ use the secret stored for mounting credentials to pods. Users running must also
19
+ have the credentials present on their local machine, otherwise they won't be able to
20
+ upload files to object storage.
21
+
22
+ We have to check that the credentials are valid on the client side.
23
+ If the check fails, then we will attempt to check the credentials present on the client.
24
+ If these credentials are valid, we update the secret on the cluster, and
25
+ run the job as usual.
26
+ If these credentials are not valid, we fail the job and alert the user.
27
+
28
+ """
29
+
30
+ import traceback
31
+ import typing
32
+ from typing import Iterable, List, Optional, Tuple
33
+
34
+ import click
35
+ import colorama
36
+
37
+ from konduktor import cloud_stores, logging
38
+ from konduktor import config as konduktor_config
39
+ from konduktor.utils import rich_utils
40
+
41
+ if typing.TYPE_CHECKING:
42
+ pass
43
+
44
+ logger = logging.get_logger(__name__)
45
+
46
+
47
def check(
    quiet: bool = False,
    clouds: Optional[Iterable[str]] = None,
) -> List[str]:
    """Checks credentials of registered cloud storages and reports status.

    Args:
        quiet: If True, suppress all console output.
        clouds: Specific storage names to check; defaults to every entry
            in ``cloud_stores._REGISTRY``.

    Returns:
        Names of the cloud storages whose credentials checked out.

    Raises:
        SystemExit: If no cloud storage ends up enabled.
    """
    # echo is a no-op when quiet, otherwise a colorized click.echo.
    echo = (
        (lambda *_args, **_kwargs: None)
        if quiet
        else lambda *args, **kwargs: click.echo(*args, **kwargs, color=True)
    )
    echo('Checking credentials to enable clouds storage for Konduktor.')
    enabled_clouds = []
    disabled_clouds = []

    def check_one_cloud(
        cloud_tuple: Tuple[str, 'cloud_stores.CloudStorage'],
    ) -> None:
        # Checks one storage backend's local credentials, appending its
        # name to enabled_clouds or disabled_clouds (closure side effect).
        cloud_repr, cloud = cloud_tuple
        with rich_utils.safe_status(f'Checking {cloud_repr}...'):
            try:
                logger.info(f'Checking {cloud_repr} local client credentials...')
                ok, reason = cloud.check_credentials()
            except Exception:  # pylint: disable=broad-except
                # Catch all exceptions to prevent a single cloud from blocking
                # the check for other clouds.
                ok, reason = False, traceback.format_exc()
        status_msg = 'enabled' if ok else 'disabled'
        styles = {'fg': 'green', 'bold': False} if ok else {'dim': True}
        echo(' ' + click.style(f'{cloud_repr}: {status_msg}', **styles) + ' ' * 30)
        if ok:
            enabled_clouds.append(cloud_repr)
            if reason is not None:
                echo(f' Hint: {reason}')
        else:
            disabled_clouds.append(cloud_repr)
            echo(f' Reason: {reason}')

    def get_cloud_tuple(cloud_name: str) -> Tuple[str, 'cloud_stores.CloudStorage']:
        # Validates cloud_name and returns a tuple of the cloud's name and
        # the cloud object. Includes special handling for Cloudflare.
        cloud_obj = cloud_stores._REGISTRY.get(cloud_name, None)
        assert cloud_obj is not None, f'Cloud {cloud_name!r} not found'
        return cloud_name, cloud_obj

    def get_all_clouds():
        # All registered storage-backend names.
        return tuple([c for c in cloud_stores._REGISTRY.keys()])

    if clouds is not None:
        cloud_list = clouds
    else:
        cloud_list = get_all_clouds()
    clouds_to_check = [get_cloud_tuple(c) for c in cloud_list]

    # Use allowed_clouds from config if it exists, otherwise check all clouds.
    # Also validate names with get_cloud_tuple.
    config_allowed_cloud_names = [
        c for c in konduktor_config.get_nested(('allowed_clouds',), get_all_clouds())
    ]
    # Use disallowed_cloud_names for logging the clouds that will be disabled
    # because they are not included in allowed_clouds in config.yaml.
    disallowed_cloud_names = [
        c for c in get_all_clouds() if c not in config_allowed_cloud_names
    ]
    # Check only the clouds which are allowed in the config.
    clouds_to_check = [c for c in clouds_to_check if c[0] in config_allowed_cloud_names]

    for cloud_tuple in sorted(clouds_to_check):
        check_one_cloud(cloud_tuple)

    # NOTE(review): the Cloudflare filtering below is inherited from
    # skypilot; nothing in this file registers a 'Cloudflare' entry, so
    # these filters may be dead code — confirm against cloud_stores.
    # Cloudflare is not a real cloud in registry.CLOUD_REGISTRY, and should
    # not be inserted into the DB (otherwise `sky launch` and other code would
    # error out when it's trying to look it up in the registry).
    enabled_clouds_set = {
        cloud for cloud in enabled_clouds if not cloud.startswith('Cloudflare')
    }
    disabled_clouds_set = {
        cloud for cloud in disabled_clouds if not cloud.startswith('Cloudflare')
    }

    # Determine the set of enabled clouds: (previously enabled clouds + newly
    # enabled clouds - newly disabled clouds) intersected with
    # config_allowed_clouds, if specified in config.yaml.
    # This means that if a cloud is already enabled and is not included in
    # allowed_clouds in config.yaml, it will be disabled.
    all_enabled_clouds = enabled_clouds_set - disabled_clouds_set

    disallowed_clouds_hint = None
    if disallowed_cloud_names:
        disallowed_clouds_hint = (
            '\nNote: The following clouds were disabled because they were not '
            'included in allowed_clouds in ~/.konduktor/config.yaml: '
            f'{", ".join([c for c in disallowed_cloud_names])}'
        )
    if not all_enabled_clouds:
        echo(
            click.style(
                'No cloud is enabled. Konduktor will not be able to run any '
                'task. Run `konduktor check` for more info.',
                fg='red',
                bold=True,
            )
        )
        if disallowed_clouds_hint:
            echo(click.style(disallowed_clouds_hint, dim=True))
        raise SystemExit()
    else:
        clouds_arg = ' ' + ' '.join(disabled_clouds) if clouds is not None else ''
        echo(
            click.style(
                '\nTo enable a cloud, follow the hints above and rerun: ', dim=True
            )
            + click.style(f'konduktor check {clouds_arg}', bold=True)
            + '\n'
            + click.style(
                'If any problems remain, refer to detailed docs at: '
                'https://konduktor.readthedocs.io/en/latest/admin/installation.html',  # pylint: disable=line-too-long
                dim=True,
            )
        )

        if disallowed_clouds_hint:
            echo(click.style(disallowed_clouds_hint, dim=True))

    # Pretty print for UX.
    if not quiet:
        enabled_clouds_str = '\n ' + '\n '.join(
            [_format_enabled_storage(cloud) for cloud in sorted(all_enabled_clouds)]
        )
        echo(
            f'\n{colorama.Fore.GREEN}{logging.PARTY_POPPER_EMOJI} '
            f'Enabled clouds {logging.PARTY_POPPER_EMOJI}'
            f'{colorama.Style.RESET_ALL}{enabled_clouds_str}'
        )
    return enabled_clouds
180
+
181
+
182
+ # === Helper functions ===
183
def storage_in_iterable(
    cloud: 'cloud_stores.GcsCloudStorage',
    cloud_list: Iterable['cloud_stores.GcsCloudStorage'],
) -> bool:
    """Returns whether the cloud is in the given cloud list."""
    for candidate in cloud_list:
        if cloud == candidate:
            return True
    return False
189
+
190
+
191
def _format_enabled_storage(cloud_name: str) -> str:
    """Returns *cloud_name* wrapped in green terminal color codes."""
    return ''.join((colorama.Fore.GREEN, cloud_name, colorama.Style.RESET_ALL))