wandb 0.15.3__py3-none-any.whl → 0.15.5__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (156) hide show
  1. wandb/__init__.py +1 -1
  2. wandb/analytics/sentry.py +1 -0
  3. wandb/apis/importers/base.py +20 -5
  4. wandb/apis/importers/mlflow.py +7 -1
  5. wandb/apis/internal.py +12 -0
  6. wandb/apis/public.py +247 -1387
  7. wandb/apis/reports/_panels.py +58 -35
  8. wandb/beta/workflows.py +6 -7
  9. wandb/cli/cli.py +130 -60
  10. wandb/data_types.py +3 -1
  11. wandb/filesync/dir_watcher.py +21 -27
  12. wandb/filesync/step_checksum.py +8 -8
  13. wandb/filesync/step_prepare.py +23 -10
  14. wandb/filesync/step_upload.py +13 -13
  15. wandb/filesync/upload_job.py +4 -8
  16. wandb/integration/cohere/__init__.py +3 -0
  17. wandb/integration/cohere/cohere.py +21 -0
  18. wandb/integration/cohere/resolver.py +347 -0
  19. wandb/integration/gym/__init__.py +4 -6
  20. wandb/integration/huggingface/__init__.py +3 -0
  21. wandb/integration/huggingface/huggingface.py +18 -0
  22. wandb/integration/huggingface/resolver.py +213 -0
  23. wandb/integration/langchain/wandb_tracer.py +16 -179
  24. wandb/integration/openai/__init__.py +1 -3
  25. wandb/integration/openai/openai.py +11 -143
  26. wandb/integration/openai/resolver.py +111 -38
  27. wandb/integration/sagemaker/config.py +2 -2
  28. wandb/integration/tensorboard/log.py +4 -4
  29. wandb/old/settings.py +24 -7
  30. wandb/proto/v3/wandb_telemetry_pb2.py +12 -12
  31. wandb/proto/v4/wandb_telemetry_pb2.py +12 -12
  32. wandb/proto/wandb_deprecated.py +3 -1
  33. wandb/sdk/__init__.py +1 -1
  34. wandb/sdk/artifacts/__init__.py +0 -0
  35. wandb/sdk/artifacts/artifact.py +2101 -0
  36. wandb/sdk/artifacts/artifact_download_logger.py +42 -0
  37. wandb/sdk/artifacts/artifact_manifest.py +67 -0
  38. wandb/sdk/artifacts/artifact_manifest_entry.py +159 -0
  39. wandb/sdk/artifacts/artifact_manifests/__init__.py +0 -0
  40. wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +91 -0
  41. wandb/sdk/{internal → artifacts}/artifact_saver.py +6 -5
  42. wandb/sdk/artifacts/artifact_state.py +10 -0
  43. wandb/sdk/{interface/artifacts/artifact_cache.py → artifacts/artifacts_cache.py} +22 -12
  44. wandb/sdk/artifacts/exceptions.py +55 -0
  45. wandb/sdk/artifacts/storage_handler.py +59 -0
  46. wandb/sdk/artifacts/storage_handlers/__init__.py +0 -0
  47. wandb/sdk/artifacts/storage_handlers/azure_handler.py +192 -0
  48. wandb/sdk/artifacts/storage_handlers/gcs_handler.py +224 -0
  49. wandb/sdk/artifacts/storage_handlers/http_handler.py +112 -0
  50. wandb/sdk/artifacts/storage_handlers/local_file_handler.py +134 -0
  51. wandb/sdk/artifacts/storage_handlers/multi_handler.py +53 -0
  52. wandb/sdk/artifacts/storage_handlers/s3_handler.py +301 -0
  53. wandb/sdk/artifacts/storage_handlers/tracking_handler.py +67 -0
  54. wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +132 -0
  55. wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +72 -0
  56. wandb/sdk/artifacts/storage_layout.py +6 -0
  57. wandb/sdk/artifacts/storage_policies/__init__.py +0 -0
  58. wandb/sdk/artifacts/storage_policies/s3_bucket_policy.py +61 -0
  59. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +386 -0
  60. wandb/sdk/{interface/artifacts/artifact_storage.py → artifacts/storage_policy.py} +5 -57
  61. wandb/sdk/data_types/_dtypes.py +7 -12
  62. wandb/sdk/data_types/base_types/json_metadata.py +3 -2
  63. wandb/sdk/data_types/base_types/media.py +8 -8
  64. wandb/sdk/data_types/base_types/wb_value.py +12 -13
  65. wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +5 -6
  66. wandb/sdk/data_types/helper_types/classes.py +6 -8
  67. wandb/sdk/data_types/helper_types/image_mask.py +5 -6
  68. wandb/sdk/data_types/histogram.py +4 -3
  69. wandb/sdk/data_types/html.py +3 -4
  70. wandb/sdk/data_types/image.py +11 -9
  71. wandb/sdk/data_types/molecule.py +5 -3
  72. wandb/sdk/data_types/object_3d.py +7 -5
  73. wandb/sdk/data_types/plotly.py +3 -2
  74. wandb/sdk/data_types/saved_model.py +11 -11
  75. wandb/sdk/data_types/trace_tree.py +5 -4
  76. wandb/sdk/data_types/utils.py +3 -5
  77. wandb/sdk/data_types/video.py +5 -4
  78. wandb/sdk/integration_utils/auto_logging.py +215 -0
  79. wandb/sdk/interface/interface.py +15 -15
  80. wandb/sdk/internal/file_pusher.py +8 -16
  81. wandb/sdk/internal/file_stream.py +5 -11
  82. wandb/sdk/internal/handler.py +13 -1
  83. wandb/sdk/internal/internal_api.py +287 -13
  84. wandb/sdk/internal/job_builder.py +119 -30
  85. wandb/sdk/internal/sender.py +6 -26
  86. wandb/sdk/internal/settings_static.py +2 -0
  87. wandb/sdk/internal/system/assets/__init__.py +2 -0
  88. wandb/sdk/internal/system/assets/gpu.py +42 -0
  89. wandb/sdk/internal/system/assets/gpu_amd.py +216 -0
  90. wandb/sdk/internal/system/env_probe_helpers.py +13 -0
  91. wandb/sdk/internal/system/system_info.py +3 -3
  92. wandb/sdk/internal/tb_watcher.py +32 -22
  93. wandb/sdk/internal/thread_local_settings.py +18 -0
  94. wandb/sdk/launch/_project_spec.py +57 -11
  95. wandb/sdk/launch/agent/agent.py +147 -65
  96. wandb/sdk/launch/agent/job_status_tracker.py +34 -0
  97. wandb/sdk/launch/agent/run_queue_item_file_saver.py +45 -0
  98. wandb/sdk/launch/builder/abstract.py +5 -1
  99. wandb/sdk/launch/builder/build.py +21 -18
  100. wandb/sdk/launch/builder/docker_builder.py +10 -4
  101. wandb/sdk/launch/builder/kaniko_builder.py +113 -23
  102. wandb/sdk/launch/builder/noop.py +6 -3
  103. wandb/sdk/launch/builder/templates/_wandb_bootstrap.py +46 -14
  104. wandb/sdk/launch/environment/aws_environment.py +3 -2
  105. wandb/sdk/launch/environment/azure_environment.py +124 -0
  106. wandb/sdk/launch/environment/gcp_environment.py +2 -4
  107. wandb/sdk/launch/environment/local_environment.py +1 -1
  108. wandb/sdk/launch/errors.py +19 -0
  109. wandb/sdk/launch/github_reference.py +32 -19
  110. wandb/sdk/launch/launch.py +3 -8
  111. wandb/sdk/launch/launch_add.py +6 -2
  112. wandb/sdk/launch/loader.py +21 -2
  113. wandb/sdk/launch/registry/azure_container_registry.py +132 -0
  114. wandb/sdk/launch/registry/elastic_container_registry.py +39 -5
  115. wandb/sdk/launch/registry/google_artifact_registry.py +68 -26
  116. wandb/sdk/launch/registry/local_registry.py +2 -1
  117. wandb/sdk/launch/runner/abstract.py +24 -3
  118. wandb/sdk/launch/runner/kubernetes_runner.py +479 -26
  119. wandb/sdk/launch/runner/local_container.py +103 -51
  120. wandb/sdk/launch/runner/local_process.py +1 -1
  121. wandb/sdk/launch/runner/sagemaker_runner.py +60 -10
  122. wandb/sdk/launch/runner/vertex_runner.py +10 -5
  123. wandb/sdk/launch/sweeps/__init__.py +7 -9
  124. wandb/sdk/launch/sweeps/scheduler.py +307 -77
  125. wandb/sdk/launch/sweeps/scheduler_sweep.py +2 -1
  126. wandb/sdk/launch/sweeps/utils.py +82 -35
  127. wandb/sdk/launch/utils.py +89 -75
  128. wandb/sdk/lib/_settings_toposort_generated.py +7 -0
  129. wandb/sdk/lib/capped_dict.py +26 -0
  130. wandb/sdk/lib/{git.py → gitlib.py} +76 -59
  131. wandb/sdk/lib/hashutil.py +12 -4
  132. wandb/sdk/lib/paths.py +96 -8
  133. wandb/sdk/lib/sock_client.py +2 -2
  134. wandb/sdk/lib/timer.py +1 -0
  135. wandb/sdk/service/server.py +22 -9
  136. wandb/sdk/service/server_sock.py +1 -1
  137. wandb/sdk/service/service.py +27 -8
  138. wandb/sdk/verify/verify.py +4 -7
  139. wandb/sdk/wandb_config.py +2 -6
  140. wandb/sdk/wandb_init.py +57 -53
  141. wandb/sdk/wandb_require.py +7 -0
  142. wandb/sdk/wandb_run.py +61 -223
  143. wandb/sdk/wandb_settings.py +28 -4
  144. wandb/testing/relay.py +15 -2
  145. wandb/util.py +74 -36
  146. {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/METADATA +15 -9
  147. {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/RECORD +151 -116
  148. {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/entry_points.txt +1 -0
  149. wandb/integration/langchain/util.py +0 -191
  150. wandb/sdk/interface/artifacts/__init__.py +0 -33
  151. wandb/sdk/interface/artifacts/artifact.py +0 -615
  152. wandb/sdk/interface/artifacts/artifact_manifest.py +0 -131
  153. wandb/sdk/wandb_artifacts.py +0 -2226
  154. {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/LICENSE +0 -0
  155. {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/WHEEL +0 -0
  156. {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/top_level.txt +0 -0
@@ -1,27 +1,32 @@
1
+ """Implementation of KubernetesRunner class for wandb launch."""
2
+
1
3
  import base64
2
4
  import json
3
5
  import logging
4
6
  import time
5
- from typing import Any, Dict, List, Optional, Tuple
7
+ from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
6
8
 
7
9
  import wandb
8
10
  from wandb.apis.internal import Api
11
+ from wandb.sdk.launch.agent.job_status_tracker import JobAndRunStatusTracker
9
12
  from wandb.sdk.launch.builder.abstract import AbstractBuilder
10
13
  from wandb.sdk.launch.environment.abstract import AbstractEnvironment
11
14
  from wandb.sdk.launch.registry.abstract import AbstractRegistry
15
+ from wandb.sdk.launch.registry.azure_container_registry import AzureContainerRegistry
12
16
  from wandb.sdk.launch.registry.local_registry import LocalRegistry
17
+ from wandb.sdk.launch.runner.abstract import State, Status
13
18
  from wandb.util import get_module
14
19
 
15
20
  from .._project_spec import EntryPoint, LaunchProject
16
21
  from ..builder.build import get_env_vars_dict
22
+ from ..errors import LaunchError
17
23
  from ..utils import (
18
24
  LOG_PREFIX,
19
25
  PROJECT_SYNCHRONOUS,
20
- LaunchError,
21
26
  get_kube_context_and_api_client,
22
27
  make_name_dns_safe,
23
28
  )
24
- from .abstract import AbstractRun, AbstractRunner, Status
29
+ from .abstract import AbstractRun, AbstractRunner
25
30
 
26
31
  get_module(
27
32
  "kubernetes",
@@ -31,8 +36,12 @@ get_module(
31
36
  from kubernetes import client # type: ignore # noqa: E402
32
37
  from kubernetes.client.api.batch_v1_api import BatchV1Api # type: ignore # noqa: E402
33
38
  from kubernetes.client.api.core_v1_api import CoreV1Api # type: ignore # noqa: E402
39
+ from kubernetes.client.api.custom_objects_api import ( # type: ignore # noqa: E402
40
+ CustomObjectsApi,
41
+ )
34
42
  from kubernetes.client.models.v1_job import V1Job # type: ignore # noqa: E402
35
43
  from kubernetes.client.models.v1_secret import V1Secret # type: ignore # noqa: E402
44
+ from kubernetes.client.rest import ApiException # type: ignore # noqa: E402
36
45
 
37
46
  TIMEOUT = 5
38
47
  MAX_KUBERNETES_RETRIES = (
@@ -43,7 +52,22 @@ FAIL_MESSAGE_INTERVAL = 60
43
52
  _logger = logging.getLogger(__name__)
44
53
 
45
54
 
55
+ # Dict for mapping possible states of custom objects to the states we want to report
56
+ # to the agent.
57
+ CRD_STATE_DICT: Dict[str, State] = {
58
+ "pending": "starting",
59
+ "running": "running",
60
+ "completed": "finished",
61
+ "failed": "failed",
62
+ "aborted": "failed",
63
+ "terminating": "stopping",
64
+ "terminated": "stopped",
65
+ }
66
+
67
+
46
68
  class KubernetesSubmittedRun(AbstractRun):
69
+ """Wrapper for a launched run on Kubernetes."""
70
+
47
71
  def __init__(
48
72
  self,
49
73
  batch_api: "BatchV1Api",
@@ -53,6 +77,19 @@ class KubernetesSubmittedRun(AbstractRun):
53
77
  namespace: Optional[str] = "default",
54
78
  secret: Optional["V1Secret"] = None,
55
79
  ) -> None:
80
+ """Initialize a KubernetesSubmittedRun.
81
+
82
+ Arguments:
83
+ batch_api: Kubernetes BatchV1Api object.
84
+ core_api: Kubernetes CoreV1Api object.
85
+ name: Name of the job.
86
+ pod_names: List of pod names.
87
+ namespace: Kubernetes namespace.
88
+ secret: Kubernetes secret.
89
+
90
+ Returns:
91
+ None.
92
+ """
56
93
  self.batch_api = batch_api
57
94
  self.core_api = core_api
58
95
  self.name = name
@@ -66,14 +103,37 @@ class KubernetesSubmittedRun(AbstractRun):
66
103
 
67
104
  @property
68
105
  def id(self) -> str:
106
+ """Return the run id."""
69
107
  return self.name
70
108
 
109
+ def get_logs(self) -> Optional[str]:
110
+ try:
111
+ logs = self.core_api.read_namespaced_pod_log(
112
+ name=self.pod_names[0], namespace=self.namespace
113
+ )
114
+ if logs:
115
+ return str(logs)
116
+ else:
117
+ wandb.termwarn(
118
+ f"Retrieved no logs for kubernetes pod(s): {self.pod_names}"
119
+ )
120
+ return None
121
+ except Exception as e:
122
+ wandb.termerror(f"{LOG_PREFIX}Failed to get pod logs: {e}")
123
+ return None
124
+
71
125
  def get_job(self) -> "V1Job":
126
+ """Return the job object."""
72
127
  return self.batch_api.read_namespaced_job(
73
128
  name=self.name, namespace=self.namespace
74
129
  )
75
130
 
76
131
  def wait(self) -> bool:
132
+ """Wait for the run to finish.
133
+
134
+ Returns:
135
+ True if the run finished successfully, False otherwise.
136
+ """
77
137
  while True:
78
138
  status = self.get_status()
79
139
  wandb.termlog(f"{LOG_PREFIX}Job {self.name} status: {status}")
@@ -85,14 +145,23 @@ class KubernetesSubmittedRun(AbstractRun):
85
145
  ) # todo: not sure if this (copied from aws runner) is the right approach? should we return false on failure
86
146
 
87
147
  def get_status(self) -> Status:
88
- job_response = self.batch_api.read_namespaced_job_status(
89
- name=self.name, namespace=self.namespace
90
- )
91
- status = job_response.status
148
+ """Return the run status."""
149
+ try:
150
+ job_response = self.batch_api.read_namespaced_job_status(
151
+ name=self.name, namespace=self.namespace
152
+ )
153
+ status = job_response.status
154
+
155
+ pod = self.core_api.read_namespaced_pod(
156
+ name=self.pod_names[0], namespace=self.namespace
157
+ )
158
+ except ApiException as e:
159
+ if "(404)" not in str(e):
160
+ raise
161
+ # 404 = Pod/job not reachable
162
+ wandb.termlog(f"{LOG_PREFIX}Job or pod disconnected for job: {self.name}")
163
+ return Status("preempted")
92
164
 
93
- pod = self.core_api.read_namespaced_pod(
94
- name=self.pod_names[0], namespace=self.namespace
95
- )
96
165
  if pod.status.phase in ["Pending", "Unknown"]:
97
166
  now = time.time()
98
167
  if self._fail_count == 0:
@@ -111,7 +180,13 @@ class KubernetesSubmittedRun(AbstractRun):
111
180
  if status.succeeded == 1:
112
181
  return_status = Status("finished")
113
182
  elif status.failed is not None and status.failed >= 1:
114
- return_status = Status("failed")
183
+ if status.conditions[0].reason == "BackoffLimitExceeded":
184
+ wandb.termlog(
185
+ f"{LOG_PREFIX}Job or pod disconnected for job: {self.name}"
186
+ )
187
+ return_status = Status("preempted")
188
+ else:
189
+ return_status = Status("failed")
115
190
  elif status.active == 1:
116
191
  return Status("running")
117
192
  elif status.conditions is not None and status.conditions[0].type == "Suspended":
@@ -133,6 +208,7 @@ class KubernetesSubmittedRun(AbstractRun):
133
208
  return return_status
134
209
 
135
210
  def suspend(self) -> None:
211
+ """Suspend the run."""
136
212
  self.job.spec.suspend = True
137
213
  self.batch_api.patch_namespaced_job(
138
214
  name=self.name, namespace=self.namespace, body=self.job
@@ -156,29 +232,183 @@ class KubernetesSubmittedRun(AbstractRun):
156
232
  )
157
233
 
158
234
  def cancel(self) -> None:
235
+ """Cancel the run."""
159
236
  self.suspend()
160
237
  self.batch_api.delete_namespaced_job(name=self.name, namespace=self.namespace)
161
238
 
162
239
 
240
class CrdSubmittedRun(AbstractRun):
    """Run submitted to a CRD backend, e.g. Volcano."""

    def __init__(
        self,
        group: str,
        version: str,
        plural: str,
        name: str,
        namespace: str,
        core_api: CoreV1Api,
        custom_api: CustomObjectsApi,
        pod_names: List[str],
    ) -> None:
        """Create a run object for tracking the progress of a CRD.

        Arguments:
            group: The API group of the CRD.
            version: The API version of the CRD.
            plural: The plural name of the CRD.
            name: The name of the CRD instance.
            namespace: The namespace of the CRD instance.
            core_api: The Kubernetes core API client.
            custom_api: The Kubernetes custom object API client.
            pod_names: The names of the pods associated with the CRD instance.

        Raises:
            LaunchError: If the CRD instance does not exist.
        """
        self.group = group
        self.version = version
        self.plural = plural
        self.name = name
        self.namespace = namespace
        self.core_api = core_api
        self.custom_api = custom_api
        self.pod_names = pod_names
        self._fail_count = 0
        # Fetch the object once up front so that a missing/unreachable custom
        # object is reported at construction time rather than on first poll.
        try:
            self.job = self.custom_api.get_namespaced_custom_object(
                group=self.group,
                version=self.version,
                namespace=self.namespace,
                plural=self.plural,
                name=self.name,
            )
        except ApiException as e:
            raise LaunchError(
                f"Failed to get CRD {self.name} in namespace {self.namespace}: {str(e)}"
            ) from e

    @property
    def id(self) -> str:
        """Get the name of the custom object."""
        return self.name

    def get_logs(self) -> Optional[str]:
        """Get logs for custom object.

        Returns:
            The logs of every pod in ``pod_names`` joined into one string, each
            section prefixed with ``Pod <name>:``, or None if there are no pods
            or any pod log request fails.
        """
        # TODO: test more carefully once we release multi-node support
        logs: Dict[str, Optional[str]] = {}
        try:
            for pod_name in self.pod_names:
                logs[pod_name] = self.core_api.read_namespaced_pod_log(
                    name=pod_name, namespace=self.namespace
                )
        except ApiException as e:
            wandb.termwarn(f"Failed to get logs for {self.name}: {str(e)}")
            return None
        if not logs:
            return None
        logs_as_array = [f"Pod {pod_name}:\n{log}" for pod_name, log in logs.items()]
        return "\n".join(logs_as_array)

    def get_status(self) -> Status:
        """Get status of custom object.

        Returns:
            The reported state mapped through CRD_STATE_DICT; states that are
            not in the mapping are reported as "unknown".

        Raises:
            LaunchError: If the status request fails or the response carries
                no state.
        """
        try:
            job_response = self.custom_api.get_namespaced_custom_object_status(
                group=self.group,
                version=self.version,
                namespace=self.namespace,
                plural=self.plural,
                name=self.name,
            )
        except ApiException as e:
            raise LaunchError(
                f"Failed to get CRD {self.name} in namespace {self.namespace}: {str(e)}"
            ) from e
        # Custom objects can technically define whatever states and format the
        # response to the status request however they want. This checks for
        # the most common cases.
        status = job_response.get("status")
        if not isinstance(status, dict):
            # A response without a usable status block is funnelled into the
            # "no state found" error below instead of a bare KeyError.
            status = {}
        state = status.get("state")
        if isinstance(state, dict):
            state = state.get("phase")
        if state is None:
            raise LaunchError(
                f"Failed to get CRD {self.name} in namespace {self.namespace}: no state found"
            )
        # str() guards against controllers that report a non-string state.
        return Status(CRD_STATE_DICT.get(str(state).lower(), "unknown"))

    def cancel(self) -> None:
        """Cancel the custom object.

        Raises:
            LaunchError: If the delete request fails.
        """
        try:
            self.custom_api.delete_namespaced_custom_object(
                group=self.group,
                version=self.version,
                namespace=self.namespace,
                plural=self.plural,
                name=self.name,
            )
        except ApiException as e:
            raise LaunchError(
                f"Failed to delete CRD {self.name} in namespace {self.namespace}: {str(e)}"
            ) from e

    def wait(self) -> bool:
        """Wait for this custom object to finish running.

        Polls the status every 5 seconds while the object is still in a
        transitional state ("starting", "running" or "stopping").

        Returns:
            True if the final state is "finished", False otherwise.
        """
        while True:
            status = self.get_status()
            wandb.termlog(f"{LOG_PREFIX}Job {self.name} status: {status}")
            # Keep polling through pending/terminating phases as well; only a
            # settled (or unrecognised) state ends the wait. Breaking on
            # anything other than "running" would report a still-pending job
            # as failed.
            if status.state not in ("starting", "running", "stopping"):
                break
            time.sleep(5)
        return status.state == "finished"
364
+
365
+
163
366
  class KubernetesRunner(AbstractRunner):
367
+ """Launches runs onto kubernetes."""
368
+
164
369
  def __init__(
165
370
  self, api: Api, backend_config: Dict[str, Any], environment: AbstractEnvironment
166
371
  ) -> None:
372
+ """Create a Kubernetes runner.
373
+
374
+ Arguments:
375
+ api: The API client object.
376
+ backend_config: The backend configuration.
377
+ environment: The environment to launch runs into.
378
+
379
+ Raises:
380
+ LaunchError: If the Kubernetes configuration is invalid.
381
+ """
167
382
  super().__init__(api, backend_config)
168
383
  self.environment = environment
169
384
 
170
385
  def wait_job_launch(
171
- self, job_name: str, namespace: str, core_api: "CoreV1Api"
386
+ self,
387
+ job_name: str,
388
+ namespace: str,
389
+ core_api: "CoreV1Api",
390
+ label: str = "job-name",
172
391
  ) -> List[str]:
392
+ """Wait for a job to be launched and return the pod names.
393
+
394
+ Arguments:
395
+ job_name: The name of the job.
396
+ namespace: The namespace of the job.
397
+ core_api: The Kubernetes core API client.
398
+ label: The label key to match against job_name.
399
+
400
+ Returns:
401
+ The names of the pods associated with the job.
402
+ """
173
403
  pods = core_api.list_namespaced_pod(
174
- label_selector=f"job-name={job_name}", namespace=namespace
404
+ label_selector=f"{label}={job_name}", namespace=namespace
175
405
  )
176
406
  timeout = TIMEOUT
177
407
  while len(pods.items) == 0 and timeout > 0:
178
408
  time.sleep(1)
179
409
  timeout -= 1
180
410
  pods = core_api.list_namespaced_pod(
181
- label_selector=f"job-name={job_name}", namespace=namespace
411
+ label_selector=f"{label}={job_name}", namespace=namespace
182
412
  )
183
413
 
184
414
  if timeout == 0:
@@ -197,6 +427,15 @@ class KubernetesRunner(AbstractRunner):
197
427
  def get_namespace(
198
428
  self, resource_args: Dict[str, Any], context: Dict[str, Any]
199
429
  ) -> str:
430
+ """Get the namespace to launch into.
431
+
432
+ Arguments:
433
+ resource_args: The resource args to launch.
434
+ context: The k8s config context.
435
+
436
+ Returns:
437
+ The namespace to launch into.
438
+ """
200
439
  default_namespace = (
201
440
  context["context"].get("namespace", "default") if context else "default"
202
441
  )
@@ -213,8 +452,20 @@ class KubernetesRunner(AbstractRunner):
213
452
  builder: Optional[AbstractBuilder],
214
453
  namespace: str,
215
454
  core_api: "CoreV1Api",
455
+ job_tracker: Optional[JobAndRunStatusTracker],
216
456
  ) -> Tuple[Dict[str, Any], Optional["V1Secret"]]:
217
- """Apply our default values, return job dict and secret."""
457
+ """Apply our default values, return job dict and secret.
458
+
459
+ Arguments:
460
+ resource_args (Dict[str, Any]): The resource args to launch.
461
+ launch_project (LaunchProject): The launch project.
462
+ builder (Optional[AbstractBuilder]): The builder.
463
+ namespace (str): The namespace.
464
+ core_api (CoreV1Api): The core api.
465
+
466
+ Returns:
467
+ Tuple[Dict[str, Any], Optional["V1Secret"]]: The resource args and secret.
468
+ """
218
469
  job: Dict[str, Any] = {
219
470
  "apiVersion": "batch/v1",
220
471
  "kind": "Job",
@@ -253,7 +504,9 @@ class KubernetesRunner(AbstractRunner):
253
504
  "Invalid specification of multiple containers. See https://docs.wandb.ai/guides/launch for guidance on submitting jobs."
254
505
  )
255
506
  # dont specify run id if user provided image, could have multiple runs
256
- containers[0]["image"] = launch_project.docker_image
507
+ image_uri = launch_project.docker_image
508
+ containers[0]["image"] = image_uri
509
+ launch_project.fill_macros(image_uri)
257
510
  # TODO: handle secret pulling image from registry
258
511
  elif not any(["image" in cont for cont in containers]):
259
512
  if len(containers) > 1:
@@ -262,7 +515,9 @@ class KubernetesRunner(AbstractRunner):
262
515
  )
263
516
  assert entry_point is not None
264
517
  assert builder is not None
265
- image_uri = builder.build_image(launch_project, entry_point)
518
+ image_uri = builder.build_image(launch_project, entry_point, job_tracker)
519
+ image_uri = image_uri.replace("https://", "")
520
+ launch_project.fill_macros(image_uri)
266
521
  # in the non instance case we need to make an imagePullSecret
267
522
  # so the new job can pull the image
268
523
  if not builder.registry:
@@ -276,8 +531,8 @@ class KubernetesRunner(AbstractRunner):
276
531
  pod_spec["imagePullSecrets"] = [
277
532
  {"name": f"regcred-{launch_project.run_id}"}
278
533
  ]
279
-
280
534
  containers[0]["image"] = image_uri
535
+ launch_project.fill_macros(image_uri)
281
536
 
282
537
  inject_entrypoint_and_args(
283
538
  containers,
@@ -306,8 +561,18 @@ class KubernetesRunner(AbstractRunner):
306
561
  def run(
307
562
  self,
308
563
  launch_project: LaunchProject,
309
- builder: Optional[AbstractBuilder],
564
+ builder: AbstractBuilder,
565
+ job_tracker: Optional[JobAndRunStatusTracker] = None,
310
566
  ) -> Optional[AbstractRun]: # noqa: C901
567
+ """Execute a launch project on Kubernetes.
568
+
569
+ Arguments:
570
+ launch_project: The launch project to execute.
571
+ builder: The builder to use to build the image.
572
+
573
+ Returns:
574
+ The run object if the run was successful, otherwise None.
575
+ """
311
576
  kubernetes = get_module( # noqa: F811
312
577
  "kubernetes",
313
578
  required="Kubernetes runner requires the kubernetes package. Please"
@@ -316,23 +581,86 @@ class KubernetesRunner(AbstractRunner):
316
581
  resource_args = launch_project.resource_args.get("kubernetes", {})
317
582
  if not resource_args:
318
583
  wandb.termlog(
319
- f"{LOG_PREFIX}Note: no resource args specified. Add a Kubernetes yaml spec or other options in a json file with --resource-args <json>."
584
+ f"{LOG_PREFIX}Note: no resource args specified. Add a "
585
+ "Kubernetes yaml spec or other options in a json file "
586
+ "with --resource-args <json>."
320
587
  )
321
588
  _logger.info(f"Running Kubernetes job with resource args: {resource_args}")
322
589
 
323
590
  context, api_client = get_kube_context_and_api_client(kubernetes, resource_args)
324
591
 
592
+ # If the user specified an alternate api, we will execute this
593
+ # run by creating a custom object.
594
+ api_version = resource_args.get("apiVersion", "batch/v1")
595
+ if api_version not in ["batch/v1", "batch/v1beta1"]:
596
+ entrypoint = launch_project.get_single_entry_point()
597
+ if launch_project.docker_image:
598
+ image_uri = launch_project.docker_image
599
+ else:
600
+ assert entrypoint is not None
601
+ image_uri = builder.build_image(launch_project, entrypoint, job_tracker)
602
+ launch_project.fill_macros(image_uri)
603
+ env_vars = get_env_vars_dict(launch_project, self._api)
604
+ # Crawl the resource args and add our env vars to the containers.
605
+ add_wandb_env(launch_project.resource_args, env_vars)
606
+ # Crawl the resource args and add our labels to the pods. This is
607
+ # necessary for the agent to find the pods later on.
608
+ add_label_to_pods(
609
+ launch_project.resource_args, "wandb/run-id", launch_project.run_id
610
+ )
611
+ overrides = {}
612
+ if launch_project.override_args:
613
+ overrides["args"] = launch_project.override_args
614
+ if launch_project.override_entrypoint:
615
+ overrides["command"] = launch_project.override_entrypoint.command
616
+ add_entrypoint_args_overrides(
617
+ launch_project.resource_args,
618
+ overrides,
619
+ )
620
+ api = client.CustomObjectsApi(api_client)
621
+ # Infer the attributes of a custom object from the apiVersion and/or
622
+ # a kind: attribute in the resource args.
623
+ namespace = self.get_namespace(resource_args, context)
624
+ group = resource_args.get("group", api_version.split("/")[0])
625
+ version = api_version.split("/")[1]
626
+ kind = resource_args.get("kind", version)
627
+ plural = f"{kind.lower()}s"
628
+ try:
629
+ response = api.create_namespaced_custom_object(
630
+ group=group,
631
+ version=version,
632
+ namespace=namespace,
633
+ plural=plural,
634
+ body=launch_project.resource_args.get("kubernetes"),
635
+ )
636
+ except ApiException as e:
637
+ raise LaunchError(
638
+ f"Error creating CRD of kind {kind}: {e.status} {e.reason}"
639
+ ) from e
640
+ name = response.get("metadata", {}).get("name")
641
+ _logger.info(f"Created {kind} {response['metadata']['name']}")
642
+ core = client.CoreV1Api(api_client)
643
+ pod_names = self.wait_job_launch(
644
+ launch_project.run_id, namespace, core, label="wandb/run-id"
645
+ )
646
+ return CrdSubmittedRun(
647
+ name=name,
648
+ group=group,
649
+ version=version,
650
+ namespace=namespace,
651
+ plural=plural,
652
+ core_api=client.CoreV1Api(api_client),
653
+ custom_api=api,
654
+ pod_names=pod_names,
655
+ )
656
+
325
657
  batch_api = kubernetes.client.BatchV1Api(api_client)
326
658
  core_api = kubernetes.client.CoreV1Api(api_client)
327
659
 
328
660
  namespace = self.get_namespace(resource_args, context)
329
661
 
330
662
  job, secret = self._inject_defaults(
331
- resource_args,
332
- launch_project,
333
- builder,
334
- namespace,
335
- core_api,
663
+ resource_args, launch_project, builder, namespace, core_api, job_tracker
336
664
  )
337
665
 
338
666
  msg = "Creating Kubernetes job"
@@ -364,6 +692,17 @@ def inject_entrypoint_and_args(
364
692
  override_args: List[str],
365
693
  should_override_entrypoint: bool,
366
694
  ) -> None:
695
+ """Inject the entrypoint and args into the containers.
696
+
697
+ Arguments:
698
+ containers: The containers to inject the entrypoint and args into.
699
+ entry_point: The entrypoint to inject.
700
+ override_args: The args to inject.
701
+ should_override_entrypoint: Whether to override the entrypoint.
702
+
703
+ Returns:
704
+ None
705
+ """
367
706
  for i in range(len(containers)):
368
707
  if override_args:
369
708
  containers[i]["args"] = override_args
@@ -379,8 +718,21 @@ def maybe_create_imagepull_secret(
379
718
  run_id: str,
380
719
  namespace: str,
381
720
  ) -> Optional["V1Secret"]:
721
+ """Create a secret for pulling images from a private registry.
722
+
723
+ Arguments:
724
+ core_api: The Kubernetes CoreV1Api object.
725
+ registry: The registry to pull from.
726
+ run_id: The run id.
727
+ namespace: The namespace to create the secret in.
728
+
729
+ Returns:
730
+ A secret if one was created, otherwise None.
731
+ """
382
732
  secret = None
383
- if isinstance(registry, LocalRegistry):
733
+ if isinstance(registry, LocalRegistry) or isinstance(
734
+ registry, AzureContainerRegistry
735
+ ):
384
736
  # Secret not required
385
737
  return None
386
738
  uname, token = registry.get_username_password()
@@ -406,3 +758,104 @@ def maybe_create_imagepull_secret(
406
758
  return core_api.create_namespaced_secret(namespace, secret)
407
759
  except Exception as e:
408
760
  raise LaunchError(f"Exception when creating Kubernetes secret: {str(e)}\n")
761
+
762
+
763
def add_wandb_env(root: Union[dict, list], env_vars: Dict[str, str]) -> None:
    """Injects wandb environment variables into specs.

    Recursively walks the spec and injects the environment variables into
    every container spec. Containers are identified by the "containers" key;
    the walk does not descend into the containers themselves.

    Any entry already present in a container's ``env`` list under a name that
    is being injected is replaced by the injected value, so no duplicate names
    are produced. If a setting for WANDB_RUN_ID is provided in env_vars, it is
    only set in the first container modified by this function.

    The spec is modified in place. ``env_vars`` itself is left untouched.

    Arguments:
        root: The spec to modify.
        env_vars: The environment variables to inject.

    Returns: None.
    """

    def yield_containers(node: Any) -> Iterator[dict]:
        if isinstance(node, dict):
            for key, value in node.items():
                if key == "containers":
                    if isinstance(value, list):
                        yield from value
                elif isinstance(value, (dict, list)):
                    yield from yield_containers(value)
        elif isinstance(node, list):
            for item in node:
                yield from yield_containers(item)

    # Work on a copy so the caller's mapping is not mutated when WANDB_RUN_ID
    # is dropped after the first container.
    pending = dict(env_vars)
    for cont in yield_containers(root):
        if not isinstance(cont, dict):
            continue
        env = cont.get("env")
        if env is None:
            # Covers both a missing key and an explicit null value.
            env = cont["env"] = []
        kept = [
            entry
            for entry in env
            if not (isinstance(entry, dict) and entry.get("name") in pending)
        ]
        injected = [{"name": key, "value": value} for key, value in pending.items()]
        # Slice-assign so the existing list object is updated in place.
        env[:] = kept + injected
        # After we have set WANDB_RUN_ID once, we don't want to set it again
        pending.pop("WANDB_RUN_ID", None)
800
+
801
+
802
def add_label_to_pods(
    manifest: Union[dict, list], label_key: str, label_value: str
) -> None:
    """Add a label to all pod specs in a manifest.

    Recursively traverses the manifest and adds the label to all pod specs.
    Pod specs are identified by the presence of a "spec" key whose value is a
    mapping that contains a "containers" key. The manifest is modified in
    place; missing or null "metadata"/"labels" entries are created.

    Arguments:
        manifest: The manifest to modify.
        label_key: The label key to add.
        label_value: The label value to add.

    Returns: None.
    """

    def yield_pods(node: Any) -> Iterator[dict]:
        if isinstance(node, list):
            for item in node:
                yield from yield_pods(item)
        elif isinstance(node, dict):
            spec = node.get("spec")
            # Only a mapping can be a pod spec; a null or string "spec" must
            # not raise or be matched by substring.
            if isinstance(spec, dict) and "containers" in spec:
                yield node
            for value in node.values():
                if isinstance(value, (dict, list)):
                    yield from yield_pods(value)

    # Finish the traversal before mutating, so no dict changes size while it
    # is still being iterated.
    for pod in list(yield_pods(manifest)):
        metadata = pod.get("metadata")
        if metadata is None:
            metadata = pod["metadata"] = {}
        labels = metadata.get("labels")
        if labels is None:
            labels = metadata["labels"] = {}
        labels[label_key] = label_value
834
+
835
+
836
def add_entrypoint_args_overrides(manifest: Union[dict, list], overrides: dict) -> None:
    """Add entrypoint and args overrides to all containers in a manifest.

    Recursively traverses the manifest and applies the overrides to all
    containers. Containers are identified by the presence of a "spec" key
    whose value is a mapping that contains a "containers" key. The manifest
    is modified in place.

    Arguments:
        manifest: The manifest to modify.
        overrides: Dictionary with optional "command" and "args" keys. Only
            the keys that are present are written onto each container.

    Returns: None.
    """
    if isinstance(manifest, list):
        for item in manifest:
            add_entrypoint_args_overrides(item, overrides)
    elif isinstance(manifest, dict):
        spec = manifest.get("spec")
        # Only a mapping can be a pod spec; a null or string "spec" must not
        # raise or be matched by substring.
        if isinstance(spec, dict) and "containers" in spec:
            for container in spec["containers"] or []:
                if not isinstance(container, dict):
                    continue
                if "command" in overrides:
                    container["command"] = overrides["command"]
                if "args" in overrides:
                    container["args"] = overrides["args"]
        for value in manifest.values():
            # Scalars cannot contain pod specs, so only recurse into
            # containers of further nodes.
            if isinstance(value, (dict, list)):
                add_entrypoint_args_overrides(value, overrides)