wandb 0.15.4__py3-none-any.whl → 0.15.5__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (102) hide show
  1. wandb/__init__.py +1 -1
  2. wandb/analytics/sentry.py +1 -0
  3. wandb/apis/internal.py +3 -0
  4. wandb/apis/public.py +18 -20
  5. wandb/beta/workflows.py +5 -6
  6. wandb/cli/cli.py +27 -27
  7. wandb/data_types.py +2 -0
  8. wandb/integration/langchain/wandb_tracer.py +16 -179
  9. wandb/integration/sagemaker/config.py +2 -2
  10. wandb/integration/tensorboard/log.py +4 -4
  11. wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
  12. wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
  13. wandb/proto/wandb_deprecated.py +3 -1
  14. wandb/sdk/__init__.py +1 -4
  15. wandb/sdk/artifacts/__init__.py +0 -14
  16. wandb/sdk/artifacts/artifact.py +1757 -277
  17. wandb/sdk/artifacts/artifact_manifest_entry.py +26 -6
  18. wandb/sdk/artifacts/artifact_state.py +10 -0
  19. wandb/sdk/artifacts/artifacts_cache.py +7 -8
  20. wandb/sdk/artifacts/exceptions.py +4 -4
  21. wandb/sdk/artifacts/storage_handler.py +2 -2
  22. wandb/sdk/artifacts/storage_handlers/azure_handler.py +16 -6
  23. wandb/sdk/artifacts/storage_handlers/gcs_handler.py +2 -2
  24. wandb/sdk/artifacts/storage_handlers/http_handler.py +2 -2
  25. wandb/sdk/artifacts/storage_handlers/local_file_handler.py +2 -2
  26. wandb/sdk/artifacts/storage_handlers/multi_handler.py +2 -2
  27. wandb/sdk/artifacts/storage_handlers/s3_handler.py +35 -32
  28. wandb/sdk/artifacts/storage_handlers/tracking_handler.py +2 -2
  29. wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +5 -9
  30. wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +2 -2
  31. wandb/sdk/artifacts/storage_policies/s3_bucket_policy.py +2 -2
  32. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +24 -16
  33. wandb/sdk/artifacts/storage_policy.py +3 -3
  34. wandb/sdk/data_types/_dtypes.py +7 -12
  35. wandb/sdk/data_types/base_types/json_metadata.py +2 -2
  36. wandb/sdk/data_types/base_types/media.py +5 -6
  37. wandb/sdk/data_types/base_types/wb_value.py +12 -13
  38. wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +4 -5
  39. wandb/sdk/data_types/helper_types/classes.py +5 -8
  40. wandb/sdk/data_types/helper_types/image_mask.py +4 -5
  41. wandb/sdk/data_types/histogram.py +3 -3
  42. wandb/sdk/data_types/html.py +3 -4
  43. wandb/sdk/data_types/image.py +4 -5
  44. wandb/sdk/data_types/molecule.py +2 -2
  45. wandb/sdk/data_types/object_3d.py +3 -3
  46. wandb/sdk/data_types/plotly.py +2 -2
  47. wandb/sdk/data_types/saved_model.py +7 -8
  48. wandb/sdk/data_types/trace_tree.py +4 -4
  49. wandb/sdk/data_types/video.py +4 -4
  50. wandb/sdk/interface/interface.py +8 -10
  51. wandb/sdk/internal/file_stream.py +2 -3
  52. wandb/sdk/internal/internal_api.py +99 -4
  53. wandb/sdk/internal/job_builder.py +15 -7
  54. wandb/sdk/internal/sender.py +4 -0
  55. wandb/sdk/internal/settings_static.py +1 -0
  56. wandb/sdk/launch/_project_spec.py +9 -7
  57. wandb/sdk/launch/agent/agent.py +115 -58
  58. wandb/sdk/launch/agent/job_status_tracker.py +34 -0
  59. wandb/sdk/launch/agent/run_queue_item_file_saver.py +45 -0
  60. wandb/sdk/launch/builder/abstract.py +5 -1
  61. wandb/sdk/launch/builder/build.py +16 -10
  62. wandb/sdk/launch/builder/docker_builder.py +9 -2
  63. wandb/sdk/launch/builder/kaniko_builder.py +108 -22
  64. wandb/sdk/launch/builder/noop.py +3 -1
  65. wandb/sdk/launch/environment/aws_environment.py +2 -1
  66. wandb/sdk/launch/environment/azure_environment.py +124 -0
  67. wandb/sdk/launch/github_reference.py +30 -18
  68. wandb/sdk/launch/launch.py +1 -1
  69. wandb/sdk/launch/loader.py +15 -0
  70. wandb/sdk/launch/registry/azure_container_registry.py +132 -0
  71. wandb/sdk/launch/registry/elastic_container_registry.py +38 -4
  72. wandb/sdk/launch/registry/google_artifact_registry.py +46 -7
  73. wandb/sdk/launch/runner/abstract.py +19 -3
  74. wandb/sdk/launch/runner/kubernetes_runner.py +111 -47
  75. wandb/sdk/launch/runner/local_container.py +101 -48
  76. wandb/sdk/launch/runner/sagemaker_runner.py +59 -9
  77. wandb/sdk/launch/runner/vertex_runner.py +8 -4
  78. wandb/sdk/launch/sweeps/scheduler.py +102 -27
  79. wandb/sdk/launch/sweeps/utils.py +21 -0
  80. wandb/sdk/launch/utils.py +19 -7
  81. wandb/sdk/lib/_settings_toposort_generated.py +3 -0
  82. wandb/sdk/service/server.py +22 -9
  83. wandb/sdk/service/service.py +27 -8
  84. wandb/sdk/verify/verify.py +6 -9
  85. wandb/sdk/wandb_config.py +2 -4
  86. wandb/sdk/wandb_init.py +2 -0
  87. wandb/sdk/wandb_require.py +7 -0
  88. wandb/sdk/wandb_run.py +32 -35
  89. wandb/sdk/wandb_settings.py +10 -3
  90. wandb/testing/relay.py +15 -2
  91. wandb/util.py +55 -23
  92. {wandb-0.15.4.dist-info → wandb-0.15.5.dist-info}/METADATA +11 -8
  93. {wandb-0.15.4.dist-info → wandb-0.15.5.dist-info}/RECORD +97 -97
  94. wandb/integration/langchain/util.py +0 -191
  95. wandb/sdk/artifacts/invalid_artifact.py +0 -23
  96. wandb/sdk/artifacts/lazy_artifact.py +0 -162
  97. wandb/sdk/artifacts/local_artifact.py +0 -719
  98. wandb/sdk/artifacts/public_artifact.py +0 -1188
  99. {wandb-0.15.4.dist-info → wandb-0.15.5.dist-info}/LICENSE +0 -0
  100. {wandb-0.15.4.dist-info → wandb-0.15.5.dist-info}/WHEEL +0 -0
  101. {wandb-0.15.4.dist-info → wandb-0.15.5.dist-info}/entry_points.txt +0 -0
  102. {wandb-0.15.4.dist-info → wandb-0.15.5.dist-info}/top_level.txt +0 -0
wandb/sdk/launch/registry/azure_container_registry.py
@@ -0,0 +1,132 @@
1
+ """Implementation of AzureContainerRegistry class."""
2
+ import re
3
+ from typing import TYPE_CHECKING, Tuple
4
+
5
+ from wandb.util import get_module
6
+
7
+ from ..environment.abstract import AbstractEnvironment
8
+ from ..environment.azure_environment import AzureEnvironment
9
+ from ..errors import LaunchError
10
+ from .abstract import AbstractRegistry
11
+
12
+ if TYPE_CHECKING:
13
+ from azure.containerregistry import ContainerRegistryClient # type: ignore
14
+ from azure.core.exceptions import ResourceNotFoundError # type: ignore
15
+
16
+
17
+ ContainerRegistryClient = get_module( # noqa: F811
18
+ "azure.containerregistry",
19
+ required="The azure-containerregistry package is required to use launch with Azure. Please install it with `pip install azure-containerregistry`.",
20
+ ).ContainerRegistryClient
21
+
22
+ ResourceNotFoundError = get_module( # noqa: F811
23
+ "azure.core.exceptions",
24
+ required="The azure-core package is required to use launch with Azure. Please install it with `pip install azure-core`.",
25
+ ).ResourceNotFoundError
26
+
27
+
28
+ class AzureContainerRegistry(AbstractRegistry):
29
+ """Helper for accessing Azure Container Registry resources."""
30
+
31
+ def __init__(
32
+ self,
33
+ environment: AzureEnvironment,
34
+ uri: str,
35
+ verify: bool = True,
36
+ ):
37
+ """Initialize an AzureContainerRegistry."""
38
+ self.environment = environment
39
+ self.uri = uri
40
+ if verify:
41
+ self.verify()
42
+
43
+ @classmethod
44
+ def from_config(
45
+ cls, config: dict, environment: AbstractEnvironment, verify: bool = True
46
+ ) -> "AzureContainerRegistry":
47
+ """Create an AzureContainerRegistry from a config dict.
48
+
49
+ Args:
50
+ config (dict): The config dict.
51
+ environment (AbstractEnvironment): The environment to use.
52
+ verify (bool, optional): Whether to verify the registry. Defaults to True.
53
+
54
+ Returns:
55
+ AzureContainerRegistry: The registry.
56
+
57
+ Raises:
58
+ LaunchError: If the config is invalid.
59
+ """
60
+ if not isinstance(environment, AzureEnvironment):
61
+ raise LaunchError(
62
+ "AzureContainerRegistry requires an AzureEnvironment to be passed in."
63
+ )
64
+ uri = config.get("uri")
65
+ if uri is None:
66
+ raise LaunchError(
67
+ "Please specify a registry name to use under the registry.uri."
68
+ )
69
+ return cls(
70
+ uri=uri,
71
+ environment=environment,
72
+ verify=verify,
73
+ )
74
+
75
+ def get_username_password(self) -> Tuple[str, str]:
76
+ """Get username and password for container registry."""
77
+ raise NotImplementedError
78
+
79
+ def check_image_exists(self, image_uri: str) -> bool:
80
+ """Check if image exists in container registry.
81
+
82
+ Args:
83
+ image_uri (str): Image URI to check.
84
+
85
+ Returns:
86
+ bool: True if image exists, False otherwise.
87
+ """
88
+ credential = self.environment.get_credentials()
89
+ registry, repository, tag = self.parse_azurecr_uri(image_uri)
90
+ client = ContainerRegistryClient(f"https://{registry}.azurecr.io", credential)
91
+ try:
92
+ client.get_manifest_properties(repository, tag)
93
+ return True
94
+ except ResourceNotFoundError:
95
+ return False
96
+ except Exception as e:
97
+ raise LaunchError(
98
+ f"Unable to check if image exists in Azure Container Registry: {e}"
99
+ ) from e
100
+
101
+ def get_repo_uri(self) -> str:
102
+ return self.uri
103
+
104
+ def verify(self) -> None:
105
+ try:
106
+ _ = self.registry_name
107
+ except Exception as e:
108
+ raise LaunchError(f"Unable to verify Azure Container Registry: {e}") from e
109
+
110
+ @property
111
+ def registry_name(self) -> str:
112
+ """Get registry name."""
113
+ return self.parse_azurecr_uri(self.uri)[0]
114
+
115
+ @staticmethod
116
+ def parse_azurecr_uri(uri: str) -> Tuple[str, str, str]:
117
+ """Parse an Azure Container Registry URI.
118
+
119
+ Args:
120
+ uri (str): URI to parse.
121
+
122
+ Returns:
123
+ Tuple[str, str, str]: Tuple of registry name, repository name, and tag.
124
+
125
+ Raises:
126
+ LaunchError: If unable to parse URI.
127
+ """
128
+ regex = r"(?:https://)([\w]+)\.azurecr\.io/([\w\-]+):?(.*)"
129
+ match = re.match(regex, uri)
130
+ if match is None:
131
+ raise LaunchError(f"Unable to parse Azure Container Registry URI: {uri}")
132
+ return match.group(1), match.group(2), match.group(3)
wandb/sdk/launch/registry/elastic_container_registry.py
@@ -1,8 +1,11 @@
1
1
  """Implementation of Elastic Container Registry class for wandb launch."""
2
2
  import base64
3
3
  import logging
4
+ import re
4
5
  from typing import Dict, Tuple
5
6
 
7
+ import yaml
8
+
6
9
  from wandb.sdk.launch.environment.aws_environment import AwsEnvironment
7
10
  from wandb.sdk.launch.errors import LaunchError
8
11
  from wandb.util import get_module
@@ -52,7 +55,7 @@ class ElasticContainerRegistry(AbstractRegistry):
52
55
  @classmethod
53
56
  def from_config( # type: ignore[override]
54
57
  cls,
55
- config: Dict,
58
+ config: Dict[str, str],
56
59
  environment: AwsEnvironment,
57
60
  verify: bool = True,
58
61
  ) -> "ElasticContainerRegistry":
@@ -70,10 +73,41 @@ class ElasticContainerRegistry(AbstractRegistry):
70
73
  f"Could not create ElasticContainerRegistry from config. Expected type 'ecr' "
71
74
  f"but got '{config.get('type')}'."
72
75
  )
73
- repository = config.get("repository")
74
- if not repository:
76
+ if ("uri" in config) == ("repository" in config):
77
+ raise LaunchError(
78
+ "Could not create ElasticContainerRegistry from config. Either 'uri' or "
79
+ f"'repository' is required. The config received was:\n{yaml.dump(config)}."
80
+ )
81
+ if "repository" in config:
82
+ repository = config.get("repository")
83
+ else:
84
+ match = re.match(
85
+ r"^(?P<account>.*)\.dkr\.ecr\.(?P<region>.*)\.amazonaws\.com/(?P<repository>.*)/?$",
86
+ config["uri"],
87
+ )
88
+ if not match:
89
+ raise LaunchError(
90
+ f"Could not create ElasticContainerRegistry from config. The uri "
91
+ f"{config.get('uri')} is invalid."
92
+ )
93
+ repository = match.group("repository")
94
+ if match.group("region") != environment.region:
95
+ raise LaunchError(
96
+ f"Could not create ElasticContainerRegistry from config. The uri "
97
+ f"{config.get('uri')} is in region {match.group('region')} but the "
98
+ f"environment is in region {environment.region}."
99
+ )
100
+ if match.group("account") != environment._account:
101
+ raise LaunchError(
102
+ f"Could not create ElasticContainerRegistry from config. The uri "
103
+ f"{config.get('uri')} is in account {match.group('account')} but the "
104
+ f"account being used is {environment._account}."
105
+ )
106
+ if not isinstance(repository, str):
107
+ # This is for mypy. We should never get here.
75
108
  raise LaunchError(
76
- "Could not create ElasticContainerRegistry from config. 'repository' is required."
109
+ f"Could not create ElasticContainerRegistry from config. The repository "
110
+ f"{repository} is invalid: repository should be a string."
77
111
  )
78
112
  return cls(repository, environment)
79
113
 
wandb/sdk/launch/registry/google_artifact_registry.py
@@ -3,6 +3,8 @@ import logging
3
3
  import re
4
4
  from typing import Tuple
5
5
 
6
+ import yaml
7
+
6
8
  from wandb.sdk.launch.environment.gcp_environment import GcpEnvironment
7
9
  from wandb.sdk.launch.errors import LaunchError
8
10
  from wandb.util import get_module
@@ -105,14 +107,51 @@ class GoogleArtifactRegistry(AbstractRegistry):
105
107
  Returns:
106
108
  A GoogleArtifactRegistry.
107
109
  """
108
- repository = config.get("repository")
109
- if not repository:
110
- raise LaunchError(
111
- "The Google Artifact Registry repository must be specified."
110
+ if "uri" in config:
111
+ if "repository" in config or "image-name" in config:
112
+ raise LaunchError(
113
+ "The Google Artifact Registry must be specified with either "
114
+ "the uri key or the repository and image-name keys, but not both. "
115
+ f"The provided config is:\n{yaml.dump(config)}."
116
+ )
117
+ match = re.match(
118
+ r"^(?P<region>[\w-]+)-docker\.pkg\.dev/(?P<project>[\w-]+)/(?P<repository>[\w-]+)/(?P<image_name>[\w-]+)$",
119
+ config["uri"],
112
120
  )
113
- image_name = config.get("image-name")
114
- if not image_name:
115
- raise LaunchError("The image name must be specified.")
121
+ if not match:
122
+ raise LaunchError(
123
+ f"The Google Artifact Registry uri {config['uri']} is invalid. "
124
+ "Please provide a uri of the form "
125
+ "REGION-docker.pkg.dev/PROJECT/REPOSITORY/IMAGE_NAME."
126
+ )
127
+ else:
128
+ repository = match.group("repository")
129
+ image_name = match.group("image_name")
130
+ if match.group("region") != environment.region:
131
+ raise LaunchError(
132
+ f"The Google Artifact Registry uri {config['uri']} does not "
133
+ f"match the configured region {environment.region}."
134
+ )
135
+ if match.group("project") != environment.project:
136
+ raise LaunchError(
137
+ f"The Google Artifact Registry uri {config['uri']} does not "
138
+ f"match the configured project {environment.project}."
139
+ )
140
+ else:
141
+ repository = config.get("repository")
142
+ if not repository:
143
+ raise LaunchError(
144
+ "The Google Artifact Registry repository must be specified "
145
+ "by setting the either the uri or repository key of your "
146
+ f"registry config. The provided config is:\n{yaml.dump(config)}."
147
+ )
148
+ image_name = config.get("image-name")
149
+ if not image_name:
150
+ raise LaunchError(
151
+ "The Google Artifact Registry repository must be specified "
152
+ "by setting the either the uri or repository key of your "
153
+ f"registry config. The provided config is:\n{yaml.dump(config)}."
154
+ )
116
155
  return cls(repository, image_name, environment, verify=verify)
117
156
 
118
157
  def verify(self) -> None:
wandb/sdk/launch/runner/abstract.py
@@ -8,7 +8,7 @@ import os
8
8
  import subprocess
9
9
  import sys
10
10
  from abc import ABC, abstractmethod
11
- from typing import Any, Dict, List, Optional, Union
11
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
12
12
 
13
13
  from dockerpycreds.utils import find_executable # type: ignore
14
14
 
@@ -22,6 +22,9 @@ from .._project_spec import LaunchProject
22
22
 
23
23
  _logger = logging.getLogger(__name__)
24
24
 
25
+ if TYPE_CHECKING:
26
+ from wandb.sdk.launch.agent.job_status_tracker import JobAndRunStatusTracker
27
+
25
28
 
26
29
  if sys.version_info >= (3, 8):
27
30
  from typing import Literal
@@ -29,7 +32,14 @@ else:
29
32
  from typing_extensions import Literal
30
33
 
31
34
  State = Literal[
32
- "unknown", "starting", "running", "failed", "finished", "stopping", "stopped"
35
+ "unknown",
36
+ "starting",
37
+ "running",
38
+ "failed",
39
+ "finished",
40
+ "stopping",
41
+ "stopped",
42
+ "preempted",
33
43
  ]
34
44
 
35
45
 
@@ -61,6 +71,11 @@ class AbstractRun(ABC):
61
71
  def status(self) -> Status:
62
72
  return self._status
63
73
 
74
+ @abstractmethod
75
+ def get_logs(self) -> Optional[str]:
76
+ """Return the logs associated with the run."""
77
+ pass
78
+
64
79
  def _run_cmd(
65
80
  self, cmd: List[str], output_only: Optional[bool] = False
66
81
  ) -> Optional[Union["subprocess.Popen[bytes]", bytes]]:
@@ -106,7 +121,7 @@ class AbstractRun(ABC):
106
121
 
107
122
  @property
108
123
  @abstractmethod
109
- def id(self) -> str:
124
+ def id(self) -> Optional[str]:
110
125
  pass
111
126
 
112
127
 
@@ -152,6 +167,7 @@ class AbstractRunner(ABC):
152
167
  self,
153
168
  launch_project: LaunchProject,
154
169
  builder: AbstractBuilder,
170
+ job_tracker: Optional["JobAndRunStatusTracker"] = None,
155
171
  ) -> Optional[AbstractRun]:
156
172
  """Submit an LaunchProject to be run.
157
173
 
wandb/sdk/launch/runner/kubernetes_runner.py
@@ -4,13 +4,15 @@ import base64
4
4
  import json
5
5
  import logging
6
6
  import time
7
- from typing import Any, Dict, List, Optional, Tuple, Union
7
+ from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
8
8
 
9
9
  import wandb
10
10
  from wandb.apis.internal import Api
11
+ from wandb.sdk.launch.agent.job_status_tracker import JobAndRunStatusTracker
11
12
  from wandb.sdk.launch.builder.abstract import AbstractBuilder
12
13
  from wandb.sdk.launch.environment.abstract import AbstractEnvironment
13
14
  from wandb.sdk.launch.registry.abstract import AbstractRegistry
15
+ from wandb.sdk.launch.registry.azure_container_registry import AzureContainerRegistry
14
16
  from wandb.sdk.launch.registry.local_registry import LocalRegistry
15
17
  from wandb.sdk.launch.runner.abstract import State, Status
16
18
  from wandb.util import get_module
@@ -104,6 +106,22 @@ class KubernetesSubmittedRun(AbstractRun):
104
106
  """Return the run id."""
105
107
  return self.name
106
108
 
109
+ def get_logs(self) -> Optional[str]:
110
+ try:
111
+ logs = self.core_api.read_namespaced_pod_log(
112
+ name=self.pod_names[0], namespace=self.namespace
113
+ )
114
+ if logs:
115
+ return str(logs)
116
+ else:
117
+ wandb.termwarn(
118
+ f"Retrieved no logs for kubernetes pod(s): {self.pod_names}"
119
+ )
120
+ return None
121
+ except Exception as e:
122
+ wandb.termerror(f"{LOG_PREFIX}Failed to get pod logs: {e}")
123
+ return None
124
+
107
125
  def get_job(self) -> "V1Job":
108
126
  """Return the job object."""
109
127
  return self.batch_api.read_namespaced_job(
@@ -128,14 +146,22 @@ class KubernetesSubmittedRun(AbstractRun):
128
146
 
129
147
  def get_status(self) -> Status:
130
148
  """Return the run status."""
131
- job_response = self.batch_api.read_namespaced_job_status(
132
- name=self.name, namespace=self.namespace
133
- )
134
- status = job_response.status
149
+ try:
150
+ job_response = self.batch_api.read_namespaced_job_status(
151
+ name=self.name, namespace=self.namespace
152
+ )
153
+ status = job_response.status
154
+
155
+ pod = self.core_api.read_namespaced_pod(
156
+ name=self.pod_names[0], namespace=self.namespace
157
+ )
158
+ except ApiException as e:
159
+ if "(404)" not in str(e):
160
+ raise
161
+ # 404 = Pod/job not reachable
162
+ wandb.termlog(f"{LOG_PREFIX}Job or pod disconnected for job: {self.name}")
163
+ return Status("preempted")
135
164
 
136
- pod = self.core_api.read_namespaced_pod(
137
- name=self.pod_names[0], namespace=self.namespace
138
- )
139
165
  if pod.status.phase in ["Pending", "Unknown"]:
140
166
  now = time.time()
141
167
  if self._fail_count == 0:
@@ -154,7 +180,13 @@ class KubernetesSubmittedRun(AbstractRun):
154
180
  if status.succeeded == 1:
155
181
  return_status = Status("finished")
156
182
  elif status.failed is not None and status.failed >= 1:
157
- return_status = Status("failed")
183
+ if status.conditions[0].reason == "BackoffLimitExceeded":
184
+ wandb.termlog(
185
+ f"{LOG_PREFIX}Job or pod disconnected for job: {self.name}"
186
+ )
187
+ return_status = Status("preempted")
188
+ else:
189
+ return_status = Status("failed")
158
190
  elif status.active == 1:
159
191
  return Status("running")
160
192
  elif status.conditions is not None and status.conditions[0].type == "Suspended":
@@ -261,6 +293,23 @@ class CrdSubmittedRun(AbstractRun):
261
293
  """Get the name of the custom object."""
262
294
  return self.name
263
295
 
296
+ def get_logs(self) -> Optional[str]:
297
+ """Get logs for custom object."""
298
+ # TODO: test more carefully once we release multi-node support
299
+ logs: Dict[str, Optional[str]] = {}
300
+ try:
301
+ for pod_name in self.pod_names:
302
+ logs[pod_name] = self.core_api.read_namespaced_pod_log(
303
+ name=pod_name, namespace=self.namespace
304
+ )
305
+ except ApiException as e:
306
+ wandb.termwarn(f"Failed to get logs for {self.name}: {str(e)}")
307
+ return None
308
+ if not logs:
309
+ return None
310
+ logs_as_array = [f"Pod {pod_name}:\n{log}" for pod_name, log in logs.items()]
311
+ return "\n".join(logs_as_array)
312
+
264
313
  def get_status(self) -> Status:
265
314
  """Get status of custom object."""
266
315
  try:
@@ -403,6 +452,7 @@ class KubernetesRunner(AbstractRunner):
403
452
  builder: Optional[AbstractBuilder],
404
453
  namespace: str,
405
454
  core_api: "CoreV1Api",
455
+ job_tracker: Optional[JobAndRunStatusTracker],
406
456
  ) -> Tuple[Dict[str, Any], Optional["V1Secret"]]:
407
457
  """Apply our default values, return job dict and secret.
408
458
 
@@ -454,9 +504,10 @@ class KubernetesRunner(AbstractRunner):
454
504
  "Invalid specification of multiple containers. See https://docs.wandb.ai/guides/launch for guidance on submitting jobs."
455
505
  )
456
506
  # dont specify run id if user provided image, could have multiple runs
457
- containers[0]["image"] = launch_project.docker_image
507
+ image_uri = launch_project.docker_image
508
+ containers[0]["image"] = image_uri
509
+ launch_project.fill_macros(image_uri)
458
510
  # TODO: handle secret pulling image from registry
459
- launch_project.fill_macros(launch_project.docker_image)
460
511
  elif not any(["image" in cont for cont in containers]):
461
512
  if len(containers) > 1:
462
513
  raise LaunchError(
@@ -464,7 +515,8 @@ class KubernetesRunner(AbstractRunner):
464
515
  )
465
516
  assert entry_point is not None
466
517
  assert builder is not None
467
- image_uri = builder.build_image(launch_project, entry_point)
518
+ image_uri = builder.build_image(launch_project, entry_point, job_tracker)
519
+ image_uri = image_uri.replace("https://", "")
468
520
  launch_project.fill_macros(image_uri)
469
521
  # in the non instance case we need to make an imagePullSecret
470
522
  # so the new job can pull the image
@@ -510,6 +562,7 @@ class KubernetesRunner(AbstractRunner):
510
562
  self,
511
563
  launch_project: LaunchProject,
512
564
  builder: AbstractBuilder,
565
+ job_tracker: Optional[JobAndRunStatusTracker] = None,
513
566
  ) -> Optional[AbstractRun]: # noqa: C901
514
567
  """Execute a launch project on Kubernetes.
515
568
 
@@ -545,7 +598,7 @@ class KubernetesRunner(AbstractRunner):
545
598
  image_uri = launch_project.docker_image
546
599
  else:
547
600
  assert entrypoint is not None
548
- image_uri = builder.build_image(launch_project, entrypoint)
601
+ image_uri = builder.build_image(launch_project, entrypoint, job_tracker)
549
602
  launch_project.fill_macros(image_uri)
550
603
  env_vars = get_env_vars_dict(launch_project, self._api)
551
604
  # Crawl the resource args and add our env vars to the containers.
@@ -607,11 +660,7 @@ class KubernetesRunner(AbstractRunner):
607
660
  namespace = self.get_namespace(resource_args, context)
608
661
 
609
662
  job, secret = self._inject_defaults(
610
- resource_args,
611
- launch_project,
612
- builder,
613
- namespace,
614
- core_api,
663
+ resource_args, launch_project, builder, namespace, core_api, job_tracker
615
664
  )
616
665
 
617
666
  msg = "Creating Kubernetes job"
@@ -681,7 +730,9 @@ def maybe_create_imagepull_secret(
681
730
  A secret if one was created, otherwise None.
682
731
  """
683
732
  secret = None
684
- if isinstance(registry, LocalRegistry):
733
+ if isinstance(registry, LocalRegistry) or isinstance(
734
+ registry, AzureContainerRegistry
735
+ ):
685
736
  # Secret not required
686
737
  return None
687
738
  uname, token = registry.get_username_password()
@@ -715,30 +766,37 @@ def add_wandb_env(root: Union[dict, list], env_vars: Dict[str, str]) -> None:
715
766
  Recursively walks the spec and injects the environment variables into
716
767
  every container spec. Containers are identified by the "containers" key.
717
768
 
769
+ This function treats the WANDB_RUN_ID and WANDB_GROUP_ID environment variables
770
+ specially. If they are present in the spec, they will be overwritten. If a setting
771
+ for WANDB_RUN_ID is provided in env_vars, then that environment variable will only be
772
+ set in the first container modified by this function.
773
+
718
774
  Arguments:
719
775
  root: The spec to modify.
720
776
  env_vars: The environment variables to inject.
721
777
 
722
778
  Returns: None.
723
779
  """
724
- if isinstance(root, dict):
725
- for k, v in root.items():
726
- if k == "containers":
727
- if isinstance(v, list):
728
- for cont in v:
729
- env = cont.get("env", [])
730
- env.extend(
731
- [
732
- {"name": key, "value": value}
733
- for key, value in env_vars.items()
734
- ]
735
- )
736
- cont["env"] = env
737
- elif isinstance(v, (dict, list)):
738
- add_wandb_env(v, env_vars)
739
- elif isinstance(root, list):
740
- for item in root:
741
- add_wandb_env(item, env_vars)
780
+
781
+ def yield_containers(root: Any) -> Iterator[dict]:
782
+ if isinstance(root, dict):
783
+ for k, v in root.items():
784
+ if k == "containers":
785
+ if isinstance(v, list):
786
+ yield from v
787
+ elif isinstance(v, (dict, list)):
788
+ yield from yield_containers(v)
789
+ elif isinstance(root, list):
790
+ for item in root:
791
+ yield from yield_containers(item)
792
+
793
+ for cont in yield_containers(root):
794
+ env = cont.setdefault("env", [])
795
+ env.extend([{"name": key, "value": value} for key, value in env_vars.items()])
796
+ cont["env"] = env
797
+ # After we have set WANDB_RUN_ID once, we don't want to set it again
798
+ if "WANDB_RUN_ID" in env_vars:
799
+ env_vars.pop("WANDB_RUN_ID")
742
800
 
743
801
 
744
802
  def add_label_to_pods(
@@ -757,16 +815,22 @@ def add_label_to_pods(
757
815
 
758
816
  Returns: None.
759
817
  """
760
- if isinstance(manifest, list):
761
- for item in manifest:
762
- add_label_to_pods(item, label_key, label_value)
763
- elif isinstance(manifest, dict):
764
- if "spec" in manifest and "containers" in manifest["spec"]:
765
- metadata = manifest.setdefault("metadata", {})
766
- labels = metadata.setdefault("labels", {})
767
- labels[label_key] = label_value
768
- for value in manifest.values():
769
- add_label_to_pods(value, label_key, label_value)
818
+
819
+ def yield_pods(manifest: Any) -> Iterator[dict]:
820
+ if isinstance(manifest, list):
821
+ for item in manifest:
822
+ yield from yield_pods(item)
823
+ elif isinstance(manifest, dict):
824
+ if "spec" in manifest and "containers" in manifest["spec"]:
825
+ yield manifest
826
+ for value in manifest.values():
827
+ if isinstance(value, (dict, list)):
828
+ yield from yield_pods(value)
829
+
830
+ for pod in yield_pods(manifest):
831
+ metadata = pod.setdefault("metadata", {})
832
+ labels = metadata.setdefault("labels", {})
833
+ labels[label_key] = label_value
770
834
 
771
835
 
772
836
  def add_entrypoint_args_overrides(manifest: Union[dict, list], overrides: dict) -> None: