wandb 0.19.12rc1__py3-none-win32.whl → 0.20.1__py3-none-win32.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wandb/__init__.py +1 -2
- wandb/__init__.pyi +3 -6
- wandb/_iterutils.py +26 -7
- wandb/_pydantic/__init__.py +2 -1
- wandb/_pydantic/utils.py +7 -0
- wandb/agents/pyagent.py +9 -15
- wandb/analytics/sentry.py +1 -2
- wandb/apis/attrs.py +3 -4
- wandb/apis/importers/internals/util.py +1 -1
- wandb/apis/importers/validation.py +2 -2
- wandb/apis/importers/wandb.py +30 -25
- wandb/apis/normalize.py +2 -2
- wandb/apis/public/__init__.py +1 -0
- wandb/apis/public/api.py +37 -33
- wandb/apis/public/artifacts.py +103 -72
- wandb/apis/public/jobs.py +3 -2
- wandb/apis/public/registries/registries_search.py +4 -2
- wandb/apis/public/registries/registry.py +1 -1
- wandb/apis/public/registries/utils.py +9 -9
- wandb/apis/public/runs.py +18 -6
- wandb/automations/_filters/expressions.py +1 -1
- wandb/automations/_filters/operators.py +1 -1
- wandb/automations/_filters/run_metrics.py +1 -1
- wandb/beta/workflows.py +6 -5
- wandb/bin/gpu_stats.exe +0 -0
- wandb/bin/wandb-core +0 -0
- wandb/cli/cli.py +54 -73
- wandb/docker/__init__.py +21 -74
- wandb/docker/names.py +40 -0
- wandb/env.py +0 -1
- wandb/errors/util.py +1 -1
- wandb/filesync/step_checksum.py +1 -1
- wandb/filesync/step_upload.py +1 -1
- wandb/integration/diffusers/resolvers/multimodal.py +1 -2
- wandb/integration/gym/__init__.py +5 -6
- wandb/integration/keras/callbacks/model_checkpoint.py +2 -2
- wandb/integration/keras/keras.py +13 -19
- wandb/integration/kfp/kfp_patch.py +2 -3
- wandb/integration/langchain/wandb_tracer.py +1 -1
- wandb/integration/metaflow/metaflow.py +13 -13
- wandb/integration/openai/fine_tuning.py +3 -2
- wandb/integration/sagemaker/auth.py +2 -1
- wandb/integration/sklearn/utils.py +2 -1
- wandb/integration/tensorboard/__init__.py +1 -1
- wandb/integration/tensorboard/log.py +2 -5
- wandb/integration/tensorflow/__init__.py +2 -2
- wandb/jupyter.py +20 -17
- wandb/plot/confusion_matrix.py +1 -1
- wandb/plot/utils.py +8 -7
- wandb/proto/v3/wandb_internal_pb2.py +355 -335
- wandb/proto/v3/wandb_settings_pb2.py +2 -2
- wandb/proto/v3/wandb_telemetry_pb2.py +12 -12
- wandb/proto/v4/wandb_internal_pb2.py +339 -335
- wandb/proto/v4/wandb_settings_pb2.py +2 -2
- wandb/proto/v4/wandb_telemetry_pb2.py +12 -12
- wandb/proto/v5/wandb_internal_pb2.py +339 -335
- wandb/proto/v5/wandb_settings_pb2.py +2 -2
- wandb/proto/v5/wandb_telemetry_pb2.py +12 -12
- wandb/proto/v6/wandb_internal_pb2.py +339 -335
- wandb/proto/v6/wandb_settings_pb2.py +2 -2
- wandb/proto/v6/wandb_telemetry_pb2.py +12 -12
- wandb/proto/wandb_deprecated.py +6 -8
- wandb/sdk/artifacts/_internal_artifact.py +43 -0
- wandb/sdk/artifacts/_validators.py +55 -35
- wandb/sdk/artifacts/artifact.py +117 -115
- wandb/sdk/artifacts/artifact_download_logger.py +2 -0
- wandb/sdk/artifacts/artifact_saver.py +1 -3
- wandb/sdk/artifacts/artifact_state.py +2 -0
- wandb/sdk/artifacts/artifact_ttl.py +2 -0
- wandb/sdk/artifacts/exceptions.py +14 -0
- wandb/sdk/artifacts/staging.py +2 -0
- wandb/sdk/artifacts/storage_handlers/local_file_handler.py +2 -6
- wandb/sdk/artifacts/storage_handlers/multi_handler.py +1 -1
- wandb/sdk/artifacts/storage_handlers/tracking_handler.py +2 -6
- wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +1 -5
- wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +1 -1
- wandb/sdk/artifacts/storage_layout.py +2 -0
- wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +3 -3
- wandb/sdk/backend/backend.py +11 -182
- wandb/sdk/data_types/_dtypes.py +2 -6
- wandb/sdk/data_types/audio.py +20 -3
- wandb/sdk/data_types/base_types/media.py +12 -7
- wandb/sdk/data_types/base_types/wb_value.py +8 -18
- wandb/sdk/data_types/bokeh.py +19 -2
- wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +17 -1
- wandb/sdk/data_types/helper_types/image_mask.py +7 -1
- wandb/sdk/data_types/html.py +4 -4
- wandb/sdk/data_types/image.py +178 -103
- wandb/sdk/data_types/molecule.py +6 -6
- wandb/sdk/data_types/object_3d.py +10 -5
- wandb/sdk/data_types/saved_model.py +11 -6
- wandb/sdk/data_types/table.py +313 -83
- wandb/sdk/data_types/table_decorators.py +108 -0
- wandb/sdk/data_types/utils.py +43 -7
- wandb/sdk/data_types/video.py +21 -3
- wandb/sdk/interface/interface.py +10 -0
- wandb/sdk/internal/datastore.py +2 -6
- wandb/sdk/internal/file_pusher.py +1 -5
- wandb/sdk/internal/file_stream.py +8 -17
- wandb/sdk/internal/handler.py +2 -2
- wandb/sdk/internal/incremental_table_util.py +53 -0
- wandb/sdk/internal/internal.py +3 -5
- wandb/sdk/internal/internal_api.py +66 -89
- wandb/sdk/internal/job_builder.py +2 -7
- wandb/sdk/internal/profiler.py +2 -2
- wandb/sdk/internal/progress.py +1 -3
- wandb/sdk/internal/run.py +1 -6
- wandb/sdk/internal/sender.py +24 -36
- wandb/sdk/internal/system/assets/aggregators.py +1 -7
- wandb/sdk/internal/system/assets/disk.py +3 -3
- wandb/sdk/internal/system/assets/gpu.py +4 -4
- wandb/sdk/internal/system/assets/gpu_amd.py +4 -4
- wandb/sdk/internal/system/assets/interfaces.py +6 -6
- wandb/sdk/internal/system/assets/tpu.py +1 -1
- wandb/sdk/internal/system/assets/trainium.py +6 -6
- wandb/sdk/internal/system/system_info.py +5 -7
- wandb/sdk/internal/system/system_monitor.py +4 -4
- wandb/sdk/internal/tb_watcher.py +5 -7
- wandb/sdk/launch/_launch.py +1 -1
- wandb/sdk/launch/_project_spec.py +19 -20
- wandb/sdk/launch/agent/agent.py +3 -3
- wandb/sdk/launch/agent/config.py +1 -1
- wandb/sdk/launch/agent/job_status_tracker.py +2 -2
- wandb/sdk/launch/builder/build.py +2 -3
- wandb/sdk/launch/builder/kaniko_builder.py +5 -4
- wandb/sdk/launch/environment/gcp_environment.py +1 -2
- wandb/sdk/launch/registry/azure_container_registry.py +2 -2
- wandb/sdk/launch/registry/elastic_container_registry.py +2 -2
- wandb/sdk/launch/registry/google_artifact_registry.py +3 -3
- wandb/sdk/launch/runner/abstract.py +5 -5
- wandb/sdk/launch/runner/kubernetes_monitor.py +2 -2
- wandb/sdk/launch/runner/kubernetes_runner.py +1 -1
- wandb/sdk/launch/runner/sagemaker_runner.py +2 -4
- wandb/sdk/launch/runner/vertex_runner.py +2 -7
- wandb/sdk/launch/sweeps/__init__.py +1 -1
- wandb/sdk/launch/sweeps/scheduler.py +2 -2
- wandb/sdk/launch/sweeps/utils.py +3 -3
- wandb/sdk/launch/utils.py +3 -4
- wandb/sdk/lib/apikey.py +5 -8
- wandb/sdk/lib/config_util.py +3 -3
- wandb/sdk/lib/fsm.py +3 -18
- wandb/sdk/lib/gitlib.py +6 -5
- wandb/sdk/lib/ipython.py +2 -2
- wandb/sdk/lib/json_util.py +9 -14
- wandb/sdk/lib/printer.py +3 -8
- wandb/sdk/lib/redirect.py +1 -1
- wandb/sdk/lib/retry.py +3 -7
- wandb/sdk/lib/run_moment.py +2 -2
- wandb/sdk/lib/service_connection.py +3 -1
- wandb/sdk/lib/service_token.py +1 -2
- wandb/sdk/mailbox/mailbox_handle.py +3 -7
- wandb/sdk/mailbox/response_handle.py +2 -6
- wandb/sdk/service/streams.py +3 -7
- wandb/sdk/verify/verify.py +5 -6
- wandb/sdk/wandb_config.py +1 -1
- wandb/sdk/wandb_init.py +38 -106
- wandb/sdk/wandb_login.py +7 -6
- wandb/sdk/wandb_run.py +52 -240
- wandb/sdk/wandb_settings.py +71 -60
- wandb/sdk/wandb_setup.py +40 -14
- wandb/sdk/wandb_watch.py +5 -7
- wandb/sync/__init__.py +1 -1
- wandb/sync/sync.py +13 -13
- wandb/util.py +17 -35
- wandb/wandb_agent.py +8 -11
- {wandb-0.19.12rc1.dist-info → wandb-0.20.1.dist-info}/METADATA +5 -5
- {wandb-0.19.12rc1.dist-info → wandb-0.20.1.dist-info}/RECORD +170 -168
- wandb/docker/auth.py +0 -435
- wandb/docker/www_authenticate.py +0 -94
- {wandb-0.19.12rc1.dist-info → wandb-0.20.1.dist-info}/WHEEL +0 -0
- {wandb-0.19.12rc1.dist-info → wandb-0.20.1.dist-info}/entry_points.txt +0 -0
- {wandb-0.19.12rc1.dist-info → wandb-0.20.1.dist-info}/licenses/LICENSE +0 -0
wandb/sdk/data_types/image.py
CHANGED
@@ -1,10 +1,13 @@
|
|
1
1
|
import hashlib
|
2
2
|
import logging
|
3
3
|
import os
|
4
|
+
import pathlib
|
4
5
|
from io import BytesIO
|
5
6
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Type, Union, cast
|
6
7
|
from urllib import parse
|
7
8
|
|
9
|
+
from packaging.version import parse as parse_version
|
10
|
+
|
8
11
|
import wandb
|
9
12
|
from wandb import util
|
10
13
|
from wandb.sdk.lib import hashutil, runid
|
@@ -30,10 +33,74 @@ if TYPE_CHECKING: # pragma: no cover
|
|
30
33
|
ImageDataType = Union[
|
31
34
|
"matplotlib.artist.Artist", "PILImage", "TorchTensorType", "np.ndarray"
|
32
35
|
]
|
33
|
-
ImageDataOrPathType = Union[str, "Image", ImageDataType]
|
36
|
+
ImageDataOrPathType = Union[str, pathlib.Path, "Image", ImageDataType]
|
34
37
|
TorchTensorType = Union["torch.Tensor", "torch.Variable"]
|
35
38
|
|
36
39
|
|
40
|
+
def _warn_on_invalid_data_range(
|
41
|
+
data: "np.ndarray",
|
42
|
+
normalize: bool = True,
|
43
|
+
) -> None:
|
44
|
+
if not normalize:
|
45
|
+
return
|
46
|
+
|
47
|
+
np = util.get_module(
|
48
|
+
"numpy",
|
49
|
+
required="wandb.Image requires numpy if not supplying PIL Images: pip install numpy",
|
50
|
+
)
|
51
|
+
|
52
|
+
if np.min(data) < 0 or np.max(data) > 255:
|
53
|
+
wandb.termwarn(
|
54
|
+
"Data passed to `wandb.Image` should consist of values in the range [0, 255], "
|
55
|
+
"image data will be normalized to this range, "
|
56
|
+
"but behavior will be removed in a future version of wandb.",
|
57
|
+
repeat=False,
|
58
|
+
)
|
59
|
+
|
60
|
+
|
61
|
+
def _guess_and_rescale_to_0_255(data: "np.ndarray") -> "np.ndarray":
|
62
|
+
"""Guess the image's format and rescale its values to the range [0, 255].
|
63
|
+
|
64
|
+
This is an unfortunate design flaw carried forward for backward
|
65
|
+
compatibility. A better design would have been to document the expected
|
66
|
+
data format and not mangle the data provided by the user.
|
67
|
+
|
68
|
+
If given data in the range [0, 1], we multiply all values by 255
|
69
|
+
and round down to get integers.
|
70
|
+
|
71
|
+
If given data in the range [-1, 1], we rescale it by mapping -1 to 0 and
|
72
|
+
1 to 255, then round down to get integers.
|
73
|
+
|
74
|
+
We clip and round all other data.
|
75
|
+
"""
|
76
|
+
try:
|
77
|
+
import numpy as np
|
78
|
+
except ImportError:
|
79
|
+
raise wandb.Error(
|
80
|
+
"wandb.Image requires numpy if not supplying PIL images: pip install numpy"
|
81
|
+
) from None
|
82
|
+
|
83
|
+
data_min: float = data.min()
|
84
|
+
data_max: float = data.max()
|
85
|
+
|
86
|
+
if 0 <= data_min and data_max <= 1:
|
87
|
+
return (data * 255).astype(np.uint8)
|
88
|
+
|
89
|
+
elif -1 <= data_min and data_max <= 1:
|
90
|
+
return (255 * 0.5 * (data + 1)).astype(np.uint8)
|
91
|
+
|
92
|
+
else:
|
93
|
+
return data.clip(0, 255).astype(np.uint8)
|
94
|
+
|
95
|
+
|
96
|
+
def _convert_to_uint8(data: "np.ndarray") -> "np.ndarray":
|
97
|
+
np = util.get_module(
|
98
|
+
"numpy",
|
99
|
+
required="wandb.Image requires numpy if not supplying PIL Images: pip install numpy",
|
100
|
+
)
|
101
|
+
return data.astype(np.uint8)
|
102
|
+
|
103
|
+
|
37
104
|
def _server_accepts_image_filenames(run: "LocalRun") -> bool:
|
38
105
|
if run.offline:
|
39
106
|
return True
|
@@ -43,10 +110,9 @@ def _server_accepts_image_filenames(run: "LocalRun") -> bool:
|
|
43
110
|
max_cli_version = util._get_max_cli_version()
|
44
111
|
if max_cli_version is None:
|
45
112
|
return False
|
46
|
-
from wandb.util import parse_version
|
47
113
|
|
48
|
-
accepts_image_filenames: bool = parse_version(
|
49
|
-
|
114
|
+
accepts_image_filenames: bool = parse_version(max_cli_version) >= parse_version(
|
115
|
+
"0.12.10"
|
50
116
|
)
|
51
117
|
return accepts_image_filenames
|
52
118
|
|
@@ -59,69 +125,11 @@ def _server_accepts_artifact_path(run: "LocalRun") -> bool:
|
|
59
125
|
if max_cli_version is None:
|
60
126
|
return False
|
61
127
|
|
62
|
-
return
|
128
|
+
return parse_version(max_cli_version) >= parse_version("0.12.14")
|
63
129
|
|
64
130
|
|
65
131
|
class Image(BatchableMedia):
|
66
|
-
"""
|
67
|
-
|
68
|
-
Args:
|
69
|
-
data_or_path: (numpy array, string, io) Accepts numpy array of
|
70
|
-
image data, or a PIL image. The class attempts to infer
|
71
|
-
the data format and converts it.
|
72
|
-
mode: (string) The PIL mode for an image. Most common are "L", "RGB",
|
73
|
-
"RGBA". Full explanation at https://pillow.readthedocs.io/en/stable/handbook/concepts.html#modes
|
74
|
-
caption: (string) Label for display of image.
|
75
|
-
|
76
|
-
Note : When logging a `torch.Tensor` as a `wandb.Image`, images are normalized. If you do not want to normalize your images, please convert your tensors to a PIL Image.
|
77
|
-
|
78
|
-
Examples:
|
79
|
-
### Create a wandb.Image from a numpy array
|
80
|
-
```python
|
81
|
-
import numpy as np
|
82
|
-
import wandb
|
83
|
-
|
84
|
-
with wandb.init() as run:
|
85
|
-
examples = []
|
86
|
-
for i in range(3):
|
87
|
-
pixels = np.random.randint(low=0, high=256, size=(100, 100, 3))
|
88
|
-
image = wandb.Image(pixels, caption=f"random field {i}")
|
89
|
-
examples.append(image)
|
90
|
-
run.log({"examples": examples})
|
91
|
-
```
|
92
|
-
|
93
|
-
### Create a wandb.Image from a PILImage
|
94
|
-
```python
|
95
|
-
import numpy as np
|
96
|
-
from PIL import Image as PILImage
|
97
|
-
import wandb
|
98
|
-
|
99
|
-
with wandb.init() as run:
|
100
|
-
examples = []
|
101
|
-
for i in range(3):
|
102
|
-
pixels = np.random.randint(
|
103
|
-
low=0, high=256, size=(100, 100, 3), dtype=np.uint8
|
104
|
-
)
|
105
|
-
pil_image = PILImage.fromarray(pixels, mode="RGB")
|
106
|
-
image = wandb.Image(pil_image, caption=f"random field {i}")
|
107
|
-
examples.append(image)
|
108
|
-
run.log({"examples": examples})
|
109
|
-
```
|
110
|
-
|
111
|
-
### log .jpg rather than .png (default)
|
112
|
-
```python
|
113
|
-
import numpy as np
|
114
|
-
import wandb
|
115
|
-
|
116
|
-
with wandb.init() as run:
|
117
|
-
examples = []
|
118
|
-
for i in range(3):
|
119
|
-
pixels = np.random.randint(low=0, high=256, size=(100, 100, 3))
|
120
|
-
image = wandb.Image(pixels, caption=f"random field {i}", file_type="jpg")
|
121
|
-
examples.append(image)
|
122
|
-
run.log({"examples": examples})
|
123
|
-
```
|
124
|
-
"""
|
132
|
+
"""A class for logging images to W&B."""
|
125
133
|
|
126
134
|
MAX_ITEMS = 108
|
127
135
|
|
@@ -151,7 +159,85 @@ class Image(BatchableMedia):
|
|
151
159
|
boxes: Optional[Union[Dict[str, "BoundingBoxes2D"], Dict[str, dict]]] = None,
|
152
160
|
masks: Optional[Union[Dict[str, "ImageMask"], Dict[str, dict]]] = None,
|
153
161
|
file_type: Optional[str] = None,
|
162
|
+
normalize: bool = True,
|
154
163
|
) -> None:
|
164
|
+
"""Initialize a wandb.Image object.
|
165
|
+
|
166
|
+
Args:
|
167
|
+
data_or_path: Accepts numpy array/pytorch tensor of image data,
|
168
|
+
a PIL image object, or a path to an image file.
|
169
|
+
|
170
|
+
If a numpy array or pytorch tensor is provided,
|
171
|
+
the image data will be saved to the given file type.
|
172
|
+
If the values are not in the range [0, 255] or all values are in the range [0, 1],
|
173
|
+
the image pixel values will be normalized to the range [0, 255]
|
174
|
+
unless `normalize` is set to False.
|
175
|
+
- pytorch tensor should be in the format (channel, height, width)
|
176
|
+
- numpy array should be in the format (height, width, channel)
|
177
|
+
mode: The PIL mode for an image. Most common are "L", "RGB",
|
178
|
+
"RGBA". Full explanation at https://pillow.readthedocs.io/en/stable/handbook/concepts.html#modes
|
179
|
+
caption: Label for display of image.
|
180
|
+
grouping: The grouping number for the image.
|
181
|
+
classes: A list of class information for the image,
|
182
|
+
used for labeling bounding boxes, and image masks.
|
183
|
+
boxes: A dictionary containing bounding box information for the image.
|
184
|
+
see: https://docs.wandb.ai/ref/python/data-types/boundingboxes2d/
|
185
|
+
masks: A dictionary containing mask information for the image.
|
186
|
+
see: https://docs.wandb.ai/ref/python/data-types/imagemask/
|
187
|
+
file_type: The file type to save the image as.
|
188
|
+
This parameter has no effect if data_or_path is a path to an image file.
|
189
|
+
normalize: If True, normalize the image pixel values to fall within the range of [0, 255].
|
190
|
+
Normalize is only applied if data_or_path is a numpy array or pytorch tensor.
|
191
|
+
|
192
|
+
Examples:
|
193
|
+
### Create a wandb.Image from a numpy array
|
194
|
+
```python
|
195
|
+
import numpy as np
|
196
|
+
import wandb
|
197
|
+
|
198
|
+
with wandb.init() as run:
|
199
|
+
examples = []
|
200
|
+
for i in range(3):
|
201
|
+
pixels = np.random.randint(low=0, high=256, size=(100, 100, 3))
|
202
|
+
image = wandb.Image(pixels, caption=f"random field {i}")
|
203
|
+
examples.append(image)
|
204
|
+
run.log({"examples": examples})
|
205
|
+
```
|
206
|
+
|
207
|
+
### Create a wandb.Image from a PILImage
|
208
|
+
```python
|
209
|
+
import numpy as np
|
210
|
+
from PIL import Image as PILImage
|
211
|
+
import wandb
|
212
|
+
|
213
|
+
with wandb.init() as run:
|
214
|
+
examples = []
|
215
|
+
for i in range(3):
|
216
|
+
pixels = np.random.randint(
|
217
|
+
low=0, high=256, size=(100, 100, 3), dtype=np.uint8
|
218
|
+
)
|
219
|
+
pil_image = PILImage.fromarray(pixels, mode="RGB")
|
220
|
+
image = wandb.Image(pil_image, caption=f"random field {i}")
|
221
|
+
examples.append(image)
|
222
|
+
run.log({"examples": examples})
|
223
|
+
```
|
224
|
+
|
225
|
+
### log .jpg rather than .png (default)
|
226
|
+
```python
|
227
|
+
import numpy as np
|
228
|
+
import wandb
|
229
|
+
|
230
|
+
with wandb.init() as run:
|
231
|
+
examples = []
|
232
|
+
for i in range(3):
|
233
|
+
pixels = np.random.randint(low=0, high=256, size=(100, 100, 3))
|
234
|
+
image = wandb.Image(
|
235
|
+
pixels, caption=f"random field {i}", file_type="jpg"
|
236
|
+
)
|
237
|
+
examples.append(image)
|
238
|
+
run.log({"examples": examples})
|
239
|
+
```
|
240
|
+
"""
|
155
241
|
super().__init__(caption=caption)
|
156
242
|
# TODO: We should remove grouping, it's a terrible name and I don't
|
157
243
|
# think anyone uses it.
|
@@ -169,13 +255,15 @@ class Image(BatchableMedia):
|
|
169
255
|
# only overriding additional metadata passed in. If this pattern is compelling, we can generalize.
|
170
256
|
if isinstance(data_or_path, Image):
|
171
257
|
self._initialize_from_wbimage(data_or_path)
|
172
|
-
elif isinstance(data_or_path, str):
|
258
|
+
elif isinstance(data_or_path, (str, pathlib.Path)):
|
259
|
+
data_or_path = str(data_or_path)
|
260
|
+
|
173
261
|
if self.path_is_reference(data_or_path):
|
174
262
|
self._initialize_from_reference(data_or_path)
|
175
263
|
else:
|
176
264
|
self._initialize_from_path(data_or_path)
|
177
265
|
else:
|
178
|
-
self._initialize_from_data(data_or_path, mode, file_type)
|
266
|
+
self._initialize_from_data(data_or_path, mode, file_type, normalize)
|
179
267
|
self._set_initialization_meta(
|
180
268
|
grouping, caption, classes, boxes, masks, file_type
|
181
269
|
)
|
@@ -288,6 +376,7 @@ class Image(BatchableMedia):
|
|
288
376
|
data: "ImageDataType",
|
289
377
|
mode: Optional[str] = None,
|
290
378
|
file_type: Optional[str] = None,
|
379
|
+
normalize: bool = True,
|
291
380
|
) -> None:
|
292
381
|
pil_image = util.get_module(
|
293
382
|
"PIL.Image",
|
@@ -309,28 +398,39 @@ class Image(BatchableMedia):
|
|
309
398
|
elif isinstance(data, pil_image.Image):
|
310
399
|
self._image = data
|
311
400
|
elif util.is_pytorch_tensor_typename(util.get_full_typename(data)):
|
312
|
-
vis_util = util.get_module(
|
313
|
-
"torchvision.utils", "torchvision is required to render images"
|
314
|
-
)
|
315
401
|
if hasattr(data, "requires_grad") and data.requires_grad:
|
316
402
|
data = data.detach() # type: ignore
|
317
403
|
if hasattr(data, "dtype") and str(data.dtype) == "torch.uint8":
|
318
|
-
data = data.to(float)
|
319
|
-
data = vis_util.make_grid(data, normalize=True)
|
404
|
+
data = data.to(float) # type: ignore [union-attr]
|
320
405
|
mode = mode or self.guess_mode(data, file_type)
|
406
|
+
data = data.permute(1, 2, 0).cpu().numpy() # type: ignore [union-attr]
|
407
|
+
|
408
|
+
_warn_on_invalid_data_range(data, normalize)
|
409
|
+
|
410
|
+
data = _guess_and_rescale_to_0_255(data) if normalize else data # type: ignore [arg-type]
|
411
|
+
data = _convert_to_uint8(data)
|
412
|
+
|
413
|
+
if data.ndim > 2:
|
414
|
+
data = data.squeeze()
|
415
|
+
|
321
416
|
self._image = pil_image.fromarray(
|
322
|
-
data
|
417
|
+
data,
|
323
418
|
mode=mode,
|
324
419
|
)
|
325
420
|
else:
|
326
421
|
if hasattr(data, "numpy"): # TF data eager tensors
|
327
422
|
data = data.numpy()
|
328
|
-
if data.ndim > 2:
|
329
|
-
|
423
|
+
if data.ndim > 2: # type: ignore [union-attr]
|
424
|
+
# get rid of trivial dimensions as a convenience
|
425
|
+
data = data.squeeze() # type: ignore [union-attr]
|
426
|
+
|
427
|
+
_warn_on_invalid_data_range(data, normalize) # type: ignore [arg-type]
|
330
428
|
|
331
429
|
mode = mode or self.guess_mode(data, file_type)
|
430
|
+
data = _guess_and_rescale_to_0_255(data) if normalize else data # type: ignore [arg-type]
|
431
|
+
data = _convert_to_uint8(data) # type: ignore [arg-type]
|
332
432
|
self._image = pil_image.fromarray(
|
333
|
-
|
433
|
+
data,
|
334
434
|
mode=mode,
|
335
435
|
)
|
336
436
|
|
@@ -459,7 +559,7 @@ class Image(BatchableMedia):
|
|
459
559
|
}
|
460
560
|
|
461
561
|
elif not isinstance(run_or_artifact, Run):
|
462
|
-
raise
|
562
|
+
raise TypeError("to_json accepts wandb_run.Run or wandb_artifact.Artifact")
|
463
563
|
|
464
564
|
if self._boxes:
|
465
565
|
json_dict["boxes"] = {
|
@@ -485,7 +585,7 @@ class Image(BatchableMedia):
|
|
485
585
|
else:
|
486
586
|
num_channels = data.shape[-1]
|
487
587
|
|
488
|
-
if ndims == 2:
|
588
|
+
if ndims == 2 or num_channels == 1:
|
489
589
|
return "L"
|
490
590
|
elif num_channels == 3:
|
491
591
|
return "RGB"
|
@@ -501,34 +601,9 @@ class Image(BatchableMedia):
|
|
501
601
|
return "RGBA"
|
502
602
|
else:
|
503
603
|
raise ValueError(
|
504
|
-
"Un-supported shape for image conversion {
|
604
|
+
f"Un-supported shape for image conversion {list(data.shape)}"
|
505
605
|
)
|
506
606
|
|
507
|
-
@classmethod
|
508
|
-
def to_uint8(cls, data: "np.ndarray") -> "np.ndarray":
|
509
|
-
"""Convert image data to uint8.
|
510
|
-
|
511
|
-
Convert floating point image on the range [0,1] and integer images on the range
|
512
|
-
[0,255] to uint8, clipping if necessary.
|
513
|
-
"""
|
514
|
-
np = util.get_module(
|
515
|
-
"numpy",
|
516
|
-
required="wandb.Image requires numpy if not supplying PIL Images: pip install numpy",
|
517
|
-
)
|
518
|
-
|
519
|
-
# I think it's better to check the image range vs the data type, since many
|
520
|
-
# image libraries will return floats between 0 and 255
|
521
|
-
|
522
|
-
# some images have range -1...1 or 0-1
|
523
|
-
dmin = np.min(data)
|
524
|
-
if dmin < 0:
|
525
|
-
data = (data - np.min(data)) / np.ptp(data)
|
526
|
-
if np.max(data) <= 1.0:
|
527
|
-
data = (data * 255).astype(np.int32)
|
528
|
-
|
529
|
-
# assert issubclass(data.dtype.type, np.integer), 'Illegal image format.'
|
530
|
-
return data.clip(0, 255).astype(np.uint8)
|
531
|
-
|
532
607
|
@classmethod
|
533
608
|
def seq_to_json(
|
534
609
|
cls: Type["Image"],
|
wandb/sdk/data_types/molecule.py
CHANGED
@@ -26,7 +26,7 @@ class Molecule(BatchableMedia):
|
|
26
26
|
"""Wandb class for 3D Molecular data.
|
27
27
|
|
28
28
|
Args:
|
29
|
-
data_or_path: (string, io)
|
29
|
+
data_or_path: (pathlib.Path, string, io)
|
30
30
|
Molecule can be initialized from a file name or an io object.
|
31
31
|
caption: (string)
|
32
32
|
Caption associated with the molecule for display.
|
@@ -49,7 +49,7 @@ class Molecule(BatchableMedia):
|
|
49
49
|
|
50
50
|
def __init__(
|
51
51
|
self,
|
52
|
-
data_or_path: Union[str, "TextIO"],
|
52
|
+
data_or_path: Union[str, pathlib.Path, "TextIO"],
|
53
53
|
caption: Optional[str] = None,
|
54
54
|
**kwargs: str,
|
55
55
|
) -> None:
|
@@ -82,7 +82,9 @@ class Molecule(BatchableMedia):
|
|
82
82
|
f.write(molecule)
|
83
83
|
|
84
84
|
self._set_file(tmp_path, is_tmp=True)
|
85
|
-
elif isinstance(data_or_path, str):
|
85
|
+
elif isinstance(data_or_path, (str, pathlib.Path)):
|
86
|
+
data_or_path = str(data_or_path)
|
87
|
+
|
86
88
|
extension = os.path.splitext(data_or_path)[1][1:]
|
87
89
|
if extension not in Molecule.SUPPORTED_TYPES:
|
88
90
|
raise ValueError(
|
@@ -144,9 +146,7 @@ class Molecule(BatchableMedia):
|
|
144
146
|
elif isinstance(data_or_path, rdkit_chem.rdchem.Mol):
|
145
147
|
molecule = data_or_path
|
146
148
|
else:
|
147
|
-
raise
|
148
|
-
"Data must be file name or an rdkit.Chem.rdchem.Mol object"
|
149
|
-
)
|
149
|
+
raise TypeError("Data must be file name or an rdkit.Chem.rdchem.Mol object")
|
150
150
|
|
151
151
|
if convert_to_3d_and_optimize:
|
152
152
|
molecule = rdkit_chem.AddHs(molecule)
|
@@ -2,6 +2,7 @@ import codecs
|
|
2
2
|
import itertools
|
3
3
|
import json
|
4
4
|
import os
|
5
|
+
import pathlib
|
5
6
|
from typing import (
|
6
7
|
TYPE_CHECKING,
|
7
8
|
ClassVar,
|
@@ -187,7 +188,7 @@ class Object3D(BatchableMedia):
|
|
187
188
|
"""Wandb class for 3D point clouds.
|
188
189
|
|
189
190
|
Args:
|
190
|
-
data_or_path: (numpy array, string, io)
|
191
|
+
data_or_path: (numpy array, pathlib.Path, string, io)
|
191
192
|
Object3D can be initialized from a file or a numpy array.
|
192
193
|
|
193
194
|
You can pass a path to a file or an io object and a file_type
|
@@ -214,14 +215,16 @@ class Object3D(BatchableMedia):
|
|
214
215
|
|
215
216
|
def __init__(
|
216
217
|
self,
|
217
|
-
data_or_path: Union["np.ndarray", str, "TextIO", dict],
|
218
|
+
data_or_path: Union["np.ndarray", str, pathlib.Path, "TextIO", dict],
|
218
219
|
caption: Optional[str] = None,
|
219
220
|
**kwargs: Optional[Union[str, "FileFormat3D"]],
|
220
221
|
) -> None:
|
221
222
|
super().__init__(caption=caption)
|
222
223
|
|
223
|
-
if hasattr(data_or_path, "name"):
|
224
|
-
# if the file has a path, we just detect the type and copy it from there
|
224
|
+
if hasattr(data_or_path, "name") and not isinstance(data_or_path, pathlib.Path):
|
225
|
+
# if the file has a path, we just detect the type and copy it from there.
|
226
|
+
# this does not work for pathlib.Path objects,
|
227
|
+
# where `.name` returns the last directory in the path.
|
225
228
|
data_or_path = data_or_path.name
|
226
229
|
|
227
230
|
if hasattr(data_or_path, "read"):
|
@@ -247,7 +250,9 @@ class Object3D(BatchableMedia):
|
|
247
250
|
f.write(object_3d)
|
248
251
|
|
249
252
|
self._set_file(tmp_path, is_tmp=True, extension=extension)
|
250
|
-
elif isinstance(data_or_path, str):
|
253
|
+
elif isinstance(data_or_path, (str, pathlib.Path)):
|
254
|
+
data_or_path = str(data_or_path)
|
255
|
+
|
251
256
|
path = data_or_path
|
252
257
|
extension = None
|
253
258
|
for supported_type in Object3D.SUPPORTED_TYPES:
|
@@ -1,6 +1,7 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
3
|
import os
|
4
|
+
import pathlib
|
4
5
|
import shutil
|
5
6
|
import sys
|
6
7
|
from types import ModuleType
|
@@ -72,10 +73,12 @@ class _SavedModel(WBValue, Generic[SavedModelObjType]):
|
|
72
73
|
|
73
74
|
_model_obj: SavedModelObjType | None
|
74
75
|
_path: str | None
|
75
|
-
_input_obj_or_path: SavedModelObjType | str
|
76
|
+
_input_obj_or_path: SavedModelObjType | str | pathlib.Path
|
76
77
|
|
77
78
|
# Public Methods
|
78
|
-
def __init__(
|
79
|
+
def __init__(
|
80
|
+
self, obj_or_path: SavedModelObjType | str | pathlib.Path, **kwargs: Any
|
81
|
+
) -> None:
|
79
82
|
super().__init__()
|
80
83
|
if self.__class__ == _SavedModel:
|
81
84
|
raise TypeError(
|
@@ -84,9 +87,11 @@ class _SavedModel(WBValue, Generic[SavedModelObjType]):
|
|
84
87
|
self._model_obj = None
|
85
88
|
self._path = None
|
86
89
|
self._input_obj_or_path = obj_or_path
|
87
|
-
input_is_path = isinstance(obj_or_path, str) and os.path.exists(
|
90
|
+
input_is_path = isinstance(obj_or_path, (str, pathlib.Path)) and os.path.exists(
|
91
|
+
obj_or_path
|
92
|
+
)
|
88
93
|
if input_is_path:
|
89
|
-
|
94
|
+
obj_or_path = str(obj_or_path)
|
90
95
|
self._set_obj(self._deserialize(obj_or_path))
|
91
96
|
else:
|
92
97
|
self._set_obj(obj_or_path)
|
@@ -140,7 +145,7 @@ class _SavedModel(WBValue, Generic[SavedModelObjType]):
|
|
140
145
|
from wandb.sdk.wandb_run import Run
|
141
146
|
|
142
147
|
if isinstance(run_or_artifact, Run):
|
143
|
-
raise
|
148
|
+
raise TypeError("SavedModel cannot be added to run - must use artifact")
|
144
149
|
artifact = run_or_artifact
|
145
150
|
json_obj = {
|
146
151
|
"type": self._log_type,
|
@@ -280,7 +285,7 @@ class _PicklingSavedModel(_SavedModel[SavedModelObjType]):
|
|
280
285
|
|
281
286
|
def __init__(
|
282
287
|
self,
|
283
|
-
obj_or_path: SavedModelObjType | str,
|
288
|
+
obj_or_path: SavedModelObjType | str | pathlib.Path,
|
284
289
|
dep_py_files: list[str] | None = None,
|
285
290
|
):
|
286
291
|
super().__init__(obj_or_path)
|