wandb 0.22.0__py3-none-macosx_12_0_arm64.whl → 0.22.1__py3-none-macosx_12_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wandb/__init__.py +1 -1
- wandb/__init__.pyi +3 -3
- wandb/_pydantic/__init__.py +12 -11
- wandb/_pydantic/base.py +49 -19
- wandb/apis/__init__.py +2 -0
- wandb/apis/attrs.py +2 -0
- wandb/apis/importers/internals/internal.py +16 -23
- wandb/apis/internal.py +2 -0
- wandb/apis/normalize.py +2 -0
- wandb/apis/public/__init__.py +3 -2
- wandb/apis/public/api.py +215 -164
- wandb/apis/public/artifacts.py +23 -20
- wandb/apis/public/const.py +2 -0
- wandb/apis/public/files.py +33 -24
- wandb/apis/public/history.py +2 -0
- wandb/apis/public/jobs.py +20 -18
- wandb/apis/public/projects.py +4 -2
- wandb/apis/public/query_generator.py +3 -0
- wandb/apis/public/registries/__init__.py +7 -0
- wandb/apis/public/registries/_freezable_list.py +9 -12
- wandb/apis/public/registries/registries_search.py +8 -6
- wandb/apis/public/registries/registry.py +22 -17
- wandb/apis/public/reports.py +2 -0
- wandb/apis/public/runs.py +261 -57
- wandb/apis/public/sweeps.py +10 -9
- wandb/apis/public/teams.py +2 -0
- wandb/apis/public/users.py +2 -0
- wandb/apis/public/utils.py +16 -15
- wandb/automations/_generated/__init__.py +54 -127
- wandb/automations/_generated/create_generic_webhook_integration.py +1 -7
- wandb/automations/_generated/fragments.py +26 -91
- wandb/bin/wandb-core +0 -0
- wandb/cli/beta_sync.py +9 -11
- wandb/errors/errors.py +3 -3
- wandb/proto/v3/wandb_sync_pb2.py +19 -6
- wandb/proto/v4/wandb_sync_pb2.py +10 -6
- wandb/proto/v5/wandb_sync_pb2.py +10 -6
- wandb/proto/v6/wandb_sync_pb2.py +10 -6
- wandb/sdk/artifacts/_factories.py +7 -2
- wandb/sdk/artifacts/_generated/__init__.py +112 -412
- wandb/sdk/artifacts/_generated/fragments.py +65 -0
- wandb/sdk/artifacts/_generated/operations.py +52 -22
- wandb/sdk/artifacts/_generated/run_input_artifacts.py +3 -23
- wandb/sdk/artifacts/_generated/run_output_artifacts.py +3 -23
- wandb/sdk/artifacts/_generated/type_info.py +19 -0
- wandb/sdk/artifacts/_gqlutils.py +47 -0
- wandb/sdk/artifacts/_models/__init__.py +4 -0
- wandb/sdk/artifacts/_models/base_model.py +20 -0
- wandb/sdk/artifacts/_validators.py +40 -12
- wandb/sdk/artifacts/artifact.py +69 -88
- wandb/sdk/artifacts/artifact_file_cache.py +6 -1
- wandb/sdk/artifacts/artifact_manifest_entry.py +61 -2
- wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +10 -0
- wandb/sdk/data_types/bokeh.py +5 -1
- wandb/sdk/data_types/image.py +17 -6
- wandb/sdk/interface/interface.py +31 -4
- wandb/sdk/interface/interface_queue.py +10 -0
- wandb/sdk/interface/interface_shared.py +0 -7
- wandb/sdk/interface/interface_sock.py +9 -3
- wandb/sdk/internal/_generated/__init__.py +2 -12
- wandb/sdk/internal/sender.py +1 -1
- wandb/sdk/internal/settings_static.py +2 -82
- wandb/sdk/launch/runner/kubernetes_runner.py +25 -20
- wandb/sdk/launch/utils.py +82 -1
- wandb/sdk/lib/progress.py +7 -4
- wandb/sdk/lib/service/service_client.py +5 -9
- wandb/sdk/lib/service/service_connection.py +39 -23
- wandb/sdk/mailbox/mailbox_handle.py +2 -0
- wandb/sdk/projects/_generated/__init__.py +12 -33
- wandb/sdk/wandb_init.py +22 -2
- wandb/sdk/wandb_login.py +53 -27
- wandb/sdk/wandb_run.py +5 -3
- wandb/sdk/wandb_settings.py +50 -13
- wandb/sync/sync.py +7 -2
- wandb/util.py +1 -1
- {wandb-0.22.0.dist-info → wandb-0.22.1.dist-info}/METADATA +1 -1
- {wandb-0.22.0.dist-info → wandb-0.22.1.dist-info}/RECORD +80 -77
- wandb/sdk/artifacts/_graphql_fragments.py +0 -19
- {wandb-0.22.0.dist-info → wandb-0.22.1.dist-info}/WHEEL +0 -0
- {wandb-0.22.0.dist-info → wandb-0.22.1.dist-info}/entry_points.txt +0 -0
- {wandb-0.22.0.dist-info → wandb-0.22.1.dist-info}/licenses/LICENSE +0 -0
wandb/sdk/artifacts/artifact.py
CHANGED
@@ -16,7 +16,7 @@ import tempfile
 import time
 from collections import deque
 from copy import copy
-from dataclasses import dataclass
+from dataclasses import asdict, dataclass, replace
 from datetime import timedelta
 from functools import partial
 from itertools import filterfalse
@@ -30,7 +30,6 @@ from typing import (
     Literal,
     Sequence,
     Type,
-    cast,
     final,
 )
 from urllib.parse import quote, urljoin, urlparse
@@ -109,10 +108,11 @@ from ._generated import (
     TagInput,
     UpdateArtifact,
 )
-from .
+from ._gqlutils import omit_artifact_fields, supports_enable_tracking_var, type_info
 from ._validators import (
     LINKED_ARTIFACT_COLLECTION_TYPE,
     ArtifactPath,
+    FullArtifactPath,
     _LinkArtifactFields,
     ensure_logged,
     ensure_not_finalized,
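Note: omit_artifact_fields, supports_enable_tracking_var, and type_info now live in the new wandb/sdk/artifacts/_gqlutils.py module and take the GraphQL client explicitly, so server capabilities are probed per client rather than via module-level state. As a rough sketch of the omit-fields idea (the real gql_compat is more involved; the helper below is illustrative only, not wandb's implementation):

    # Illustrative sketch: drop fields an older server doesn't support from a
    # GraphQL selection before sending it. Not wandb's actual gql_compat.
    ARTIFACT_FRAGMENT = """\
    fragment ArtifactFragment on Artifact {
        id
        ttlIsInherited
        ttlDurationSeconds
        digest
    }
    """

    def strip_fields(query: str, omit_fields: set) -> str:
        # Keep every line whose bare field name is not in omit_fields.
        kept = [
            line
            for line in query.splitlines()
            if line.strip() not in omit_fields
        ]
        return "\n".join(kept)

    # Against a server without artifact TTLs, send the reduced selection:
    print(strip_fields(ARTIFACT_FRAGMENT, {"ttlIsInherited", "ttlDurationSeconds"}))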
@@ -210,6 +210,7 @@ class Artifact:
         metadata: dict[str, Any] | None = None,
         incremental: bool = False,
         use_as: str | None = None,
+        storage_region: str | None = None,
     ) -> None:
         if not re.match(r"^[a-zA-Z0-9_\-.]+$", name):
             raise ValueError(
@@ -268,7 +269,7 @@ class Artifact:
         self._use_as: str | None = None
         self._state: ArtifactState = ArtifactState.PENDING
         self._manifest: ArtifactManifest | _DeferredArtifactManifest | None = (
-            ArtifactManifestV1(storage_policy=make_storage_policy())
+            ArtifactManifestV1(storage_policy=make_storage_policy(storage_region))
         )
         self._commit_hash: str | None = None
         self._file_count: int | None = None
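If this new constructor argument is exposed through the public wandb.Artifact API, usage would presumably look like the sketch below; the accepted region values aren't listed in this diff (a comment later in interface.py says only coreweave-us is allowed on wandb.ai for now), so treat the value as an assumption:

    import wandb

    # Sketch: pin a new artifact's files to a specific storage region.
    # "coreweave-us" is an assumption taken from a comment elsewhere in this diff.
    artifact = wandb.Artifact(
        name="training-data",
        type="dataset",
        storage_region="coreweave-us",
    )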
@@ -286,36 +287,33 @@ class Artifact:
 
     @classmethod
     def _from_id(cls, artifact_id: str, client: RetryingClient) -> Artifact | None:
-        if
-            return
+        if cached_artifact := artifact_instance_cache.get(artifact_id):
+            return cached_artifact
 
-        query = gql_compat(ARTIFACT_BY_ID_GQL, omit_fields=omit_artifact_fields())
+        query = gql_compat(ARTIFACT_BY_ID_GQL, omit_fields=omit_artifact_fields(client))
 
         data = client.execute(query, variable_values={"id": artifact_id})
         result = ArtifactByID.model_validate(data)
 
-        if (
+        if (artifact := result.artifact) is None:
             return None
 
-        src_collection =
+        src_collection = artifact.artifact_sequence
         src_project = src_collection.project
 
         entity_name = src_project.entity_name if src_project else ""
         project_name = src_project.name if src_project else ""
 
-        name = f"{src_collection.name}:v{
-
+        name = f"{src_collection.name}:v{artifact.version_index}"
+
+        path = FullArtifactPath(prefix=entity_name, project=project_name, name=name)
+        return cls._from_attrs(path, artifact, client)
 
     @classmethod
     def _membership_from_name(
-        cls,
-        *,
-        entity: str,
-        project: str,
-        name: str,
-        client: RetryingClient,
+        cls, *, path: FullArtifactPath, client: RetryingClient
     ) -> Artifact:
-        if not
+        if not InternalApi()._server_supports(
             pb.ServerFeature.PROJECT_ARTIFACT_COLLECTION_MEMBERSHIP
         ):
             raise UnsupportedError(
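The lookup methods above and below gate on InternalApi()._server_supports(pb.ServerFeature...). A minimal stand-in showing the capability-gating pattern (the classes here are simplified stand-ins, not wandb's real InternalApi or protobuf enum):

    from enum import Enum, auto

    class ServerFeature(Enum):
        # Stand-in for the protobuf-generated wandb ServerFeature enum.
        PROJECT_ARTIFACT_COLLECTION_MEMBERSHIP = auto()

    class StubApi:
        def __init__(self, features):
            self._features = set(features)

        def _server_supports(self, feature):
            return feature in self._features

    api = StubApi({ServerFeature.PROJECT_ARTIFACT_COLLECTION_MEMBERSHIP})
    if api._server_supports(ServerFeature.PROJECT_ARTIFACT_COLLECTION_MEMBERSHIP):
        print("use the membership-based artifact lookup")
    else:
        print("fall back to the legacy artifact-by-name query")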
@@ -325,74 +323,70 @@ class Artifact:
 
         query = gql_compat(
             ARTIFACT_VIA_MEMBERSHIP_BY_NAME_GQL,
-            omit_fields=omit_artifact_fields(
+            omit_fields=omit_artifact_fields(client),
         )
-
-
+        gql_vars = {
+            "entityName": path.prefix,
+            "projectName": path.project,
+            "name": path.name,
+        }
         data = client.execute(query, variable_values=gql_vars)
         result = ArtifactViaMembershipByName.model_validate(data)
 
-        if not (
-            raise ValueError(f"project {project!r} not found under entity {entity!r}")
-
-        if not (acm_attrs := project_attrs.artifact_collection_membership):
-            entity_project = f"{entity}/{project}"
+        if not (project := result.project):
             raise ValueError(
-                f"
+                f"project {path.project!r} not found under entity {path.prefix!r}"
             )
-
-
-
+        if not (membership := project.artifact_collection_membership):
+            entity_project = f"{path.prefix}/{path.project}"
+            raise ValueError(
+                f"artifact membership {path.name!r} not found in {entity_project!r}"
+            )
+        return cls._from_membership(membership, target=path, client=client)
 
     @classmethod
     def _from_name(
         cls,
         *,
-
-        project: str,
-        name: str,
+        path: FullArtifactPath,
         client: RetryingClient,
         enable_tracking: bool = False,
     ) -> Artifact:
-        if
+        if InternalApi()._server_supports(
             pb.ServerFeature.PROJECT_ARTIFACT_COLLECTION_MEMBERSHIP
         ):
-            return cls._membership_from_name(
-                entity=entity, project=project, name=name, client=client
-            )
-
-        supports_enable_tracking_gql_var = api.server_project_type_introspection()
-        omit_vars = None if supports_enable_tracking_gql_var else {"enableTracking"}
+            return cls._membership_from_name(path=path, client=client)
 
+        omit_vars = None if supports_enable_tracking_var(client) else {"enableTracking"}
         gql_vars = {
-            "entityName":
-            "projectName": project,
-            "name": name,
+            "entityName": path.prefix,
+            "projectName": path.project,
+            "name": path.name,
             "enableTracking": enable_tracking,
         }
         query = gql_compat(
             ARTIFACT_BY_NAME_GQL,
             omit_variables=omit_vars,
-            omit_fields=omit_artifact_fields(
+            omit_fields=omit_artifact_fields(client),
         )
-
         data = client.execute(query, variable_values=gql_vars)
         result = ArtifactByName.model_validate(data)
 
-        if not (
-            raise ValueError(
-
-
-
+        if not (project := result.project):
+            raise ValueError(
+                f"project {path.project!r} not found under entity {path.prefix!r}"
+            )
+        if not (artifact := project.artifact):
+            entity_project = f"{path.prefix}/{path.project}"
+            raise ValueError(f"artifact {path.name!r} not found in {entity_project!r}")
 
-        return cls._from_attrs(
+        return cls._from_attrs(path, artifact, client)
 
     @classmethod
     def _from_membership(
         cls,
         membership: MembershipWithArtifact,
-        target:
+        target: FullArtifactPath,
         client: RetryingClient,
     ) -> Artifact:
         if not (
@@ -410,35 +404,31 @@ class Artifact:
                 f"Your request was redirected to the corresponding artifact {name!r} in the new registry. "
                 f"Please update your paths to point to the migrated registry directly, '{proj.name}/{name}'."
             )
-
+            new_target = replace(target, prefix=proj.entity_name, project=proj.name)
         else:
-
-            new_project = cast(str, target.project)
+            new_target = copy(target)
 
         if not (artifact := membership.artifact):
             raise ValueError(f"Artifact {target.to_str()!r} not found in response")
 
-        return cls._from_attrs(
+        return cls._from_attrs(new_target, artifact, client)
 
     @classmethod
     def _from_attrs(
         cls,
-
-
-        name: str,
-        attrs: dict[str, Any] | ArtifactFragment,
+        path: FullArtifactPath,
+        attrs: ArtifactFragment,
         client: RetryingClient,
         aliases: list[str] | None = None,
     ) -> Artifact:
         # Placeholder is required to skip validation.
         artifact = cls("placeholder", type="placeholder")
         artifact._client = client
-        artifact._entity =
-        artifact._project = project
-        artifact._name = name
+        artifact._entity = path.prefix
+        artifact._project = path.project
+        artifact._name = path.name
 
-
-        artifact._assign_attrs(validated_attrs, aliases)
+        artifact._assign_attrs(attrs, aliases)
 
         artifact.finalize()
 
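The through-line of this refactor is that the _from_* constructors now accept one FullArtifactPath value instead of separate entity/project/name strings. The real class lives in wandb/sdk/artifacts/_validators.py and is not shown in this diff; a plausible minimal shape, for orientation only:

    # Assumed shape of FullArtifactPath for illustration; the actual definition
    # in _validators.py may differ.
    from dataclasses import dataclass

    @dataclass(frozen=True)
    class FullArtifactPath:
        prefix: str   # entity (or org, for registry paths)
        project: str
        name: str     # "collection:alias" or "collection:vN"

        def to_str(self) -> str:
            return f"{self.prefix}/{self.project}/{self.name}"

    path = FullArtifactPath(prefix="my-team", project="my-project", name="model:v3")
    print(path.to_str())  # my-team/my-project/model:v3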
@@ -743,13 +733,12 @@ class Artifact:
             raise ValueError("Client is not initialized")
 
         try:
-
-
+            path = FullArtifactPath(
+                prefix=self.source_entity,
                 project=self.source_project,
                 name=self.source_name,
-                client=self._client,
             )
-            self._source_artifact =
+            self._source_artifact = self._from_name(path=path, client=self._client)
         except Exception as e:
             raise ValueError(
                 f"Unable to fetch source artifact for linked artifact {self.name}"
@@ -1234,8 +1223,9 @@ class Artifact:
     def _populate_after_save(self, artifact_id: str) -> None:
         assert self._client is not None
 
-        query = gql_compat(
-
+        query = gql_compat(
+            ARTIFACT_BY_ID_GQL, omit_fields=omit_artifact_fields(self._client)
+        )
         data = self._client.execute(query, variable_values={"id": artifact_id})
         result = ArtifactByID.model_validate(data)
 
@@ -1258,21 +1248,9 @@ class Artifact:
         collection = self.name.split(":")[0]
 
         aliases = None
-        introspect_query = gql(
-            """
-            query ProbeServerAddAliasesInput {
-                AddAliasesInputInfoType: __type(name: "AddAliasesInput") {
-                    name
-                    inputFields {
-                        name
-                    }
-                }
-            }
-            """
-        )
 
-
-
+        if type_info(self._client, "AddAliasesInput") is not None:
+            # wandb backend version >= 0.13.0
             alias_props = {
                 "entity_name": entity,
                 "project_name": project,
@@ -1330,7 +1308,7 @@ class Artifact:
             for alias in self.aliases
         ]
 
-        omit_fields = omit_artifact_fields()
+        omit_fields = omit_artifact_fields(self._client)
         omit_variables = set()
 
         if {"ttlIsInherited", "ttlDurationSeconds"} & omit_fields:
@@ -2427,7 +2405,7 @@ class Artifact:
 
         # Parse the entity (first part of the path) appropriately,
         # depending on whether we're linking to a registry
-        if target.
+        if target.is_registry_path():
             # In a Registry linking, the entity is used to fetch the organization of the artifact
             # therefore the source artifact's entity is passed to the backend
             org = target.prefix or settings.get("organization") or ""
@@ -2435,6 +2413,9 @@ class Artifact:
         else:
             target = target.with_defaults(prefix=self.source_entity)
 
+        # Explicitly convert to FullArtifactPath to ensure all fields are present
+        target = FullArtifactPath(**asdict(target))
+
         # Prepare the validated GQL input, send it
         alias_inputs = [
             ArtifactAliasInput(artifact_collection_name=target.name, alias=a)
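The FullArtifactPath(**asdict(target)) conversion re-tags an ArtifactPath whose optional fields were just filled in by with_defaults(...) as the stricter full-path type. A toy demonstration of the pattern (both dataclasses here are simplified assumptions):

    from __future__ import annotations

    from dataclasses import asdict, dataclass

    @dataclass(frozen=True)
    class ArtifactPath:
        name: str
        prefix: str | None = None   # optional until defaults are applied
        project: str | None = None

    @dataclass(frozen=True)
    class FullArtifactPath:
        name: str
        prefix: str    # required: the caller guarantees it is filled in
        project: str

    loose = ArtifactPath(name="model:latest", prefix="team", project="proj")
    # asdict() flattens the dataclass to a dict of fields, which then feeds
    # the stricter constructor keyword-by-keyword.
    strict = FullArtifactPath(**asdict(loose))
    print(strict)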
wandb/sdk/artifacts/artifact_file_cache.py
CHANGED
@@ -26,6 +26,11 @@ class Opener(Protocol):
         pass
 
 
+def artifacts_cache_dir() -> Path:
+    """Get the artifacts cache directory."""
+    return env.get_cache_dir() / "artifacts"
+
+
 def _get_sys_umask_threadsafe() -> int:
     # Workaround to get the current system umask, since
     # - `os.umask()` isn't thread-safe
@@ -248,4 +253,4 @@ def _build_artifact_file_cache(cache_dir: StrPath) -> ArtifactFileCache:
 
 
 def get_artifact_file_cache() -> ArtifactFileCache:
-    return _build_artifact_file_cache(
+    return _build_artifact_file_cache(artifacts_cache_dir())
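The new artifacts_cache_dir() helper centralizes the "<cache>/artifacts" location that the checksum cache below also builds on. The base directory comes from wandb's env module, which honors the WANDB_CACHE_DIR environment variable:

    # Where artifact files (and now checksums) are cached; by assumption the
    # base directory is resolved from WANDB_CACHE_DIR when it is set.
    import os

    os.environ["WANDB_CACHE_DIR"] = "/tmp/wandb-cache"

    from wandb import env

    print(env.get_cache_dir() / "artifacts")  # e.g. /tmp/wandb-cache/artifacts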
wandb/sdk/artifacts/artifact_manifest_entry.py
CHANGED
@@ -3,9 +3,11 @@
 from __future__ import annotations
 
 import concurrent.futures
+import hashlib
 import json
 import logging
 import os
+from contextlib import suppress
 from pathlib import Path
 from typing import TYPE_CHECKING
 from urllib.parse import urlparse
@@ -43,6 +45,48 @@ if TYPE_CHECKING:
 _WB_ARTIFACT_SCHEME = "wandb-artifact"
 
 
+def _checksum_cache_path(file_path: str) -> str:
+    """Get path for checksum in central cache directory."""
+    from wandb.sdk.artifacts.artifact_file_cache import artifacts_cache_dir
+
+    # Create a unique cache key based on the file's absolute path
+    abs_path = os.path.abspath(file_path)
+    path_hash = hashlib.sha256(abs_path.encode()).hexdigest()
+
+    # Store in wandb cache directory under checksums subdirectory
+    cache_dir = artifacts_cache_dir() / "checksums"
+    cache_dir.mkdir(parents=True, exist_ok=True)
+
+    return str(cache_dir / f"{path_hash}.checksum")
+
+
+def _read_cached_checksum(file_path: str) -> str | None:
+    """Read checksum from cache if it exists and is valid."""
+    checksum_path = _checksum_cache_path(file_path)
+
+    try:
+        with open(file_path) as f, open(checksum_path) as f_checksum:
+            if os.path.getmtime(f_checksum.name) < os.path.getmtime(f.name):
+                # File was modified after checksum was written
+                return None
+            # Read and return the cached checksum
+            return f_checksum.read().strip()
+    except OSError:
+        # File doesn't exist or couldn't be opened
+        return None
+
+
+def _write_cached_checksum(file_path: str, checksum: str) -> None:
+    """Write checksum to cache directory."""
+    checksum_path = _checksum_cache_path(file_path)
+    try:
+        with open(checksum_path, "w") as f:
+            f.write(checksum)
+    except OSError:
+        # Non-critical failure, just log it
+        logger.debug(f"Failed to write checksum cache for {file_path!r}")
+
+
 class ArtifactManifestEntry:
     """A single entry in an artifact manifest."""
 
@@ -164,11 +208,19 @@ class ArtifactManifestEntry:
 
         # Skip checking the cache (and possibly downloading) if the file already exists
        # and has the digest we're expecting.
+
+        # Fast integrity check using cached checksum from persistent cache
+        with suppress(OSError):
+            if self.digest == _read_cached_checksum(dest_path):
+                return FilePathStr(dest_path)
+
+        # Fallback to computing/caching the checksum hash
         try:
             md5_hash = md5_file_b64(dest_path)
         except (FileNotFoundError, IsADirectoryError):
-            logger.debug(f"unable to find {dest_path}, skip searching for file")
+            logger.debug(f"unable to find {dest_path!r}, skip searching for file")
         else:
+            _write_cached_checksum(dest_path, md5_hash)
             if self.digest == md5_hash:
                 return FilePathStr(dest_path)
 
@@ -184,12 +236,19 @@ class ArtifactManifestEntry:
             executor=executor,
             multipart=multipart,
         )
-
+
+        # Determine the final path
+        final_path = (
             dest_path
             if skip_cache
             else copy_or_overwrite_changed(cache_path, dest_path)
         )
+
+        # Cache the checksum for future downloads
+        _write_cached_checksum(str(final_path), self.digest)
+
+        return FilePathStr(final_path)
+
     def ref_target(self) -> FilePathStr | URIStr:
         """Get the reference URL that is targeted by this artifact entry.
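The net effect of the changes above: after a download (or a first local hash), the file's checksum is persisted under <cache>/artifacts/checksums/<sha256-of-abs-path>.checksum, and later downloads skip re-hashing unmodified files by comparing mtimes. A standalone re-implementation of the same mtime-invalidation idea (it does not call wandb's private helpers):

    import hashlib
    import os
    import tempfile

    cache = {}  # path -> (mtime, checksum); the real cache is on disk

    def cached_md5(path):
        mtime = os.path.getmtime(path)
        entry = cache.get(path)
        if entry is not None and entry[0] >= mtime:
            return entry[1]  # cache hit: file unchanged since the hash was stored
        with open(path, "rb") as f:
            digest = hashlib.md5(f.read()).hexdigest()
        cache[path] = (mtime, digest)
        return digest

    with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as f:
        f.write("hello")
        path = f.name

    first = cached_md5(path)            # computes and stores the hash
    assert cached_md5(path) == first    # served from the cache

    mtime = os.path.getmtime(path)
    with open(path, "w") as f:
        f.write("changed")
    os.utime(path, (mtime + 1, mtime + 1))  # force a newer mtime deterministically
    assert cached_md5(path) != first        # stale entry invalidated, re-hashed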
wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py
CHANGED
@@ -91,6 +91,8 @@ class WandbStoragePolicy(StoragePolicy):
         session: requests.Session | None = None,
     ) -> None:
         self._config = config or {}
+        if (storage_region := self._config.get("storageRegion")) is not None:
+            self._validate_storage_region(storage_region)
         self._cache = cache or get_artifact_file_cache()
         self._session = session or make_http_session()
         self._api = api or InternalApi()
@@ -99,6 +101,14 @@ class WandbStoragePolicy(StoragePolicy):
             default_handler=TrackingHandler(),
         )
 
+    def _validate_storage_region(self, storage_region: Any) -> None:
+        if not isinstance(storage_region, str):
+            raise TypeError(
+                f"storageRegion must be a string, got {type(storage_region).__name__}: {storage_region!r}"
+            )
+        if not storage_region.strip():
+            raise ValueError("storageRegion must be a non-empty string")
+
     def config(self) -> dict:
         return self._config
 
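The validation rules are small enough to exercise standalone (this replicates _validate_storage_region rather than constructing a WandbStoragePolicy):

    def validate_storage_region(storage_region):
        if not isinstance(storage_region, str):
            raise TypeError(
                f"storageRegion must be a string, got "
                f"{type(storage_region).__name__}: {storage_region!r}"
            )
        if not storage_region.strip():
            raise ValueError("storageRegion must be a non-empty string")

    validate_storage_region("coreweave-us")  # accepted
    for bad in (123, "   "):
        try:
            validate_storage_region(bad)
        except (TypeError, ValueError) as e:
            print(f"{bad!r} rejected: {e}")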
wandb/sdk/data_types/bokeh.py
CHANGED
@@ -5,6 +5,7 @@ import pathlib
 from typing import TYPE_CHECKING, Union
 
 from wandb import util
+from wandb._strutils import nameof
 from wandb.sdk.lib import runid
 
 from . import _dtypes
@@ -34,7 +35,10 @@ class Bokeh(Media):
         ],
     ):
         super().__init__()
-        bokeh = util.get_module(
+        bokeh = util.get_module(
+            "bokeh",
+            required=f"{nameof(Bokeh)!r} requires the bokeh package. Please install it with `pip install bokeh`.",
+        )
         if isinstance(data_or_path, (str, pathlib.Path)) and os.path.exists(
             data_or_path
         ):
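util.get_module is wandb's lazy optional-dependency importer; passing required= makes a missing package fail with the given message instead of returning None. A minimal sketch of the pattern (assumed behavior, not wandb's exact implementation):

    import importlib

    def get_module(name, required=None):
        # Return the imported module; on ImportError, raise with the given
        # message if required, else return None.
        try:
            return importlib.import_module(name)
        except ImportError:
            if required:
                raise ImportError(required) from None
            return None

    bokeh = get_module(
        "bokeh",
        required="'Bokeh' requires the bokeh package. Please install it with `pip install bokeh`.",
    )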
wandb/sdk/data_types/image.py
CHANGED
@@ -161,6 +161,17 @@ class Image(BatchableMedia):
     ) -> None:
         """Initialize a `wandb.Image` object.
 
+        This class handles various image data formats and automatically normalizes
+        pixel values to the range [0, 255] when needed, ensuring compatibility
+        with the W&B backend.
+
+        * Data in range [0, 1] is multiplied by 255 and converted to uint8
+        * Data in range [-1, 1] is rescaled from [-1, 1] to [0, 255] by mapping
+          -1 to 0 and 1 to 255, then converted to uint8
+        * Data outside [-1, 1] but not in [0, 255] is clipped to [0, 255] and
+          converted to uint8 (with a warning if values fall outside [0, 255])
+        * Data already in [0, 255] is converted to uint8 without modification
+
         Args:
             data_or_path: Accepts NumPy array/pytorch tensor of image data,
                 a PIL image object, or a path to an image file. If a NumPy
@@ -168,7 +179,7 @@ class Image(BatchableMedia):
                 the image data will be saved to the given file type.
                 If the values are not in the range [0, 255] or all values are in the range [0, 1],
                 the image pixel values will be normalized to the range [0, 255]
-                unless `normalize` is set to False
+                unless `normalize` is set to `False`.
                 - pytorch tensor should be in the format (channel, height, width)
                 - NumPy array should be in the format (height, width, channel)
             mode: The PIL mode for an image. Most common are "L", "RGB",
@@ -178,13 +189,13 @@ class Image(BatchableMedia):
             classes: A list of class information for the image,
                 used for labeling bounding boxes, and image masks.
             boxes: A dictionary containing bounding box information for the image.
-                see
+                see https://docs.wandb.ai/ref/python/data-types/boundingboxes2d/
             masks: A dictionary containing mask information for the image.
-                see
+                see https://docs.wandb.ai/ref/python/data-types/imagemask/
             file_type: The file type to save the image as.
-                This parameter has no effect if data_or_path is a path to an image file.
-            normalize: If True
-                Normalize is only applied if data_or_path is a numpy array or pytorch tensor.
+                This parameter has no effect if `data_or_path` is a path to an image file.
+            normalize: If `True`, normalize the image pixel values to fall within the range of [0, 255].
+                Normalize is only applied if `data_or_path` is a numpy array or pytorch tensor.
 
         Examples:
             Create a wandb.Image from a numpy array
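The four normalization rules added to the docstring are easy to see with numpy (this sketch mirrors the documented behavior; wandb's internal implementation may differ):

    import numpy as np

    def normalize_to_uint8(data):
        lo, hi = data.min(), data.max()
        if 0 <= lo and hi <= 1:        # [0, 1]: scale by 255
            data = data * 255
        elif -1 <= lo and hi <= 1:     # [-1, 1]: map -1 -> 0 and 1 -> 255
            data = (data + 1) * 127.5
        elif lo < 0 or hi > 255:       # out of range: clip (wandb warns here)
            data = np.clip(data, 0, 255)
        return data.astype(np.uint8)   # already [0, 255]: converted unchanged

    print(normalize_to_uint8(np.array([0.0, 0.5, 1.0])))   # [  0 127 255]
    print(normalize_to_uint8(np.array([-1.0, 0.0, 1.0])))  # [  0 127 255]
    print(normalize_to_uint8(np.array([-50.0, 300.0])))    # [  0 255]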
wandb/sdk/interface/interface.py
CHANGED
@@ -87,11 +87,38 @@ def file_enum_to_policy(enum: "pb.FilesItem.PolicyType.V") -> "PolicyName":
 
 
 class InterfaceBase:
+    """Methods for sending different types of Records to the service.
+
+    None of the methods may be called from an asyncio context other than
+    deliver_async().
+    """
+
     _drop: bool
 
     def __init__(self) -> None:
         self._drop = False
 
+    @abstractmethod
+    async def deliver_async(
+        self,
+        record: pb.Record,
+    ) -> MailboxHandle[pb.Result]:
+        """Send a record and create a handle to wait for the response.
+
+        The synchronous publish and deliver methods on this class cannot be
+        called in the asyncio thread because they block. Instead of having
+        an async copy of every method, this is a general method for sending
+        any kind of record in the asyncio thread.
+
+        Args:
+            record: The record to send. This method takes ownership of the
+                record and it must not be used afterward.
+
+        Returns:
+            A handle to wait for a response to the record.
+        """
+        raise NotImplementedError
+
     def publish_header(self) -> None:
         header = pb.HeaderRecord()
         self._publish_header(header)
@@ -392,9 +419,13 @@ class InterfaceBase:
             proto_manifest.manifest_file_path = path
             return proto_manifest
 
+        # Set storage policy on storageLayout (always V2) and storageRegion, only allow coreweave-us on wandb.ai for now.
+        # NOTE: the decode logic is NewManifestFromProto in core/pkg/artifacts/manifest.go
+        # The creation logic is in artifacts/_factories.py make_storage_policy
         for k, v in artifact_manifest.storage_policy.config().items() or {}.items():
             cfg = proto_manifest.storage_policy_config.add()
             cfg.key = k
+            # TODO: Why json.dumps when existing values are plain string? We want to send complex structure without defining the proto?
             cfg.value_json = json.dumps(v)
 
         for entry in sorted(artifact_manifest.entries.values(), key=lambda k: k.path):
@@ -1020,10 +1051,6 @@ class InterfaceBase:
     ) -> MailboxHandle[pb.Result]:
         raise NotImplementedError
 
-    @abstractmethod
-    def deliver_operation_stats(self) -> MailboxHandle[pb.Result]:
-        raise NotImplementedError
-
     def deliver_poll_exit(self) -> MailboxHandle[pb.Result]:
         poll_exit = pb.PollExitRequest()
         return self._deliver_poll_exit(poll_exit)
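A hypothetical usage sketch of the new deliver_async entry point from an asyncio context; the handle-awaiting call (wait_async) is an assumption for illustration, as MailboxHandle's async API is not shown in this diff:

    from wandb.proto import wandb_internal_pb2 as pb

    async def send_record(interface):
        record = pb.Record()
        record.request.poll_exit.SetInParent()  # mark any request-type payload
        handle = await interface.deliver_async(record)
        # wait_async(timeout=...) is assumed here; see wandb/sdk/mailbox.
        return await handle.wait_async(timeout=30)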
wandb/sdk/interface/interface_queue.py
CHANGED
@@ -8,12 +8,15 @@ import logging
 from multiprocessing.process import BaseProcess
 from typing import TYPE_CHECKING, Optional
 
+from typing_extensions import override
+
 from .interface_shared import InterfaceShared
 
 if TYPE_CHECKING:
     from queue import Queue
 
     from wandb.proto import wandb_internal_pb2 as pb
+    from wandb.sdk.mailbox.mailbox_handle import MailboxHandle
 
 
 logger = logging.getLogger("wandb")
@@ -31,6 +34,13 @@ class InterfaceQueue(InterfaceShared):
         self._process = process
         super().__init__()
 
+    @override
+    async def deliver_async(
+        self,
+        record: "pb.Record",
+    ) -> "MailboxHandle[pb.Result]":
+        raise NotImplementedError
+
     def _publish(self, record: "pb.Record", local: Optional[bool] = None) -> None:
         if self._process and not self._process.is_alive():
             raise Exception("The wandb backend process has shutdown")
wandb/sdk/interface/interface_shared.py
CHANGED
@@ -87,7 +87,6 @@ class InterfaceShared(InterfaceBase):
         stop_status: Optional[pb.StopStatusRequest] = None,
         internal_messages: Optional[pb.InternalMessagesRequest] = None,
         network_status: Optional[pb.NetworkStatusRequest] = None,
-        operation_stats: Optional[pb.OperationStatsRequest] = None,
         poll_exit: Optional[pb.PollExitRequest] = None,
         partial_history: Optional[pb.PartialHistoryRequest] = None,
         sampled_history: Optional[pb.SampledHistoryRequest] = None,
@@ -129,8 +128,6 @@ class InterfaceShared(InterfaceBase):
             request.internal_messages.CopyFrom(internal_messages)
         elif network_status:
             request.network_status.CopyFrom(network_status)
-        elif operation_stats:
-            request.operations.CopyFrom(operation_stats)
         elif poll_exit:
             request.poll_exit.CopyFrom(poll_exit)
         elif partial_history:
@@ -424,10 +421,6 @@ class InterfaceShared(InterfaceBase):
         record = self._make_record(exit=exit_data)
         return self._deliver(record)
 
-    def deliver_operation_stats(self):
-        record = self._make_request(operation_stats=pb.OperationStatsRequest())
-        return self._deliver(record)
-
     def _deliver_poll_exit(
         self,
         poll_exit: pb.PollExitRequest,