wandb 0.16.6__py3-none-any.whl → 0.17.0rc2__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- package_readme.md +95 -0
- wandb/__init__.py +2 -2
- wandb/agents/pyagent.py +0 -1
- wandb/analytics/sentry.py +2 -1
- wandb/apis/importers/internals/protocols.py +30 -56
- wandb/apis/importers/mlflow.py +13 -26
- wandb/apis/importers/wandb.py +8 -14
- wandb/apis/public/api.py +1 -0
- wandb/apis/public/artifacts.py +1 -0
- wandb/apis/public/files.py +1 -0
- wandb/apis/public/history.py +1 -0
- wandb/apis/public/jobs.py +1 -0
- wandb/apis/public/projects.py +1 -0
- wandb/apis/public/reports.py +1 -0
- wandb/apis/public/runs.py +1 -0
- wandb/apis/public/sweeps.py +1 -0
- wandb/apis/public/teams.py +1 -0
- wandb/apis/public/users.py +1 -0
- wandb/apis/reports/v1/_blocks.py +3 -7
- wandb/apis/reports/v2/gql.py +1 -0
- wandb/apis/reports/v2/interface.py +3 -4
- wandb/apis/reports/v2/internal.py +5 -8
- wandb/cli/cli.py +2 -2
- wandb/data_types.py +9 -6
- wandb/docker/__init__.py +1 -1
- wandb/env.py +38 -8
- wandb/errors/__init__.py +5 -0
- wandb/integration/catboost/catboost.py +1 -1
- wandb/integration/fastai/__init__.py +1 -0
- wandb/integration/huggingface/resolver.py +2 -2
- wandb/integration/keras/__init__.py +1 -0
- wandb/integration/keras/callbacks/metrics_logger.py +1 -1
- wandb/integration/keras/keras.py +7 -7
- wandb/integration/langchain/wandb_tracer.py +1 -0
- wandb/integration/lightning/fabric/logger.py +1 -3
- wandb/integration/metaflow/metaflow.py +41 -6
- wandb/integration/openai/fine_tuning.py +3 -3
- wandb/keras/__init__.py +1 -0
- wandb/old/summary.py +1 -1
- wandb/plot/confusion_matrix.py +1 -1
- wandb/plots/precision_recall.py +1 -1
- wandb/plots/roc.py +1 -1
- wandb/proto/v3/wandb_internal_pb2.py +364 -332
- wandb/proto/v3/wandb_settings_pb2.py +1 -1
- wandb/proto/v4/wandb_internal_pb2.py +322 -316
- wandb/proto/v4/wandb_settings_pb2.py +1 -1
- wandb/proto/wandb_internal_codegen.py +0 -25
- wandb/sdk/artifacts/artifact.py +16 -4
- wandb/sdk/artifacts/artifact_download_logger.py +1 -0
- wandb/sdk/artifacts/artifact_file_cache.py +18 -4
- wandb/sdk/artifacts/artifact_instance_cache.py +1 -0
- wandb/sdk/artifacts/artifact_manifest.py +1 -0
- wandb/sdk/artifacts/artifact_manifest_entry.py +1 -0
- wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +1 -0
- wandb/sdk/artifacts/artifact_saver.py +5 -2
- wandb/sdk/artifacts/artifact_state.py +1 -0
- wandb/sdk/artifacts/artifact_ttl.py +1 -0
- wandb/sdk/artifacts/exceptions.py +1 -0
- wandb/sdk/artifacts/storage_handlers/azure_handler.py +1 -0
- wandb/sdk/artifacts/storage_handlers/gcs_handler.py +13 -18
- wandb/sdk/artifacts/storage_handlers/http_handler.py +1 -0
- wandb/sdk/artifacts/storage_handlers/local_file_handler.py +1 -0
- wandb/sdk/artifacts/storage_handlers/multi_handler.py +1 -0
- wandb/sdk/artifacts/storage_handlers/s3_handler.py +5 -3
- wandb/sdk/artifacts/storage_handlers/tracking_handler.py +1 -0
- wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +1 -0
- wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +1 -0
- wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +1 -0
- wandb/sdk/artifacts/storage_policy.py +1 -0
- wandb/sdk/data_types/_dtypes.py +8 -8
- wandb/sdk/data_types/base_types/media.py +3 -6
- wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +3 -1
- wandb/sdk/data_types/image.py +1 -1
- wandb/sdk/data_types/video.py +1 -1
- wandb/sdk/integration_utils/auto_logging.py +5 -6
- wandb/sdk/integration_utils/data_logging.py +10 -6
- wandb/sdk/interface/interface.py +55 -32
- wandb/sdk/interface/interface_shared.py +7 -13
- wandb/sdk/internal/datastore.py +1 -1
- wandb/sdk/internal/handler.py +18 -2
- wandb/sdk/internal/internal.py +0 -1
- wandb/sdk/internal/internal_util.py +0 -1
- wandb/sdk/internal/job_builder.py +5 -4
- wandb/sdk/internal/profiler.py +1 -0
- wandb/sdk/internal/run.py +1 -0
- wandb/sdk/internal/sender.py +1 -1
- wandb/sdk/internal/system/assets/gpu_amd.py +44 -44
- wandb/sdk/internal/system/assets/gpu_apple.py +56 -11
- wandb/sdk/internal/system/assets/interfaces.py +6 -8
- wandb/sdk/internal/system/assets/open_metrics.py +2 -2
- wandb/sdk/internal/system/assets/trainium.py +1 -3
- wandb/sdk/launch/_project_spec.py +8 -4
- wandb/sdk/launch/agent/agent.py +2 -1
- wandb/sdk/launch/agent/config.py +72 -11
- wandb/sdk/launch/builder/abstract.py +2 -1
- wandb/sdk/launch/builder/build.py +29 -2
- wandb/sdk/launch/builder/docker_builder.py +1 -0
- wandb/sdk/launch/builder/kaniko_builder.py +2 -2
- wandb/sdk/launch/builder/noop.py +1 -0
- wandb/sdk/launch/create_job.py +18 -0
- wandb/sdk/launch/environment/abstract.py +1 -0
- wandb/sdk/launch/environment/gcp_environment.py +1 -0
- wandb/sdk/launch/environment/local_environment.py +1 -0
- wandb/sdk/launch/loader.py +1 -0
- wandb/sdk/launch/registry/abstract.py +1 -0
- wandb/sdk/launch/registry/azure_container_registry.py +1 -0
- wandb/sdk/launch/registry/elastic_container_registry.py +1 -0
- wandb/sdk/launch/registry/google_artifact_registry.py +2 -1
- wandb/sdk/launch/registry/local_registry.py +1 -0
- wandb/sdk/launch/runner/abstract.py +1 -0
- wandb/sdk/launch/runner/kubernetes_monitor.py +1 -0
- wandb/sdk/launch/runner/kubernetes_runner.py +4 -3
- wandb/sdk/launch/runner/sagemaker_runner.py +11 -10
- wandb/sdk/launch/sweeps/scheduler.py +4 -3
- wandb/sdk/launch/sweeps/scheduler_sweep.py +2 -1
- wandb/sdk/launch/sweeps/utils.py +3 -3
- wandb/sdk/launch/utils.py +3 -3
- wandb/sdk/lib/fsm.py +8 -12
- wandb/sdk/lib/gitlib.py +4 -4
- wandb/sdk/lib/import_hooks.py +1 -1
- wandb/sdk/lib/lazyloader.py +0 -1
- wandb/sdk/lib/proto_util.py +1 -1
- wandb/sdk/lib/redirect.py +19 -14
- wandb/sdk/lib/retry.py +3 -2
- wandb/sdk/lib/tracelog.py +1 -1
- wandb/sdk/service/service.py +17 -15
- wandb/sdk/verify/verify.py +2 -1
- wandb/sdk/wandb_manager.py +2 -2
- wandb/sdk/wandb_require.py +5 -0
- wandb/sdk/wandb_run.py +25 -20
- wandb/sdk/wandb_settings.py +0 -1
- wandb/sdk/wandb_setup.py +1 -1
- wandb/sklearn/__init__.py +1 -0
- wandb/sklearn/plot/__init__.py +1 -0
- wandb/sklearn/plot/classifier.py +7 -6
- wandb/sklearn/plot/clusterer.py +2 -1
- wandb/sklearn/plot/regressor.py +1 -0
- wandb/sklearn/plot/shared.py +1 -0
- wandb/sklearn/utils.py +1 -0
- wandb/testing/relay.py +4 -4
- wandb/trigger.py +1 -0
- wandb/util.py +40 -17
- wandb/wandb_controller.py +2 -3
- wandb/wandb_torch.py +1 -2
- {wandb-0.16.6.dist-info → wandb-0.17.0rc2.dist-info}/METADATA +68 -69
- {wandb-0.16.6.dist-info → wandb-0.17.0rc2.dist-info}/RECORD +149 -150
- {wandb-0.16.6.dist-info → wandb-0.17.0rc2.dist-info}/WHEEL +1 -2
- wandb/bin/apple_gpu_stats +0 -0
- wandb-0.16.6.dist-info/top_level.txt +0 -1
- {wandb-0.16.6.dist-info → wandb-0.17.0rc2.dist-info}/entry_points.txt +0 -0
- {wandb-0.16.6.dist-info → wandb-0.17.0rc2.dist-info/licenses}/LICENSE +0 -0
@@ -2,6 +2,7 @@
|
|
2
2
|
|
3
3
|
Arguments can come from a launch spec or call to wandb launch.
|
4
4
|
"""
|
5
|
+
|
5
6
|
import enum
|
6
7
|
import logging
|
7
8
|
import os
|
@@ -120,6 +121,7 @@ class LaunchProject:
|
|
120
121
|
self.override_args: List[str] = overrides.get("args", [])
|
121
122
|
self.override_config: Dict[str, Any] = overrides.get("run_config", {})
|
122
123
|
self.override_artifacts: Dict[str, Any] = overrides.get("artifacts", {})
|
124
|
+
self.override_files: Dict[str, Any] = overrides.get("files", {})
|
123
125
|
self.override_entrypoint: Optional[EntryPoint] = None
|
124
126
|
self.override_dockerfile: Optional[str] = overrides.get("dockerfile")
|
125
127
|
self.deps_type: Optional[str] = None
|
@@ -128,9 +130,9 @@ class LaunchProject:
|
|
128
130
|
self._queue_name: Optional[str] = None
|
129
131
|
self._queue_entity: Optional[str] = None
|
130
132
|
self._run_queue_item_id: Optional[str] = None
|
131
|
-
self._entry_point: Optional[
|
132
|
-
|
133
|
-
|
133
|
+
self._entry_point: Optional[EntryPoint] = (
|
134
|
+
None # todo: keep multiple entrypoint support?
|
135
|
+
)
|
134
136
|
|
135
137
|
override_entrypoint = overrides.get("entry_point")
|
136
138
|
if override_entrypoint:
|
@@ -402,7 +404,9 @@ class LaunchProject:
|
|
402
404
|
_logger.debug("")
|
403
405
|
return self.docker_image
|
404
406
|
else:
|
405
|
-
raise LaunchError(
|
407
|
+
raise LaunchError(
|
408
|
+
"Unknown source type when determining image source string"
|
409
|
+
)
|
406
410
|
|
407
411
|
def _ensure_not_docker_image_and_local_process(self) -> None:
|
408
412
|
"""Ensure that docker image is not specified with local-process resource runner.
|
wandb/sdk/launch/agent/agent.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1
1
|
"""Implementation of launch agent."""
|
2
|
+
|
2
3
|
import asyncio
|
3
4
|
import logging
|
4
5
|
import os
|
@@ -240,7 +241,7 @@ class LaunchAgent:
|
|
240
241
|
"""Determine whether a job/runSpec is a sweep scheduler."""
|
241
242
|
if not run_spec:
|
242
243
|
self._internal_logger.debug(
|
243
|
-
"
|
244
|
+
"Received runSpec in _is_scheduler_job that was empty"
|
244
245
|
)
|
245
246
|
|
246
247
|
if run_spec.get("uri") != Scheduler.PLACEHOLDER_URI:
|
wandb/sdk/launch/agent/config.py
CHANGED
@@ -80,17 +80,7 @@ class RegistryConfig(BaseModel):
|
|
80
80
|
@validator("uri") # type: ignore
|
81
81
|
@classmethod
|
82
82
|
def validate_uri(cls, uri: str) -> str:
|
83
|
-
|
84
|
-
GCP_ARTIFACT_REGISTRY_URI_REGEX,
|
85
|
-
AZURE_CONTAINER_REGISTRY_URI_REGEX,
|
86
|
-
ELASTIC_CONTAINER_REGISTRY_URI_REGEX,
|
87
|
-
]:
|
88
|
-
if regex.match(uri):
|
89
|
-
return uri
|
90
|
-
raise ValueError(
|
91
|
-
"Invalid uri. URI must be a repository URI for an "
|
92
|
-
"ECR, ACR, or GCP Artifact Registry."
|
93
|
-
)
|
83
|
+
return validate_registry_uri(uri)
|
94
84
|
|
95
85
|
|
96
86
|
class EnvironmentConfig(BaseModel):
|
@@ -186,6 +176,14 @@ class BuilderConfig(BaseModel):
|
|
186
176
|
"""Right now there are no required fields for docker builds."""
|
187
177
|
return values
|
188
178
|
|
179
|
+
@validator("destination") # type: ignore
|
180
|
+
@classmethod
|
181
|
+
def validate_destination(cls, destination: Optional[str]) -> Optional[str]:
|
182
|
+
"""Validate that the destination is a valid container registry URI."""
|
183
|
+
if destination is None:
|
184
|
+
return None
|
185
|
+
return validate_registry_uri(destination)
|
186
|
+
|
189
187
|
|
190
188
|
class AgentConfig(BaseModel):
|
191
189
|
"""Configuration for the Launch agent."""
|
@@ -236,3 +234,66 @@ class AgentConfig(BaseModel):
|
|
236
234
|
|
237
235
|
class Config:
|
238
236
|
extra = "forbid"
|
237
|
+
|
238
|
+
|
239
|
+
def validate_registry_uri(uri: str) -> str:
|
240
|
+
"""Validate that the registry URI is a valid container registry URI.
|
241
|
+
|
242
|
+
The URI should resolve to an image name in a container registry. The recognized
|
243
|
+
formats are for ECR, ACR, and GCP Artifact Registry. If the URI does not match
|
244
|
+
any of these formats, a warning is printed indicating the registry type is not
|
245
|
+
recognized and the agent can't guarantee that images can be pushed.
|
246
|
+
|
247
|
+
If the format is recognized but does not resolve to an image name, an
|
248
|
+
error is raised. For example, if the URI is an ECR URI but does not include
|
249
|
+
an image name or includes a tag as well as an image name, an error is raised.
|
250
|
+
"""
|
251
|
+
tag_msg = (
|
252
|
+
"Destination for built images may not include a tag, but the URI provided "
|
253
|
+
"includes the suffix '{tag}'. Please remove the tag and try again. The agent "
|
254
|
+
"will automatically tag each image with a unique hash of the source code."
|
255
|
+
)
|
256
|
+
if uri.startswith("https://"):
|
257
|
+
uri = uri[8:]
|
258
|
+
|
259
|
+
match = GCP_ARTIFACT_REGISTRY_URI_REGEX.match(uri)
|
260
|
+
if match:
|
261
|
+
if match.group("tag"):
|
262
|
+
raise ValueError(tag_msg.format(tag=match.group("tag")))
|
263
|
+
if not match.group("image_name"):
|
264
|
+
raise ValueError(
|
265
|
+
"An image name must be specified in the URI for a GCP Artifact Registry. "
|
266
|
+
"Please provide a uri with the format "
|
267
|
+
"'https://<region>-docker.pkg.dev/<project>/<repository>/<image>'."
|
268
|
+
)
|
269
|
+
return uri
|
270
|
+
|
271
|
+
match = AZURE_CONTAINER_REGISTRY_URI_REGEX.match(uri)
|
272
|
+
if match:
|
273
|
+
if match.group("tag"):
|
274
|
+
raise ValueError(tag_msg.format(tag=match.group("tag")))
|
275
|
+
if not match.group("repository"):
|
276
|
+
raise ValueError(
|
277
|
+
"A repository name must be specified in the URI for an "
|
278
|
+
"Azure Container Registry. Please provide a uri with the format "
|
279
|
+
"'https://<registry-name>.azurecr.io/<repository>'."
|
280
|
+
)
|
281
|
+
return uri
|
282
|
+
|
283
|
+
match = ELASTIC_CONTAINER_REGISTRY_URI_REGEX.match(uri)
|
284
|
+
if match:
|
285
|
+
if match.group("tag"):
|
286
|
+
raise ValueError(tag_msg.format(tag=match.group("tag")))
|
287
|
+
if not match.group("repository"):
|
288
|
+
raise ValueError(
|
289
|
+
"A repository name must be specified in the URI for an "
|
290
|
+
"Elastic Container Registry. Please provide a uri with the format "
|
291
|
+
"'https://<account-id>.dkr.ecr.<region>.amazonaws.com/<repository>'."
|
292
|
+
)
|
293
|
+
return uri
|
294
|
+
|
295
|
+
wandb.termwarn(
|
296
|
+
f"Unable to recognize registry type in URI {uri}. You are responsible "
|
297
|
+
"for ensuring the agent can push images to this registry."
|
298
|
+
)
|
299
|
+
return uri
|
@@ -1,4 +1,5 @@
|
|
1
1
|
"""Abstract plugin class defining the interface needed to build container images for W&B Launch."""
|
2
|
+
|
2
3
|
from abc import ABC, abstractmethod
|
3
4
|
from typing import TYPE_CHECKING, Any, Dict, Optional
|
4
5
|
|
@@ -34,7 +35,7 @@ class AbstractBuilder(ABC):
|
|
34
35
|
verify: Whether to verify the functionality of the builder.
|
35
36
|
|
36
37
|
Raises:
|
37
|
-
LaunchError: If the builder cannot be
|
38
|
+
LaunchError: If the builder cannot be initialized or verified.
|
38
39
|
"""
|
39
40
|
raise NotImplementedError
|
40
41
|
|
@@ -65,7 +65,7 @@ def registry_from_uri(uri: str) -> AbstractRegistry:
|
|
65
65
|
it as an AWS Elastic Container Registry. If the uri contains
|
66
66
|
`-docker.pkg.dev`, we classify it as a Google Artifact Registry.
|
67
67
|
|
68
|
-
This function will attempt to load the
|
68
|
+
This function will attempt to load the appropriate cloud helpers for the
|
69
69
|
|
70
70
|
`https://` prefix is optional for all of the above.
|
71
71
|
|
@@ -237,7 +237,11 @@ def get_base_setup(
|
|
237
237
|
|
238
238
|
CPU version is built on python, Accelerator version is built on user provided.
|
239
239
|
"""
|
240
|
-
|
240
|
+
minor = int(py_version.split(".")[1])
|
241
|
+
if minor < 12:
|
242
|
+
python_base_image = f"python:{py_version}-buster"
|
243
|
+
else:
|
244
|
+
python_base_image = f"python:{py_version}-bookworm"
|
241
245
|
if launch_project.accelerator_base_image:
|
242
246
|
_logger.info(
|
243
247
|
f"Using accelerator base image: {launch_project.accelerator_base_image}"
|
@@ -311,6 +315,11 @@ def get_env_vars_dict(
|
|
311
315
|
_inject_wandb_config_env_vars(
|
312
316
|
launch_project.override_config, env_vars, max_env_length
|
313
317
|
)
|
318
|
+
|
319
|
+
_inject_file_overrides_env_vars(
|
320
|
+
launch_project.override_files, env_vars, max_env_length
|
321
|
+
)
|
322
|
+
|
314
323
|
artifacts = {}
|
315
324
|
# if we're spinning up a launch process from a job
|
316
325
|
# we should tell the run to use that artifact
|
@@ -677,3 +686,21 @@ def _inject_wandb_config_env_vars(
|
|
677
686
|
]
|
678
687
|
config_chunks_dict = {f"WANDB_CONFIG_{i}": chunk for i, chunk in enumerate(chunks)}
|
679
688
|
env_dict.update(config_chunks_dict)
|
689
|
+
|
690
|
+
|
691
|
+
def _inject_file_overrides_env_vars(
|
692
|
+
overrides: Dict[str, Any], env_dict: Dict[str, Any], maximum_env_length: int
|
693
|
+
) -> None:
|
694
|
+
str_overrides = json.dumps(overrides)
|
695
|
+
if len(str_overrides) <= maximum_env_length:
|
696
|
+
env_dict["WANDB_LAUNCH_FILE_OVERRIDES"] = str_overrides
|
697
|
+
return
|
698
|
+
|
699
|
+
chunks = [
|
700
|
+
str_overrides[i : i + maximum_env_length]
|
701
|
+
for i in range(0, len(str_overrides), maximum_env_length)
|
702
|
+
]
|
703
|
+
overrides_chunks_dict = {
|
704
|
+
f"WANDB_LAUNCH_FILE_OVERRIDES_{i}": chunk for i, chunk in enumerate(chunks)
|
705
|
+
}
|
706
|
+
env_dict.update(overrides_chunks_dict)
|
@@ -286,7 +286,7 @@ class KanikoBuilder(AbstractBuilder):
|
|
286
286
|
_, api_client = await get_kube_context_and_api_client(
|
287
287
|
kubernetes, launch_project.resource_args
|
288
288
|
)
|
289
|
-
# TODO: use same client as
|
289
|
+
# TODO: use same client as kubernetes_runner.py
|
290
290
|
batch_v1 = client.BatchV1Api(api_client)
|
291
291
|
core_v1 = client.CoreV1Api(api_client)
|
292
292
|
|
@@ -522,7 +522,7 @@ class KanikoBuilder(AbstractBuilder):
|
|
522
522
|
volume_mounts.append(
|
523
523
|
{"name": "docker-config", "mountPath": "/kaniko/.docker/"}
|
524
524
|
)
|
525
|
-
# Kaniko doesn't want https:// at the
|
525
|
+
# Kaniko doesn't want https:// at the beginning of the image tag.
|
526
526
|
destination = image_tag
|
527
527
|
if destination.startswith("https://"):
|
528
528
|
destination = destination.replace("https://", "")
|
wandb/sdk/launch/builder/noop.py
CHANGED
wandb/sdk/launch/create_job.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1
1
|
import json
|
2
2
|
import logging
|
3
3
|
import os
|
4
|
+
import re
|
4
5
|
import sys
|
5
6
|
import tempfile
|
6
7
|
from typing import Any, Dict, List, Optional, Tuple
|
@@ -19,6 +20,9 @@ logging.basicConfig(stream=sys.stdout, level=logging.INFO)
|
|
19
20
|
_logger = logging.getLogger("wandb")
|
20
21
|
|
21
22
|
|
23
|
+
CODE_ARTIFACT_EXCLUDE_PATHS = ["wandb", ".git"]
|
24
|
+
|
25
|
+
|
22
26
|
def create_job(
|
23
27
|
path: str,
|
24
28
|
job_type: str,
|
@@ -107,6 +111,13 @@ def _create_job(
|
|
107
111
|
)
|
108
112
|
return None, "", []
|
109
113
|
|
114
|
+
if runtime is not None:
|
115
|
+
if not re.match(r"^3\.\d+$", runtime):
|
116
|
+
wandb.termerror(
|
117
|
+
f"Runtime (-r, --runtime) must be a minor version of Python 3, "
|
118
|
+
f"e.g. 3.9 or 3.10, received {runtime}"
|
119
|
+
)
|
120
|
+
return None, "", []
|
110
121
|
aliases = aliases or []
|
111
122
|
tempdir = tempfile.TemporaryDirectory()
|
112
123
|
try:
|
@@ -436,6 +447,13 @@ def _make_code_artifact(
|
|
436
447
|
wandb.termerror(f"Error adding to code artifact: {e}")
|
437
448
|
return None
|
438
449
|
|
450
|
+
# Remove paths we don't want to include, if present
|
451
|
+
for item in CODE_ARTIFACT_EXCLUDE_PATHS:
|
452
|
+
try:
|
453
|
+
code_artifact.remove(item)
|
454
|
+
except FileNotFoundError:
|
455
|
+
pass
|
456
|
+
|
439
457
|
res, _ = api.create_artifact(
|
440
458
|
artifact_type_name="code",
|
441
459
|
artifact_collection_name=artifact_name,
|
wandb/sdk/launch/loader.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1
1
|
"""Implementation of Google Artifact Registry for wandb launch."""
|
2
|
+
|
2
3
|
import logging
|
3
4
|
from typing import Optional, Tuple
|
4
5
|
|
@@ -210,7 +211,7 @@ class GoogleArtifactRegistry(AbstractRegistry):
|
|
210
211
|
for image in await list_images(request={"parent": parent}):
|
211
212
|
if tag in image.tags:
|
212
213
|
return True
|
213
|
-
except google.api_core.exceptions.NotFound as e:
|
214
|
+
except google.api_core.exceptions.NotFound as e: # type: ignore[attr-defined]
|
214
215
|
raise LaunchError(
|
215
216
|
f"The Google Artifact Registry repository {self.repository} "
|
216
217
|
f"does not exist. Please create it or modify your registry configuration."
|
@@ -1,4 +1,5 @@
|
|
1
1
|
"""Implementation of KubernetesRunner class for wandb launch."""
|
2
|
+
|
2
3
|
import asyncio
|
3
4
|
import base64
|
4
5
|
import datetime
|
@@ -539,9 +540,9 @@ class KubernetesRunner(AbstractRunner):
|
|
539
540
|
WANDB_K8S_LABEL_MONITOR,
|
540
541
|
LaunchAgent.name(),
|
541
542
|
)
|
542
|
-
resource_args["metadata"]["labels"][
|
543
|
-
|
544
|
-
|
543
|
+
resource_args["metadata"]["labels"][WANDB_K8S_LABEL_AGENT] = (
|
544
|
+
LaunchAgent.name()
|
545
|
+
)
|
545
546
|
|
546
547
|
overrides = {}
|
547
548
|
if launch_project.override_args:
|
@@ -1,4 +1,5 @@
|
|
1
1
|
"""Implementation of the SageMakerRunner class."""
|
2
|
+
|
2
3
|
import asyncio
|
3
4
|
import logging
|
4
5
|
from typing import Any, Dict, List, Optional, cast
|
@@ -324,16 +325,16 @@ def build_sagemaker_args(
|
|
324
325
|
sagemaker_args["TrainingJobName"] = training_job_name
|
325
326
|
entry_cmd = entry_point.command if entry_point else []
|
326
327
|
|
327
|
-
sagemaker_args[
|
328
|
-
|
329
|
-
|
330
|
-
|
331
|
-
|
332
|
-
|
333
|
-
|
334
|
-
|
335
|
-
|
336
|
-
|
328
|
+
sagemaker_args["AlgorithmSpecification"] = (
|
329
|
+
merge_image_uri_with_algorithm_specification(
|
330
|
+
given_sagemaker_args.get(
|
331
|
+
"AlgorithmSpecification",
|
332
|
+
given_sagemaker_args.get("algorithm_specification"),
|
333
|
+
),
|
334
|
+
image_uri,
|
335
|
+
entry_cmd,
|
336
|
+
args,
|
337
|
+
)
|
337
338
|
)
|
338
339
|
|
339
340
|
sagemaker_args["RoleArn"] = role_arn
|
@@ -1,4 +1,5 @@
|
|
1
1
|
"""Abstract Scheduler class."""
|
2
|
+
|
2
3
|
import asyncio
|
3
4
|
import base64
|
4
5
|
import copy
|
@@ -407,7 +408,7 @@ class Scheduler(ABC):
|
|
407
408
|
return count
|
408
409
|
|
409
410
|
def _try_load_executable(self) -> bool:
|
410
|
-
"""Check
|
411
|
+
"""Check existence of valid executable for a run.
|
411
412
|
|
412
413
|
logs and returns False when job is unreachable
|
413
414
|
"""
|
@@ -422,7 +423,7 @@ class Scheduler(ABC):
|
|
422
423
|
return False
|
423
424
|
return True
|
424
425
|
elif self._kwargs.get("image_uri"):
|
425
|
-
# TODO(gst): check docker
|
426
|
+
# TODO(gst): check docker existence? Use registry in launch config?
|
426
427
|
return True
|
427
428
|
else:
|
428
429
|
return False
|
@@ -610,7 +611,7 @@ class Scheduler(ABC):
|
|
610
611
|
f"Failed to get runstate for run ({run_id}). Error: {traceback.format_exc()}"
|
611
612
|
)
|
612
613
|
run_state = RunState.FAILED
|
613
|
-
else: # first time we get
|
614
|
+
else: # first time we get unknown state
|
614
615
|
run_state = RunState.UNKNOWN
|
615
616
|
except (AttributeError, ValueError):
|
616
617
|
wandb.termwarn(
|
@@ -1,4 +1,5 @@
|
|
1
1
|
"""Scheduler for classic wandb Sweeps."""
|
2
|
+
|
2
3
|
import logging
|
3
4
|
from pprint import pformat as pf
|
4
5
|
from typing import Any, Dict, List, Optional
|
@@ -58,7 +59,7 @@ class SweepScheduler(Scheduler):
|
|
58
59
|
return None
|
59
60
|
|
60
61
|
def _get_sweep_commands(self, worker_id: int) -> List[Dict[str, Any]]:
|
61
|
-
"""Helper to
|
62
|
+
"""Helper to receive sweep command from backend."""
|
62
63
|
# AgentHeartbeat wants a Dict of runs which are running or queued
|
63
64
|
_run_states: Dict[str, bool] = {}
|
64
65
|
for run_id, run in self._yield_runs():
|
wandb/sdk/launch/sweeps/utils.py
CHANGED
@@ -217,7 +217,7 @@ def create_sweep_command_args(command: Dict) -> Dict[str, Any]:
|
|
217
217
|
flags: List[str] = []
|
218
218
|
# (2) flags without hyphens (e.g. foo=bar)
|
219
219
|
flags_no_hyphens: List[str] = []
|
220
|
-
# (3) flags with false booleans
|
220
|
+
# (3) flags with false booleans omitted (e.g. --foo)
|
221
221
|
flags_no_booleans: List[str] = []
|
222
222
|
# (4) flags as a dictionary (used for constructing a json)
|
223
223
|
flags_dict: Dict[str, Any] = {}
|
@@ -257,7 +257,7 @@ def make_launch_sweep_entrypoint(
|
|
257
257
|
"""Use args dict from create_sweep_command_args to construct entrypoint.
|
258
258
|
|
259
259
|
If replace is True, remove macros from entrypoint, fill them in with args
|
260
|
-
and then return the args in
|
260
|
+
and then return the args in separate return value.
|
261
261
|
"""
|
262
262
|
if not command:
|
263
263
|
return None, None
|
@@ -296,7 +296,7 @@ def check_job_exists(public_api: "PublicApi", job: Optional[str]) -> bool:
|
|
296
296
|
|
297
297
|
|
298
298
|
def get_previous_args(
|
299
|
-
run_spec: Dict[str, Any]
|
299
|
+
run_spec: Dict[str, Any],
|
300
300
|
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
301
301
|
"""Parse through previous scheduler run_spec.
|
302
302
|
|
wandb/sdk/launch/utils.py
CHANGED
@@ -57,15 +57,15 @@ API_KEY_REGEX = r"WANDB_API_KEY=\w+(-\w+)?"
|
|
57
57
|
MACRO_REGEX = re.compile(r"\$\{(\w+)\}")
|
58
58
|
|
59
59
|
AZURE_CONTAINER_REGISTRY_URI_REGEX = re.compile(
|
60
|
-
r"(?:https://)?([\w]+)\.azurecr\.io/([\w\-]+):?(
|
60
|
+
r"^(?:https://)?([\w]+)\.azurecr\.io/(?P<repository>[\w\-]+):?(?P<tag>.*)"
|
61
61
|
)
|
62
62
|
|
63
63
|
ELASTIC_CONTAINER_REGISTRY_URI_REGEX = re.compile(
|
64
|
-
r"^(?P<account
|
64
|
+
r"^(?:https://)?(?P<account>[\w-]+)\.dkr\.ecr\.(?P<region>[\w-]+)\.amazonaws\.com/(?P<repository>[\w-]+):?(?P<tag>.*)$"
|
65
65
|
)
|
66
66
|
|
67
67
|
GCP_ARTIFACT_REGISTRY_URI_REGEX = re.compile(
|
68
|
-
r"^(?P<region>[\w-]+)-docker\.pkg\.dev/(?P<project>[\w-]+)/(?P<repository>[\w-]+)
|
68
|
+
r"^(?:https://)?(?P<region>[\w-]+)-docker\.pkg\.dev/(?P<project>[\w-]+)/(?P<repository>[\w-]+)/?(?P<image_name>[\w-]+)?(?P<tag>:.*)?$",
|
69
69
|
re.IGNORECASE,
|
70
70
|
)
|
71
71
|
|
wandb/sdk/lib/fsm.py
CHANGED
@@ -52,43 +52,39 @@ T_FsmContext_contra = TypeVar("T_FsmContext_contra", contravariant=True)
|
|
52
52
|
@runtime_checkable
|
53
53
|
class FsmStateCheck(Protocol[T_FsmInputs]):
|
54
54
|
@abstractmethod
|
55
|
-
def on_check(self, inputs: T_FsmInputs) -> None:
|
56
|
-
... # pragma: no cover
|
55
|
+
def on_check(self, inputs: T_FsmInputs) -> None: ... # pragma: no cover
|
57
56
|
|
58
57
|
|
59
58
|
@runtime_checkable
|
60
59
|
class FsmStateOutput(Protocol[T_FsmInputs]):
|
61
60
|
@abstractmethod
|
62
|
-
def on_state(self, inputs: T_FsmInputs) -> None:
|
63
|
-
... # pragma: no cover
|
61
|
+
def on_state(self, inputs: T_FsmInputs) -> None: ... # pragma: no cover
|
64
62
|
|
65
63
|
|
66
64
|
@runtime_checkable
|
67
65
|
class FsmStateEnter(Protocol[T_FsmInputs]):
|
68
66
|
@abstractmethod
|
69
|
-
def on_enter(self, inputs: T_FsmInputs) -> None:
|
70
|
-
... # pragma: no cover
|
67
|
+
def on_enter(self, inputs: T_FsmInputs) -> None: ... # pragma: no cover
|
71
68
|
|
72
69
|
|
73
70
|
@runtime_checkable
|
74
71
|
class FsmStateEnterWithContext(Protocol[T_FsmInputs, T_FsmContext_contra]):
|
75
72
|
@abstractmethod
|
76
|
-
def on_enter(
|
77
|
-
|
73
|
+
def on_enter(
|
74
|
+
self, inputs: T_FsmInputs, context: T_FsmContext_contra
|
75
|
+
) -> None: ... # pragma: no cover
|
78
76
|
|
79
77
|
|
80
78
|
@runtime_checkable
|
81
79
|
class FsmStateStay(Protocol[T_FsmInputs]):
|
82
80
|
@abstractmethod
|
83
|
-
def on_stay(self, inputs: T_FsmInputs) -> None:
|
84
|
-
... # pragma: no cover
|
81
|
+
def on_stay(self, inputs: T_FsmInputs) -> None: ... # pragma: no cover
|
85
82
|
|
86
83
|
|
87
84
|
@runtime_checkable
|
88
85
|
class FsmStateExit(Protocol[T_FsmInputs, T_FsmContext_cov]):
|
89
86
|
@abstractmethod
|
90
|
-
def on_exit(self, inputs: T_FsmInputs) -> T_FsmContext_cov:
|
91
|
-
... # pragma: no cover
|
87
|
+
def on_exit(self, inputs: T_FsmInputs) -> T_FsmContext_cov: ... # pragma: no cover
|
92
88
|
|
93
89
|
|
94
90
|
# It would be nice if python provided optional protocol members, but it doesnt as described here:
|
wandb/sdk/lib/gitlib.py
CHANGED
@@ -14,7 +14,7 @@ try:
|
|
14
14
|
Repo,
|
15
15
|
)
|
16
16
|
except ImportError:
|
17
|
-
Repo = None
|
17
|
+
Repo = None # type: ignore
|
18
18
|
|
19
19
|
if TYPE_CHECKING:
|
20
20
|
from git import Repo
|
@@ -121,7 +121,7 @@ class GitRepo:
|
|
121
121
|
# TODO: Saw a user getting a Unicode decode error when parsing refs,
|
122
122
|
# more details on implementing a real fix in [WB-4064]
|
123
123
|
try:
|
124
|
-
if len(self.repo.refs) > 0:
|
124
|
+
if len(self.repo.refs) > 0: # type: ignore[arg-type]
|
125
125
|
return self.repo.head.commit.hexsha
|
126
126
|
else:
|
127
127
|
return self.repo.git.show_ref("--head").split(" ")[0]
|
@@ -140,7 +140,7 @@ class GitRepo:
|
|
140
140
|
if not self.repo:
|
141
141
|
return None
|
142
142
|
try:
|
143
|
-
return self.repo.remotes[self.remote_name]
|
143
|
+
return self.repo.remotes[self.remote_name] # type: ignore[index]
|
144
144
|
except IndexError:
|
145
145
|
return None
|
146
146
|
|
@@ -200,7 +200,7 @@ class GitRepo:
|
|
200
200
|
possible_relatives.append(tracking_branch.commit)
|
201
201
|
|
202
202
|
if not possible_relatives:
|
203
|
-
for branch in self.repo.branches:
|
203
|
+
for branch in self.repo.branches: # type: ignore[attr-defined]
|
204
204
|
tracking_branch = branch.tracking_branch()
|
205
205
|
if tracking_branch is not None:
|
206
206
|
possible_relatives.append(tracking_branch.commit)
|
wandb/sdk/lib/import_hooks.py
CHANGED
@@ -143,7 +143,7 @@ class _ImportHookChainedLoader:
|
|
143
143
|
# None, so handle None as well. The module may not support attribute
|
144
144
|
# assignment, in which case we simply skip it. Note that we also deal
|
145
145
|
# with __loader__ not existing at all. This is to future proof things
|
146
|
-
# due to proposal to remove the
|
146
|
+
# due to proposal to remove the attribute as described in the GitHub
|
147
147
|
# issue at https://github.com/python/cpython/issues/77458. Also prior
|
148
148
|
# to Python 3.3, the __loader__ attribute was only set if a custom
|
149
149
|
# module loader was used. It isn't clear whether the attribute still
|
wandb/sdk/lib/lazyloader.py
CHANGED
wandb/sdk/lib/proto_util.py
CHANGED
@@ -29,7 +29,7 @@ def _assign_end_offset(record: "pb.Record", end_offset: int) -> None:
|
|
29
29
|
|
30
30
|
|
31
31
|
def proto_encode_to_dict(
|
32
|
-
pb_obj: Union["tpb.TelemetryRecord", "pb.MetricRecord"]
|
32
|
+
pb_obj: Union["tpb.TelemetryRecord", "pb.MetricRecord"],
|
33
33
|
) -> Dict[int, Any]:
|
34
34
|
data: Dict[int, Any] = dict()
|
35
35
|
fields = pb_obj.ListFields()
|