wandb 0.13.11__py3-none-any.whl → 0.14.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (41) hide show
  1. wandb/__init__.py +1 -1
  2. wandb/apis/importers/__init__.py +4 -0
  3. wandb/apis/importers/base.py +312 -0
  4. wandb/apis/importers/mlflow.py +113 -0
  5. wandb/apis/internal.py +9 -0
  6. wandb/apis/public.py +0 -2
  7. wandb/cli/cli.py +100 -72
  8. wandb/docker/__init__.py +33 -5
  9. wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
  10. wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
  11. wandb/sdk/internal/internal_api.py +85 -9
  12. wandb/sdk/launch/_project_spec.py +45 -55
  13. wandb/sdk/launch/agent/agent.py +80 -18
  14. wandb/sdk/launch/builder/build.py +16 -74
  15. wandb/sdk/launch/builder/docker_builder.py +36 -8
  16. wandb/sdk/launch/builder/kaniko_builder.py +78 -37
  17. wandb/sdk/launch/builder/templates/_wandb_bootstrap.py +68 -18
  18. wandb/sdk/launch/environment/aws_environment.py +4 -0
  19. wandb/sdk/launch/launch.py +1 -6
  20. wandb/sdk/launch/launch_add.py +0 -5
  21. wandb/sdk/launch/registry/abstract.py +12 -0
  22. wandb/sdk/launch/registry/elastic_container_registry.py +31 -1
  23. wandb/sdk/launch/registry/google_artifact_registry.py +32 -0
  24. wandb/sdk/launch/registry/local_registry.py +15 -1
  25. wandb/sdk/launch/runner/abstract.py +0 -14
  26. wandb/sdk/launch/runner/kubernetes_runner.py +25 -19
  27. wandb/sdk/launch/runner/local_container.py +7 -8
  28. wandb/sdk/launch/runner/local_process.py +0 -3
  29. wandb/sdk/launch/runner/sagemaker_runner.py +0 -3
  30. wandb/sdk/launch/runner/vertex_runner.py +0 -2
  31. wandb/sdk/launch/sweeps/scheduler.py +39 -10
  32. wandb/sdk/launch/utils.py +52 -4
  33. wandb/sdk/wandb_run.py +3 -10
  34. wandb/sync/sync.py +1 -0
  35. wandb/util.py +1 -0
  36. {wandb-0.13.11.dist-info → wandb-0.14.0.dist-info}/METADATA +1 -1
  37. {wandb-0.13.11.dist-info → wandb-0.14.0.dist-info}/RECORD +41 -38
  38. {wandb-0.13.11.dist-info → wandb-0.14.0.dist-info}/WHEEL +1 -1
  39. {wandb-0.13.11.dist-info → wandb-0.14.0.dist-info}/LICENSE +0 -0
  40. {wandb-0.13.11.dist-info → wandb-0.14.0.dist-info}/entry_points.txt +0 -0
  41. {wandb-0.13.11.dist-info → wandb-0.14.0.dist-info}/top_level.txt +0 -0
@@ -1,13 +1,11 @@
1
1
  import base64
2
2
  import json
3
+ import logging
3
4
  import tarfile
4
5
  import tempfile
5
6
  import time
6
7
  from typing import Optional
7
8
 
8
- import kubernetes # type: ignore
9
- from kubernetes import client
10
-
11
9
  import wandb
12
10
  from wandb.sdk.launch.builder.abstract import AbstractBuilder
13
11
  from wandb.sdk.launch.environment.abstract import AbstractEnvironment
@@ -16,7 +14,7 @@ from wandb.sdk.launch.registry.elastic_container_registry import (
16
14
  ElasticContainerRegistry,
17
15
  )
18
16
  from wandb.sdk.launch.registry.google_artifact_registry import GoogleArtifactRegistry
19
- from wandb.sdk.launch.utils import LaunchError
17
+ from wandb.util import get_module
20
18
 
21
19
  from .._project_spec import (
22
20
  EntryPoint,
@@ -24,8 +22,28 @@ from .._project_spec import (
24
22
  create_metadata_file,
25
23
  get_entry_point_command,
26
24
  )
27
- from ..utils import LOG_PREFIX, get_kube_context_and_api_client, sanitize_wandb_api_key
28
- from .build import _create_docker_build_ctx, generate_dockerfile
25
+ from ..utils import (
26
+ LOG_PREFIX,
27
+ LaunchError,
28
+ get_kube_context_and_api_client,
29
+ sanitize_wandb_api_key,
30
+ warn_failed_packages_from_build_logs,
31
+ )
32
+ from .build import (
33
+ _create_docker_build_ctx,
34
+ generate_dockerfile,
35
+ image_tag_from_dockerfile_and_source,
36
+ )
37
+
38
+ get_module(
39
+ "kubernetes",
40
+ required="Kaniko builder requires the kubernetes package. Please install it with `pip install wandb[launch]`.",
41
+ )
42
+
43
+ import kubernetes # type: ignore # noqa: E402
44
+ from kubernetes import client # noqa: E402
45
+
46
+ _logger = logging.getLogger(__name__)
29
47
 
30
48
  _DEFAULT_BUILD_TIMEOUT_SECS = 1800 # 30 minute build timeout
31
49
 
@@ -155,7 +173,7 @@ class KanikoBuilder(AbstractBuilder):
155
173
  pass
156
174
 
157
175
  def _create_docker_ecr_config_map(
158
- self, corev1_client: client.CoreV1Api, repository: str
176
+ self, job_name: str, corev1_client: client.CoreV1Api, repository: str
159
177
  ) -> None:
160
178
  if self.registry is None:
161
179
  raise LaunchError("No registry specified for Kaniko build.")
@@ -165,7 +183,7 @@ class KanikoBuilder(AbstractBuilder):
165
183
  api_version="v1",
166
184
  kind="ConfigMap",
167
185
  metadata=client.V1ObjectMeta(
168
- name="docker-config",
186
+ name=f"docker-config-{job_name}",
169
187
  namespace="wandb",
170
188
  ),
171
189
  data={
@@ -177,8 +195,11 @@ class KanikoBuilder(AbstractBuilder):
177
195
  )
178
196
  corev1_client.create_namespaced_config_map("wandb", ecr_config_map)
179
197
 
180
- def _delete_docker_ecr_config_map(self, client: client.CoreV1Api) -> None:
181
- client.delete_namespaced_config_map("docker-config", "wandb")
198
+ def _delete_docker_ecr_config_map(
199
+ self, job_name: str, client: client.CoreV1Api
200
+ ) -> None:
201
+ if self.secret_name:
202
+ client.delete_namespaced_config_map(f"docker-config-{job_name}", "wandb")
182
203
 
183
204
  def _upload_build_context(self, run_id: str, context_path: str) -> str:
184
205
  # creat a tar archive of the build context and upload it to s3
@@ -197,20 +218,28 @@ class KanikoBuilder(AbstractBuilder):
197
218
  launch_project: LaunchProject,
198
219
  entrypoint: EntryPoint,
199
220
  ) -> str:
221
+ # TODO: this should probably throw an error if the registry is a local registry
200
222
  if not self.registry:
201
223
  raise LaunchError("No registry specified for Kaniko build.")
224
+ # kaniko builder doesn't seem to work with a custom user id, need more investigation
225
+ dockerfile_str = generate_dockerfile(
226
+ launch_project, entrypoint, launch_project.resource, "kaniko"
227
+ )
228
+ image_tag = image_tag_from_dockerfile_and_source(launch_project, dockerfile_str)
202
229
  repo_uri = self.registry.get_repo_uri()
203
- image_uri = repo_uri + ":" + launch_project.image_tag
230
+ image_uri = repo_uri + ":" + image_tag
231
+
232
+ if not launch_project.build_required() and self.registry.check_image_exists(
233
+ image_uri
234
+ ):
235
+ return image_uri
236
+
237
+ _logger.info(f"Building image {image_uri}...")
204
238
 
205
239
  entry_cmd = " ".join(
206
240
  get_entry_point_command(entrypoint, launch_project.override_args)
207
241
  )
208
242
 
209
- # kaniko builder doesn't seem to work with a custom user id, need more investigation
210
- dockerfile_str = generate_dockerfile(
211
- launch_project, entrypoint, launch_project.resource, self.type
212
- )
213
-
214
243
  create_metadata_file(
215
244
  launch_project,
216
245
  image_uri,
@@ -240,7 +269,8 @@ class KanikoBuilder(AbstractBuilder):
240
269
 
241
270
  try:
242
271
  # core_v1.create_namespaced_config_map("wandb", dockerfile_config_map)
243
- self._create_docker_ecr_config_map(core_v1, repo_uri)
272
+ if self.secret_name:
273
+ self._create_docker_ecr_config_map(build_job_name, core_v1, repo_uri)
244
274
  batch_v1.create_namespaced_job("wandb", build_job)
245
275
 
246
276
  # wait for double the job deadline since it might take time to schedule
@@ -248,6 +278,13 @@ class KanikoBuilder(AbstractBuilder):
248
278
  batch_v1, build_job_name, 3 * _DEFAULT_BUILD_TIMEOUT_SECS
249
279
  ):
250
280
  raise Exception(f"Failed to build image in kaniko for job {run_id}")
281
+ try:
282
+ logs = batch_v1.read_namespaced_job_log(build_job_name, "wandb")
283
+ warn_failed_packages_from_build_logs(logs, image_uri)
284
+ except Exception as e:
285
+ wandb.termwarn(
286
+ f"{LOG_PREFIX}Failed to get logs for kaniko job {build_job_name}: {e}"
287
+ )
251
288
  except Exception as e:
252
289
  wandb.termerror(
253
290
  f"{LOG_PREFIX}Exception when creating Kubernetes resources: {e}\n"
@@ -258,7 +295,8 @@ class KanikoBuilder(AbstractBuilder):
258
295
  try:
259
296
  # should we clean up the s3 build contexts? can set bucket level policy to auto deletion
260
297
  # core_v1.delete_namespaced_config_map(config_map_name, "wandb")
261
- self._delete_docker_ecr_config_map(core_v1)
298
+ if self.secret_name:
299
+ self._delete_docker_ecr_config_map(build_job_name, core_v1)
262
300
  batch_v1.delete_namespaced_job(build_job_name, "wandb")
263
301
  except Exception as e:
264
302
  raise LaunchError(f"Exception during Kubernetes resource clean up {e}")
@@ -273,25 +311,34 @@ class KanikoBuilder(AbstractBuilder):
273
311
  build_context_path: str,
274
312
  ) -> "client.V1Job":
275
313
  env = []
276
- volume_mounts = [
277
- client.V1VolumeMount(name="docker-config", mount_path="/kaniko/.docker/"),
278
- ]
279
- volumes = [
280
- client.V1Volume(
281
- name="docker-config",
282
- config_map=client.V1ConfigMapVolumeSource(
283
- name="docker-config",
284
- ),
285
- ),
286
- ]
314
+ volume_mounts = []
315
+ volumes = []
287
316
  if bool(self.secret_name) != bool(self.secret_key):
288
317
  raise LaunchError(
289
318
  "Both secret_name and secret_key or neither must be specified "
290
319
  "for kaniko build. You provided only one of them."
291
320
  )
321
+ if isinstance(self.registry, ElasticContainerRegistry):
322
+ env += [
323
+ client.V1EnvVar(
324
+ name="AWS_REGION",
325
+ value=self.registry.environment.region,
326
+ )
327
+ ]
292
328
  if self.secret_name and self.secret_key:
293
- # TODO: We should validate that the secret exists and has the key
294
- # before creating the job. Or when we create the builder.
329
+ volumes += [
330
+ client.V1Volume(
331
+ name="docker-config",
332
+ config_map=client.V1ConfigMapVolumeSource(
333
+ name=f"docker-config-{job_name}",
334
+ ),
335
+ ),
336
+ ]
337
+ volume_mounts += [
338
+ client.V1VolumeMount(
339
+ name="docker-config", mount_path="/kaniko/.docker/"
340
+ ),
341
+ ]
295
342
  # TODO: I don't like conditioning on the registry type here. As a
296
343
  # future change I want the registry and environment classes to provide
297
344
  # a list of environment variables and volume mounts that need to be
@@ -303,12 +350,6 @@ class KanikoBuilder(AbstractBuilder):
303
350
  if isinstance(self.registry, ElasticContainerRegistry):
304
351
  mount_path = "/root/.aws"
305
352
  key = "credentials"
306
- env += [
307
- client.V1EnvVar(
308
- name="AWS_REGION",
309
- value=self.registry.environment.region,
310
- )
311
- ]
312
353
  elif isinstance(self.registry, GoogleArtifactRegistry):
313
354
  mount_path = "/kaniko/.config/gcloud"
314
355
  key = "config.json"
@@ -1,10 +1,13 @@
1
1
  import json
2
2
  import multiprocessing
3
3
  import os
4
+ import re
4
5
  import subprocess
5
6
  import sys
6
7
  from typing import List, Optional, Set
7
8
 
9
+ FAILED_PACKAGES_PREFIX = "ERROR: Failed to install: "
10
+ FAILED_PACKAGES_POSTFIX = ". During automated build process."
8
11
  CORES = multiprocessing.cpu_count()
9
12
  ONLY_INCLUDE = {x for x in os.getenv("WANDB_ONLY_INCLUDE", "").split(",") if x != ""}
10
13
  OPTS = []
@@ -21,7 +24,10 @@ else:
21
24
 
22
25
 
23
26
  def install_deps(
24
- deps: List[str], failed: Optional[Set[str]] = None
27
+ deps: List[str],
28
+ failed: Optional[Set[str]] = None,
29
+ extra_index: Optional[str] = None,
30
+ opts: Optional[List[str]] = None,
25
31
  ) -> Optional[Set[str]]:
26
32
  """Install pip dependencies.
27
33
 
@@ -35,33 +41,45 @@ def install_deps(
35
41
  try:
36
42
  # Include only uri if @ is present
37
43
  clean_deps = [d.split("@")[-1].strip() if "@" in d else d for d in deps]
38
-
44
+ index_args = ["--extra-index-url", extra_index] if extra_index else []
39
45
  print("installing {}...".format(", ".join(clean_deps)))
46
+ opts = opts or []
47
+ args = ["pip", "install"] + opts + clean_deps + index_args
40
48
  sys.stdout.flush()
41
- subprocess.check_output(
42
- ["pip", "install"] + OPTS + clean_deps, stderr=subprocess.STDOUT
43
- )
44
- if failed is not None and len(failed) > 0:
45
- sys.stderr.write(
46
- "ERROR: Unable to install: {}".format(", ".join(clean_deps))
47
- )
48
- sys.stderr.flush()
49
+ subprocess.check_output(args, stderr=subprocess.STDOUT)
49
50
  return failed
50
51
  except subprocess.CalledProcessError as e:
51
52
  if failed is None:
52
53
  failed = set()
53
54
  num_failed = len(failed)
54
- for line in e.output.decode("utf8"):
55
+ for line in e.output.decode("utf8").splitlines():
55
56
  if line.startswith("ERROR:"):
56
- failed.add(line.split(" ")[-1])
57
- if len(failed) > num_failed:
58
- return install_deps(list(set(clean_deps) - failed), failed)
57
+ clean_dep = find_package_in_error_string(clean_deps, line)
58
+ if clean_dep is not None:
59
+ if clean_dep in deps:
60
+ failed.add(clean_dep)
61
+ else:
62
+ for d in deps:
63
+ if clean_dep in d:
64
+ failed.add(d.replace(" ", ""))
65
+ break
66
+ if len(set(clean_deps) - failed) == 0:
67
+ return failed
68
+ elif len(failed) > num_failed:
69
+ return install_deps(
70
+ list(set(clean_deps) - failed),
71
+ failed,
72
+ extra_index=extra_index,
73
+ opts=opts,
74
+ )
59
75
  else:
60
76
  return failed
61
77
 
62
78
 
63
79
  def main() -> None:
64
80
  """Install deps in requirements.frozen.txt."""
81
+ extra_index = None
82
+ torch_reqs = []
65
83
  if os.path.exists("requirements.frozen.txt"):
66
84
  with open("requirements.frozen.txt") as f:
67
85
  print("Installing frozen dependencies...")
@@ -72,28 +90,60 @@ def main() -> None:
72
90
  # can't pip install wandb==0.*.*.dev1 through pip. Lets just install wandb for now
73
91
  if req.startswith("wandb==") and "dev1" in req:
74
92
  req = "wandb"
75
- reqs.append(req.strip().replace(" ", ""))
93
+ match = re.match(
94
+ r"torch(vision|audio)?==\d+\.\d+\.\d+(\+(?:cu[\d]{2,3})|(?:cpu))?",
95
+ req,
96
+ )
97
+ if match:
98
+ variant = match.group(2)
99
+ if variant:
100
+ extra_index = (
101
+ f"https://download.pytorch.org/whl/{variant[1:]}"
102
+ )
103
+ torch_reqs.append(req.strip().replace(" ", ""))
104
+ else:
105
+ reqs.append(req.strip().replace(" ", ""))
76
106
  else:
77
107
  print(f"Ignoring requirement: {req} from frozen requirements")
78
108
  if len(reqs) >= CORES:
79
- deps_failed = install_deps(reqs)
109
+ deps_failed = install_deps(reqs, opts=OPTS)
80
110
  reqs = []
81
111
  if deps_failed is not None:
82
112
  failed = failed.union(deps_failed)
83
113
  if len(reqs) > 0:
84
- deps_failed = install_deps(reqs)
114
+ deps_failed = install_deps(reqs, opts=OPTS)
85
115
  if deps_failed is not None:
86
116
  failed = failed.union(deps_failed)
87
117
  with open("_wandb_bootstrap_errors.json", "w") as f:
88
118
  f.write(json.dumps({"pip": list(failed)}))
89
119
  if len(failed) > 0:
90
120
  sys.stderr.write(
91
- "ERROR: Failed to install: {}".format(",".join(failed))
121
+ FAILED_PACKAGES_PREFIX + ",".join(failed) + FAILED_PACKAGES_POSTFIX
92
122
  )
93
123
  sys.stderr.flush()
124
+ install_deps(torch_reqs, extra_index=extra_index)
94
125
  else:
95
126
  print("No frozen requirements found")
96
127
 
97
128
 
129
+ # hacky way to get the name of the requirement that failed
130
+ # attempt last word which is the name of the package often
131
+ # fall back to checking all words in the line for the package name
132
+ def find_package_in_error_string(deps: List[str], line: str) -> Optional[str]:
133
+ # if the last word in the error string is in the list of deps, return it
134
+ last_word = line.split(" ")[-1]
135
+ if last_word in deps:
136
+ return last_word
137
+ # if the last word is not in the list of deps, check all words
138
+ # TODO: this could report the wrong package if the error string
139
+ # contains a reference to another package in the deps
140
+ # before the package that failed to install
141
+ for word in line.split(" "):
142
+ if word in deps:
143
+ return word
144
+ # if we can't find the package, return None
145
+ return None
146
+
147
+
98
148
  if __name__ == "__main__":
99
149
  main()
@@ -70,6 +70,10 @@ class AwsEnvironment(AbstractEnvironment):
70
70
  session = boto3.Session()
71
71
  region = region or session.region_name
72
72
  credentials = session.get_credentials()
73
+ if not credentials:
74
+ raise LaunchError(
75
+ "Could not create AWS environment from default environment. Please verify that your AWS credentials are configured correctly."
76
+ )
73
77
  access_key = credentials.access_key
74
78
  secret_key = credentials.secret_key
75
79
  session_token = credentials.token
@@ -55,7 +55,7 @@ def resolve_agent_config( # noqa: C901
55
55
  "api_key": api.api_key,
56
56
  "base_url": api.settings("base_url"),
57
57
  "registry": {},
58
- "build": {},
58
+ "builder": {},
59
59
  "runner": {},
60
60
  }
61
61
  user_set_project = False
@@ -151,7 +151,6 @@ def _run(
151
151
  resource_args: Optional[Dict[str, Any]],
152
152
  launch_config: Optional[Dict[str, Any]],
153
153
  synchronous: Optional[bool],
154
- cuda: Optional[bool],
155
154
  api: Api,
156
155
  run_id: Optional[str],
157
156
  repository: Optional[str],
@@ -171,7 +170,6 @@ def _run(
171
170
  parameters,
172
171
  resource_args,
173
172
  launch_config,
174
- cuda,
175
173
  run_id,
176
174
  repository,
177
175
  )
@@ -217,7 +215,6 @@ def run(
217
215
  docker_image: Optional[str] = None,
218
216
  config: Optional[Dict[str, Any]] = None,
219
217
  synchronous: Optional[bool] = True,
220
- cuda: Optional[bool] = None,
221
218
  run_id: Optional[str] = None,
222
219
  repository: Optional[str] = None,
223
220
  ) -> AbstractRun:
@@ -247,7 +244,6 @@ def run(
247
244
  asynchronous runs launched via this method will be terminated. If
248
245
  ``synchronous`` is True and the run fails, the current process will
249
246
  error out as well.
250
- cuda: Whether to build a CUDA-enabled docker image or not
251
247
  run_id: ID for the run (To ultimately replace the :name: field)
252
248
  repository: string name of repository path for remote registry
253
249
 
@@ -290,7 +286,6 @@ def run(
290
286
  resource_args=resource_args,
291
287
  launch_config=config,
292
288
  synchronous=synchronous,
293
- cuda=cuda,
294
289
  api=api,
295
290
  run_id=run_id,
296
291
  repository=repository,
@@ -44,7 +44,6 @@ def launch_add(
44
44
  params: Optional[Dict[str, Any]] = None,
45
45
  project_queue: Optional[str] = None,
46
46
  resource_args: Optional[Dict[str, Any]] = None,
47
- cuda: Optional[bool] = None,
48
47
  run_id: Optional[str] = None,
49
48
  build: Optional[bool] = False,
50
49
  repository: Optional[str] = None,
@@ -69,7 +68,6 @@ def launch_add(
69
68
  the parameters used to run the original run.
70
69
  resource_args: Resource related arguments for launching runs onto a remote backend.
71
70
  Will be stored on the constructed launch config under ``resource_args``.
72
- cuda: Whether to build a CUDA-enabled docker image or not
73
71
  run_id: optional string indicating the id of the launched run
74
72
  build: optional flag defaulting to false, requires queue to be set
75
73
  if build, an image is created, creates a job artifact, pushes a reference
@@ -116,7 +114,6 @@ def launch_add(
116
114
  params,
117
115
  project_queue,
118
116
  resource_args,
119
- cuda,
120
117
  run_id=run_id,
121
118
  build=build,
122
119
  repository=repository,
@@ -139,7 +136,6 @@ def _launch_add(
139
136
  params: Optional[Dict[str, Any]],
140
137
  project_queue: Optional[str],
141
138
  resource_args: Optional[Dict[str, Any]] = None,
142
- cuda: Optional[bool] = None,
143
139
  run_id: Optional[str] = None,
144
140
  build: Optional[bool] = False,
145
141
  repository: Optional[str] = None,
@@ -158,7 +154,6 @@ def _launch_add(
158
154
  params,
159
155
  resource_args,
160
156
  config,
161
- cuda,
162
157
  run_id,
163
158
  repository,
164
159
  )
@@ -33,6 +33,18 @@ class AbstractRegistry(ABC):
33
33
  """
34
34
  raise NotImplementedError
35
35
 
36
+ @abstractmethod
37
+ def check_image_exists(self, image_uri: str) -> bool:
38
+ """Check if an image exists in the registry.
39
+
40
+ Arguments:
41
+ image_uri (str): The URI of the image.
42
+
43
+ Returns:
44
+ bool: True if the image exists.
45
+ """
46
+ raise NotImplementedError
47
+
36
48
  @classmethod
37
49
  @abstractmethod
38
50
  def from_config(
@@ -87,7 +87,6 @@ class ElasticContainerRegistry(AbstractRegistry):
87
87
  try:
88
88
  session = self.environment.get_session()
89
89
  client = session.client("ecr")
90
- response = client.describe_registry()
91
90
  response = client.describe_repositories(repositoryNames=[self.repo_name])
92
91
  self.uri = response["repositories"][0]["repositoryUri"].split("/")[0]
93
92
 
@@ -131,3 +130,34 @@ class ElasticContainerRegistry(AbstractRegistry):
131
130
  str: The uri of the repository.
132
131
  """
133
132
  return self.uri + "/" + self.repo_name
133
+
134
+ def check_image_exists(self, image_uri: str) -> bool:
135
+ """Check if the image tag exists.
136
+
137
+ Arguments:
138
+ image_uri (str): The full image_uri.
139
+
140
+ Returns:
141
+ bool: True if the image tag exists.
142
+ """
143
+ uri, tag = image_uri.split(":")
144
+ if uri != self.get_repo_uri():
145
+ raise LaunchError(
146
+ f"Image uri {image_uri} does not match Elastic Container Registry uri {self.get_repo_uri()}."
147
+ )
148
+
149
+ _logger.debug("Checking if image tag exists.")
150
+ try:
151
+ session = self.environment.get_session()
152
+ client = session.client("ecr")
153
+ response = client.describe_images(
154
+ repositoryName=self.repo_name, imageIds=[{"imageTag": tag}]
155
+ )
156
+ return len(response["imageDetails"]) > 0
157
+
158
+ except botocore.exceptions.ClientError as e:
159
+ code = e.response["Error"]["Code"]
160
+ if code == "ImageNotFoundException":
161
+ return False
162
+ msg = e.response["Error"]["Message"]
163
+ raise LaunchError(f"Error checking if image tag exists: {code} {msg}")
@@ -169,3 +169,35 @@ class GoogleArtifactRegistry(AbstractRegistry):
169
169
  f"{self.environment.region}-docker.pkg.dev/"
170
170
  f"{self.environment.project}/{self.repository}/{self.image_name}"
171
171
  )
172
+
173
+ def check_image_exists(self, image_uri: str) -> bool:
174
+ """Check if the image exists.
175
+
176
+ Arguments:
177
+ image_uri: The image URI.
178
+
179
+ Returns:
180
+ True if the image exists, False otherwise.
181
+ """
182
+ _logger.info(
183
+ f"Checking if image {image_uri} exists. In Google Artifact Registry {self.uri}."
184
+ )
185
+
186
+ return False
187
+ # TODO: Test GCP Artifact Registry image exists to get working
188
+ # repo_uri, _ = image_uri.split(":")
189
+ # if repo_uri != self.get_repo_uri():
190
+ # raise LaunchError(
191
+ # f"The image {image_uri} does not belong to the Google Artifact "
192
+ # f"Repository {self.get_repo_uri()}."
193
+ # )
194
+ # credentials = self.environment.get_credentials()
195
+ # request = google.cloud.artifactregistry.GetTagRequest(parent=image_uri)
196
+ # client = google.cloud.artifactregistry.ArtifactRegistryClient(
197
+ # credentials=credentials
198
+ # )
199
+ # try:
200
+ # client.get_tag(request=request)
201
+ # return True
202
+ # except google.api_core.exceptions.NotFound:
203
+ # return False
@@ -1,11 +1,14 @@
1
1
  """Local registry implementation."""
2
+ import logging
2
3
  from typing import Tuple
3
4
 
4
- from wandb.sdk.launch.utils import LaunchError
5
+ from wandb.sdk.launch.utils import LaunchError, docker_image_exists
5
6
 
6
7
  from ..environment.abstract import AbstractEnvironment
7
8
  from .abstract import AbstractRegistry
8
9
 
10
+ _logger = logging.getLogger(__name__)
11
+
9
12
 
10
13
  class LocalRegistry(AbstractRegistry):
11
14
  """A local registry.
@@ -46,3 +49,14 @@ class LocalRegistry(AbstractRegistry):
46
49
  Returns: An empty string.
47
50
  """
48
51
  return ""
52
+
53
+ def check_image_exists(self, image_uri: str) -> bool:
54
+ """Check if an image exists in the local registry.
55
+
56
+ Arguments:
57
+ image_uri (str): The uri of the image.
58
+
59
+ Returns:
60
+ bool: True.
61
+ """
62
+ return docker_image_exists(image_uri)
@@ -10,7 +10,6 @@ from dockerpycreds.utils import find_executable # type: ignore
10
10
  import wandb
11
11
  from wandb import Settings
12
12
  from wandb.apis.internal import Api
13
- from wandb.errors import CommError
14
13
  from wandb.sdk.launch.builder.abstract import AbstractBuilder
15
14
  from wandb.sdk.lib import runid
16
15
 
@@ -143,19 +142,6 @@ class AbstractRunner(ABC):
143
142
  sys.exit(1)
144
143
  return True
145
144
 
146
- def ack_run_queue_item(self, launch_project: LaunchProject) -> bool:
147
- if self.backend_config.get("runQueueItemId"):
148
- try:
149
- self._api.ack_run_queue_item(
150
- self.backend_config["runQueueItemId"], launch_project.run_id
151
- )
152
- except CommError:
153
- wandb.termerror(
154
- "Error acking run queue item. Item lease may have ended or another process may have acked it."
155
- )
156
- return False
157
- return True
158
-
159
145
  @abstractmethod
160
146
  def run(
161
147
  self,