wandb 0.19.12rc1__py3-none-win32.whl → 0.20.1__py3-none-win32.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wandb/__init__.py +1 -2
- wandb/__init__.pyi +3 -6
- wandb/_iterutils.py +26 -7
- wandb/_pydantic/__init__.py +2 -1
- wandb/_pydantic/utils.py +7 -0
- wandb/agents/pyagent.py +9 -15
- wandb/analytics/sentry.py +1 -2
- wandb/apis/attrs.py +3 -4
- wandb/apis/importers/internals/util.py +1 -1
- wandb/apis/importers/validation.py +2 -2
- wandb/apis/importers/wandb.py +30 -25
- wandb/apis/normalize.py +2 -2
- wandb/apis/public/__init__.py +1 -0
- wandb/apis/public/api.py +37 -33
- wandb/apis/public/artifacts.py +103 -72
- wandb/apis/public/jobs.py +3 -2
- wandb/apis/public/registries/registries_search.py +4 -2
- wandb/apis/public/registries/registry.py +1 -1
- wandb/apis/public/registries/utils.py +9 -9
- wandb/apis/public/runs.py +18 -6
- wandb/automations/_filters/expressions.py +1 -1
- wandb/automations/_filters/operators.py +1 -1
- wandb/automations/_filters/run_metrics.py +1 -1
- wandb/beta/workflows.py +6 -5
- wandb/bin/gpu_stats.exe +0 -0
- wandb/bin/wandb-core +0 -0
- wandb/cli/cli.py +54 -73
- wandb/docker/__init__.py +21 -74
- wandb/docker/names.py +40 -0
- wandb/env.py +0 -1
- wandb/errors/util.py +1 -1
- wandb/filesync/step_checksum.py +1 -1
- wandb/filesync/step_upload.py +1 -1
- wandb/integration/diffusers/resolvers/multimodal.py +1 -2
- wandb/integration/gym/__init__.py +5 -6
- wandb/integration/keras/callbacks/model_checkpoint.py +2 -2
- wandb/integration/keras/keras.py +13 -19
- wandb/integration/kfp/kfp_patch.py +2 -3
- wandb/integration/langchain/wandb_tracer.py +1 -1
- wandb/integration/metaflow/metaflow.py +13 -13
- wandb/integration/openai/fine_tuning.py +3 -2
- wandb/integration/sagemaker/auth.py +2 -1
- wandb/integration/sklearn/utils.py +2 -1
- wandb/integration/tensorboard/__init__.py +1 -1
- wandb/integration/tensorboard/log.py +2 -5
- wandb/integration/tensorflow/__init__.py +2 -2
- wandb/jupyter.py +20 -17
- wandb/plot/confusion_matrix.py +1 -1
- wandb/plot/utils.py +8 -7
- wandb/proto/v3/wandb_internal_pb2.py +355 -335
- wandb/proto/v3/wandb_settings_pb2.py +2 -2
- wandb/proto/v3/wandb_telemetry_pb2.py +12 -12
- wandb/proto/v4/wandb_internal_pb2.py +339 -335
- wandb/proto/v4/wandb_settings_pb2.py +2 -2
- wandb/proto/v4/wandb_telemetry_pb2.py +12 -12
- wandb/proto/v5/wandb_internal_pb2.py +339 -335
- wandb/proto/v5/wandb_settings_pb2.py +2 -2
- wandb/proto/v5/wandb_telemetry_pb2.py +12 -12
- wandb/proto/v6/wandb_internal_pb2.py +339 -335
- wandb/proto/v6/wandb_settings_pb2.py +2 -2
- wandb/proto/v6/wandb_telemetry_pb2.py +12 -12
- wandb/proto/wandb_deprecated.py +6 -8
- wandb/sdk/artifacts/_internal_artifact.py +43 -0
- wandb/sdk/artifacts/_validators.py +55 -35
- wandb/sdk/artifacts/artifact.py +117 -115
- wandb/sdk/artifacts/artifact_download_logger.py +2 -0
- wandb/sdk/artifacts/artifact_saver.py +1 -3
- wandb/sdk/artifacts/artifact_state.py +2 -0
- wandb/sdk/artifacts/artifact_ttl.py +2 -0
- wandb/sdk/artifacts/exceptions.py +14 -0
- wandb/sdk/artifacts/staging.py +2 -0
- wandb/sdk/artifacts/storage_handlers/local_file_handler.py +2 -6
- wandb/sdk/artifacts/storage_handlers/multi_handler.py +1 -1
- wandb/sdk/artifacts/storage_handlers/tracking_handler.py +2 -6
- wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +1 -5
- wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +1 -1
- wandb/sdk/artifacts/storage_layout.py +2 -0
- wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +3 -3
- wandb/sdk/backend/backend.py +11 -182
- wandb/sdk/data_types/_dtypes.py +2 -6
- wandb/sdk/data_types/audio.py +20 -3
- wandb/sdk/data_types/base_types/media.py +12 -7
- wandb/sdk/data_types/base_types/wb_value.py +8 -18
- wandb/sdk/data_types/bokeh.py +19 -2
- wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +17 -1
- wandb/sdk/data_types/helper_types/image_mask.py +7 -1
- wandb/sdk/data_types/html.py +4 -4
- wandb/sdk/data_types/image.py +178 -103
- wandb/sdk/data_types/molecule.py +6 -6
- wandb/sdk/data_types/object_3d.py +10 -5
- wandb/sdk/data_types/saved_model.py +11 -6
- wandb/sdk/data_types/table.py +313 -83
- wandb/sdk/data_types/table_decorators.py +108 -0
- wandb/sdk/data_types/utils.py +43 -7
- wandb/sdk/data_types/video.py +21 -3
- wandb/sdk/interface/interface.py +10 -0
- wandb/sdk/internal/datastore.py +2 -6
- wandb/sdk/internal/file_pusher.py +1 -5
- wandb/sdk/internal/file_stream.py +8 -17
- wandb/sdk/internal/handler.py +2 -2
- wandb/sdk/internal/incremental_table_util.py +53 -0
- wandb/sdk/internal/internal.py +3 -5
- wandb/sdk/internal/internal_api.py +66 -89
- wandb/sdk/internal/job_builder.py +2 -7
- wandb/sdk/internal/profiler.py +2 -2
- wandb/sdk/internal/progress.py +1 -3
- wandb/sdk/internal/run.py +1 -6
- wandb/sdk/internal/sender.py +24 -36
- wandb/sdk/internal/system/assets/aggregators.py +1 -7
- wandb/sdk/internal/system/assets/disk.py +3 -3
- wandb/sdk/internal/system/assets/gpu.py +4 -4
- wandb/sdk/internal/system/assets/gpu_amd.py +4 -4
- wandb/sdk/internal/system/assets/interfaces.py +6 -6
- wandb/sdk/internal/system/assets/tpu.py +1 -1
- wandb/sdk/internal/system/assets/trainium.py +6 -6
- wandb/sdk/internal/system/system_info.py +5 -7
- wandb/sdk/internal/system/system_monitor.py +4 -4
- wandb/sdk/internal/tb_watcher.py +5 -7
- wandb/sdk/launch/_launch.py +1 -1
- wandb/sdk/launch/_project_spec.py +19 -20
- wandb/sdk/launch/agent/agent.py +3 -3
- wandb/sdk/launch/agent/config.py +1 -1
- wandb/sdk/launch/agent/job_status_tracker.py +2 -2
- wandb/sdk/launch/builder/build.py +2 -3
- wandb/sdk/launch/builder/kaniko_builder.py +5 -4
- wandb/sdk/launch/environment/gcp_environment.py +1 -2
- wandb/sdk/launch/registry/azure_container_registry.py +2 -2
- wandb/sdk/launch/registry/elastic_container_registry.py +2 -2
- wandb/sdk/launch/registry/google_artifact_registry.py +3 -3
- wandb/sdk/launch/runner/abstract.py +5 -5
- wandb/sdk/launch/runner/kubernetes_monitor.py +2 -2
- wandb/sdk/launch/runner/kubernetes_runner.py +1 -1
- wandb/sdk/launch/runner/sagemaker_runner.py +2 -4
- wandb/sdk/launch/runner/vertex_runner.py +2 -7
- wandb/sdk/launch/sweeps/__init__.py +1 -1
- wandb/sdk/launch/sweeps/scheduler.py +2 -2
- wandb/sdk/launch/sweeps/utils.py +3 -3
- wandb/sdk/launch/utils.py +3 -4
- wandb/sdk/lib/apikey.py +5 -8
- wandb/sdk/lib/config_util.py +3 -3
- wandb/sdk/lib/fsm.py +3 -18
- wandb/sdk/lib/gitlib.py +6 -5
- wandb/sdk/lib/ipython.py +2 -2
- wandb/sdk/lib/json_util.py +9 -14
- wandb/sdk/lib/printer.py +3 -8
- wandb/sdk/lib/redirect.py +1 -1
- wandb/sdk/lib/retry.py +3 -7
- wandb/sdk/lib/run_moment.py +2 -2
- wandb/sdk/lib/service_connection.py +3 -1
- wandb/sdk/lib/service_token.py +1 -2
- wandb/sdk/mailbox/mailbox_handle.py +3 -7
- wandb/sdk/mailbox/response_handle.py +2 -6
- wandb/sdk/service/streams.py +3 -7
- wandb/sdk/verify/verify.py +5 -6
- wandb/sdk/wandb_config.py +1 -1
- wandb/sdk/wandb_init.py +38 -106
- wandb/sdk/wandb_login.py +7 -6
- wandb/sdk/wandb_run.py +52 -240
- wandb/sdk/wandb_settings.py +71 -60
- wandb/sdk/wandb_setup.py +40 -14
- wandb/sdk/wandb_watch.py +5 -7
- wandb/sync/__init__.py +1 -1
- wandb/sync/sync.py +13 -13
- wandb/util.py +17 -35
- wandb/wandb_agent.py +8 -11
- {wandb-0.19.12rc1.dist-info → wandb-0.20.1.dist-info}/METADATA +5 -5
- {wandb-0.19.12rc1.dist-info → wandb-0.20.1.dist-info}/RECORD +170 -168
- wandb/docker/auth.py +0 -435
- wandb/docker/www_authenticate.py +0 -94
- {wandb-0.19.12rc1.dist-info → wandb-0.20.1.dist-info}/WHEEL +0 -0
- {wandb-0.19.12rc1.dist-info → wandb-0.20.1.dist-info}/entry_points.txt +0 -0
- {wandb-0.19.12rc1.dist-info → wandb-0.20.1.dist-info}/licenses/LICENSE +0 -0
wandb/apis/public/artifacts.py
CHANGED
@@ -1,20 +1,11 @@
|
|
1
1
|
"""Public API: artifacts."""
|
2
2
|
|
3
|
+
from __future__ import annotations
|
4
|
+
|
3
5
|
import json
|
4
6
|
import re
|
5
7
|
from copy import copy
|
6
|
-
from typing import
|
7
|
-
TYPE_CHECKING,
|
8
|
-
Any,
|
9
|
-
Iterable,
|
10
|
-
List,
|
11
|
-
Literal,
|
12
|
-
Mapping,
|
13
|
-
Optional,
|
14
|
-
Sequence,
|
15
|
-
Type,
|
16
|
-
Union,
|
17
|
-
)
|
8
|
+
from typing import TYPE_CHECKING, Any, Iterable, Literal, Mapping, Sequence
|
18
9
|
|
19
10
|
from typing_extensions import override
|
20
11
|
from wandb_gql import Client, gql
|
@@ -62,6 +53,7 @@ from wandb.sdk.artifacts._graphql_fragments import omit_artifact_fields
|
|
62
53
|
from wandb.sdk.artifacts._validators import (
|
63
54
|
SOURCE_ARTIFACT_COLLECTION_TYPE,
|
64
55
|
validate_artifact_name,
|
56
|
+
validate_artifact_type,
|
65
57
|
)
|
66
58
|
from wandb.sdk.internal.internal_api import Api as InternalApi
|
67
59
|
from wandb.sdk.lib import deprecate
|
@@ -69,13 +61,15 @@ from wandb.sdk.lib import deprecate
|
|
69
61
|
from .utils import gql_compat
|
70
62
|
|
71
63
|
if TYPE_CHECKING:
|
72
|
-
from wandb.
|
64
|
+
from wandb.sdk.artifacts.artifact import Artifact
|
65
|
+
|
66
|
+
from . import RetryingClient, Run
|
73
67
|
|
74
68
|
|
75
69
|
class ArtifactTypes(Paginator["ArtifactType"]):
|
76
70
|
QUERY = gql(PROJECT_ARTIFACT_TYPES_GQL)
|
77
71
|
|
78
|
-
last_response:
|
72
|
+
last_response: ArtifactTypesFragment | None
|
79
73
|
|
80
74
|
def __init__(
|
81
75
|
self,
|
@@ -117,7 +111,7 @@ class ArtifactTypes(Paginator["ArtifactType"]):
|
|
117
111
|
return self.last_response.page_info.has_next_page
|
118
112
|
|
119
113
|
@property
|
120
|
-
def cursor(self) ->
|
114
|
+
def cursor(self) -> str | None:
|
121
115
|
if self.last_response is None:
|
122
116
|
return None
|
123
117
|
return self.last_response.edges[-1].cursor
|
@@ -125,7 +119,7 @@ class ArtifactTypes(Paginator["ArtifactType"]):
|
|
125
119
|
def update_variables(self) -> None:
|
126
120
|
self.variables.update({"cursor": self.cursor})
|
127
121
|
|
128
|
-
def convert_objects(self) ->
|
122
|
+
def convert_objects(self) -> list[ArtifactType]:
|
129
123
|
if self.last_response is None:
|
130
124
|
return []
|
131
125
|
|
@@ -149,7 +143,7 @@ class ArtifactType:
|
|
149
143
|
entity: str,
|
150
144
|
project: str,
|
151
145
|
type_name: str,
|
152
|
-
attrs:
|
146
|
+
attrs: Mapping[str, Any] | None = None,
|
153
147
|
):
|
154
148
|
self.client = client
|
155
149
|
self.entity = entity
|
@@ -159,8 +153,8 @@ class ArtifactType:
|
|
159
153
|
if self._attrs is None:
|
160
154
|
self.load()
|
161
155
|
|
162
|
-
def load(self):
|
163
|
-
data:
|
156
|
+
def load(self) -> Mapping[str, Any]:
|
157
|
+
data: Mapping[str, Any] | None = self.client.execute(
|
164
158
|
gql(PROJECT_ARTIFACT_TYPE_GQL),
|
165
159
|
variable_values={
|
166
160
|
"entityName": self.entity,
|
@@ -176,29 +170,29 @@ class ArtifactType:
|
|
176
170
|
return self._attrs
|
177
171
|
|
178
172
|
@property
|
179
|
-
def id(self):
|
173
|
+
def id(self) -> str:
|
180
174
|
return self._attrs["id"]
|
181
175
|
|
182
176
|
@property
|
183
|
-
def name(self):
|
177
|
+
def name(self) -> str:
|
184
178
|
return self._attrs["name"]
|
185
179
|
|
186
180
|
@normalize_exceptions
|
187
|
-
def collections(self, per_page=50):
|
181
|
+
def collections(self, per_page: int = 50) -> ArtifactCollections:
|
188
182
|
"""Artifact collections."""
|
189
183
|
return ArtifactCollections(self.client, self.entity, self.project, self.type)
|
190
184
|
|
191
|
-
def collection(self, name):
|
185
|
+
def collection(self, name: str) -> ArtifactCollection:
|
192
186
|
return ArtifactCollection(
|
193
187
|
self.client, self.entity, self.project, name, self.type
|
194
188
|
)
|
195
189
|
|
196
|
-
def __repr__(self):
|
190
|
+
def __repr__(self) -> str:
|
197
191
|
return f"<ArtifactType {self.type}>"
|
198
192
|
|
199
193
|
|
200
194
|
class ArtifactCollections(SizedPaginator["ArtifactCollection"]):
|
201
|
-
last_response:
|
195
|
+
last_response: ArtifactCollectionsFragment | None
|
202
196
|
|
203
197
|
def __init__(
|
204
198
|
self,
|
@@ -266,7 +260,10 @@ class ArtifactCollections(SizedPaginator["ArtifactCollection"]):
|
|
266
260
|
def update_variables(self) -> None:
|
267
261
|
self.variables.update({"cursor": self.cursor})
|
268
262
|
|
269
|
-
def convert_objects(self) ->
|
263
|
+
def convert_objects(self) -> list[ArtifactCollection]:
|
264
|
+
if self.last_response is None:
|
265
|
+
return []
|
266
|
+
|
270
267
|
return [
|
271
268
|
ArtifactCollection(
|
272
269
|
client=self.client,
|
@@ -288,9 +285,9 @@ class ArtifactCollection:
|
|
288
285
|
project: str,
|
289
286
|
name: str,
|
290
287
|
type: str,
|
291
|
-
organization:
|
292
|
-
attrs:
|
293
|
-
is_sequence:
|
288
|
+
organization: str | None = None,
|
289
|
+
attrs: Mapping[str, Any] | None = None,
|
290
|
+
is_sequence: bool | None = None,
|
294
291
|
):
|
295
292
|
self.client = client
|
296
293
|
self.entity = entity
|
@@ -302,8 +299,7 @@ class ArtifactCollection:
|
|
302
299
|
self._attrs = attrs
|
303
300
|
if is_sequence is not None:
|
304
301
|
self._is_sequence = is_sequence
|
305
|
-
|
306
|
-
if not is_loaded:
|
302
|
+
if (attrs is None) or (is_sequence is None):
|
307
303
|
self.load()
|
308
304
|
self._aliases = [a["node"]["alias"] for a in self._attrs["aliases"]["edges"]]
|
309
305
|
self._description = self._attrs["description"]
|
@@ -317,7 +313,7 @@ class ArtifactCollection:
|
|
317
313
|
return self._attrs["id"]
|
318
314
|
|
319
315
|
@normalize_exceptions
|
320
|
-
def artifacts(self, per_page: int = 50) ->
|
316
|
+
def artifacts(self, per_page: int = 50) -> Artifacts:
|
321
317
|
"""Artifacts."""
|
322
318
|
return Artifacts(
|
323
319
|
client=self.client,
|
@@ -329,7 +325,7 @@ class ArtifactCollection:
|
|
329
325
|
)
|
330
326
|
|
331
327
|
@property
|
332
|
-
def aliases(self) ->
|
328
|
+
def aliases(self) -> list[str]:
|
333
329
|
"""Artifact Collection Aliases."""
|
334
330
|
return self._aliases
|
335
331
|
|
@@ -372,6 +368,7 @@ class ArtifactCollection:
|
|
372
368
|
self._attrs = collection.model_dump(exclude_unset=True)
|
373
369
|
return self._attrs
|
374
370
|
|
371
|
+
@normalize_exceptions
|
375
372
|
def change_type(self, new_type: str) -> None:
|
376
373
|
"""Deprecated, change type directly with `save` instead."""
|
377
374
|
deprecate.deprecate(
|
@@ -379,6 +376,17 @@ class ArtifactCollection:
|
|
379
376
|
warning_message="ArtifactCollection.change_type(type) is deprecated, use ArtifactCollection.save() instead.",
|
380
377
|
)
|
381
378
|
|
379
|
+
if self._saved_type != new_type:
|
380
|
+
try:
|
381
|
+
validate_artifact_type(self._saved_type, self.name)
|
382
|
+
except ValueError as e:
|
383
|
+
raise ValueError(
|
384
|
+
f"The current type '{self._saved_type!r}' is an internal type and cannot be changed."
|
385
|
+
) from e
|
386
|
+
|
387
|
+
# Check that the new type is not going to conflict with internal types
|
388
|
+
validate_artifact_type(new_type, self.name)
|
389
|
+
|
382
390
|
if not self.is_sequence():
|
383
391
|
raise ValueError("Artifact collection needs to be a sequence")
|
384
392
|
termlog(
|
@@ -416,16 +424,16 @@ class ArtifactCollection:
|
|
416
424
|
return self._description
|
417
425
|
|
418
426
|
@description.setter
|
419
|
-
def description(self, description:
|
427
|
+
def description(self, description: str | None) -> None:
|
420
428
|
self._description = description
|
421
429
|
|
422
430
|
@property
|
423
|
-
def tags(self) ->
|
431
|
+
def tags(self) -> list[str]:
|
424
432
|
"""The tags associated with the artifact collection."""
|
425
433
|
return self._tags
|
426
434
|
|
427
435
|
@tags.setter
|
428
|
-
def tags(self, tags:
|
436
|
+
def tags(self, tags: list[str]) -> None:
|
429
437
|
if any(not re.match(r"^[-\w]+([ ]+[-\w]+)*$", tag) for tag in tags):
|
430
438
|
raise ValueError(
|
431
439
|
"Tags must only contain alphanumeric characters or underscores separated by spaces or hyphens"
|
@@ -447,7 +455,7 @@ class ArtifactCollection:
|
|
447
455
|
return self._type
|
448
456
|
|
449
457
|
@type.setter
|
450
|
-
def type(self, type:
|
458
|
+
def type(self, type: list[str]) -> None:
|
451
459
|
if not self.is_sequence():
|
452
460
|
raise ValueError(
|
453
461
|
"Type can only be changed if the artifact collection is a sequence."
|
@@ -501,8 +509,22 @@ class ArtifactCollection:
|
|
501
509
|
},
|
502
510
|
)
|
503
511
|
|
512
|
+
@normalize_exceptions
|
504
513
|
def save(self) -> None:
|
505
514
|
"""Persist any changes made to the artifact collection."""
|
515
|
+
if self._saved_type != self.type:
|
516
|
+
try:
|
517
|
+
validate_artifact_type(self.type, self._name)
|
518
|
+
except ValueError as e:
|
519
|
+
raise ValueError(f"Failed to save artifact collection: {e}") from e
|
520
|
+
try:
|
521
|
+
validate_artifact_type(self._saved_type, self._name)
|
522
|
+
except ValueError as e:
|
523
|
+
raise ValueError(
|
524
|
+
f"Failed to save artifact collection '{self._name}': "
|
525
|
+
f"The current type '{self._saved_type!r}' is an internal type and cannot be changed."
|
526
|
+
) from e
|
527
|
+
|
506
528
|
self._update_collection()
|
507
529
|
|
508
530
|
if self.is_sequence() and (self._saved_type != self._type):
|
@@ -520,13 +542,13 @@ class ArtifactCollection:
|
|
520
542
|
return f"<ArtifactCollection {self._name} ({self._type})>"
|
521
543
|
|
522
544
|
|
523
|
-
class Artifacts(SizedPaginator["
|
545
|
+
class Artifacts(SizedPaginator["Artifact"]):
|
524
546
|
"""An iterable collection of artifact versions associated with a project and optional filter.
|
525
547
|
|
526
548
|
This is generally used indirectly via the `Api`.artifact_versions method.
|
527
549
|
"""
|
528
550
|
|
529
|
-
last_response:
|
551
|
+
last_response: ArtifactsFragment | None
|
530
552
|
|
531
553
|
def __init__(
|
532
554
|
self,
|
@@ -535,10 +557,10 @@ class Artifacts(SizedPaginator["wandb.Artifact"]):
|
|
535
557
|
project: str,
|
536
558
|
collection_name: str,
|
537
559
|
type: str,
|
538
|
-
filters:
|
539
|
-
order:
|
560
|
+
filters: Mapping[str, Any] | None = None,
|
561
|
+
order: str | None = None,
|
540
562
|
per_page: int = 50,
|
541
|
-
tags:
|
563
|
+
tags: str | list[str] | None = None,
|
542
564
|
):
|
543
565
|
self.entity = entity
|
544
566
|
self.collection_name = collection_name
|
@@ -586,7 +608,7 @@ class Artifacts(SizedPaginator["wandb.Artifact"]):
|
|
586
608
|
self.last_response = ArtifactsFragment.model_validate(conn)
|
587
609
|
|
588
610
|
@property
|
589
|
-
def length(self) ->
|
611
|
+
def length(self) -> int | None:
|
590
612
|
if self.last_response is None:
|
591
613
|
return None
|
592
614
|
return self.last_response.total_count
|
@@ -598,12 +620,15 @@ class Artifacts(SizedPaginator["wandb.Artifact"]):
|
|
598
620
|
return self.last_response.page_info.has_next_page
|
599
621
|
|
600
622
|
@property
|
601
|
-
def cursor(self) ->
|
623
|
+
def cursor(self) -> str | None:
|
602
624
|
if self.last_response is None:
|
603
625
|
return None
|
604
626
|
return self.last_response.edges[-1].cursor
|
605
627
|
|
606
|
-
def convert_objects(self) ->
|
628
|
+
def convert_objects(self) -> list[Artifact]:
|
629
|
+
if self.last_response is None:
|
630
|
+
return []
|
631
|
+
|
607
632
|
artifact_edges = (edge for edge in self.last_response.edges if edge.node)
|
608
633
|
artifacts = (
|
609
634
|
wandb.Artifact._from_attrs(
|
@@ -619,24 +644,22 @@ class Artifacts(SizedPaginator["wandb.Artifact"]):
|
|
619
644
|
return [art for art in artifacts if required_tags.issubset(art.tags)]
|
620
645
|
|
621
646
|
|
622
|
-
class RunArtifacts(SizedPaginator["
|
623
|
-
last_response:
|
624
|
-
RunOutputArtifactsProjectRunOutputArtifacts
|
625
|
-
RunInputArtifactsProjectRunInputArtifacts
|
626
|
-
|
647
|
+
class RunArtifacts(SizedPaginator["Artifact"]):
|
648
|
+
last_response: (
|
649
|
+
RunOutputArtifactsProjectRunOutputArtifacts
|
650
|
+
| RunInputArtifactsProjectRunInputArtifacts
|
651
|
+
)
|
627
652
|
|
628
653
|
#: The pydantic model used to parse the (inner part of the) raw response.
|
629
|
-
_response_cls:
|
630
|
-
|
631
|
-
|
632
|
-
RunInputArtifactsProjectRunInputArtifacts,
|
633
|
-
]
|
654
|
+
_response_cls: type[
|
655
|
+
RunOutputArtifactsProjectRunOutputArtifacts
|
656
|
+
| RunInputArtifactsProjectRunInputArtifacts
|
634
657
|
]
|
635
658
|
|
636
659
|
def __init__(
|
637
660
|
self,
|
638
661
|
client: Client,
|
639
|
-
run:
|
662
|
+
run: Run,
|
640
663
|
mode: Literal["logged", "used"] = "logged",
|
641
664
|
per_page: int = 50,
|
642
665
|
):
|
@@ -675,7 +698,7 @@ class RunArtifacts(SizedPaginator["wandb.Artifact"]):
|
|
675
698
|
self.last_response = self._response_cls.model_validate(inner_data)
|
676
699
|
|
677
700
|
@property
|
678
|
-
def length(self) ->
|
701
|
+
def length(self) -> int | None:
|
679
702
|
if self.last_response is None:
|
680
703
|
return None
|
681
704
|
return self.last_response.total_count
|
@@ -687,36 +710,41 @@ class RunArtifacts(SizedPaginator["wandb.Artifact"]):
|
|
687
710
|
return self.last_response.page_info.has_next_page
|
688
711
|
|
689
712
|
@property
|
690
|
-
def cursor(self) ->
|
713
|
+
def cursor(self) -> str | None:
|
691
714
|
if self.last_response is None:
|
692
715
|
return None
|
693
716
|
return self.last_response.edges[-1].cursor
|
694
717
|
|
695
|
-
def convert_objects(self) ->
|
718
|
+
def convert_objects(self) -> list[Artifact]:
|
719
|
+
if self.last_response is None:
|
720
|
+
return []
|
721
|
+
|
696
722
|
return [
|
697
723
|
wandb.Artifact._from_attrs(
|
698
|
-
entity=
|
699
|
-
project=
|
700
|
-
name=f"{
|
724
|
+
entity=proj.entity_name,
|
725
|
+
project=proj.name,
|
726
|
+
name=f"{artifact_seq.name}:v{node.version_index}",
|
701
727
|
attrs=node.model_dump(exclude_unset=True),
|
702
728
|
client=self.client,
|
703
729
|
)
|
704
|
-
for
|
705
|
-
if (node :=
|
730
|
+
for edge in self.last_response.edges
|
731
|
+
if (node := edge.node)
|
732
|
+
and (artifact_seq := node.artifact_sequence)
|
733
|
+
and (proj := artifact_seq.project)
|
706
734
|
]
|
707
735
|
|
708
736
|
|
709
737
|
class ArtifactFiles(SizedPaginator["public.File"]):
|
710
|
-
last_response:
|
738
|
+
last_response: FilesFragment | None
|
711
739
|
|
712
740
|
def __init__(
|
713
741
|
self,
|
714
742
|
client: Client,
|
715
|
-
artifact:
|
716
|
-
names:
|
743
|
+
artifact: Artifact,
|
744
|
+
names: Sequence[str] | None = None,
|
717
745
|
per_page: int = 50,
|
718
746
|
):
|
719
|
-
self.query_via_membership = InternalApi().
|
747
|
+
self.query_via_membership = InternalApi()._server_supports(
|
720
748
|
ServerFeature.ARTIFACT_COLLECTION_MEMBERSHIP_FILES
|
721
749
|
)
|
722
750
|
self.artifact = artifact
|
@@ -767,7 +795,7 @@ class ArtifactFiles(SizedPaginator["public.File"]):
|
|
767
795
|
self.last_response = FilesFragment.model_validate(conn)
|
768
796
|
|
769
797
|
@property
|
770
|
-
def path(self) ->
|
798
|
+
def path(self) -> list[str]:
|
771
799
|
return [self.artifact.entity, self.artifact.project, self.artifact.name]
|
772
800
|
|
773
801
|
@property
|
@@ -781,7 +809,7 @@ class ArtifactFiles(SizedPaginator["public.File"]):
|
|
781
809
|
return self.last_response.page_info.has_next_page
|
782
810
|
|
783
811
|
@property
|
784
|
-
def cursor(self) ->
|
812
|
+
def cursor(self) -> str | None:
|
785
813
|
if self.last_response is None:
|
786
814
|
return None
|
787
815
|
return self.last_response.edges[-1].cursor
|
@@ -789,7 +817,10 @@ class ArtifactFiles(SizedPaginator["public.File"]):
|
|
789
817
|
def update_variables(self) -> None:
|
790
818
|
self.variables.update({"fileLimit": self.per_page, "fileCursor": self.cursor})
|
791
819
|
|
792
|
-
def convert_objects(self) ->
|
820
|
+
def convert_objects(self) -> list[public.File]:
|
821
|
+
if self.last_response is None:
|
822
|
+
return []
|
823
|
+
|
793
824
|
return [
|
794
825
|
public.File(
|
795
826
|
client=self.client,
|
@@ -805,7 +836,7 @@ class ArtifactFiles(SizedPaginator["public.File"]):
|
|
805
836
|
|
806
837
|
|
807
838
|
def server_supports_artifact_collections_gql_edges(
|
808
|
-
client:
|
839
|
+
client: RetryingClient, warn: bool = False
|
809
840
|
) -> bool:
|
810
841
|
# TODO: Validate this version
|
811
842
|
# Edges were merged into core on Mar 2, 2022: https://github.com/wandb/core/commit/81c90b29eaacfe0a96dc1ebd83c53560ca763e8b
|
wandb/apis/public/jobs.py
CHANGED
@@ -405,9 +405,10 @@ class QueuedRun:
|
|
405
405
|
None,
|
406
406
|
)
|
407
407
|
self._run_id = item["associatedRunId"]
|
408
|
-
return self._run
|
409
408
|
except ValueError as e:
|
410
|
-
wandb.termwarn(e)
|
409
|
+
wandb.termwarn(str(e))
|
410
|
+
else:
|
411
|
+
return self._run
|
411
412
|
elif item:
|
412
413
|
wandb.termlog("Waiting for run to start")
|
413
414
|
|
@@ -10,13 +10,13 @@ from wandb_gql import gql
|
|
10
10
|
|
11
11
|
import wandb
|
12
12
|
from wandb.apis.paginator import Paginator
|
13
|
-
from wandb.apis.public.artifacts import ArtifactCollection
|
14
|
-
from wandb.apis.public.registries.utils import _ensure_registry_prefix_on_names
|
15
13
|
from wandb.sdk.artifacts._graphql_fragments import (
|
16
14
|
_gql_artifact_fragment,
|
17
15
|
_gql_registry_fragment,
|
18
16
|
)
|
19
17
|
|
18
|
+
from .utils import _ensure_registry_prefix_on_names
|
19
|
+
|
20
20
|
|
21
21
|
class Registries(Paginator):
|
22
22
|
"""Iterator that returns Registries."""
|
@@ -286,6 +286,8 @@ class Collections(Paginator):
|
|
286
286
|
return None
|
287
287
|
|
288
288
|
def convert_objects(self):
|
289
|
+
from wandb.apis.public import ArtifactCollection
|
290
|
+
|
289
291
|
if not self.last_response:
|
290
292
|
return []
|
291
293
|
if (
|
@@ -296,7 +296,7 @@ class Registry:
|
|
296
296
|
|
297
297
|
def save(self) -> None:
|
298
298
|
"""Save registry attributes to the backend."""
|
299
|
-
if not InternalApi().
|
299
|
+
if not InternalApi()._server_supports(
|
300
300
|
ServerFeature.INCLUDE_ARTIFACT_TYPES_IN_REGISTRY_CREATION
|
301
301
|
):
|
302
302
|
raise RuntimeError(
|
@@ -125,16 +125,16 @@ def _fetch_org_entity_from_organization(client: "Client", organization: str) ->
|
|
125
125
|
)
|
126
126
|
try:
|
127
127
|
response = client.execute(query, variable_values={"organization": organization})
|
128
|
-
if response["organization"] and response["organization"]["orgEntity"]:
|
129
|
-
if response["organization"]["orgEntity"]["name"]:
|
130
|
-
return response["organization"]["orgEntity"]["name"]
|
131
|
-
return ValueError(
|
132
|
-
f"Organization entity for organization: {organization} is empty"
|
133
|
-
)
|
134
|
-
raise ValueError(
|
135
|
-
f"Organization entity for organization: {organization} not found"
|
136
|
-
)
|
137
128
|
except Exception as e:
|
138
129
|
raise ValueError(
|
139
130
|
f"Error fetching org entity for organization: {organization}"
|
140
131
|
) from e
|
132
|
+
|
133
|
+
if (
|
134
|
+
not (org := response["organization"])
|
135
|
+
or not (org_entity := org["orgEntity"])
|
136
|
+
or not (org_name := org_entity["name"])
|
137
|
+
):
|
138
|
+
raise ValueError(f"Organization entity for {organization} not found.")
|
139
|
+
|
140
|
+
return org_name
|
wandb/apis/public/runs.py
CHANGED
@@ -400,14 +400,21 @@ class Run(Attrs):
|
|
400
400
|
return new_name
|
401
401
|
|
402
402
|
@classmethod
|
403
|
-
def create(
|
403
|
+
def create(
|
404
|
+
cls,
|
405
|
+
api,
|
406
|
+
run_id=None,
|
407
|
+
project=None,
|
408
|
+
entity=None,
|
409
|
+
state: Literal["running", "pending"] = "running",
|
410
|
+
):
|
404
411
|
"""Create a run for the given project."""
|
405
412
|
run_id = run_id or runid.generate_id()
|
406
413
|
project = project or api.settings.get("project") or "uncategorized"
|
407
414
|
mutation = gql(
|
408
415
|
"""
|
409
|
-
mutation UpsertBucket($project: String, $entity: String, $name: String
|
410
|
-
upsertBucket(input: {modelName: $project, entityName: $entity, name: $name}) {
|
416
|
+
mutation UpsertBucket($project: String, $entity: String, $name: String!, $state: String) {
|
417
|
+
upsertBucket(input: {modelName: $project, entityName: $entity, name: $name, state: $state}) {
|
411
418
|
bucket {
|
412
419
|
project {
|
413
420
|
name
|
@@ -421,7 +428,12 @@ class Run(Attrs):
|
|
421
428
|
}
|
422
429
|
"""
|
423
430
|
)
|
424
|
-
variables = {
|
431
|
+
variables = {
|
432
|
+
"entity": entity,
|
433
|
+
"project": project,
|
434
|
+
"name": run_id,
|
435
|
+
"state": state,
|
436
|
+
}
|
425
437
|
res = api.client.execute(mutation, variable_values=variables)
|
426
438
|
res = res["upsertBucket"]["bucket"]
|
427
439
|
return Run(
|
@@ -437,7 +449,7 @@ class Run(Attrs):
|
|
437
449
|
"tags": [],
|
438
450
|
"description": None,
|
439
451
|
"notes": None,
|
440
|
-
"state":
|
452
|
+
"state": state,
|
441
453
|
},
|
442
454
|
)
|
443
455
|
|
@@ -918,7 +930,7 @@ class Run(Attrs):
|
|
918
930
|
api.set_current_run_id(self.id)
|
919
931
|
|
920
932
|
if not isinstance(artifact, wandb.Artifact):
|
921
|
-
raise
|
933
|
+
raise TypeError("You must pass a wandb.Api().artifact() to use_artifact")
|
922
934
|
if artifact.is_draft():
|
923
935
|
raise ValueError(
|
924
936
|
"Only existing artifacts are accepted by this api. "
|
@@ -147,7 +147,7 @@ class FilterExpr(CompatBaseModel, SupportsLogicalOpSyntax):
|
|
147
147
|
def __repr__(self) -> str:
|
148
148
|
return f"{type(self).__name__}({self.field!s}: {self.op!r})"
|
149
149
|
|
150
|
-
def __rich_repr__(self) -> RichReprResult:
|
150
|
+
def __rich_repr__(self) -> RichReprResult:
|
151
151
|
# https://rich.readthedocs.io/en/stable/pretty.html
|
152
152
|
yield self.field, self.op
|
153
153
|
|
@@ -64,7 +64,7 @@ class BaseOp(GQLBase, SupportsLogicalOpSyntax):
|
|
64
64
|
values_repr = ", ".join(map(repr, self.model_dump().values()))
|
65
65
|
return f"{type(self).__name__}({values_repr})"
|
66
66
|
|
67
|
-
def __rich_repr__(self) -> RichReprResult:
|
67
|
+
def __rich_repr__(self) -> RichReprResult:
|
68
68
|
# Display field values as positional args:
|
69
69
|
# https://rich.readthedocs.io/en/stable/pretty.html
|
70
70
|
yield from ((None, v) for v in self.model_dump().values())
|
@@ -112,7 +112,7 @@ class BaseMetricFilter(GQLBase, ABC, extra="forbid"):
|
|
112
112
|
raise NotImplementedError
|
113
113
|
|
114
114
|
@override
|
115
|
-
def __rich_repr__(self) -> RichReprResult:
|
115
|
+
def __rich_repr__(self) -> RichReprResult:
|
116
116
|
"""The representation of the metric filter when using `rich` for pretty-printing."""
|
117
117
|
# See: https://rich.readthedocs.io/en/stable/pretty.html#rich-repr-protocol
|
118
118
|
yield None, repr(self)
|
wandb/beta/workflows.py
CHANGED
@@ -47,8 +47,9 @@ def _add_any(
|
|
47
47
|
with artifact.new_file(name) as f:
|
48
48
|
f.write(json.dumps(path_or_obj, sort_keys=True))
|
49
49
|
else:
|
50
|
-
raise
|
51
|
-
|
50
|
+
raise TypeError(
|
51
|
+
"Expected `path_or_obj` to be instance of `ArtifactManifestEntry`,"
|
52
|
+
f" `WBValue`, or `str, found {type(path_or_obj)}"
|
52
53
|
)
|
53
54
|
|
54
55
|
|
@@ -86,7 +87,7 @@ def _log_artifact_version(
|
|
86
87
|
Artifact
|
87
88
|
|
88
89
|
"""
|
89
|
-
run = wandb_setup.
|
90
|
+
run = wandb_setup.singleton().most_recent_active_run
|
90
91
|
if not run:
|
91
92
|
run = wandb.init(
|
92
93
|
project=project,
|
@@ -217,7 +218,7 @@ def use_model(aliased_path: str, unsafe: bool = False) -> "_SavedModel":
|
|
217
218
|
)
|
218
219
|
|
219
220
|
# Returns a _SavedModel instance
|
220
|
-
if run := wandb_setup.
|
221
|
+
if run := wandb_setup.singleton().most_recent_active_run:
|
221
222
|
artifact = run.use_artifact(aliased_path)
|
222
223
|
sm = artifact.get("index")
|
223
224
|
|
@@ -262,7 +263,7 @@ def link_model(
|
|
262
263
|
"""
|
263
264
|
aliases = wandb.util._resolve_aliases(aliases)
|
264
265
|
|
265
|
-
if run := wandb_setup.
|
266
|
+
if run := wandb_setup.singleton().most_recent_active_run:
|
266
267
|
# _artifact_source, if it exists, points to a Public Artifact.
|
267
268
|
# Its existence means that _SavedModel was deserialized from a logged artifact, most likely from `use_model`.
|
268
269
|
if model._artifact_source:
|
wandb/bin/gpu_stats.exe
CHANGED
Binary file
|
wandb/bin/wandb-core
CHANGED
Binary file
|