wandb 0.16.6__py3-none-any.whl → 0.17.0__py3-none-any.whl
- package_readme.md +95 -0
- wandb/__init__.py +2 -3
- wandb/agents/pyagent.py +0 -1
- wandb/analytics/sentry.py +2 -1
- wandb/apis/importers/internals/internal.py +0 -1
- wandb/apis/importers/internals/protocols.py +30 -56
- wandb/apis/importers/mlflow.py +13 -26
- wandb/apis/importers/wandb.py +8 -14
- wandb/apis/internal.py +0 -3
- wandb/apis/public/api.py +55 -3
- wandb/apis/public/artifacts.py +1 -0
- wandb/apis/public/files.py +1 -0
- wandb/apis/public/history.py +1 -0
- wandb/apis/public/jobs.py +17 -4
- wandb/apis/public/projects.py +1 -0
- wandb/apis/public/reports.py +1 -0
- wandb/apis/public/runs.py +15 -17
- wandb/apis/public/sweeps.py +1 -0
- wandb/apis/public/teams.py +1 -0
- wandb/apis/public/users.py +1 -0
- wandb/apis/reports/v1/_blocks.py +3 -7
- wandb/apis/reports/v2/gql.py +1 -0
- wandb/apis/reports/v2/interface.py +3 -4
- wandb/apis/reports/v2/internal.py +5 -8
- wandb/cli/cli.py +92 -22
- wandb/data_types.py +9 -6
- wandb/docker/__init__.py +1 -1
- wandb/env.py +38 -8
- wandb/errors/__init__.py +5 -0
- wandb/errors/term.py +10 -2
- wandb/filesync/step_checksum.py +1 -4
- wandb/filesync/step_prepare.py +4 -24
- wandb/filesync/step_upload.py +4 -106
- wandb/filesync/upload_job.py +0 -76
- wandb/integration/catboost/catboost.py +1 -1
- wandb/integration/fastai/__init__.py +1 -0
- wandb/integration/huggingface/resolver.py +2 -2
- wandb/integration/keras/__init__.py +1 -0
- wandb/integration/keras/callbacks/metrics_logger.py +1 -1
- wandb/integration/keras/keras.py +7 -7
- wandb/integration/langchain/wandb_tracer.py +1 -0
- wandb/integration/lightning/fabric/logger.py +1 -3
- wandb/integration/metaflow/metaflow.py +41 -6
- wandb/integration/openai/fine_tuning.py +3 -3
- wandb/integration/prodigy/prodigy.py +1 -1
- wandb/old/summary.py +1 -1
- wandb/plot/confusion_matrix.py +1 -1
- wandb/plot/pr_curve.py +2 -1
- wandb/plot/roc_curve.py +2 -1
- wandb/{plots → plot}/utils.py +13 -25
- wandb/proto/v3/wandb_internal_pb2.py +364 -332
- wandb/proto/v3/wandb_settings_pb2.py +2 -2
- wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
- wandb/proto/v4/wandb_internal_pb2.py +322 -316
- wandb/proto/v4/wandb_settings_pb2.py +2 -2
- wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
- wandb/proto/wandb_deprecated.py +7 -1
- wandb/proto/wandb_internal_codegen.py +3 -29
- wandb/sdk/artifacts/artifact.py +26 -11
- wandb/sdk/artifacts/artifact_download_logger.py +1 -0
- wandb/sdk/artifacts/artifact_file_cache.py +18 -4
- wandb/sdk/artifacts/artifact_instance_cache.py +1 -0
- wandb/sdk/artifacts/artifact_manifest.py +1 -0
- wandb/sdk/artifacts/artifact_manifest_entry.py +7 -3
- wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +1 -0
- wandb/sdk/artifacts/artifact_saver.py +2 -8
- wandb/sdk/artifacts/artifact_state.py +1 -0
- wandb/sdk/artifacts/artifact_ttl.py +1 -0
- wandb/sdk/artifacts/exceptions.py +1 -0
- wandb/sdk/artifacts/storage_handlers/azure_handler.py +1 -0
- wandb/sdk/artifacts/storage_handlers/gcs_handler.py +13 -18
- wandb/sdk/artifacts/storage_handlers/http_handler.py +1 -0
- wandb/sdk/artifacts/storage_handlers/local_file_handler.py +1 -0
- wandb/sdk/artifacts/storage_handlers/multi_handler.py +1 -0
- wandb/sdk/artifacts/storage_handlers/s3_handler.py +5 -3
- wandb/sdk/artifacts/storage_handlers/tracking_handler.py +1 -0
- wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +1 -0
- wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +1 -0
- wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +3 -42
- wandb/sdk/artifacts/storage_policy.py +2 -12
- wandb/sdk/data_types/_dtypes.py +8 -8
- wandb/sdk/data_types/base_types/media.py +3 -6
- wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +3 -1
- wandb/sdk/data_types/image.py +1 -1
- wandb/sdk/data_types/video.py +1 -1
- wandb/sdk/integration_utils/auto_logging.py +5 -6
- wandb/sdk/integration_utils/data_logging.py +10 -6
- wandb/sdk/interface/interface.py +68 -32
- wandb/sdk/interface/interface_shared.py +7 -13
- wandb/sdk/internal/datastore.py +1 -1
- wandb/sdk/internal/file_pusher.py +2 -5
- wandb/sdk/internal/file_stream.py +5 -18
- wandb/sdk/internal/handler.py +18 -2
- wandb/sdk/internal/internal.py +0 -1
- wandb/sdk/internal/internal_api.py +1 -129
- wandb/sdk/internal/internal_util.py +0 -1
- wandb/sdk/internal/job_builder.py +159 -45
- wandb/sdk/internal/profiler.py +1 -0
- wandb/sdk/internal/progress.py +0 -28
- wandb/sdk/internal/run.py +1 -0
- wandb/sdk/internal/sender.py +1 -2
- wandb/sdk/internal/system/assets/gpu_amd.py +44 -44
- wandb/sdk/internal/system/assets/gpu_apple.py +56 -11
- wandb/sdk/internal/system/assets/interfaces.py +6 -8
- wandb/sdk/internal/system/assets/open_metrics.py +2 -2
- wandb/sdk/internal/system/assets/trainium.py +1 -3
- wandb/sdk/launch/__init__.py +9 -1
- wandb/sdk/launch/_launch.py +4 -24
- wandb/sdk/launch/_launch_add.py +1 -3
- wandb/sdk/launch/_project_spec.py +186 -224
- wandb/sdk/launch/agent/agent.py +37 -13
- wandb/sdk/launch/agent/config.py +72 -14
- wandb/sdk/launch/builder/abstract.py +69 -1
- wandb/sdk/launch/builder/build.py +156 -555
- wandb/sdk/launch/builder/context_manager.py +235 -0
- wandb/sdk/launch/builder/docker_builder.py +8 -23
- wandb/sdk/launch/builder/kaniko_builder.py +12 -25
- wandb/sdk/launch/builder/noop.py +1 -0
- wandb/sdk/launch/builder/templates/dockerfile.py +92 -0
- wandb/sdk/launch/create_job.py +47 -37
- wandb/sdk/launch/environment/abstract.py +1 -0
- wandb/sdk/launch/environment/gcp_environment.py +1 -0
- wandb/sdk/launch/environment/local_environment.py +1 -0
- wandb/sdk/launch/inputs/files.py +148 -0
- wandb/sdk/launch/inputs/internal.py +217 -0
- wandb/sdk/launch/inputs/manage.py +95 -0
- wandb/sdk/launch/loader.py +1 -0
- wandb/sdk/launch/registry/abstract.py +1 -0
- wandb/sdk/launch/registry/azure_container_registry.py +1 -0
- wandb/sdk/launch/registry/elastic_container_registry.py +1 -0
- wandb/sdk/launch/registry/google_artifact_registry.py +2 -1
- wandb/sdk/launch/registry/local_registry.py +1 -0
- wandb/sdk/launch/runner/abstract.py +1 -0
- wandb/sdk/launch/runner/kubernetes_monitor.py +1 -0
- wandb/sdk/launch/runner/kubernetes_runner.py +9 -10
- wandb/sdk/launch/runner/local_container.py +2 -3
- wandb/sdk/launch/runner/local_process.py +8 -29
- wandb/sdk/launch/runner/sagemaker_runner.py +21 -20
- wandb/sdk/launch/runner/vertex_runner.py +8 -7
- wandb/sdk/launch/sweeps/scheduler.py +4 -3
- wandb/sdk/launch/sweeps/scheduler_sweep.py +2 -1
- wandb/sdk/launch/sweeps/utils.py +3 -3
- wandb/sdk/launch/utils.py +15 -140
- wandb/sdk/lib/_settings_toposort_generated.py +0 -5
- wandb/sdk/lib/fsm.py +8 -12
- wandb/sdk/lib/gitlib.py +4 -4
- wandb/sdk/lib/import_hooks.py +1 -1
- wandb/sdk/lib/lazyloader.py +0 -1
- wandb/sdk/lib/proto_util.py +23 -2
- wandb/sdk/lib/redirect.py +19 -14
- wandb/sdk/lib/retry.py +3 -2
- wandb/sdk/lib/tracelog.py +1 -1
- wandb/sdk/service/service.py +19 -16
- wandb/sdk/verify/verify.py +2 -1
- wandb/sdk/wandb_init.py +14 -55
- wandb/sdk/wandb_manager.py +2 -2
- wandb/sdk/wandb_require.py +5 -0
- wandb/sdk/wandb_run.py +114 -56
- wandb/sdk/wandb_settings.py +0 -48
- wandb/sdk/wandb_setup.py +1 -1
- wandb/sklearn/__init__.py +1 -0
- wandb/sklearn/plot/__init__.py +1 -0
- wandb/sklearn/plot/classifier.py +11 -12
- wandb/sklearn/plot/clusterer.py +2 -1
- wandb/sklearn/plot/regressor.py +1 -0
- wandb/sklearn/plot/shared.py +1 -0
- wandb/sklearn/utils.py +1 -0
- wandb/testing/relay.py +4 -4
- wandb/trigger.py +1 -0
- wandb/util.py +67 -54
- wandb/wandb_controller.py +2 -3
- wandb/wandb_torch.py +1 -2
- {wandb-0.16.6.dist-info → wandb-0.17.0.dist-info}/METADATA +67 -70
- {wandb-0.16.6.dist-info → wandb-0.17.0.dist-info}/RECORD +177 -187
- {wandb-0.16.6.dist-info → wandb-0.17.0.dist-info}/WHEEL +1 -2
- wandb/bin/apple_gpu_stats +0 -0
- wandb/catboost/__init__.py +0 -9
- wandb/fastai/__init__.py +0 -9
- wandb/keras/__init__.py +0 -18
- wandb/lightgbm/__init__.py +0 -9
- wandb/plots/__init__.py +0 -6
- wandb/plots/explain_text.py +0 -36
- wandb/plots/heatmap.py +0 -81
- wandb/plots/named_entity.py +0 -43
- wandb/plots/part_of_speech.py +0 -50
- wandb/plots/plot_definitions.py +0 -768
- wandb/plots/precision_recall.py +0 -121
- wandb/plots/roc.py +0 -103
- wandb/sacred/__init__.py +0 -3
- wandb/xgboost/__init__.py +0 -9
- wandb-0.16.6.dist-info/top_level.txt +0 -1
- {wandb-0.16.6.dist-info → wandb-0.17.0.dist-info}/entry_points.txt +0 -0
- {wandb-0.16.6.dist-info → wandb-0.17.0.dist-info/licenses}/LICENSE +0 -0
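A note on the removed top-level shim packages listed above (wandb/keras, wandb/catboost, wandb/fastai, wandb/lightgbm, wandb/xgboost, wandb/plots, wandb/sacred): the corresponding integrations live under wandb.integration.*, so imports that relied on the shims would move there. A hedged migration sketch; verify the exact exported symbols against the 0.17.0 package:

# Before (0.16.x), via the removed top-level shim:
#   from wandb.keras import WandbCallback
# After (0.17.0), import from the integration package instead (assumed to
# export the same callback classes as the old shim):
from wandb.integration.keras import WandbCallback, WandbMetricsLogger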
wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py
CHANGED
@@ -1,4 +1,5 @@
 """WandB storage policy."""
+
 import hashlib
 import math
 import os
@@ -262,7 +263,7 @@ class WandbStoragePolicy(StoragePolicy):
             return math.ceil(file_size / S3_MAX_PART_NUMBERS)
         return default_chunk_size
 
-    def store_file_sync(
+    def store_file(
         self,
         artifact_id: str,
         artifact_manifest_id: str,
@@ -300,7 +301,7 @@ class WandbStoragePolicy(StoragePolicy):
             hex_digests[part_number] = hex_digest
             part_number += 1
 
-        resp = preparer.prepare_sync(
+        resp = preparer.prepare(
            {
                "artifactID": artifact_id,
                "artifactManifestID": artifact_manifest_id,
@@ -346,46 +347,6 @@ class WandbStoragePolicy(StoragePolicy):
 
         return False
 
-    async def store_file_async(
-        self,
-        artifact_id: str,
-        artifact_manifest_id: str,
-        entry: "ArtifactManifestEntry",
-        preparer: "StepPrepare",
-        progress_callback: Optional["progress.ProgressFn"] = None,
-    ) -> bool:
-        """Async equivalent to `store_file_sync`."""
-        resp = await preparer.prepare_async(
-            {
-                "artifactID": artifact_id,
-                "artifactManifestID": artifact_manifest_id,
-                "name": entry.path,
-                "md5": entry.digest,
-            }
-        )
-
-        entry.birth_artifact_id = resp.birth_artifact_id
-        if resp.upload_url is None:
-            return True
-        if entry.local_path is None:
-            return False
-
-        with open(entry.local_path, "rb") as file:
-            # This fails if we don't send the first byte before the signed URL expires.
-            await self._api.upload_file_retry_async(
-                resp.upload_url,
-                file,
-                progress_callback,
-                extra_headers={
-                    header.split(":", 1)[0]: header.split(":", 1)[1]
-                    for header in (resp.upload_headers or {})
-                },
-            )
-
-        self._write_cache(entry)
-
-        return False
-
     def _write_cache(self, entry: "ArtifactManifestEntry") -> None:
         if entry.local_path is None:
             return
wandb/sdk/artifacts/storage_policy.py
CHANGED
@@ -1,4 +1,5 @@
 """Storage policy."""
+
 from typing import TYPE_CHECKING, Dict, Optional, Sequence, Type, Union
 
 from wandb.sdk.internal.internal_api import Api as InternalApi
@@ -42,17 +43,7 @@ class StoragePolicy:
     ) -> FilePathStr:
         raise NotImplementedError
 
-    def store_file_sync(
-        self,
-        artifact_id: str,
-        artifact_manifest_id: str,
-        entry: "ArtifactManifestEntry",
-        preparer: "StepPrepare",
-        progress_callback: Optional["ProgressFn"] = None,
-    ) -> bool:
-        raise NotImplementedError
-
-    async def store_file_async(
+    def store_file(
         self,
         artifact_id: str,
         artifact_manifest_id: str,
@@ -60,7 +51,6 @@ class StoragePolicy:
         preparer: "StepPrepare",
         progress_callback: Optional["ProgressFn"] = None,
     ) -> bool:
-        """Async equivalent to `store_file_sync`."""
         raise NotImplementedError
 
     def store_reference(
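With the synchronous and asynchronous upload paths collapsed into a single hook, a custom policy would now override only store_file. A minimal sketch using the signature shown above; MyPolicy and its no-op body are hypothetical and not part of the package:

from typing import Optional

from wandb.sdk.artifacts.storage_policy import StoragePolicy


class MyPolicy(StoragePolicy):
    def store_file(
        self,
        artifact_id: str,
        artifact_manifest_id: str,
        entry: "ArtifactManifestEntry",
        preparer: "StepPrepare",
        progress_callback: Optional["ProgressFn"] = None,
    ) -> bool:
        # Mirroring the removed built-in implementation above: True signals the
        # file was deduplicated server-side (no upload needed), False signals
        # the entry was uploaded.
        return False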
wandb/sdk/data_types/_dtypes.py
CHANGED
@@ -14,7 +14,7 @@ np = get_module("numpy") # intentionally not required
 if t.TYPE_CHECKING:
     from wandb.sdk.artifacts.artifact import Artifact
 
-ConvertableToType = t.Union["Type", t.Type["Type"], type, t.Any]
+ConvertibleToType = t.Union["Type", t.Type["Type"], type, t.Any]
 
 
 class TypeRegistry:
@@ -84,7 +84,7 @@ class TypeRegistry:
         return _type.from_json(json_dict, artifact)
 
     @staticmethod
-    def type_from_dtype(dtype: ConvertableToType) -> "Type":
+    def type_from_dtype(dtype: ConvertibleToType) -> "Type":
         # The dtype is already an instance of Type
         if isinstance(dtype, Type):
             wbtype: Type = dtype
@@ -528,7 +528,7 @@ class UnionType(Type):
 
     def __init__(
         self,
-        allowed_types: t.Optional[t.Sequence[ConvertableToType]] = None,
+        allowed_types: t.Optional[t.Sequence[ConvertibleToType]] = None,
     ):
         assert allowed_types is None or (allowed_types.__class__ == list)
         if allowed_types is None:
@@ -576,7 +576,7 @@ class UnionType(Type):
         return "{}".format(" or ".join([str(t) for t in self.params["allowed_types"]]))
 
 
-def OptionalType(dtype: ConvertableToType) -> UnionType:  # noqa: N802
+def OptionalType(dtype: ConvertibleToType) -> UnionType:  # noqa: N802
     """Function that mimics the Type class API for constructing an "Optional Type".
 
     This is just a Union[wb_type, NoneType].
@@ -591,14 +591,14 @@ def OptionalType(dtype: ConvertableToType) -> UnionType: # noqa: N802
 
 
 class ListType(Type):
-    """A list of homogenous types."""
+    """A list of homogeneous types."""
 
     name = "list"
     types: t.ClassVar[t.List[type]] = [list, tuple, set, frozenset]
 
     def __init__(
         self,
-        element_type: t.Optional[ConvertableToType] = None,
+        element_type: t.Optional[ConvertibleToType] = None,
         length: t.Optional[int] = None,
     ):
         if element_type is None:
@@ -691,7 +691,7 @@ class ListType(Type):
 
 
 class NDArrayType(Type):
-    """Represents a list of homogenous types."""
+    """Represents a list of homogeneous types."""
 
     name = "ndarray"
     types: t.ClassVar[t.List[type]] = []  # will manually add type if np is available
@@ -786,7 +786,7 @@ class TypedDictType(Type):
 
     def __init__(
         self,
-        type_map: t.Optional[t.Dict[str, ConvertableToType]] = None,
+        type_map: t.Optional[t.Dict[str, ConvertibleToType]] = None,
     ):
         if type_map is None:
             type_map = {}
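The renamed alias covers W&B Type instances, Type subclasses, and plain Python types alike, which is what lets helpers such as type_from_dtype, OptionalType, and ListType accept any of them. A small illustrative sketch; note _dtypes is an internal module, so the import path is an assumption and the values are only examples:

from wandb.sdk.data_types._dtypes import ListType, OptionalType, TypeRegistry

maybe_number = OptionalType(int)              # per the docstring: Union[wb_type, NoneType]
string_list = ListType(str)                   # "a list of homogeneous types"
from_plain = TypeRegistry.type_from_dtype(float)  # plain Python types are ConvertibleToType too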
wandb/sdk/data_types/base_types/media.py
CHANGED
@@ -177,9 +177,7 @@ class Media(WBValue):
             json_obj["_latest_artifact_path"] = artifact_entry_latest_url
 
         if artifact_entry_url is None or self.is_bound():
-            assert (
-                self.is_bound()
-            ), "Value of type {} must be bound to a run with bind_to_run() before being serialized to JSON.".format(
+            assert self.is_bound(), "Value of type {} must be bound to a run with bind_to_run() before being serialized to JSON.".format(
                 type(self).__name__
             )
 
@@ -221,8 +219,7 @@ class Media(WBValue):
 
         # if not, check to see if there is a source artifact for this object
         if (
-            self._artifact_source
-            is not None
+            self._artifact_source is not None
             # and self._artifact_source.artifact != artifact
         ):
             default_root = self._artifact_source.artifact._default_root()
@@ -292,7 +289,7 @@ class BatchableMedia(Media):
 
 
 def _numpy_arrays_to_lists(
-    payload: Union[dict, Sequence, "np.ndarray"]
+    payload: Union[dict, Sequence, "np.ndarray"],
 ) -> Union[Sequence, dict, str, int, float, bool]:
     # Casts all numpy arrays to lists so we don't convert them to histograms, primarily for Plotly
 
wandb/sdk/data_types/helper_types/bounding_boxes_2d.py
CHANGED
@@ -231,8 +231,10 @@ class BoundingBoxes2D(JSONMetadata):
 
         for box in boxes:
             # Required arguments
-            error_str = "Each box must contain a position with: middle, width, and height or \
+            error_str = (
+                "Each box must contain a position with: middle, width, and height or \
                 \nminX, maxX, minY, maxY."
+            )
             if "position" not in box:
                 raise TypeError(error_str)
             else:
wandb/sdk/data_types/image.py
CHANGED
@@ -167,7 +167,7 @@ class Image(BatchableMedia):
         self._file_type = None
 
         # Allows the user to pass an Image object as the first parameter and have a perfect copy,
-        # only overriding additional
+        # only overriding additional metadata passed in. If this pattern is compelling, we can generalize.
         if isinstance(data_or_path, Image):
             self._initialize_from_wbimage(data_or_path)
         elif isinstance(data_or_path, str):
wandb/sdk/data_types/video.py
CHANGED
@@ -24,7 +24,7 @@ if TYPE_CHECKING: # pragma: no cover
 # https://github.com/wandb/wandb/issues/3472
 #
 # Essentially, the issue is that moviepy's write_gif function fails to close
-# the open write / file
+# the open write / file descriptor returned from `imageio.save`. The following
 # function is a simplified copy of the function in the moviepy source code.
 # See https://github.com/Zulko/moviepy/blob/7e3e8bb1b739eb6d1c0784b0cb2594b587b93b39/moviepy/video/io/gif_writers.py#L428
 #
wandb/sdk/integration_utils/auto_logging.py
CHANGED
@@ -27,11 +27,11 @@ V = TypeVar("V")
 
 
 class Response(Protocol[K, V]):
-    def __getitem__(self, key: K) -> V:
-        ...  # pragma: no cover
+    def __getitem__(self, key: K) -> V: ...  # pragma: no cover
 
-    def get(self, key: K, default: Optional[V] = None) -> Optional[V]:
-        ...  # pragma: no cover
+    def get(
+        self, key: K, default: Optional[V] = None
+    ) -> Optional[V]: ...  # pragma: no cover
 
 
 class ArgumentResponseResolver(Protocol):
@@ -42,8 +42,7 @@ class ArgumentResponseResolver(Protocol):
         response: Response,
         start_time: float,
         time_elapsed: float,
-    ) -> Optional[Dict[str, Any]]:
-        ...  # pragma: no cover
+    ) -> Optional[Dict[str, Any]]: ...  # pragma: no cover
 
 
 class PatchAPI:
wandb/sdk/integration_utils/data_logging.py
CHANGED
@@ -78,7 +78,7 @@ class ValidationDataLogger:
             Defaults to `"wb_validation_data"`.
         artifact_type: The artifact type to use for the validation data.
             Defaults to `"validation_dataset"`.
-        class_labels: Optional list of
+        class_labels: Optional list of labels to use in the inferred
             processors. If the model's `target` or `output` is inferred to be a class,
             we will attempt to map the class to these labels. Defaults to `None`.
         infer_missing_processors: Determines if processors are inferred if
@@ -262,7 +262,7 @@ def _infer_single_example_keyed_processor(
     ):
         np = wandb.util.get_module(
             "numpy",
-            required="
+            required="Inferring processors require numpy",
         )
         # Assume these are logits
         class_names = class_labels_table.get_column("label")
@@ -291,13 +291,17 @@ def _infer_single_example_keyed_processor(
     ):
         # assume this is a class
         if class_labels_table is not None:
-            processors["class"] =
+            processors["class"] = (
+                lambda n, d, p: class_labels_table.index_ref(d[0])
+                if d[0] < len(class_labels_table.data)
+                else d[0]
+            )  # type: ignore
         else:
             processors["val"] = lambda n, d, p: d[0]
     elif len(shape) == 1:
         np = wandb.util.get_module(
             "numpy",
-            required="
+            required="Inferring processors require numpy",
         )
         # This could be anything
         if shape[0] <= 10:
@@ -350,7 +354,7 @@ def _infer_validation_row_processor(
     input_col_name: str = "input",
     target_col_name: str = "target",
 ) -> Callable:
-    """Infers the
+    """Infers the composite processor for the validation data."""
     single_processors = {}
     if isinstance(example_input, dict):
         for key in example_input:
@@ -427,7 +431,7 @@ def _infer_prediction_row_processor(
     input_col_name: str = "input",
     output_col_name: str = "output",
 ) -> Callable:
-    """Infers the
+    """Infers the composite processor for the prediction output data."""
     single_processors = {}
 
     if isinstance(example_prediction, dict):
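The inferred "class" processor above maps a single predicted index to a label reference when the index is in range and otherwise passes the raw value through. A standalone sketch of that behavior; labels is a hypothetical stand-in for class_labels_table, and the real processor returns a table reference rather than a string:

labels = ["cat", "dog", "bird"]

def class_processor(name, row, probabilities):
    # row[0] is the predicted class index for this example
    return labels[row[0]] if row[0] < len(labels) else row[0]

print(class_processor("output", [1], None))  # -> "dog"
print(class_processor("output", [7], None))  # -> 7 (out of range, passed through)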
wandb/sdk/interface/interface.py
CHANGED
@@ -387,18 +387,30 @@ class InterfaceBase:
     def _make_partial_source_str(
         source: Any, job_info: Dict[str, Any], metadata: Dict[str, Any]
     ) -> str:
-        """Construct use_artifact.partial.source_info.
+        """Construct use_artifact.partial.source_info.source as str."""
         source_type = job_info.get("source_type", "").strip()
         if source_type == "artifact":
             info_source = job_info.get("source", {})
             source.artifact.artifact = info_source.get("artifact", "")
             source.artifact.entrypoint.extend(info_source.get("entrypoint", []))
             source.artifact.notebook = info_source.get("notebook", False)
+            build_context = info_source.get("build_context")
+            if build_context:
+                source.artifact.build_context = build_context
+            dockerfile = info_source.get("dockerfile")
+            if dockerfile:
+                source.artifact.dockerfile = dockerfile
         elif source_type == "repo":
             source.git.git_info.remote = metadata.get("git", {}).get("remote", "")
             source.git.git_info.commit = metadata.get("git", {}).get("commit", "")
             source.git.entrypoint.extend(metadata.get("entrypoint", []))
             source.git.notebook = metadata.get("notebook", False)
+            build_context = metadata.get("build_context")
+            if build_context:
+                source.git.build_context = build_context
+            dockerfile = metadata.get("dockerfile")
+            if dockerfile:
+                source.git.dockerfile = dockerfile
         elif source_type == "image":
             source.image.image = metadata.get("docker", "")
         else:
@@ -424,7 +436,7 @@ class InterfaceBase:
             job_info=job_info,
             metadata=metadata,
         )
-        use_artifact.partial.source_info.source.ParseFromString(src_str)
+        use_artifact.partial.source_info.source.ParseFromString(src_str)  # type: ignore[arg-type]
 
         return use_artifact
 
@@ -516,11 +528,15 @@ class InterfaceBase:
         artifact_id: str,
         download_root: str,
         allow_missing_references: bool,
+        skip_cache: bool,
+        path_prefix: Optional[str],
     ) -> MailboxHandle:
         download_artifact = pb.DownloadArtifactRequest()
         download_artifact.artifact_id = artifact_id
         download_artifact.download_root = download_root
         download_artifact.allow_missing_references = allow_missing_references
+        download_artifact.skip_cache = skip_cache
+        download_artifact.path_prefix = path_prefix or ""
         resp = self._deliver_download_artifact(download_artifact)
         return resp
 
@@ -729,6 +745,56 @@ class InterfaceBase:
     def _publish_keepalive(self, keepalive: pb.KeepaliveRequest) -> None:
         raise NotImplementedError
 
+    def publish_job_input(
+        self,
+        include_paths: List[List[str]],
+        exclude_paths: List[List[str]],
+        run_config: bool = False,
+        file_path: str = "",
+    ):
+        """Publishes a request to add inputs to the job.
+
+        If run_config is True, the wandb.config will be added as a job input.
+        If file_path is provided, the file at file_path will be added as a job
+        input.
+
+        The paths provided as arguments are sequences of dictionary keys that
+        specify a path within the wandb.config. If a path is included, the
+        corresponding field will be treated as a job input. If a path is
+        excluded, the corresponding field will not be treated as a job input.
+
+        Args:
+            include_paths: paths within config to include as job inputs.
+            exclude_paths: paths within config to exclude as job inputs.
+            run_config: bool indicating whether wandb.config is the input source.
+            file_path: path to file to include as a job input.
+        """
+        if run_config and file_path:
+            raise ValueError(
+                "run_config and file_path are mutually exclusive arguments."
+            )
+        request = pb.JobInputRequest()
+        include_records = [pb.JobInputPath(path=path) for path in include_paths]
+        exclude_records = [pb.JobInputPath(path=path) for path in exclude_paths]
+        request.include_paths.extend(include_records)
+        request.exclude_paths.extend(exclude_records)
+        source = pb.JobInputSource(
+            run_config=pb.JobInputSource.RunConfigSource(),
+        )
+        if run_config:
+            source.run_config.CopyFrom(pb.JobInputSource.RunConfigSource())
+        else:
+            source.file.CopyFrom(
+                pb.JobInputSource.ConfigFileSource(path=file_path),
+            )
+        request.input_source.CopyFrom(source)
+
+        return self._publish_job_input(request)
+
+    @abstractmethod
+    def _publish_job_input(self, request: pb.JobInputRequest) -> MailboxHandle:
+        raise NotImplementedError
+
     def join(self) -> None:
         # Drop indicates that the internal process has already been shutdown
         if self._drop:
@@ -779,36 +845,6 @@ class InterfaceBase:
         run_start.run.CopyFrom(run_pb)
         return self._deliver_run_start(run_start)
 
-    def publish_launch_wandb_config_parameters(
-        self, include_paths: List[List[str]], exclude_paths: List[List[str]]
-    ):
-        """Tells the internal process to treat wandb.config fields as job inputs.
-
-        The paths provided as arguments are sequences of dictionary keys that
-        specify a path within the wandb.config. If a path is included, the
-        corresponding field will be treated as a job input. If a path is
-        excluded, the corresponding field will not be treated as a job input.
-
-        Args:
-            include_paths: paths within config to include as job inputs.
-            exclude_paths: paths within config to exclude as job inputs.
-
-        Returns:
-            None
-        """
-        config_parameters = pb.LaunchWandbConfigParametersRecord()
-        include_records = [pb.ConfigFilterPath(path=path) for path in include_paths]
-        exclude_records = [pb.ConfigFilterPath(path=path) for path in exclude_paths]
-        config_parameters.include_paths.extend(include_records)
-        config_parameters.exclude_paths.extend(exclude_records)
-        return self._publish_launch_wandb_config_parameters(config_parameters)
-
-    @abstractmethod
-    def _publish_launch_wandb_config_parameters(
-        self, config_parameters: pb.LaunchWandbConfigParametersRecord
-    ) -> None:
-        raise NotImplementedError
-
     @abstractmethod
     def _deliver_run_start(self, run_start: pb.RunStartRequest) -> MailboxHandle:
         raise NotImplementedError
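publish_job_input replaces the launch-specific publish_launch_wandb_config_parameters shown removed above and adds a file-based input source. A hedged usage sketch against an interface instance (obtaining one, for example from a run's internal backend, is outside this diff, and the config paths are made up):

# Mark everything under wandb.config["trainer"] as a job input, except the
# nested credentials entry; run_config=True selects wandb.config as the source.
interface.publish_job_input(
    include_paths=[["trainer"]],
    exclude_paths=[["trainer", "api_key"]],
    run_config=True,
)

# Alternatively, declare a config file as the input source instead.
interface.publish_job_input(
    include_paths=[],
    exclude_paths=[],
    file_path="configs/train.yaml",
)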
wandb/sdk/interface/interface_shared.py
CHANGED
@@ -100,6 +100,10 @@ class InterfaceShared(InterfaceBase):
         rec = self._make_record(telemetry=telem)
         self._publish(rec)
 
+    def _publish_job_input(self, job_input: pb.JobInputRequest) -> MailboxHandle:
+        record = self._make_request(job_input=job_input)
+        return self._deliver_record(record)
+
     def _make_stats(self, stats_dict: dict) -> pb.StatsRecord:
         stats = pb.StatsRecord()
         stats.stats_type = pb.StatsRecord.StatsType.SYSTEM
@@ -147,6 +151,7 @@ class InterfaceShared(InterfaceBase):
         telemetry_record: Optional[pb.TelemetryRecordRequest] = None,
         get_system_metrics: Optional[pb.GetSystemMetricsRequest] = None,
         python_packages: Optional[pb.PythonPackagesRequest] = None,
+        job_input: Optional[pb.JobInputRequest] = None,
     ) -> pb.Record:
         request = pb.Request()
         if login:
@@ -207,6 +212,8 @@ class InterfaceShared(InterfaceBase):
             request.sync.CopyFrom(sync)
         elif python_packages:
             request.python_packages.CopyFrom(python_packages)
+        elif job_input:
+            request.job_input.CopyFrom(job_input)
         else:
             raise Exception("Invalid request")
         record = self._make_record(request=request)
@@ -239,9 +246,6 @@ class InterfaceShared(InterfaceBase):
         use_artifact: Optional[pb.UseArtifactRecord] = None,
         output: Optional[pb.OutputRecord] = None,
         output_raw: Optional[pb.OutputRawRecord] = None,
-        launch_wandb_config_parameters: Optional[
-            pb.LaunchWandbConfigParametersRecord
-        ] = None,
     ) -> pb.Record:
         record = pb.Record()
         if run:
@@ -286,8 +290,6 @@ class InterfaceShared(InterfaceBase):
             record.output.CopyFrom(output)
         elif output_raw:
             record.output_raw.CopyFrom(output_raw)
-        elif launch_wandb_config_parameters:
-            record.wandb_config_parameters.CopyFrom(launch_wandb_config_parameters)
         else:
             raise Exception("Invalid record")
         return record
@@ -417,14 +419,6 @@ class InterfaceShared(InterfaceBase):
         rec = self._make_record(alert=proto_alert)
         self._publish(rec)
 
-    def _publish_launch_wandb_config_parameters(
-        self, launch_wandb_config_parameters: pb.LaunchWandbConfigParametersRecord
-    ) -> None:
-        rec = self._make_record(
-            launch_wandb_config_parameters=launch_wandb_config_parameters
-        )
-        self._publish(rec)
-
     def _communicate_status(
         self, status: pb.StatusRequest
     ) -> Optional[pb.StatusResponse]:
wandb/sdk/internal/datastore.py
CHANGED
wandb/sdk/internal/file_pusher.py
CHANGED
@@ -14,7 +14,7 @@ from wandb.sdk.lib.paths import LogicalPath
 
 if TYPE_CHECKING:
     from wandb.sdk.artifacts.artifact_manifest import ArtifactManifest
-    from wandb.sdk.artifacts.artifact_saver import SaveFn, SaveFnAsync
+    from wandb.sdk.artifacts.artifact_saver import SaveFn
     from wandb.sdk.internal import file_stream, internal_api
     from wandb.sdk.internal.settings_static import SettingsStatic
 
@@ -148,11 +148,8 @@ class FilePusher:
         manifest: "ArtifactManifest",
         artifact_id: str,
         save_fn: "SaveFn",
-        save_fn_async: "SaveFnAsync",
     ) -> None:
-        event = step_checksum.RequestStoreManifestFiles(
-            manifest, artifact_id, save_fn, save_fn_async
-        )
+        event = step_checksum.RequestStoreManifestFiles(manifest, artifact_id, save_fn)
         self._incoming_queue.put(event)
 
     def commit_artifact(
wandb/sdk/internal/file_stream.py
CHANGED
@@ -1,4 +1,3 @@
-import base64
 import functools
 import itertools
 import json
@@ -53,7 +52,7 @@ logger = logging.getLogger(__name__)
 
 class Chunk(NamedTuple):
     filename: str
-    data:
+    data: str
 
 
 class DefaultFilePolicy:
@@ -227,7 +226,7 @@ class CRDedupeFilePolicy(DefaultFilePolicy):
             prefix += token + " "
         return prefix, rest
 
-    def process_chunks(self, chunks: List) -> List["ProcessedChunk"]:
+    def process_chunks(self, chunks: List[Chunk]) -> List["ProcessedChunk"]:
        r"""Process chunks.
 
         Args:
@@ -300,18 +299,6 @@ class CRDedupeFilePolicy(DefaultFilePolicy):
         return ret
 
 
-class BinaryFilePolicy(DefaultFilePolicy):
-    def __init__(self) -> None:
-        super().__init__()
-        self._offset: int = 0
-
-    def process_chunks(self, chunks: List[Chunk]) -> "ProcessedBinaryChunk":
-        data = b"".join([c.data for c in chunks])
-        enc = base64.b64encode(data).decode("ascii")
-        self._offset += len(data)
-        return {"offset": self._offset, "content": enc, "encoding": "base64"}
-
-
 class FileStreamApi:
     """Pushes chunks of files to our streaming endpoint.
 
@@ -585,12 +572,12 @@ class FileStreamApi:
     def enqueue_preempting(self) -> None:
         self._queue.put(self.Preempting())
 
-    def push(self, filename: str, data:
+    def push(self, filename: str, data: str) -> None:
         """Push a chunk of a file to the streaming endpoint.
 
         Arguments:
-            filename: Name of file
-            data:
+            filename: Name of file to append to.
+            data: Text to append to the file.
         """
         self._queue.put(Chunk(filename, data))
 
wandb/sdk/internal/handler.py
CHANGED
@@ -50,6 +50,18 @@ SummaryDict = Dict[str, Any]
 
 logger = logging.getLogger(__name__)
 
+# Update (March 5, 2024): Since ~2020/2021, when constructing the summary
+# object, we had replaced the artifact path for media types with the latest
+# artifact path. The primary purpose of this was to support live updating of
+# media objects in the UI (since the default artifact path was fully qualified
+# and would not update). However, in March of 2024, a bug was discovered with
+# this approach which causes this path to be incorrect in cases where the media
+# object is logged to another artifact before being logged to the run. Setting
+# this to `False` disables this copy behavior. The impact is that users will
+# need to refresh to see updates. Ironically, this updating behavior is not
+# currently supported in the UI, so the impact of this change is minimal.
+REPLACE_SUMMARY_ART_PATH_WITH_LATEST = False
+
 
 def _dict_nested_set(target: Dict[str, Any], key_list: Sequence[str], v: Any) -> None:
     # recurse down the dictionary structure:
@@ -371,7 +383,11 @@ class HandleManager:
             updated = True
             return updated
         # If the dict is a media object, update the pointer to the latest alias
-        elif isinstance(v, dict) and handler_util.metric_is_wandb_dict(v):
+        elif (
+            REPLACE_SUMMARY_ART_PATH_WITH_LATEST
+            and isinstance(v, dict)
+            and handler_util.metric_is_wandb_dict(v)
+        ):
             if "_latest_artifact_path" in v and "artifact_path" in v:
                 # TODO: Make non-destructive?
                 v["artifact_path"] = v["_latest_artifact_path"]
@@ -381,7 +397,7 @@ class HandleManager:
     def _update_summary_media_objects(self, v: Dict[str, Any]) -> Dict[str, Any]:
         # For now, non-recursive - just top level
         for nk, nv in v.items():
-            if (
+            if REPLACE_SUMMARY_ART_PATH_WITH_LATEST and (
                 isinstance(nv, dict)
                 and handler_util.metric_is_wandb_dict(nv)
                 and "_latest_artifact_path" in nv