wandb 0.16.6__py3-none-any.whl → 0.17.0rc2__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (151)
  1. package_readme.md +95 -0
  2. wandb/__init__.py +2 -2
  3. wandb/agents/pyagent.py +0 -1
  4. wandb/analytics/sentry.py +2 -1
  5. wandb/apis/importers/internals/protocols.py +30 -56
  6. wandb/apis/importers/mlflow.py +13 -26
  7. wandb/apis/importers/wandb.py +8 -14
  8. wandb/apis/public/api.py +1 -0
  9. wandb/apis/public/artifacts.py +1 -0
  10. wandb/apis/public/files.py +1 -0
  11. wandb/apis/public/history.py +1 -0
  12. wandb/apis/public/jobs.py +1 -0
  13. wandb/apis/public/projects.py +1 -0
  14. wandb/apis/public/reports.py +1 -0
  15. wandb/apis/public/runs.py +1 -0
  16. wandb/apis/public/sweeps.py +1 -0
  17. wandb/apis/public/teams.py +1 -0
  18. wandb/apis/public/users.py +1 -0
  19. wandb/apis/reports/v1/_blocks.py +3 -7
  20. wandb/apis/reports/v2/gql.py +1 -0
  21. wandb/apis/reports/v2/interface.py +3 -4
  22. wandb/apis/reports/v2/internal.py +5 -8
  23. wandb/cli/cli.py +2 -2
  24. wandb/data_types.py +9 -6
  25. wandb/docker/__init__.py +1 -1
  26. wandb/env.py +38 -8
  27. wandb/errors/__init__.py +5 -0
  28. wandb/integration/catboost/catboost.py +1 -1
  29. wandb/integration/fastai/__init__.py +1 -0
  30. wandb/integration/huggingface/resolver.py +2 -2
  31. wandb/integration/keras/__init__.py +1 -0
  32. wandb/integration/keras/callbacks/metrics_logger.py +1 -1
  33. wandb/integration/keras/keras.py +7 -7
  34. wandb/integration/langchain/wandb_tracer.py +1 -0
  35. wandb/integration/lightning/fabric/logger.py +1 -3
  36. wandb/integration/metaflow/metaflow.py +41 -6
  37. wandb/integration/openai/fine_tuning.py +3 -3
  38. wandb/keras/__init__.py +1 -0
  39. wandb/old/summary.py +1 -1
  40. wandb/plot/confusion_matrix.py +1 -1
  41. wandb/plots/precision_recall.py +1 -1
  42. wandb/plots/roc.py +1 -1
  43. wandb/proto/v3/wandb_internal_pb2.py +364 -332
  44. wandb/proto/v3/wandb_settings_pb2.py +1 -1
  45. wandb/proto/v4/wandb_internal_pb2.py +322 -316
  46. wandb/proto/v4/wandb_settings_pb2.py +1 -1
  47. wandb/proto/wandb_internal_codegen.py +0 -25
  48. wandb/sdk/artifacts/artifact.py +16 -4
  49. wandb/sdk/artifacts/artifact_download_logger.py +1 -0
  50. wandb/sdk/artifacts/artifact_file_cache.py +18 -4
  51. wandb/sdk/artifacts/artifact_instance_cache.py +1 -0
  52. wandb/sdk/artifacts/artifact_manifest.py +1 -0
  53. wandb/sdk/artifacts/artifact_manifest_entry.py +1 -0
  54. wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +1 -0
  55. wandb/sdk/artifacts/artifact_saver.py +5 -2
  56. wandb/sdk/artifacts/artifact_state.py +1 -0
  57. wandb/sdk/artifacts/artifact_ttl.py +1 -0
  58. wandb/sdk/artifacts/exceptions.py +1 -0
  59. wandb/sdk/artifacts/storage_handlers/azure_handler.py +1 -0
  60. wandb/sdk/artifacts/storage_handlers/gcs_handler.py +13 -18
  61. wandb/sdk/artifacts/storage_handlers/http_handler.py +1 -0
  62. wandb/sdk/artifacts/storage_handlers/local_file_handler.py +1 -0
  63. wandb/sdk/artifacts/storage_handlers/multi_handler.py +1 -0
  64. wandb/sdk/artifacts/storage_handlers/s3_handler.py +5 -3
  65. wandb/sdk/artifacts/storage_handlers/tracking_handler.py +1 -0
  66. wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +1 -0
  67. wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +1 -0
  68. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +1 -0
  69. wandb/sdk/artifacts/storage_policy.py +1 -0
  70. wandb/sdk/data_types/_dtypes.py +8 -8
  71. wandb/sdk/data_types/base_types/media.py +3 -6
  72. wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +3 -1
  73. wandb/sdk/data_types/image.py +1 -1
  74. wandb/sdk/data_types/video.py +1 -1
  75. wandb/sdk/integration_utils/auto_logging.py +5 -6
  76. wandb/sdk/integration_utils/data_logging.py +10 -6
  77. wandb/sdk/interface/interface.py +55 -32
  78. wandb/sdk/interface/interface_shared.py +7 -13
  79. wandb/sdk/internal/datastore.py +1 -1
  80. wandb/sdk/internal/handler.py +18 -2
  81. wandb/sdk/internal/internal.py +0 -1
  82. wandb/sdk/internal/internal_util.py +0 -1
  83. wandb/sdk/internal/job_builder.py +5 -4
  84. wandb/sdk/internal/profiler.py +1 -0
  85. wandb/sdk/internal/run.py +1 -0
  86. wandb/sdk/internal/sender.py +1 -1
  87. wandb/sdk/internal/system/assets/gpu_amd.py +44 -44
  88. wandb/sdk/internal/system/assets/gpu_apple.py +56 -11
  89. wandb/sdk/internal/system/assets/interfaces.py +6 -8
  90. wandb/sdk/internal/system/assets/open_metrics.py +2 -2
  91. wandb/sdk/internal/system/assets/trainium.py +1 -3
  92. wandb/sdk/launch/_project_spec.py +8 -4
  93. wandb/sdk/launch/agent/agent.py +2 -1
  94. wandb/sdk/launch/agent/config.py +72 -11
  95. wandb/sdk/launch/builder/abstract.py +2 -1
  96. wandb/sdk/launch/builder/build.py +29 -2
  97. wandb/sdk/launch/builder/docker_builder.py +1 -0
  98. wandb/sdk/launch/builder/kaniko_builder.py +2 -2
  99. wandb/sdk/launch/builder/noop.py +1 -0
  100. wandb/sdk/launch/create_job.py +18 -0
  101. wandb/sdk/launch/environment/abstract.py +1 -0
  102. wandb/sdk/launch/environment/gcp_environment.py +1 -0
  103. wandb/sdk/launch/environment/local_environment.py +1 -0
  104. wandb/sdk/launch/loader.py +1 -0
  105. wandb/sdk/launch/registry/abstract.py +1 -0
  106. wandb/sdk/launch/registry/azure_container_registry.py +1 -0
  107. wandb/sdk/launch/registry/elastic_container_registry.py +1 -0
  108. wandb/sdk/launch/registry/google_artifact_registry.py +2 -1
  109. wandb/sdk/launch/registry/local_registry.py +1 -0
  110. wandb/sdk/launch/runner/abstract.py +1 -0
  111. wandb/sdk/launch/runner/kubernetes_monitor.py +1 -0
  112. wandb/sdk/launch/runner/kubernetes_runner.py +4 -3
  113. wandb/sdk/launch/runner/sagemaker_runner.py +11 -10
  114. wandb/sdk/launch/sweeps/scheduler.py +4 -3
  115. wandb/sdk/launch/sweeps/scheduler_sweep.py +2 -1
  116. wandb/sdk/launch/sweeps/utils.py +3 -3
  117. wandb/sdk/launch/utils.py +3 -3
  118. wandb/sdk/lib/fsm.py +8 -12
  119. wandb/sdk/lib/gitlib.py +4 -4
  120. wandb/sdk/lib/import_hooks.py +1 -1
  121. wandb/sdk/lib/lazyloader.py +0 -1
  122. wandb/sdk/lib/proto_util.py +1 -1
  123. wandb/sdk/lib/redirect.py +19 -14
  124. wandb/sdk/lib/retry.py +3 -2
  125. wandb/sdk/lib/tracelog.py +1 -1
  126. wandb/sdk/service/service.py +17 -15
  127. wandb/sdk/verify/verify.py +2 -1
  128. wandb/sdk/wandb_manager.py +2 -2
  129. wandb/sdk/wandb_require.py +5 -0
  130. wandb/sdk/wandb_run.py +25 -20
  131. wandb/sdk/wandb_settings.py +0 -1
  132. wandb/sdk/wandb_setup.py +1 -1
  133. wandb/sklearn/__init__.py +1 -0
  134. wandb/sklearn/plot/__init__.py +1 -0
  135. wandb/sklearn/plot/classifier.py +7 -6
  136. wandb/sklearn/plot/clusterer.py +2 -1
  137. wandb/sklearn/plot/regressor.py +1 -0
  138. wandb/sklearn/plot/shared.py +1 -0
  139. wandb/sklearn/utils.py +1 -0
  140. wandb/testing/relay.py +4 -4
  141. wandb/trigger.py +1 -0
  142. wandb/util.py +40 -17
  143. wandb/wandb_controller.py +2 -3
  144. wandb/wandb_torch.py +1 -2
  145. {wandb-0.16.6.dist-info → wandb-0.17.0rc2.dist-info}/METADATA +68 -69
  146. {wandb-0.16.6.dist-info → wandb-0.17.0rc2.dist-info}/RECORD +149 -150
  147. {wandb-0.16.6.dist-info → wandb-0.17.0rc2.dist-info}/WHEEL +1 -2
  148. wandb/bin/apple_gpu_stats +0 -0
  149. wandb-0.16.6.dist-info/top_level.txt +0 -1
  150. {wandb-0.16.6.dist-info → wandb-0.17.0rc2.dist-info}/entry_points.txt +0 -0
  151. {wandb-0.16.6.dist-info → wandb-0.17.0rc2.dist-info/licenses}/LICENSE +0 -0
wandb/sdk/launch/_project_spec.py CHANGED
@@ -2,6 +2,7 @@
2
2
 
3
3
  Arguments can come from a launch spec or call to wandb launch.
4
4
  """
5
+
5
6
  import enum
6
7
  import logging
7
8
  import os
@@ -120,6 +121,7 @@ class LaunchProject:
120
121
  self.override_args: List[str] = overrides.get("args", [])
121
122
  self.override_config: Dict[str, Any] = overrides.get("run_config", {})
122
123
  self.override_artifacts: Dict[str, Any] = overrides.get("artifacts", {})
124
+ self.override_files: Dict[str, Any] = overrides.get("files", {})
123
125
  self.override_entrypoint: Optional[EntryPoint] = None
124
126
  self.override_dockerfile: Optional[str] = overrides.get("dockerfile")
125
127
  self.deps_type: Optional[str] = None
@@ -128,9 +130,9 @@ class LaunchProject:
128
130
  self._queue_name: Optional[str] = None
129
131
  self._queue_entity: Optional[str] = None
130
132
  self._run_queue_item_id: Optional[str] = None
131
- self._entry_point: Optional[
132
- EntryPoint
133
- ] = None # todo: keep multiple entrypoint support?
133
+ self._entry_point: Optional[EntryPoint] = (
134
+ None # todo: keep multiple entrypoint support?
135
+ )
134
136
 
135
137
  override_entrypoint = overrides.get("entry_point")
136
138
  if override_entrypoint:
@@ -402,7 +404,9 @@ class LaunchProject:
402
404
  _logger.debug("")
403
405
  return self.docker_image
404
406
  else:
405
- raise LaunchError("Unknown source type when determing image source string")
407
+ raise LaunchError(
408
+ "Unknown source type when determining image source string"
409
+ )
406
410
 
407
411
  def _ensure_not_docker_image_and_local_process(self) -> None:
408
412
  """Ensure that docker image is not specified with local-process resource runner.
wandb/sdk/launch/agent/agent.py CHANGED
@@ -1,4 +1,5 @@
1
1
  """Implementation of launch agent."""
2
+
2
3
  import asyncio
3
4
  import logging
4
5
  import os
@@ -240,7 +241,7 @@ class LaunchAgent:
240
241
  """Determine whether a job/runSpec is a sweep scheduler."""
241
242
  if not run_spec:
242
243
  self._internal_logger.debug(
243
- "Recieved runSpec in _is_scheduler_job that was empty"
244
+ "Received runSpec in _is_scheduler_job that was empty"
244
245
  )
245
246
 
246
247
  if run_spec.get("uri") != Scheduler.PLACEHOLDER_URI:
wandb/sdk/launch/agent/config.py CHANGED
@@ -80,17 +80,7 @@ class RegistryConfig(BaseModel):
80
80
  @validator("uri") # type: ignore
81
81
  @classmethod
82
82
  def validate_uri(cls, uri: str) -> str:
83
- for regex in [
84
- GCP_ARTIFACT_REGISTRY_URI_REGEX,
85
- AZURE_CONTAINER_REGISTRY_URI_REGEX,
86
- ELASTIC_CONTAINER_REGISTRY_URI_REGEX,
87
- ]:
88
- if regex.match(uri):
89
- return uri
90
- raise ValueError(
91
- "Invalid uri. URI must be a repository URI for an "
92
- "ECR, ACR, or GCP Artifact Registry."
93
- )
83
+ return validate_registry_uri(uri)
94
84
 
95
85
 
96
86
  class EnvironmentConfig(BaseModel):
@@ -186,6 +176,14 @@ class BuilderConfig(BaseModel):
186
176
  """Right now there are no required fields for docker builds."""
187
177
  return values
188
178
 
179
+ @validator("destination") # type: ignore
180
+ @classmethod
181
+ def validate_destination(cls, destination: Optional[str]) -> Optional[str]:
182
+ """Validate that the destination is a valid container registry URI."""
183
+ if destination is None:
184
+ return None
185
+ return validate_registry_uri(destination)
186
+
189
187
 
190
188
  class AgentConfig(BaseModel):
191
189
  """Configuration for the Launch agent."""
@@ -236,3 +234,66 @@ class AgentConfig(BaseModel):
236
234
 
237
235
  class Config:
238
236
  extra = "forbid"
237
+
238
+
239
+ def validate_registry_uri(uri: str) -> str:
240
+ """Validate that the registry URI is a valid container registry URI.
241
+
242
+ The URI should resolve to an image name in a container registry. The recognized
243
+ formats are for ECR, ACR, and GCP Artifact Registry. If the URI does not match
244
+ any of these formats, a warning is printed indicating the registry type is not
245
+ recognized and the agent can't guarantee that images can be pushed.
246
+
247
+ If the format is recognized but does not resolve to an image name, an
248
+ error is raised. For example, if the URI is an ECR URI but does not include
249
+ an image name or includes a tag as well as an image name, an error is raised.
250
+ """
251
+ tag_msg = (
252
+ "Destination for built images may not include a tag, but the URI provided "
253
+ "includes the suffix '{tag}'. Please remove the tag and try again. The agent "
254
+ "will automatically tag each image with a unique hash of the source code."
255
+ )
256
+ if uri.startswith("https://"):
257
+ uri = uri[8:]
258
+
259
+ match = GCP_ARTIFACT_REGISTRY_URI_REGEX.match(uri)
260
+ if match:
261
+ if match.group("tag"):
262
+ raise ValueError(tag_msg.format(tag=match.group("tag")))
263
+ if not match.group("image_name"):
264
+ raise ValueError(
265
+ "An image name must be specified in the URI for a GCP Artifact Registry. "
266
+ "Please provide a uri with the format "
267
+ "'https://<region>-docker.pkg.dev/<project>/<repository>/<image>'."
268
+ )
269
+ return uri
270
+
271
+ match = AZURE_CONTAINER_REGISTRY_URI_REGEX.match(uri)
272
+ if match:
273
+ if match.group("tag"):
274
+ raise ValueError(tag_msg.format(tag=match.group("tag")))
275
+ if not match.group("repository"):
276
+ raise ValueError(
277
+ "A repository name must be specified in the URI for an "
278
+ "Azure Container Registry. Please provide a uri with the format "
279
+ "'https://<registry-name>.azurecr.io/<repository>'."
280
+ )
281
+ return uri
282
+
283
+ match = ELASTIC_CONTAINER_REGISTRY_URI_REGEX.match(uri)
284
+ if match:
285
+ if match.group("tag"):
286
+ raise ValueError(tag_msg.format(tag=match.group("tag")))
287
+ if not match.group("repository"):
288
+ raise ValueError(
289
+ "A repository name must be specified in the URI for an "
290
+ "Elastic Container Registry. Please provide a uri with the format "
291
+ "'https://<account-id>.dkr.ecr.<region>.amazonaws.com/<repository>'."
292
+ )
293
+ return uri
294
+
295
+ wandb.termwarn(
296
+ f"Unable to recognize registry type in URI {uri}. You are responsible "
297
+ "for ensuring the agent can push images to this registry."
298
+ )
299
+ return uri
wandb/sdk/launch/builder/abstract.py CHANGED
@@ -1,4 +1,5 @@
1
1
  """Abstract plugin class defining the interface needed to build container images for W&B Launch."""
2
+
2
3
  from abc import ABC, abstractmethod
3
4
  from typing import TYPE_CHECKING, Any, Dict, Optional
4
5
 
@@ -34,7 +35,7 @@ class AbstractBuilder(ABC):
34
35
  verify: Whether to verify the functionality of the builder.
35
36
 
36
37
  Raises:
37
- LaunchError: If the builder cannot be intialized or verified.
38
+ LaunchError: If the builder cannot be initialized or verified.
38
39
  """
39
40
  raise NotImplementedError
40
41
 
wandb/sdk/launch/builder/build.py CHANGED
@@ -65,7 +65,7 @@ def registry_from_uri(uri: str) -> AbstractRegistry:
65
65
  it as an AWS Elastic Container Registry. If the uri contains
66
66
  `-docker.pkg.dev`, we classify it as a Google Artifact Registry.
67
67
 
68
- This function will attempt to load the approriate cloud helpers for the
68
+ This function will attempt to load the appropriate cloud helpers for the
69
69
 
70
70
  `https://` prefix is optional for all of the above.
71
71
 
@@ -237,7 +237,11 @@ def get_base_setup(
237
237
 
238
238
  CPU version is built on python, Accelerator version is built on user provided.
239
239
  """
240
- python_base_image = f"python:{py_version}-buster"
240
+ minor = int(py_version.split(".")[1])
241
+ if minor < 12:
242
+ python_base_image = f"python:{py_version}-buster"
243
+ else:
244
+ python_base_image = f"python:{py_version}-bookworm"
241
245
  if launch_project.accelerator_base_image:
242
246
  _logger.info(
243
247
  f"Using accelerator base image: {launch_project.accelerator_base_image}"
@@ -311,6 +315,11 @@ def get_env_vars_dict(
311
315
  _inject_wandb_config_env_vars(
312
316
  launch_project.override_config, env_vars, max_env_length
313
317
  )
318
+
319
+ _inject_file_overrides_env_vars(
320
+ launch_project.override_files, env_vars, max_env_length
321
+ )
322
+
314
323
  artifacts = {}
315
324
  # if we're spinning up a launch process from a job
316
325
  # we should tell the run to use that artifact
@@ -677,3 +686,21 @@ def _inject_wandb_config_env_vars(
677
686
  ]
678
687
  config_chunks_dict = {f"WANDB_CONFIG_{i}": chunk for i, chunk in enumerate(chunks)}
679
688
  env_dict.update(config_chunks_dict)
689
+
690
+
691
+ def _inject_file_overrides_env_vars(
692
+ overrides: Dict[str, Any], env_dict: Dict[str, Any], maximum_env_length: int
693
+ ) -> None:
694
+ str_overrides = json.dumps(overrides)
695
+ if len(str_overrides) <= maximum_env_length:
696
+ env_dict["WANDB_LAUNCH_FILE_OVERRIDES"] = str_overrides
697
+ return
698
+
699
+ chunks = [
700
+ str_overrides[i : i + maximum_env_length]
701
+ for i in range(0, len(str_overrides), maximum_env_length)
702
+ ]
703
+ overrides_chunks_dict = {
704
+ f"WANDB_LAUNCH_FILE_OVERRIDES_{i}": chunk for i, chunk in enumerate(chunks)
705
+ }
706
+ env_dict.update(overrides_chunks_dict)
wandb/sdk/launch/builder/docker_builder.py CHANGED
@@ -1,4 +1,5 @@
1
1
  """Implementation of the docker builder."""
2
+
2
3
  import logging
3
4
  import os
4
5
  from typing import Any, Dict, Optional
wandb/sdk/launch/builder/kaniko_builder.py CHANGED
@@ -286,7 +286,7 @@ class KanikoBuilder(AbstractBuilder):
286
286
  _, api_client = await get_kube_context_and_api_client(
287
287
  kubernetes, launch_project.resource_args
288
288
  )
289
- # TODO: use same client as kuberentes_runner.py
289
+ # TODO: use same client as kubernetes_runner.py
290
290
  batch_v1 = client.BatchV1Api(api_client)
291
291
  core_v1 = client.CoreV1Api(api_client)
292
292
 
@@ -522,7 +522,7 @@ class KanikoBuilder(AbstractBuilder):
522
522
  volume_mounts.append(
523
523
  {"name": "docker-config", "mountPath": "/kaniko/.docker/"}
524
524
  )
525
- # Kaniko doesn't want https:// at the begining of the image tag.
525
+ # Kaniko doesn't want https:// at the beginning of the image tag.
526
526
  destination = image_tag
527
527
  if destination.startswith("https://"):
528
528
  destination = destination.replace("https://", "")
wandb/sdk/launch/builder/noop.py CHANGED
@@ -1,4 +1,5 @@
1
1
  """NoOp builder implementation."""
2
+
2
3
  from typing import Any, Dict, Optional
3
4
 
4
5
  from wandb.sdk.launch.builder.abstract import AbstractBuilder
wandb/sdk/launch/create_job.py CHANGED
@@ -1,6 +1,7 @@
1
1
  import json
2
2
  import logging
3
3
  import os
4
+ import re
4
5
  import sys
5
6
  import tempfile
6
7
  from typing import Any, Dict, List, Optional, Tuple
@@ -19,6 +20,9 @@ logging.basicConfig(stream=sys.stdout, level=logging.INFO)
19
20
  _logger = logging.getLogger("wandb")
20
21
 
21
22
 
23
+ CODE_ARTIFACT_EXCLUDE_PATHS = ["wandb", ".git"]
24
+
25
+
22
26
  def create_job(
23
27
  path: str,
24
28
  job_type: str,
@@ -107,6 +111,13 @@ def _create_job(
107
111
  )
108
112
  return None, "", []
109
113
 
114
+ if runtime is not None:
115
+ if not re.match(r"^3\.\d+$", runtime):
116
+ wandb.termerror(
117
+ f"Runtime (-r, --runtime) must be a minor version of Python 3, "
118
+ f"e.g. 3.9 or 3.10, received {runtime}"
119
+ )
120
+ return None, "", []
110
121
  aliases = aliases or []
111
122
  tempdir = tempfile.TemporaryDirectory()
112
123
  try:
@@ -436,6 +447,13 @@ def _make_code_artifact(
436
447
  wandb.termerror(f"Error adding to code artifact: {e}")
437
448
  return None
438
449
 
450
+ # Remove paths we don't want to include, if present
451
+ for item in CODE_ARTIFACT_EXCLUDE_PATHS:
452
+ try:
453
+ code_artifact.remove(item)
454
+ except FileNotFoundError:
455
+ pass
456
+
439
457
  res, _ = api.create_artifact(
440
458
  artifact_type_name="code",
441
459
  artifact_collection_name=artifact_name,
wandb/sdk/launch/environment/abstract.py CHANGED
@@ -1,4 +1,5 @@
1
1
  """Abstract base class for environments."""
2
+
2
3
  from abc import ABC, abstractmethod
3
4
 
4
5
 
wandb/sdk/launch/environment/gcp_environment.py CHANGED
@@ -1,4 +1,5 @@
1
1
  """Implementation of the GCP environment for wandb launch."""
2
+
2
3
  import logging
3
4
  import os
4
5
  import subprocess
wandb/sdk/launch/environment/local_environment.py CHANGED
@@ -1,4 +1,5 @@
1
1
  """Dummy local environment implementation. This is the default environment."""
2
+
2
3
  from typing import Any, Dict, Union
3
4
 
4
5
  from wandb.sdk.launch.errors import LaunchError
wandb/sdk/launch/loader.py CHANGED
@@ -1,4 +1,5 @@
1
1
  """Utilities for the agent."""
2
+
2
3
  from typing import Any, Dict, Optional
3
4
 
4
5
  import wandb
wandb/sdk/launch/registry/abstract.py CHANGED
@@ -1,4 +1,5 @@
1
1
  """Abstract base class for registries."""
2
+
2
3
  from abc import ABC, abstractmethod
3
4
  from typing import Tuple
4
5
 
wandb/sdk/launch/registry/azure_container_registry.py CHANGED
@@ -1,4 +1,5 @@
1
1
  """Implementation of AzureContainerRegistry class."""
2
+
2
3
  import re
3
4
  from typing import TYPE_CHECKING, Optional, Tuple
4
5
 
wandb/sdk/launch/registry/elastic_container_registry.py CHANGED
@@ -1,4 +1,5 @@
1
1
  """Implementation of Elastic Container Registry class for wandb launch."""
2
+
2
3
  import base64
3
4
  import logging
4
5
  from typing import Dict, Optional, Tuple
wandb/sdk/launch/registry/google_artifact_registry.py CHANGED
@@ -1,4 +1,5 @@
1
1
  """Implementation of Google Artifact Registry for wandb launch."""
2
+
2
3
  import logging
3
4
  from typing import Optional, Tuple
4
5
 
@@ -210,7 +211,7 @@ class GoogleArtifactRegistry(AbstractRegistry):
210
211
  for image in await list_images(request={"parent": parent}):
211
212
  if tag in image.tags:
212
213
  return True
213
- except google.api_core.exceptions.NotFound as e:
214
+ except google.api_core.exceptions.NotFound as e: # type: ignore[attr-defined]
214
215
  raise LaunchError(
215
216
  f"The Google Artifact Registry repository {self.repository} "
216
217
  f"does not exist. Please create it or modify your registry configuration."
wandb/sdk/launch/registry/local_registry.py CHANGED
@@ -1,4 +1,5 @@
1
1
  """Local registry implementation."""
2
+
2
3
  import logging
3
4
  from typing import Tuple
4
5
 
wandb/sdk/launch/runner/abstract.py CHANGED
@@ -3,6 +3,7 @@
3
3
  This class defines the interface that the W&B launch runner uses to manage the lifecycle
4
4
  of runs launched in different environments (e.g. runs launched locally or in a cluster).
5
5
  """
6
+
6
7
  import logging
7
8
  import os
8
9
  import subprocess
wandb/sdk/launch/runner/kubernetes_monitor.py CHANGED
@@ -1,4 +1,5 @@
1
1
  """Monitors kubernetes resources managed by the launch agent."""
2
+
2
3
  import asyncio
3
4
  import logging
4
5
  import sys
wandb/sdk/launch/runner/kubernetes_runner.py CHANGED
@@ -1,4 +1,5 @@
1
1
  """Implementation of KubernetesRunner class for wandb launch."""
2
+
2
3
  import asyncio
3
4
  import base64
4
5
  import datetime
@@ -539,9 +540,9 @@ class KubernetesRunner(AbstractRunner):
539
540
  WANDB_K8S_LABEL_MONITOR,
540
541
  LaunchAgent.name(),
541
542
  )
542
- resource_args["metadata"]["labels"][
543
- WANDB_K8S_LABEL_AGENT
544
- ] = LaunchAgent.name()
543
+ resource_args["metadata"]["labels"][WANDB_K8S_LABEL_AGENT] = (
544
+ LaunchAgent.name()
545
+ )
545
546
 
546
547
  overrides = {}
547
548
  if launch_project.override_args:
wandb/sdk/launch/runner/sagemaker_runner.py CHANGED
@@ -1,4 +1,5 @@
1
1
  """Implementation of the SageMakerRunner class."""
2
+
2
3
  import asyncio
3
4
  import logging
4
5
  from typing import Any, Dict, List, Optional, cast
@@ -324,16 +325,16 @@ def build_sagemaker_args(
324
325
  sagemaker_args["TrainingJobName"] = training_job_name
325
326
  entry_cmd = entry_point.command if entry_point else []
326
327
 
327
- sagemaker_args[
328
- "AlgorithmSpecification"
329
- ] = merge_image_uri_with_algorithm_specification(
330
- given_sagemaker_args.get(
331
- "AlgorithmSpecification",
332
- given_sagemaker_args.get("algorithm_specification"),
333
- ),
334
- image_uri,
335
- entry_cmd,
336
- args,
328
+ sagemaker_args["AlgorithmSpecification"] = (
329
+ merge_image_uri_with_algorithm_specification(
330
+ given_sagemaker_args.get(
331
+ "AlgorithmSpecification",
332
+ given_sagemaker_args.get("algorithm_specification"),
333
+ ),
334
+ image_uri,
335
+ entry_cmd,
336
+ args,
337
+ )
337
338
  )
338
339
 
339
340
  sagemaker_args["RoleArn"] = role_arn
wandb/sdk/launch/sweeps/scheduler.py CHANGED
@@ -1,4 +1,5 @@
1
1
  """Abstract Scheduler class."""
2
+
2
3
  import asyncio
3
4
  import base64
4
5
  import copy
@@ -407,7 +408,7 @@ class Scheduler(ABC):
407
408
  return count
408
409
 
409
410
  def _try_load_executable(self) -> bool:
410
- """Check existance of valid executable for a run.
411
+ """Check existence of valid executable for a run.
411
412
 
412
413
  logs and returns False when job is unreachable
413
414
  """
@@ -422,7 +423,7 @@ class Scheduler(ABC):
422
423
  return False
423
424
  return True
424
425
  elif self._kwargs.get("image_uri"):
425
- # TODO(gst): check docker existance? Use registry in launch config?
426
+ # TODO(gst): check docker existence? Use registry in launch config?
426
427
  return True
427
428
  else:
428
429
  return False
@@ -610,7 +611,7 @@ class Scheduler(ABC):
610
611
  f"Failed to get runstate for run ({run_id}). Error: {traceback.format_exc()}"
611
612
  )
612
613
  run_state = RunState.FAILED
613
- else: # first time we get unknwon state
614
+ else: # first time we get unknown state
614
615
  run_state = RunState.UNKNOWN
615
616
  except (AttributeError, ValueError):
616
617
  wandb.termwarn(
wandb/sdk/launch/sweeps/scheduler_sweep.py CHANGED
@@ -1,4 +1,5 @@
1
1
  """Scheduler for classic wandb Sweeps."""
2
+
2
3
  import logging
3
4
  from pprint import pformat as pf
4
5
  from typing import Any, Dict, List, Optional
@@ -58,7 +59,7 @@ class SweepScheduler(Scheduler):
58
59
  return None
59
60
 
60
61
  def _get_sweep_commands(self, worker_id: int) -> List[Dict[str, Any]]:
61
- """Helper to recieve sweep command from backend."""
62
+ """Helper to receive sweep command from backend."""
62
63
  # AgentHeartbeat wants a Dict of runs which are running or queued
63
64
  _run_states: Dict[str, bool] = {}
64
65
  for run_id, run in self._yield_runs():
wandb/sdk/launch/sweeps/utils.py CHANGED
@@ -217,7 +217,7 @@ def create_sweep_command_args(command: Dict) -> Dict[str, Any]:
217
217
  flags: List[str] = []
218
218
  # (2) flags without hyphens (e.g. foo=bar)
219
219
  flags_no_hyphens: List[str] = []
220
- # (3) flags with false booleans ommited (e.g. --foo)
220
+ # (3) flags with false booleans omitted (e.g. --foo)
221
221
  flags_no_booleans: List[str] = []
222
222
  # (4) flags as a dictionary (used for constructing a json)
223
223
  flags_dict: Dict[str, Any] = {}
@@ -257,7 +257,7 @@ def make_launch_sweep_entrypoint(
257
257
  """Use args dict from create_sweep_command_args to construct entrypoint.
258
258
 
259
259
  If replace is True, remove macros from entrypoint, fill them in with args
260
- and then return the args in seperate return value.
260
+ and then return the args in separate return value.
261
261
  """
262
262
  if not command:
263
263
  return None, None
@@ -296,7 +296,7 @@ def check_job_exists(public_api: "PublicApi", job: Optional[str]) -> bool:
296
296
 
297
297
 
298
298
  def get_previous_args(
299
- run_spec: Dict[str, Any]
299
+ run_spec: Dict[str, Any],
300
300
  ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
301
301
  """Parse through previous scheduler run_spec.
302
302
 
wandb/sdk/launch/utils.py CHANGED
@@ -57,15 +57,15 @@ API_KEY_REGEX = r"WANDB_API_KEY=\w+(-\w+)?"
57
57
  MACRO_REGEX = re.compile(r"\$\{(\w+)\}")
58
58
 
59
59
  AZURE_CONTAINER_REGISTRY_URI_REGEX = re.compile(
60
- r"(?:https://)?([\w]+)\.azurecr\.io/([\w\-]+):?(.*)"
60
+ r"^(?:https://)?([\w]+)\.azurecr\.io/(?P<repository>[\w\-]+):?(?P<tag>.*)"
61
61
  )
62
62
 
63
63
  ELASTIC_CONTAINER_REGISTRY_URI_REGEX = re.compile(
64
- r"^(?P<account>.*)\.dkr\.ecr\.(?P<region>.*)\.amazonaws\.com/(?P<repository>.*)/?$"
64
+ r"^(?:https://)?(?P<account>[\w-]+)\.dkr\.ecr\.(?P<region>[\w-]+)\.amazonaws\.com/(?P<repository>[\w-]+):?(?P<tag>.*)$"
65
65
  )
66
66
 
67
67
  GCP_ARTIFACT_REGISTRY_URI_REGEX = re.compile(
68
- r"^(?P<region>[\w-]+)-docker\.pkg\.dev/(?P<project>[\w-]+)/(?P<repository>[\w-]+)/(?P<image_name>[\w-]+)$",
68
+ r"^(?:https://)?(?P<region>[\w-]+)-docker\.pkg\.dev/(?P<project>[\w-]+)/(?P<repository>[\w-]+)/?(?P<image_name>[\w-]+)?(?P<tag>:.*)?$",
69
69
  re.IGNORECASE,
70
70
  )
71
71
 
wandb/sdk/lib/fsm.py CHANGED
@@ -52,43 +52,39 @@ T_FsmContext_contra = TypeVar("T_FsmContext_contra", contravariant=True)
52
52
  @runtime_checkable
53
53
  class FsmStateCheck(Protocol[T_FsmInputs]):
54
54
  @abstractmethod
55
- def on_check(self, inputs: T_FsmInputs) -> None:
56
- ... # pragma: no cover
55
+ def on_check(self, inputs: T_FsmInputs) -> None: ... # pragma: no cover
57
56
 
58
57
 
59
58
  @runtime_checkable
60
59
  class FsmStateOutput(Protocol[T_FsmInputs]):
61
60
  @abstractmethod
62
- def on_state(self, inputs: T_FsmInputs) -> None:
63
- ... # pragma: no cover
61
+ def on_state(self, inputs: T_FsmInputs) -> None: ... # pragma: no cover
64
62
 
65
63
 
66
64
  @runtime_checkable
67
65
  class FsmStateEnter(Protocol[T_FsmInputs]):
68
66
  @abstractmethod
69
- def on_enter(self, inputs: T_FsmInputs) -> None:
70
- ... # pragma: no cover
67
+ def on_enter(self, inputs: T_FsmInputs) -> None: ... # pragma: no cover
71
68
 
72
69
 
73
70
  @runtime_checkable
74
71
  class FsmStateEnterWithContext(Protocol[T_FsmInputs, T_FsmContext_contra]):
75
72
  @abstractmethod
76
- def on_enter(self, inputs: T_FsmInputs, context: T_FsmContext_contra) -> None:
77
- ... # pragma: no cover
73
+ def on_enter(
74
+ self, inputs: T_FsmInputs, context: T_FsmContext_contra
75
+ ) -> None: ... # pragma: no cover
78
76
 
79
77
 
80
78
  @runtime_checkable
81
79
  class FsmStateStay(Protocol[T_FsmInputs]):
82
80
  @abstractmethod
83
- def on_stay(self, inputs: T_FsmInputs) -> None:
84
- ... # pragma: no cover
81
+ def on_stay(self, inputs: T_FsmInputs) -> None: ... # pragma: no cover
85
82
 
86
83
 
87
84
  @runtime_checkable
88
85
  class FsmStateExit(Protocol[T_FsmInputs, T_FsmContext_cov]):
89
86
  @abstractmethod
90
- def on_exit(self, inputs: T_FsmInputs) -> T_FsmContext_cov:
91
- ... # pragma: no cover
87
+ def on_exit(self, inputs: T_FsmInputs) -> T_FsmContext_cov: ... # pragma: no cover
92
88
 
93
89
 
94
90
  # It would be nice if python provided optional protocol members, but it doesnt as described here:
wandb/sdk/lib/gitlib.py CHANGED
@@ -14,7 +14,7 @@ try:
14
14
  Repo,
15
15
  )
16
16
  except ImportError:
17
- Repo = None
17
+ Repo = None # type: ignore
18
18
 
19
19
  if TYPE_CHECKING:
20
20
  from git import Repo
@@ -121,7 +121,7 @@ class GitRepo:
121
121
  # TODO: Saw a user getting a Unicode decode error when parsing refs,
122
122
  # more details on implementing a real fix in [WB-4064]
123
123
  try:
124
- if len(self.repo.refs) > 0:
124
+ if len(self.repo.refs) > 0: # type: ignore[arg-type]
125
125
  return self.repo.head.commit.hexsha
126
126
  else:
127
127
  return self.repo.git.show_ref("--head").split(" ")[0]
@@ -140,7 +140,7 @@ class GitRepo:
140
140
  if not self.repo:
141
141
  return None
142
142
  try:
143
- return self.repo.remotes[self.remote_name]
143
+ return self.repo.remotes[self.remote_name] # type: ignore[index]
144
144
  except IndexError:
145
145
  return None
146
146
 
@@ -200,7 +200,7 @@ class GitRepo:
200
200
  possible_relatives.append(tracking_branch.commit)
201
201
 
202
202
  if not possible_relatives:
203
- for branch in self.repo.branches:
203
+ for branch in self.repo.branches: # type: ignore[attr-defined]
204
204
  tracking_branch = branch.tracking_branch()
205
205
  if tracking_branch is not None:
206
206
  possible_relatives.append(tracking_branch.commit)
wandb/sdk/lib/import_hooks.py CHANGED
@@ -143,7 +143,7 @@ class _ImportHookChainedLoader:
143
143
  # None, so handle None as well. The module may not support attribute
144
144
  # assignment, in which case we simply skip it. Note that we also deal
145
145
  # with __loader__ not existing at all. This is to future proof things
146
- # due to proposal to remove the attribue as described in the GitHub
146
+ # due to proposal to remove the attribute as described in the GitHub
147
147
  # issue at https://github.com/python/cpython/issues/77458. Also prior
148
148
  # to Python 3.3, the __loader__ attribute was only set if a custom
149
149
  # module loader was used. It isn't clear whether the attribute still
wandb/sdk/lib/lazyloader.py CHANGED
@@ -1,6 +1,5 @@
1
1
  """module lazyloader."""
2
2
 
3
-
4
3
  import importlib
5
4
  import sys
6
5
  import types
wandb/sdk/lib/proto_util.py CHANGED
@@ -29,7 +29,7 @@ def _assign_end_offset(record: "pb.Record", end_offset: int) -> None:
29
29
 
30
30
 
31
31
  def proto_encode_to_dict(
32
- pb_obj: Union["tpb.TelemetryRecord", "pb.MetricRecord"]
32
+ pb_obj: Union["tpb.TelemetryRecord", "pb.MetricRecord"],
33
33
  ) -> Dict[int, Any]:
34
34
  data: Dict[int, Any] = dict()
35
35
  fields = pb_obj.ListFields()