wandb 0.19.11__py3-none-win_amd64.whl → 0.20.0__py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wandb/__init__.py +1 -2
- wandb/__init__.pyi +3 -6
- wandb/_iterutils.py +26 -7
- wandb/_pydantic/__init__.py +2 -1
- wandb/_pydantic/utils.py +7 -0
- wandb/agents/pyagent.py +9 -15
- wandb/analytics/sentry.py +1 -2
- wandb/apis/attrs.py +3 -4
- wandb/apis/importers/internals/util.py +1 -1
- wandb/apis/importers/validation.py +2 -2
- wandb/apis/importers/wandb.py +30 -25
- wandb/apis/normalize.py +2 -2
- wandb/apis/public/__init__.py +1 -0
- wandb/apis/public/api.py +37 -33
- wandb/apis/public/artifacts.py +103 -72
- wandb/apis/public/jobs.py +3 -2
- wandb/apis/public/registries/registries_search.py +4 -2
- wandb/apis/public/registries/registry.py +1 -1
- wandb/apis/public/registries/utils.py +9 -9
- wandb/apis/public/runs.py +18 -6
- wandb/automations/_filters/expressions.py +1 -1
- wandb/automations/_filters/operators.py +1 -1
- wandb/automations/_filters/run_metrics.py +1 -1
- wandb/beta/workflows.py +6 -5
- wandb/bin/gpu_stats.exe +0 -0
- wandb/bin/wandb-core +0 -0
- wandb/cli/cli.py +54 -73
- wandb/docker/__init__.py +21 -74
- wandb/docker/names.py +40 -0
- wandb/env.py +0 -1
- wandb/errors/util.py +1 -1
- wandb/filesync/step_checksum.py +1 -1
- wandb/filesync/step_upload.py +1 -1
- wandb/integration/diffusers/resolvers/multimodal.py +1 -2
- wandb/integration/gym/__init__.py +5 -6
- wandb/integration/keras/callbacks/model_checkpoint.py +2 -2
- wandb/integration/keras/keras.py +13 -19
- wandb/integration/kfp/kfp_patch.py +2 -3
- wandb/integration/langchain/wandb_tracer.py +1 -1
- wandb/integration/metaflow/metaflow.py +13 -13
- wandb/integration/openai/fine_tuning.py +3 -2
- wandb/integration/sagemaker/auth.py +2 -1
- wandb/integration/sklearn/utils.py +2 -1
- wandb/integration/tensorboard/__init__.py +1 -1
- wandb/integration/tensorboard/log.py +2 -5
- wandb/integration/tensorflow/__init__.py +2 -2
- wandb/jupyter.py +20 -17
- wandb/plot/confusion_matrix.py +1 -1
- wandb/plot/utils.py +8 -7
- wandb/proto/v3/wandb_internal_pb2.py +355 -335
- wandb/proto/v3/wandb_settings_pb2.py +2 -2
- wandb/proto/v3/wandb_telemetry_pb2.py +12 -12
- wandb/proto/v4/wandb_internal_pb2.py +339 -335
- wandb/proto/v4/wandb_settings_pb2.py +2 -2
- wandb/proto/v4/wandb_telemetry_pb2.py +12 -12
- wandb/proto/v5/wandb_internal_pb2.py +339 -335
- wandb/proto/v5/wandb_settings_pb2.py +2 -2
- wandb/proto/v5/wandb_telemetry_pb2.py +12 -12
- wandb/proto/v6/wandb_internal_pb2.py +339 -335
- wandb/proto/v6/wandb_settings_pb2.py +2 -2
- wandb/proto/v6/wandb_telemetry_pb2.py +12 -12
- wandb/proto/wandb_deprecated.py +6 -8
- wandb/sdk/artifacts/_internal_artifact.py +43 -0
- wandb/sdk/artifacts/_validators.py +55 -35
- wandb/sdk/artifacts/artifact.py +117 -115
- wandb/sdk/artifacts/artifact_download_logger.py +2 -0
- wandb/sdk/artifacts/artifact_saver.py +1 -3
- wandb/sdk/artifacts/artifact_state.py +2 -0
- wandb/sdk/artifacts/artifact_ttl.py +2 -0
- wandb/sdk/artifacts/exceptions.py +14 -0
- wandb/sdk/artifacts/staging.py +2 -0
- wandb/sdk/artifacts/storage_handlers/local_file_handler.py +2 -6
- wandb/sdk/artifacts/storage_handlers/multi_handler.py +1 -1
- wandb/sdk/artifacts/storage_handlers/tracking_handler.py +2 -6
- wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +1 -5
- wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +1 -1
- wandb/sdk/artifacts/storage_layout.py +2 -0
- wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +3 -3
- wandb/sdk/backend/backend.py +11 -182
- wandb/sdk/data_types/_dtypes.py +2 -6
- wandb/sdk/data_types/audio.py +20 -3
- wandb/sdk/data_types/base_types/media.py +12 -7
- wandb/sdk/data_types/base_types/wb_value.py +8 -18
- wandb/sdk/data_types/bokeh.py +19 -2
- wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +17 -1
- wandb/sdk/data_types/helper_types/image_mask.py +7 -1
- wandb/sdk/data_types/html.py +4 -4
- wandb/sdk/data_types/image.py +164 -103
- wandb/sdk/data_types/molecule.py +6 -6
- wandb/sdk/data_types/object_3d.py +10 -5
- wandb/sdk/data_types/saved_model.py +11 -6
- wandb/sdk/data_types/table.py +313 -83
- wandb/sdk/data_types/table_decorators.py +108 -0
- wandb/sdk/data_types/utils.py +43 -7
- wandb/sdk/data_types/video.py +21 -3
- wandb/sdk/interface/interface.py +10 -0
- wandb/sdk/internal/datastore.py +2 -6
- wandb/sdk/internal/file_pusher.py +1 -5
- wandb/sdk/internal/file_stream.py +8 -17
- wandb/sdk/internal/handler.py +2 -2
- wandb/sdk/internal/incremental_table_util.py +53 -0
- wandb/sdk/internal/internal.py +3 -5
- wandb/sdk/internal/internal_api.py +66 -89
- wandb/sdk/internal/job_builder.py +2 -7
- wandb/sdk/internal/profiler.py +2 -2
- wandb/sdk/internal/progress.py +1 -3
- wandb/sdk/internal/run.py +1 -6
- wandb/sdk/internal/sender.py +24 -36
- wandb/sdk/internal/system/assets/aggregators.py +1 -7
- wandb/sdk/internal/system/assets/disk.py +3 -3
- wandb/sdk/internal/system/assets/gpu.py +4 -4
- wandb/sdk/internal/system/assets/gpu_amd.py +4 -4
- wandb/sdk/internal/system/assets/interfaces.py +6 -6
- wandb/sdk/internal/system/assets/tpu.py +1 -1
- wandb/sdk/internal/system/assets/trainium.py +6 -6
- wandb/sdk/internal/system/system_info.py +5 -7
- wandb/sdk/internal/system/system_monitor.py +4 -4
- wandb/sdk/internal/tb_watcher.py +5 -7
- wandb/sdk/launch/_launch.py +1 -1
- wandb/sdk/launch/_project_spec.py +19 -20
- wandb/sdk/launch/agent/agent.py +3 -3
- wandb/sdk/launch/agent/config.py +1 -1
- wandb/sdk/launch/agent/job_status_tracker.py +2 -2
- wandb/sdk/launch/builder/build.py +2 -3
- wandb/sdk/launch/builder/kaniko_builder.py +5 -4
- wandb/sdk/launch/environment/gcp_environment.py +1 -2
- wandb/sdk/launch/registry/azure_container_registry.py +2 -2
- wandb/sdk/launch/registry/elastic_container_registry.py +2 -2
- wandb/sdk/launch/registry/google_artifact_registry.py +3 -3
- wandb/sdk/launch/runner/abstract.py +5 -5
- wandb/sdk/launch/runner/kubernetes_monitor.py +2 -2
- wandb/sdk/launch/runner/kubernetes_runner.py +1 -1
- wandb/sdk/launch/runner/sagemaker_runner.py +2 -4
- wandb/sdk/launch/runner/vertex_runner.py +2 -7
- wandb/sdk/launch/sweeps/__init__.py +1 -1
- wandb/sdk/launch/sweeps/scheduler.py +2 -2
- wandb/sdk/launch/sweeps/utils.py +3 -3
- wandb/sdk/launch/utils.py +3 -4
- wandb/sdk/lib/apikey.py +5 -8
- wandb/sdk/lib/config_util.py +3 -3
- wandb/sdk/lib/fsm.py +3 -18
- wandb/sdk/lib/gitlib.py +6 -5
- wandb/sdk/lib/ipython.py +2 -2
- wandb/sdk/lib/json_util.py +9 -14
- wandb/sdk/lib/printer.py +3 -8
- wandb/sdk/lib/redirect.py +1 -1
- wandb/sdk/lib/retry.py +3 -7
- wandb/sdk/lib/run_moment.py +2 -2
- wandb/sdk/lib/service_connection.py +19 -21
- wandb/sdk/lib/service_token.py +1 -2
- wandb/sdk/mailbox/mailbox_handle.py +3 -7
- wandb/sdk/mailbox/response_handle.py +2 -6
- wandb/sdk/service/streams.py +3 -7
- wandb/sdk/verify/verify.py +5 -6
- wandb/sdk/wandb_config.py +1 -1
- wandb/sdk/wandb_init.py +38 -106
- wandb/sdk/wandb_login.py +7 -6
- wandb/sdk/wandb_run.py +52 -240
- wandb/sdk/wandb_settings.py +71 -60
- wandb/sdk/wandb_setup.py +43 -17
- wandb/sdk/wandb_watch.py +5 -7
- wandb/sync/__init__.py +1 -1
- wandb/sync/sync.py +13 -13
- wandb/util.py +17 -35
- wandb/wandb_agent.py +8 -11
- {wandb-0.19.11.dist-info → wandb-0.20.0.dist-info}/METADATA +5 -5
- {wandb-0.19.11.dist-info → wandb-0.20.0.dist-info}/RECORD +170 -168
- wandb/docker/auth.py +0 -435
- wandb/docker/www_authenticate.py +0 -94
- {wandb-0.19.11.dist-info → wandb-0.20.0.dist-info}/WHEEL +0 -0
- {wandb-0.19.11.dist-info → wandb-0.20.0.dist-info}/entry_points.txt +0 -0
- {wandb-0.19.11.dist-info → wandb-0.20.0.dist-info}/licenses/LICENSE +0 -0
wandb/sdk/data_types/image.py
CHANGED
@@ -1,10 +1,13 @@
|
|
1
1
|
import hashlib
|
2
2
|
import logging
|
3
3
|
import os
|
4
|
+
import pathlib
|
4
5
|
from io import BytesIO
|
5
6
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Type, Union, cast
|
6
7
|
from urllib import parse
|
7
8
|
|
9
|
+
from packaging.version import parse as parse_version
|
10
|
+
|
8
11
|
import wandb
|
9
12
|
from wandb import util
|
10
13
|
from wandb.sdk.lib import hashutil, runid
|
@@ -30,10 +33,60 @@ if TYPE_CHECKING: # pragma: no cover
|
|
30
33
|
ImageDataType = Union[
|
31
34
|
"matplotlib.artist.Artist", "PILImage", "TorchTensorType", "np.ndarray"
|
32
35
|
]
|
33
|
-
ImageDataOrPathType = Union[str, "Image", ImageDataType]
|
36
|
+
ImageDataOrPathType = Union[str, pathlib.Path, "Image", ImageDataType]
|
34
37
|
TorchTensorType = Union["torch.Tensor", "torch.Variable"]
|
35
38
|
|
36
39
|
|
40
|
+
def _warn_on_invalid_data_range(
|
41
|
+
data: "np.ndarray",
|
42
|
+
normalize: bool = True,
|
43
|
+
) -> None:
|
44
|
+
if not normalize:
|
45
|
+
return
|
46
|
+
|
47
|
+
np = util.get_module(
|
48
|
+
"numpy",
|
49
|
+
required="wandb.Image requires numpy if not supplying PIL Images: pip install numpy",
|
50
|
+
)
|
51
|
+
|
52
|
+
if np.min(data) < 0 or np.max(data) > 255:
|
53
|
+
wandb.termwarn(
|
54
|
+
"Data passed to `wandb.Image` should consist of values in the range [0, 255], "
|
55
|
+
"image data will be normalized to this range, "
|
56
|
+
"but behavior will be removed in a future version of wandb.",
|
57
|
+
repeat=False,
|
58
|
+
)
|
59
|
+
|
60
|
+
|
61
|
+
def _normalize(data: "np.ndarray") -> "np.ndarray":
|
62
|
+
"""Normalizes and converts image pixel values to uint8 in the range [0, 255]."""
|
63
|
+
np = util.get_module(
|
64
|
+
"numpy",
|
65
|
+
required="wandb.Image requires numpy if not supplying PIL Images: pip install numpy",
|
66
|
+
)
|
67
|
+
|
68
|
+
# if an image has negative values, set all values to be in the range [0, 1]
|
69
|
+
# This can lead to inconsistent behavior when an image has only a single negative value
|
70
|
+
if np.min(data) < 0:
|
71
|
+
data = data - np.min(data)
|
72
|
+
|
73
|
+
if np.ptp(data) != 0:
|
74
|
+
data = data / np.ptp(data)
|
75
|
+
|
76
|
+
if np.max(data) <= 1.0:
|
77
|
+
data = (255 * data).astype(np.int32)
|
78
|
+
|
79
|
+
return data.clip(0, 255)
|
80
|
+
|
81
|
+
|
82
|
+
def _convert_to_uint8(data: "np.ndarray") -> "np.ndarray":
|
83
|
+
np = util.get_module(
|
84
|
+
"numpy",
|
85
|
+
required="wandb.Image requires numpy if not supplying PIL Images: pip install numpy",
|
86
|
+
)
|
87
|
+
return data.astype(np.uint8)
|
88
|
+
|
89
|
+
|
37
90
|
def _server_accepts_image_filenames(run: "LocalRun") -> bool:
|
38
91
|
if run.offline:
|
39
92
|
return True
|
@@ -43,10 +96,9 @@ def _server_accepts_image_filenames(run: "LocalRun") -> bool:
|
|
43
96
|
max_cli_version = util._get_max_cli_version()
|
44
97
|
if max_cli_version is None:
|
45
98
|
return False
|
46
|
-
from wandb.util import parse_version
|
47
99
|
|
48
|
-
accepts_image_filenames: bool = parse_version(
|
49
|
-
|
100
|
+
accepts_image_filenames: bool = parse_version(max_cli_version) >= parse_version(
|
101
|
+
"0.12.10"
|
50
102
|
)
|
51
103
|
return accepts_image_filenames
|
52
104
|
|
@@ -59,69 +111,11 @@ def _server_accepts_artifact_path(run: "LocalRun") -> bool:
|
|
59
111
|
if max_cli_version is None:
|
60
112
|
return False
|
61
113
|
|
62
|
-
return
|
114
|
+
return parse_version(max_cli_version) >= parse_version("0.12.14")
|
63
115
|
|
64
116
|
|
65
117
|
class Image(BatchableMedia):
|
66
|
-
"""
|
67
|
-
|
68
|
-
Args:
|
69
|
-
data_or_path: (numpy array, string, io) Accepts numpy array of
|
70
|
-
image data, or a PIL image. The class attempts to infer
|
71
|
-
the data format and converts it.
|
72
|
-
mode: (string) The PIL mode for an image. Most common are "L", "RGB",
|
73
|
-
"RGBA". Full explanation at https://pillow.readthedocs.io/en/stable/handbook/concepts.html#modes
|
74
|
-
caption: (string) Label for display of image.
|
75
|
-
|
76
|
-
Note : When logging a `torch.Tensor` as a `wandb.Image`, images are normalized. If you do not want to normalize your images, please convert your tensors to a PIL Image.
|
77
|
-
|
78
|
-
Examples:
|
79
|
-
### Create a wandb.Image from a numpy array
|
80
|
-
```python
|
81
|
-
import numpy as np
|
82
|
-
import wandb
|
83
|
-
|
84
|
-
with wandb.init() as run:
|
85
|
-
examples = []
|
86
|
-
for i in range(3):
|
87
|
-
pixels = np.random.randint(low=0, high=256, size=(100, 100, 3))
|
88
|
-
image = wandb.Image(pixels, caption=f"random field {i}")
|
89
|
-
examples.append(image)
|
90
|
-
run.log({"examples": examples})
|
91
|
-
```
|
92
|
-
|
93
|
-
### Create a wandb.Image from a PILImage
|
94
|
-
```python
|
95
|
-
import numpy as np
|
96
|
-
from PIL import Image as PILImage
|
97
|
-
import wandb
|
98
|
-
|
99
|
-
with wandb.init() as run:
|
100
|
-
examples = []
|
101
|
-
for i in range(3):
|
102
|
-
pixels = np.random.randint(
|
103
|
-
low=0, high=256, size=(100, 100, 3), dtype=np.uint8
|
104
|
-
)
|
105
|
-
pil_image = PILImage.fromarray(pixels, mode="RGB")
|
106
|
-
image = wandb.Image(pil_image, caption=f"random field {i}")
|
107
|
-
examples.append(image)
|
108
|
-
run.log({"examples": examples})
|
109
|
-
```
|
110
|
-
|
111
|
-
### log .jpg rather than .png (default)
|
112
|
-
```python
|
113
|
-
import numpy as np
|
114
|
-
import wandb
|
115
|
-
|
116
|
-
with wandb.init() as run:
|
117
|
-
examples = []
|
118
|
-
for i in range(3):
|
119
|
-
pixels = np.random.randint(low=0, high=256, size=(100, 100, 3))
|
120
|
-
image = wandb.Image(pixels, caption=f"random field {i}", file_type="jpg")
|
121
|
-
examples.append(image)
|
122
|
-
run.log({"examples": examples})
|
123
|
-
```
|
124
|
-
"""
|
118
|
+
"""A class for logging images to W&B."""
|
125
119
|
|
126
120
|
MAX_ITEMS = 108
|
127
121
|
|
@@ -151,7 +145,85 @@ class Image(BatchableMedia):
|
|
151
145
|
boxes: Optional[Union[Dict[str, "BoundingBoxes2D"], Dict[str, dict]]] = None,
|
152
146
|
masks: Optional[Union[Dict[str, "ImageMask"], Dict[str, dict]]] = None,
|
153
147
|
file_type: Optional[str] = None,
|
148
|
+
normalize: bool = True,
|
154
149
|
) -> None:
|
150
|
+
"""Initialize a wandb.Image object.
|
151
|
+
|
152
|
+
Args:
|
153
|
+
data_or_path: Accepts numpy array/pytorch tensor of image data,
|
154
|
+
a PIL image object, or a path to an image file.
|
155
|
+
|
156
|
+
If a numpy array or pytorch tensor is provided,
|
157
|
+
the image data will be saved to the given file type.
|
158
|
+
If the values are not in the range [0, 255] or all values are in the range [0, 1],
|
159
|
+
the image pixel values will be normalized to the range [0, 255]
|
160
|
+
unless `normalize` is set to False.
|
161
|
+
- pytorch tensor should be in the format (channel, height, width)
|
162
|
+
- numpy array should be in the format (height, width, channel)
|
163
|
+
mode: The PIL mode for an image. Most common are "L", "RGB",
|
164
|
+
"RGBA". Full explanation at https://pillow.readthedocs.io/en/stable/handbook/concepts.html#modes
|
165
|
+
caption: Label for display of image.
|
166
|
+
grouping: The grouping number for the image.
|
167
|
+
classes: A list of class information for the image,
|
168
|
+
used for labeling bounding boxes, and image masks.
|
169
|
+
boxes: A dictionary containing bounding box information for the image.
|
170
|
+
see: https://docs.wandb.ai/ref/python/data-types/boundingboxes2d/
|
171
|
+
masks: A dictionary containing mask information for the image.
|
172
|
+
see: https://docs.wandb.ai/ref/python/data-types/imagemask/
|
173
|
+
file_type: The file type to save the image as.
|
174
|
+
This parameter has no effect if data_or_path is a path to an image file.
|
175
|
+
normalize: If True, normalize the image pixel values to fall within the range of [0, 255].
|
176
|
+
Normalize is only applied if data_or_path is a numpy array or pytorch tensor.
|
177
|
+
|
178
|
+
Examples:
|
179
|
+
### Create a wandb.Image from a numpy array
|
180
|
+
```python
|
181
|
+
import numpy as np
|
182
|
+
import wandb
|
183
|
+
|
184
|
+
with wandb.init() as run:
|
185
|
+
examples = []
|
186
|
+
for i in range(3):
|
187
|
+
pixels = np.random.randint(low=0, high=256, size=(100, 100, 3))
|
188
|
+
image = wandb.Image(pixels, caption=f"random field {i}")
|
189
|
+
examples.append(image)
|
190
|
+
run.log({"examples": examples})
|
191
|
+
```
|
192
|
+
|
193
|
+
### Create a wandb.Image from a PILImage
|
194
|
+
```python
|
195
|
+
import numpy as np
|
196
|
+
from PIL import Image as PILImage
|
197
|
+
import wandb
|
198
|
+
|
199
|
+
with wandb.init() as run:
|
200
|
+
examples = []
|
201
|
+
for i in range(3):
|
202
|
+
pixels = np.random.randint(
|
203
|
+
low=0, high=256, size=(100, 100, 3), dtype=np.uint8
|
204
|
+
)
|
205
|
+
pil_image = PILImage.fromarray(pixels, mode="RGB")
|
206
|
+
image = wandb.Image(pil_image, caption=f"random field {i}")
|
207
|
+
examples.append(image)
|
208
|
+
run.log({"examples": examples})
|
209
|
+
```
|
210
|
+
|
211
|
+
### log .jpg rather than .png (default)
|
212
|
+
```python
|
213
|
+
import numpy as np
|
214
|
+
import wandb
|
215
|
+
|
216
|
+
with wandb.init() as run:
|
217
|
+
examples = []
|
218
|
+
for i in range(3):
|
219
|
+
pixels = np.random.randint(low=0, high=256, size=(100, 100, 3))
|
220
|
+
image = wandb.Image(
|
221
|
+
pixels, caption=f"random field {i}", file_type="jpg"
|
222
|
+
)
|
223
|
+
examples.append(image)
|
224
|
+
run.log({"examples": examples})
|
225
|
+
```
|
226
|
+
"""
|
155
227
|
super().__init__(caption=caption)
|
156
228
|
# TODO: We should remove grouping, it's a terrible name and I don't
|
157
229
|
# think anyone uses it.
|
@@ -169,13 +241,15 @@ class Image(BatchableMedia):
|
|
169
241
|
# only overriding additional metadata passed in. If this pattern is compelling, we can generalize.
|
170
242
|
if isinstance(data_or_path, Image):
|
171
243
|
self._initialize_from_wbimage(data_or_path)
|
172
|
-
elif isinstance(data_or_path, str):
|
244
|
+
elif isinstance(data_or_path, (str, pathlib.Path)):
|
245
|
+
data_or_path = str(data_or_path)
|
246
|
+
|
173
247
|
if self.path_is_reference(data_or_path):
|
174
248
|
self._initialize_from_reference(data_or_path)
|
175
249
|
else:
|
176
250
|
self._initialize_from_path(data_or_path)
|
177
251
|
else:
|
178
|
-
self._initialize_from_data(data_or_path, mode, file_type)
|
252
|
+
self._initialize_from_data(data_or_path, mode, file_type, normalize)
|
179
253
|
self._set_initialization_meta(
|
180
254
|
grouping, caption, classes, boxes, masks, file_type
|
181
255
|
)
|
@@ -288,6 +362,7 @@ class Image(BatchableMedia):
|
|
288
362
|
data: "ImageDataType",
|
289
363
|
mode: Optional[str] = None,
|
290
364
|
file_type: Optional[str] = None,
|
365
|
+
normalize: bool = True,
|
291
366
|
) -> None:
|
292
367
|
pil_image = util.get_module(
|
293
368
|
"PIL.Image",
|
@@ -309,28 +384,39 @@ class Image(BatchableMedia):
|
|
309
384
|
elif isinstance(data, pil_image.Image):
|
310
385
|
self._image = data
|
311
386
|
elif util.is_pytorch_tensor_typename(util.get_full_typename(data)):
|
312
|
-
vis_util = util.get_module(
|
313
|
-
"torchvision.utils", "torchvision is required to render images"
|
314
|
-
)
|
315
387
|
if hasattr(data, "requires_grad") and data.requires_grad:
|
316
388
|
data = data.detach() # type: ignore
|
317
389
|
if hasattr(data, "dtype") and str(data.dtype) == "torch.uint8":
|
318
|
-
data = data.to(float)
|
319
|
-
data = vis_util.make_grid(data, normalize=True)
|
390
|
+
data = data.to(float) # type: ignore [union-attr]
|
320
391
|
mode = mode or self.guess_mode(data, file_type)
|
392
|
+
data = data.permute(1, 2, 0).cpu().numpy() # type: ignore [union-attr]
|
393
|
+
|
394
|
+
_warn_on_invalid_data_range(data, normalize)
|
395
|
+
|
396
|
+
data = _normalize(data) if normalize else data # type: ignore [arg-type]
|
397
|
+
data = _convert_to_uint8(data)
|
398
|
+
|
399
|
+
if data.ndim > 2:
|
400
|
+
data = data.squeeze()
|
401
|
+
|
321
402
|
self._image = pil_image.fromarray(
|
322
|
-
data
|
403
|
+
data,
|
323
404
|
mode=mode,
|
324
405
|
)
|
325
406
|
else:
|
326
407
|
if hasattr(data, "numpy"): # TF data eager tensors
|
327
408
|
data = data.numpy()
|
328
|
-
if data.ndim > 2:
|
329
|
-
|
409
|
+
if data.ndim > 2: # type: ignore [union-attr]
|
410
|
+
# get rid of trivial dimensions as a convenience
|
411
|
+
data = data.squeeze() # type: ignore [union-attr]
|
412
|
+
|
413
|
+
_warn_on_invalid_data_range(data, normalize) # type: ignore [arg-type]
|
330
414
|
|
331
415
|
mode = mode or self.guess_mode(data, file_type)
|
416
|
+
data = _normalize(data) if normalize else data # type: ignore [arg-type]
|
417
|
+
data = _convert_to_uint8(data) # type: ignore [arg-type]
|
332
418
|
self._image = pil_image.fromarray(
|
333
|
-
|
419
|
+
data,
|
334
420
|
mode=mode,
|
335
421
|
)
|
336
422
|
|
@@ -459,7 +545,7 @@ class Image(BatchableMedia):
|
|
459
545
|
}
|
460
546
|
|
461
547
|
elif not isinstance(run_or_artifact, Run):
|
462
|
-
raise
|
548
|
+
raise TypeError("to_json accepts wandb_run.Run or wandb_artifact.Artifact")
|
463
549
|
|
464
550
|
if self._boxes:
|
465
551
|
json_dict["boxes"] = {
|
@@ -485,7 +571,7 @@ class Image(BatchableMedia):
|
|
485
571
|
else:
|
486
572
|
num_channels = data.shape[-1]
|
487
573
|
|
488
|
-
if ndims == 2:
|
574
|
+
if ndims == 2 or num_channels == 1:
|
489
575
|
return "L"
|
490
576
|
elif num_channels == 3:
|
491
577
|
return "RGB"
|
@@ -501,34 +587,9 @@ class Image(BatchableMedia):
|
|
501
587
|
return "RGBA"
|
502
588
|
else:
|
503
589
|
raise ValueError(
|
504
|
-
"Un-supported shape for image conversion {
|
590
|
+
f"Un-supported shape for image conversion {list(data.shape)}"
|
505
591
|
)
|
506
592
|
|
507
|
-
@classmethod
|
508
|
-
def to_uint8(cls, data: "np.ndarray") -> "np.ndarray":
|
509
|
-
"""Convert image data to uint8.
|
510
|
-
|
511
|
-
Convert floating point image on the range [0,1] and integer images on the range
|
512
|
-
[0,255] to uint8, clipping if necessary.
|
513
|
-
"""
|
514
|
-
np = util.get_module(
|
515
|
-
"numpy",
|
516
|
-
required="wandb.Image requires numpy if not supplying PIL Images: pip install numpy",
|
517
|
-
)
|
518
|
-
|
519
|
-
# I think it's better to check the image range vs the data type, since many
|
520
|
-
# image libraries will return floats between 0 and 255
|
521
|
-
|
522
|
-
# some images have range -1...1 or 0-1
|
523
|
-
dmin = np.min(data)
|
524
|
-
if dmin < 0:
|
525
|
-
data = (data - np.min(data)) / np.ptp(data)
|
526
|
-
if np.max(data) <= 1.0:
|
527
|
-
data = (data * 255).astype(np.int32)
|
528
|
-
|
529
|
-
# assert issubclass(data.dtype.type, np.integer), 'Illegal image format.'
|
530
|
-
return data.clip(0, 255).astype(np.uint8)
|
531
|
-
|
532
593
|
@classmethod
|
533
594
|
def seq_to_json(
|
534
595
|
cls: Type["Image"],
|
wandb/sdk/data_types/molecule.py
CHANGED
@@ -26,7 +26,7 @@ class Molecule(BatchableMedia):
|
|
26
26
|
"""Wandb class for 3D Molecular data.
|
27
27
|
|
28
28
|
Args:
|
29
|
-
data_or_path: (string, io)
|
29
|
+
data_or_path: (pathlib.Path, string, io)
|
30
30
|
Molecule can be initialized from a file name or an io object.
|
31
31
|
caption: (string)
|
32
32
|
Caption associated with the molecule for display.
|
@@ -49,7 +49,7 @@ class Molecule(BatchableMedia):
|
|
49
49
|
|
50
50
|
def __init__(
|
51
51
|
self,
|
52
|
-
data_or_path: Union[str, "TextIO"],
|
52
|
+
data_or_path: Union[str, pathlib.Path, "TextIO"],
|
53
53
|
caption: Optional[str] = None,
|
54
54
|
**kwargs: str,
|
55
55
|
) -> None:
|
@@ -82,7 +82,9 @@ class Molecule(BatchableMedia):
|
|
82
82
|
f.write(molecule)
|
83
83
|
|
84
84
|
self._set_file(tmp_path, is_tmp=True)
|
85
|
-
elif isinstance(data_or_path, str):
|
85
|
+
elif isinstance(data_or_path, (str, pathlib.Path)):
|
86
|
+
data_or_path = str(data_or_path)
|
87
|
+
|
86
88
|
extension = os.path.splitext(data_or_path)[1][1:]
|
87
89
|
if extension not in Molecule.SUPPORTED_TYPES:
|
88
90
|
raise ValueError(
|
@@ -144,9 +146,7 @@ class Molecule(BatchableMedia):
|
|
144
146
|
elif isinstance(data_or_path, rdkit_chem.rdchem.Mol):
|
145
147
|
molecule = data_or_path
|
146
148
|
else:
|
147
|
-
raise
|
148
|
-
"Data must be file name or an rdkit.Chem.rdchem.Mol object"
|
149
|
-
)
|
149
|
+
raise TypeError("Data must be file name or an rdkit.Chem.rdchem.Mol object")
|
150
150
|
|
151
151
|
if convert_to_3d_and_optimize:
|
152
152
|
molecule = rdkit_chem.AddHs(molecule)
|
@@ -2,6 +2,7 @@ import codecs
|
|
2
2
|
import itertools
|
3
3
|
import json
|
4
4
|
import os
|
5
|
+
import pathlib
|
5
6
|
from typing import (
|
6
7
|
TYPE_CHECKING,
|
7
8
|
ClassVar,
|
@@ -187,7 +188,7 @@ class Object3D(BatchableMedia):
|
|
187
188
|
"""Wandb class for 3D point clouds.
|
188
189
|
|
189
190
|
Args:
|
190
|
-
data_or_path: (numpy array, string, io)
|
191
|
+
data_or_path: (numpy array, pathlib.Path, string, io)
|
191
192
|
Object3D can be initialized from a file or a numpy array.
|
192
193
|
|
193
194
|
You can pass a path to a file or an io object and a file_type
|
@@ -214,14 +215,16 @@ class Object3D(BatchableMedia):
|
|
214
215
|
|
215
216
|
def __init__(
|
216
217
|
self,
|
217
|
-
data_or_path: Union["np.ndarray", str, "TextIO", dict],
|
218
|
+
data_or_path: Union["np.ndarray", str, pathlib.Path, "TextIO", dict],
|
218
219
|
caption: Optional[str] = None,
|
219
220
|
**kwargs: Optional[Union[str, "FileFormat3D"]],
|
220
221
|
) -> None:
|
221
222
|
super().__init__(caption=caption)
|
222
223
|
|
223
|
-
if hasattr(data_or_path, "name"):
|
224
|
-
# if the file has a path, we just detect the type and copy it from there
|
224
|
+
if hasattr(data_or_path, "name") and not isinstance(data_or_path, pathlib.Path):
|
225
|
+
# if the file has a path, we just detect the type and copy it from there.
|
226
|
+
# this does not work for pathlib.Path objects,
|
227
|
+
# where `.name` returns the last directory in the path.
|
225
228
|
data_or_path = data_or_path.name
|
226
229
|
|
227
230
|
if hasattr(data_or_path, "read"):
|
@@ -247,7 +250,9 @@ class Object3D(BatchableMedia):
|
|
247
250
|
f.write(object_3d)
|
248
251
|
|
249
252
|
self._set_file(tmp_path, is_tmp=True, extension=extension)
|
250
|
-
elif isinstance(data_or_path, str):
|
253
|
+
elif isinstance(data_or_path, (str, pathlib.Path)):
|
254
|
+
data_or_path = str(data_or_path)
|
255
|
+
|
251
256
|
path = data_or_path
|
252
257
|
extension = None
|
253
258
|
for supported_type in Object3D.SUPPORTED_TYPES:
|
@@ -1,6 +1,7 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
3
|
import os
|
4
|
+
import pathlib
|
4
5
|
import shutil
|
5
6
|
import sys
|
6
7
|
from types import ModuleType
|
@@ -72,10 +73,12 @@ class _SavedModel(WBValue, Generic[SavedModelObjType]):
|
|
72
73
|
|
73
74
|
_model_obj: SavedModelObjType | None
|
74
75
|
_path: str | None
|
75
|
-
_input_obj_or_path: SavedModelObjType | str
|
76
|
+
_input_obj_or_path: SavedModelObjType | str | pathlib.Path
|
76
77
|
|
77
78
|
# Public Methods
|
78
|
-
def __init__(
|
79
|
+
def __init__(
|
80
|
+
self, obj_or_path: SavedModelObjType | str | pathlib.Path, **kwargs: Any
|
81
|
+
) -> None:
|
79
82
|
super().__init__()
|
80
83
|
if self.__class__ == _SavedModel:
|
81
84
|
raise TypeError(
|
@@ -84,9 +87,11 @@ class _SavedModel(WBValue, Generic[SavedModelObjType]):
|
|
84
87
|
self._model_obj = None
|
85
88
|
self._path = None
|
86
89
|
self._input_obj_or_path = obj_or_path
|
87
|
-
input_is_path = isinstance(obj_or_path, str) and os.path.exists(
|
90
|
+
input_is_path = isinstance(obj_or_path, (str, pathlib.Path)) and os.path.exists(
|
91
|
+
obj_or_path
|
92
|
+
)
|
88
93
|
if input_is_path:
|
89
|
-
|
94
|
+
obj_or_path = str(obj_or_path)
|
90
95
|
self._set_obj(self._deserialize(obj_or_path))
|
91
96
|
else:
|
92
97
|
self._set_obj(obj_or_path)
|
@@ -140,7 +145,7 @@ class _SavedModel(WBValue, Generic[SavedModelObjType]):
|
|
140
145
|
from wandb.sdk.wandb_run import Run
|
141
146
|
|
142
147
|
if isinstance(run_or_artifact, Run):
|
143
|
-
raise
|
148
|
+
raise TypeError("SavedModel cannot be added to run - must use artifact")
|
144
149
|
artifact = run_or_artifact
|
145
150
|
json_obj = {
|
146
151
|
"type": self._log_type,
|
@@ -280,7 +285,7 @@ class _PicklingSavedModel(_SavedModel[SavedModelObjType]):
|
|
280
285
|
|
281
286
|
def __init__(
|
282
287
|
self,
|
283
|
-
obj_or_path: SavedModelObjType | str,
|
288
|
+
obj_or_path: SavedModelObjType | str | pathlib.Path,
|
284
289
|
dep_py_files: list[str] | None = None,
|
285
290
|
):
|
286
291
|
super().__init__(obj_or_path)
|