wandb 0.19.12rc1__py3-none-win32.whl → 0.20.1__py3-none-win32.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (172) hide show
  1. wandb/__init__.py +1 -2
  2. wandb/__init__.pyi +3 -6
  3. wandb/_iterutils.py +26 -7
  4. wandb/_pydantic/__init__.py +2 -1
  5. wandb/_pydantic/utils.py +7 -0
  6. wandb/agents/pyagent.py +9 -15
  7. wandb/analytics/sentry.py +1 -2
  8. wandb/apis/attrs.py +3 -4
  9. wandb/apis/importers/internals/util.py +1 -1
  10. wandb/apis/importers/validation.py +2 -2
  11. wandb/apis/importers/wandb.py +30 -25
  12. wandb/apis/normalize.py +2 -2
  13. wandb/apis/public/__init__.py +1 -0
  14. wandb/apis/public/api.py +37 -33
  15. wandb/apis/public/artifacts.py +103 -72
  16. wandb/apis/public/jobs.py +3 -2
  17. wandb/apis/public/registries/registries_search.py +4 -2
  18. wandb/apis/public/registries/registry.py +1 -1
  19. wandb/apis/public/registries/utils.py +9 -9
  20. wandb/apis/public/runs.py +18 -6
  21. wandb/automations/_filters/expressions.py +1 -1
  22. wandb/automations/_filters/operators.py +1 -1
  23. wandb/automations/_filters/run_metrics.py +1 -1
  24. wandb/beta/workflows.py +6 -5
  25. wandb/bin/gpu_stats.exe +0 -0
  26. wandb/bin/wandb-core +0 -0
  27. wandb/cli/cli.py +54 -73
  28. wandb/docker/__init__.py +21 -74
  29. wandb/docker/names.py +40 -0
  30. wandb/env.py +0 -1
  31. wandb/errors/util.py +1 -1
  32. wandb/filesync/step_checksum.py +1 -1
  33. wandb/filesync/step_upload.py +1 -1
  34. wandb/integration/diffusers/resolvers/multimodal.py +1 -2
  35. wandb/integration/gym/__init__.py +5 -6
  36. wandb/integration/keras/callbacks/model_checkpoint.py +2 -2
  37. wandb/integration/keras/keras.py +13 -19
  38. wandb/integration/kfp/kfp_patch.py +2 -3
  39. wandb/integration/langchain/wandb_tracer.py +1 -1
  40. wandb/integration/metaflow/metaflow.py +13 -13
  41. wandb/integration/openai/fine_tuning.py +3 -2
  42. wandb/integration/sagemaker/auth.py +2 -1
  43. wandb/integration/sklearn/utils.py +2 -1
  44. wandb/integration/tensorboard/__init__.py +1 -1
  45. wandb/integration/tensorboard/log.py +2 -5
  46. wandb/integration/tensorflow/__init__.py +2 -2
  47. wandb/jupyter.py +20 -17
  48. wandb/plot/confusion_matrix.py +1 -1
  49. wandb/plot/utils.py +8 -7
  50. wandb/proto/v3/wandb_internal_pb2.py +355 -335
  51. wandb/proto/v3/wandb_settings_pb2.py +2 -2
  52. wandb/proto/v3/wandb_telemetry_pb2.py +12 -12
  53. wandb/proto/v4/wandb_internal_pb2.py +339 -335
  54. wandb/proto/v4/wandb_settings_pb2.py +2 -2
  55. wandb/proto/v4/wandb_telemetry_pb2.py +12 -12
  56. wandb/proto/v5/wandb_internal_pb2.py +339 -335
  57. wandb/proto/v5/wandb_settings_pb2.py +2 -2
  58. wandb/proto/v5/wandb_telemetry_pb2.py +12 -12
  59. wandb/proto/v6/wandb_internal_pb2.py +339 -335
  60. wandb/proto/v6/wandb_settings_pb2.py +2 -2
  61. wandb/proto/v6/wandb_telemetry_pb2.py +12 -12
  62. wandb/proto/wandb_deprecated.py +6 -8
  63. wandb/sdk/artifacts/_internal_artifact.py +43 -0
  64. wandb/sdk/artifacts/_validators.py +55 -35
  65. wandb/sdk/artifacts/artifact.py +117 -115
  66. wandb/sdk/artifacts/artifact_download_logger.py +2 -0
  67. wandb/sdk/artifacts/artifact_saver.py +1 -3
  68. wandb/sdk/artifacts/artifact_state.py +2 -0
  69. wandb/sdk/artifacts/artifact_ttl.py +2 -0
  70. wandb/sdk/artifacts/exceptions.py +14 -0
  71. wandb/sdk/artifacts/staging.py +2 -0
  72. wandb/sdk/artifacts/storage_handlers/local_file_handler.py +2 -6
  73. wandb/sdk/artifacts/storage_handlers/multi_handler.py +1 -1
  74. wandb/sdk/artifacts/storage_handlers/tracking_handler.py +2 -6
  75. wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +1 -5
  76. wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +1 -1
  77. wandb/sdk/artifacts/storage_layout.py +2 -0
  78. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +3 -3
  79. wandb/sdk/backend/backend.py +11 -182
  80. wandb/sdk/data_types/_dtypes.py +2 -6
  81. wandb/sdk/data_types/audio.py +20 -3
  82. wandb/sdk/data_types/base_types/media.py +12 -7
  83. wandb/sdk/data_types/base_types/wb_value.py +8 -18
  84. wandb/sdk/data_types/bokeh.py +19 -2
  85. wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +17 -1
  86. wandb/sdk/data_types/helper_types/image_mask.py +7 -1
  87. wandb/sdk/data_types/html.py +4 -4
  88. wandb/sdk/data_types/image.py +178 -103
  89. wandb/sdk/data_types/molecule.py +6 -6
  90. wandb/sdk/data_types/object_3d.py +10 -5
  91. wandb/sdk/data_types/saved_model.py +11 -6
  92. wandb/sdk/data_types/table.py +313 -83
  93. wandb/sdk/data_types/table_decorators.py +108 -0
  94. wandb/sdk/data_types/utils.py +43 -7
  95. wandb/sdk/data_types/video.py +21 -3
  96. wandb/sdk/interface/interface.py +10 -0
  97. wandb/sdk/internal/datastore.py +2 -6
  98. wandb/sdk/internal/file_pusher.py +1 -5
  99. wandb/sdk/internal/file_stream.py +8 -17
  100. wandb/sdk/internal/handler.py +2 -2
  101. wandb/sdk/internal/incremental_table_util.py +53 -0
  102. wandb/sdk/internal/internal.py +3 -5
  103. wandb/sdk/internal/internal_api.py +66 -89
  104. wandb/sdk/internal/job_builder.py +2 -7
  105. wandb/sdk/internal/profiler.py +2 -2
  106. wandb/sdk/internal/progress.py +1 -3
  107. wandb/sdk/internal/run.py +1 -6
  108. wandb/sdk/internal/sender.py +24 -36
  109. wandb/sdk/internal/system/assets/aggregators.py +1 -7
  110. wandb/sdk/internal/system/assets/disk.py +3 -3
  111. wandb/sdk/internal/system/assets/gpu.py +4 -4
  112. wandb/sdk/internal/system/assets/gpu_amd.py +4 -4
  113. wandb/sdk/internal/system/assets/interfaces.py +6 -6
  114. wandb/sdk/internal/system/assets/tpu.py +1 -1
  115. wandb/sdk/internal/system/assets/trainium.py +6 -6
  116. wandb/sdk/internal/system/system_info.py +5 -7
  117. wandb/sdk/internal/system/system_monitor.py +4 -4
  118. wandb/sdk/internal/tb_watcher.py +5 -7
  119. wandb/sdk/launch/_launch.py +1 -1
  120. wandb/sdk/launch/_project_spec.py +19 -20
  121. wandb/sdk/launch/agent/agent.py +3 -3
  122. wandb/sdk/launch/agent/config.py +1 -1
  123. wandb/sdk/launch/agent/job_status_tracker.py +2 -2
  124. wandb/sdk/launch/builder/build.py +2 -3
  125. wandb/sdk/launch/builder/kaniko_builder.py +5 -4
  126. wandb/sdk/launch/environment/gcp_environment.py +1 -2
  127. wandb/sdk/launch/registry/azure_container_registry.py +2 -2
  128. wandb/sdk/launch/registry/elastic_container_registry.py +2 -2
  129. wandb/sdk/launch/registry/google_artifact_registry.py +3 -3
  130. wandb/sdk/launch/runner/abstract.py +5 -5
  131. wandb/sdk/launch/runner/kubernetes_monitor.py +2 -2
  132. wandb/sdk/launch/runner/kubernetes_runner.py +1 -1
  133. wandb/sdk/launch/runner/sagemaker_runner.py +2 -4
  134. wandb/sdk/launch/runner/vertex_runner.py +2 -7
  135. wandb/sdk/launch/sweeps/__init__.py +1 -1
  136. wandb/sdk/launch/sweeps/scheduler.py +2 -2
  137. wandb/sdk/launch/sweeps/utils.py +3 -3
  138. wandb/sdk/launch/utils.py +3 -4
  139. wandb/sdk/lib/apikey.py +5 -8
  140. wandb/sdk/lib/config_util.py +3 -3
  141. wandb/sdk/lib/fsm.py +3 -18
  142. wandb/sdk/lib/gitlib.py +6 -5
  143. wandb/sdk/lib/ipython.py +2 -2
  144. wandb/sdk/lib/json_util.py +9 -14
  145. wandb/sdk/lib/printer.py +3 -8
  146. wandb/sdk/lib/redirect.py +1 -1
  147. wandb/sdk/lib/retry.py +3 -7
  148. wandb/sdk/lib/run_moment.py +2 -2
  149. wandb/sdk/lib/service_connection.py +3 -1
  150. wandb/sdk/lib/service_token.py +1 -2
  151. wandb/sdk/mailbox/mailbox_handle.py +3 -7
  152. wandb/sdk/mailbox/response_handle.py +2 -6
  153. wandb/sdk/service/streams.py +3 -7
  154. wandb/sdk/verify/verify.py +5 -6
  155. wandb/sdk/wandb_config.py +1 -1
  156. wandb/sdk/wandb_init.py +38 -106
  157. wandb/sdk/wandb_login.py +7 -6
  158. wandb/sdk/wandb_run.py +52 -240
  159. wandb/sdk/wandb_settings.py +71 -60
  160. wandb/sdk/wandb_setup.py +40 -14
  161. wandb/sdk/wandb_watch.py +5 -7
  162. wandb/sync/__init__.py +1 -1
  163. wandb/sync/sync.py +13 -13
  164. wandb/util.py +17 -35
  165. wandb/wandb_agent.py +8 -11
  166. {wandb-0.19.12rc1.dist-info → wandb-0.20.1.dist-info}/METADATA +5 -5
  167. {wandb-0.19.12rc1.dist-info → wandb-0.20.1.dist-info}/RECORD +170 -168
  168. wandb/docker/auth.py +0 -435
  169. wandb/docker/www_authenticate.py +0 -94
  170. {wandb-0.19.12rc1.dist-info → wandb-0.20.1.dist-info}/WHEEL +0 -0
  171. {wandb-0.19.12rc1.dist-info → wandb-0.20.1.dist-info}/entry_points.txt +0 -0
  172. {wandb-0.19.12rc1.dist-info → wandb-0.20.1.dist-info}/licenses/LICENSE +0 -0
@@ -1,10 +1,13 @@
1
1
  import hashlib
2
2
  import logging
3
3
  import os
4
+ import pathlib
4
5
  from io import BytesIO
5
6
  from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Type, Union, cast
6
7
  from urllib import parse
7
8
 
9
+ from packaging.version import parse as parse_version
10
+
8
11
  import wandb
9
12
  from wandb import util
10
13
  from wandb.sdk.lib import hashutil, runid
@@ -30,10 +33,74 @@ if TYPE_CHECKING: # pragma: no cover
30
33
  ImageDataType = Union[
31
34
  "matplotlib.artist.Artist", "PILImage", "TorchTensorType", "np.ndarray"
32
35
  ]
33
- ImageDataOrPathType = Union[str, "Image", ImageDataType]
36
+ ImageDataOrPathType = Union[str, pathlib.Path, "Image", ImageDataType]
34
37
  TorchTensorType = Union["torch.Tensor", "torch.Variable"]
35
38
 
36
39
 
40
+ def _warn_on_invalid_data_range(
41
+ data: "np.ndarray",
42
+ normalize: bool = True,
43
+ ) -> None:
44
+ if not normalize:
45
+ return
46
+
47
+ np = util.get_module(
48
+ "numpy",
49
+ required="wandb.Image requires numpy if not supplying PIL Images: pip install numpy",
50
+ )
51
+
52
+ if np.min(data) < 0 or np.max(data) > 255:
53
+ wandb.termwarn(
54
+ "Data passed to `wandb.Image` should consist of values in the range [0, 255], "
55
+ "image data will be normalized to this range, "
56
+ "but behavior will be removed in a future version of wandb.",
57
+ repeat=False,
58
+ )
59
+
60
+
61
+ def _guess_and_rescale_to_0_255(data: "np.ndarray") -> "np.ndarray":
62
+ """Guess the image's format and rescale its values to the range [0, 255].
63
+
64
+ This is an unfortunate design flaw carried forward for backward
65
+ compatibility. A better design would have been to document the expected
66
+ data format and not mangle the data provided by the user.
67
+
68
+ If given data in the range [0, 1], we multiply all values by 255
69
+ and round down to get integers.
70
+
71
+ If given data in the range [-1, 1], we rescale it by mapping -1 to 0 and
72
+ 1 to 255, then round down to get integers.
73
+
74
+ We clip and round all other data.
75
+ """
76
+ try:
77
+ import numpy as np
78
+ except ImportError:
79
+ raise wandb.Error(
80
+ "wandb.Image requires numpy if not supplying PIL images: pip install numpy"
81
+ ) from None
82
+
83
+ data_min: float = data.min()
84
+ data_max: float = data.max()
85
+
86
+ if 0 <= data_min and data_max <= 1:
87
+ return (data * 255).astype(np.uint8)
88
+
89
+ elif -1 <= data_min and data_max <= 1:
90
+ return (255 * 0.5 * (data + 1)).astype(np.uint8)
91
+
92
+ else:
93
+ return data.clip(0, 255).astype(np.uint8)
94
+
95
+
96
+ def _convert_to_uint8(data: "np.ndarray") -> "np.ndarray":
97
+ np = util.get_module(
98
+ "numpy",
99
+ required="wandb.Image requires numpy if not supplying PIL Images: pip install numpy",
100
+ )
101
+ return data.astype(np.uint8)
102
+
103
+
37
104
  def _server_accepts_image_filenames(run: "LocalRun") -> bool:
38
105
  if run.offline:
39
106
  return True
@@ -43,10 +110,9 @@ def _server_accepts_image_filenames(run: "LocalRun") -> bool:
43
110
  max_cli_version = util._get_max_cli_version()
44
111
  if max_cli_version is None:
45
112
  return False
46
- from wandb.util import parse_version
47
113
 
48
- accepts_image_filenames: bool = parse_version("0.12.10") <= parse_version(
49
- max_cli_version
114
+ accepts_image_filenames: bool = parse_version(max_cli_version) >= parse_version(
115
+ "0.12.10"
50
116
  )
51
117
  return accepts_image_filenames
52
118
 
@@ -59,69 +125,11 @@ def _server_accepts_artifact_path(run: "LocalRun") -> bool:
59
125
  if max_cli_version is None:
60
126
  return False
61
127
 
62
- return util.parse_version("0.12.14") <= util.parse_version(max_cli_version)
128
+ return parse_version(max_cli_version) >= parse_version("0.12.14")
63
129
 
64
130
 
65
131
  class Image(BatchableMedia):
66
- """Format images for logging to W&B.
67
-
68
- Args:
69
- data_or_path: (numpy array, string, io) Accepts numpy array of
70
- image data, or a PIL image. The class attempts to infer
71
- the data format and converts it.
72
- mode: (string) The PIL mode for an image. Most common are "L", "RGB",
73
- "RGBA". Full explanation at https://pillow.readthedocs.io/en/stable/handbook/concepts.html#modes
74
- caption: (string) Label for display of image.
75
-
76
- Note : When logging a `torch.Tensor` as a `wandb.Image`, images are normalized. If you do not want to normalize your images, please convert your tensors to a PIL Image.
77
-
78
- Examples:
79
- ### Create a wandb.Image from a numpy array
80
- ```python
81
- import numpy as np
82
- import wandb
83
-
84
- with wandb.init() as run:
85
- examples = []
86
- for i in range(3):
87
- pixels = np.random.randint(low=0, high=256, size=(100, 100, 3))
88
- image = wandb.Image(pixels, caption=f"random field {i}")
89
- examples.append(image)
90
- run.log({"examples": examples})
91
- ```
92
-
93
- ### Create a wandb.Image from a PILImage
94
- ```python
95
- import numpy as np
96
- from PIL import Image as PILImage
97
- import wandb
98
-
99
- with wandb.init() as run:
100
- examples = []
101
- for i in range(3):
102
- pixels = np.random.randint(
103
- low=0, high=256, size=(100, 100, 3), dtype=np.uint8
104
- )
105
- pil_image = PILImage.fromarray(pixels, mode="RGB")
106
- image = wandb.Image(pil_image, caption=f"random field {i}")
107
- examples.append(image)
108
- run.log({"examples": examples})
109
- ```
110
-
111
- ### log .jpg rather than .png (default)
112
- ```python
113
- import numpy as np
114
- import wandb
115
-
116
- with wandb.init() as run:
117
- examples = []
118
- for i in range(3):
119
- pixels = np.random.randint(low=0, high=256, size=(100, 100, 3))
120
- image = wandb.Image(pixels, caption=f"random field {i}", file_type="jpg")
121
- examples.append(image)
122
- run.log({"examples": examples})
123
- ```
124
- """
132
+ """A class for logging images to W&B."""
125
133
 
126
134
  MAX_ITEMS = 108
127
135
 
@@ -151,7 +159,85 @@ class Image(BatchableMedia):
151
159
  boxes: Optional[Union[Dict[str, "BoundingBoxes2D"], Dict[str, dict]]] = None,
152
160
  masks: Optional[Union[Dict[str, "ImageMask"], Dict[str, dict]]] = None,
153
161
  file_type: Optional[str] = None,
162
+ normalize: bool = True,
154
163
  ) -> None:
164
+ """Initialize a wandb.Image object.
165
+
166
+ Args:
167
+ data_or_path: Accepts numpy array/pytorch tensor of image data,
168
+ a PIL image object, or a path to an image file.
169
+
170
+ If a numpy array or pytorch tensor is provided,
171
+ the image data will be saved to the given file type.
172
+ If the values are not in the range [0, 255] or all values are in the range [0, 1],
173
+ the image pixel values will be normalized to the range [0, 255]
174
+ unless `normalize` is set to False.
175
+ - pytorch tensor should be in the format (channel, height, width)
176
+ - numpy array should be in the format (height, width, channel)
177
+ mode: The PIL mode for an image. Most common are "L", "RGB",
178
+ "RGBA". Full explanation at https://pillow.readthedocs.io/en/stable/handbook/concepts.html#modes
179
+ caption: Label for display of image.
180
+ grouping: The grouping number for the image.
181
+ classes: A list of class information for the image,
182
+ used for labeling bounding boxes, and image masks.
183
+ boxes: A dictionary containing bounding box information for the image.
184
+ see: https://docs.wandb.ai/ref/python/data-types/boundingboxes2d/
185
+ masks: A dictionary containing mask information for the image.
186
+ see: https://docs.wandb.ai/ref/python/data-types/imagemask/
187
+ file_type: The file type to save the image as.
188
+ This parameter has no effect if data_or_path is a path to an image file.
189
+ normalize: If True, normalize the image pixel values to fall within the range of [0, 255].
190
+ Normalize is only applied if data_or_path is a numpy array or pytorch tensor.
191
+
192
+ Examples:
193
+ ### Create a wandb.Image from a numpy array
194
+ ```python
195
+ import numpy as np
196
+ import wandb
197
+
198
+ with wandb.init() as run:
199
+ examples = []
200
+ for i in range(3):
201
+ pixels = np.random.randint(low=0, high=256, size=(100, 100, 3))
202
+ image = wandb.Image(pixels, caption=f"random field {i}")
203
+ examples.append(image)
204
+ run.log({"examples": examples})
205
+ ```
206
+
207
+ ### Create a wandb.Image from a PILImage
208
+ ```python
209
+ import numpy as np
210
+ from PIL import Image as PILImage
211
+ import wandb
212
+
213
+ with wandb.init() as run:
214
+ examples = []
215
+ for i in range(3):
216
+ pixels = np.random.randint(
217
+ low=0, high=256, size=(100, 100, 3), dtype=np.uint8
218
+ )
219
+ pil_image = PILImage.fromarray(pixels, mode="RGB")
220
+ image = wandb.Image(pil_image, caption=f"random field {i}")
221
+ examples.append(image)
222
+ run.log({"examples": examples})
223
+ ```
224
+
225
+ ### log .jpg rather than .png (default)
226
+ ```python
227
+ import numpy as np
228
+ import wandb
229
+
230
+ with wandb.init() as run:
231
+ examples = []
232
+ for i in range(3):
233
+ pixels = np.random.randint(low=0, high=256, size=(100, 100, 3))
234
+ image = wandb.Image(
235
+ pixels, caption=f"random field {i}", file_type="jpg"
236
+ )
237
+ examples.append(image)
238
+ run.log({"examples": examples})
239
+ ```
240
+ """
155
241
  super().__init__(caption=caption)
156
242
  # TODO: We should remove grouping, it's a terrible name and I don't
157
243
  # think anyone uses it.
@@ -169,13 +255,15 @@ class Image(BatchableMedia):
169
255
  # only overriding additional metadata passed in. If this pattern is compelling, we can generalize.
170
256
  if isinstance(data_or_path, Image):
171
257
  self._initialize_from_wbimage(data_or_path)
172
- elif isinstance(data_or_path, str):
258
+ elif isinstance(data_or_path, (str, pathlib.Path)):
259
+ data_or_path = str(data_or_path)
260
+
173
261
  if self.path_is_reference(data_or_path):
174
262
  self._initialize_from_reference(data_or_path)
175
263
  else:
176
264
  self._initialize_from_path(data_or_path)
177
265
  else:
178
- self._initialize_from_data(data_or_path, mode, file_type)
266
+ self._initialize_from_data(data_or_path, mode, file_type, normalize)
179
267
  self._set_initialization_meta(
180
268
  grouping, caption, classes, boxes, masks, file_type
181
269
  )
@@ -288,6 +376,7 @@ class Image(BatchableMedia):
288
376
  data: "ImageDataType",
289
377
  mode: Optional[str] = None,
290
378
  file_type: Optional[str] = None,
379
+ normalize: bool = True,
291
380
  ) -> None:
292
381
  pil_image = util.get_module(
293
382
  "PIL.Image",
@@ -309,28 +398,39 @@ class Image(BatchableMedia):
309
398
  elif isinstance(data, pil_image.Image):
310
399
  self._image = data
311
400
  elif util.is_pytorch_tensor_typename(util.get_full_typename(data)):
312
- vis_util = util.get_module(
313
- "torchvision.utils", "torchvision is required to render images"
314
- )
315
401
  if hasattr(data, "requires_grad") and data.requires_grad:
316
402
  data = data.detach() # type: ignore
317
403
  if hasattr(data, "dtype") and str(data.dtype) == "torch.uint8":
318
- data = data.to(float)
319
- data = vis_util.make_grid(data, normalize=True)
404
+ data = data.to(float) # type: ignore [union-attr]
320
405
  mode = mode or self.guess_mode(data, file_type)
406
+ data = data.permute(1, 2, 0).cpu().numpy() # type: ignore [union-attr]
407
+
408
+ _warn_on_invalid_data_range(data, normalize)
409
+
410
+ data = _guess_and_rescale_to_0_255(data) if normalize else data # type: ignore [arg-type]
411
+ data = _convert_to_uint8(data)
412
+
413
+ if data.ndim > 2:
414
+ data = data.squeeze()
415
+
321
416
  self._image = pil_image.fromarray(
322
- data.mul(255).clamp(0, 255).byte().permute(1, 2, 0).cpu().numpy(),
417
+ data,
323
418
  mode=mode,
324
419
  )
325
420
  else:
326
421
  if hasattr(data, "numpy"): # TF data eager tensors
327
422
  data = data.numpy()
328
- if data.ndim > 2:
329
- data = data.squeeze() # get rid of trivial dimensions as a convenience
423
+ if data.ndim > 2: # type: ignore [union-attr]
424
+ # get rid of trivial dimensions as a convenience
425
+ data = data.squeeze() # type: ignore [union-attr]
426
+
427
+ _warn_on_invalid_data_range(data, normalize) # type: ignore [arg-type]
330
428
 
331
429
  mode = mode or self.guess_mode(data, file_type)
430
+ data = _guess_and_rescale_to_0_255(data) if normalize else data # type: ignore [arg-type]
431
+ data = _convert_to_uint8(data) # type: ignore [arg-type]
332
432
  self._image = pil_image.fromarray(
333
- self.to_uint8(data),
433
+ data,
334
434
  mode=mode,
335
435
  )
336
436
 
@@ -459,7 +559,7 @@ class Image(BatchableMedia):
459
559
  }
460
560
 
461
561
  elif not isinstance(run_or_artifact, Run):
462
- raise ValueError("to_json accepts wandb_run.Run or wandb_artifact.Artifact")
562
+ raise TypeError("to_json accepts wandb_run.Run or wandb_artifact.Artifact")
463
563
 
464
564
  if self._boxes:
465
565
  json_dict["boxes"] = {
@@ -485,7 +585,7 @@ class Image(BatchableMedia):
485
585
  else:
486
586
  num_channels = data.shape[-1]
487
587
 
488
- if ndims == 2:
588
+ if ndims == 2 or num_channels == 1:
489
589
  return "L"
490
590
  elif num_channels == 3:
491
591
  return "RGB"
@@ -501,34 +601,9 @@ class Image(BatchableMedia):
501
601
  return "RGBA"
502
602
  else:
503
603
  raise ValueError(
504
- "Un-supported shape for image conversion {}".format(list(data.shape))
604
+ f"Un-supported shape for image conversion {list(data.shape)}"
505
605
  )
506
606
 
507
- @classmethod
508
- def to_uint8(cls, data: "np.ndarray") -> "np.ndarray":
509
- """Convert image data to uint8.
510
-
511
- Convert floating point image on the range [0,1] and integer images on the range
512
- [0,255] to uint8, clipping if necessary.
513
- """
514
- np = util.get_module(
515
- "numpy",
516
- required="wandb.Image requires numpy if not supplying PIL Images: pip install numpy",
517
- )
518
-
519
- # I think it's better to check the image range vs the data type, since many
520
- # image libraries will return floats between 0 and 255
521
-
522
- # some images have range -1...1 or 0-1
523
- dmin = np.min(data)
524
- if dmin < 0:
525
- data = (data - np.min(data)) / np.ptp(data)
526
- if np.max(data) <= 1.0:
527
- data = (data * 255).astype(np.int32)
528
-
529
- # assert issubclass(data.dtype.type, np.integer), 'Illegal image format.'
530
- return data.clip(0, 255).astype(np.uint8)
531
-
532
607
  @classmethod
533
608
  def seq_to_json(
534
609
  cls: Type["Image"],
@@ -26,7 +26,7 @@ class Molecule(BatchableMedia):
26
26
  """Wandb class for 3D Molecular data.
27
27
 
28
28
  Args:
29
- data_or_path: (string, io)
29
+ data_or_path: (pathlib.Path, string, io)
30
30
  Molecule can be initialized from a file name or an io object.
31
31
  caption: (string)
32
32
  Caption associated with the molecule for display.
@@ -49,7 +49,7 @@ class Molecule(BatchableMedia):
49
49
 
50
50
  def __init__(
51
51
  self,
52
- data_or_path: Union[str, "TextIO"],
52
+ data_or_path: Union[str, pathlib.Path, "TextIO"],
53
53
  caption: Optional[str] = None,
54
54
  **kwargs: str,
55
55
  ) -> None:
@@ -82,7 +82,9 @@ class Molecule(BatchableMedia):
82
82
  f.write(molecule)
83
83
 
84
84
  self._set_file(tmp_path, is_tmp=True)
85
- elif isinstance(data_or_path, str):
85
+ elif isinstance(data_or_path, (str, pathlib.Path)):
86
+ data_or_path = str(data_or_path)
87
+
86
88
  extension = os.path.splitext(data_or_path)[1][1:]
87
89
  if extension not in Molecule.SUPPORTED_TYPES:
88
90
  raise ValueError(
@@ -144,9 +146,7 @@ class Molecule(BatchableMedia):
144
146
  elif isinstance(data_or_path, rdkit_chem.rdchem.Mol):
145
147
  molecule = data_or_path
146
148
  else:
147
- raise ValueError(
148
- "Data must be file name or an rdkit.Chem.rdchem.Mol object"
149
- )
149
+ raise TypeError("Data must be file name or an rdkit.Chem.rdchem.Mol object")
150
150
 
151
151
  if convert_to_3d_and_optimize:
152
152
  molecule = rdkit_chem.AddHs(molecule)
@@ -2,6 +2,7 @@ import codecs
2
2
  import itertools
3
3
  import json
4
4
  import os
5
+ import pathlib
5
6
  from typing import (
6
7
  TYPE_CHECKING,
7
8
  ClassVar,
@@ -187,7 +188,7 @@ class Object3D(BatchableMedia):
187
188
  """Wandb class for 3D point clouds.
188
189
 
189
190
  Args:
190
- data_or_path: (numpy array, string, io)
191
+ data_or_path: (numpy array, pathlib.Path, string, io)
191
192
  Object3D can be initialized from a file or a numpy array.
192
193
 
193
194
  You can pass a path to a file or an io object and a file_type
@@ -214,14 +215,16 @@ class Object3D(BatchableMedia):
214
215
 
215
216
  def __init__(
216
217
  self,
217
- data_or_path: Union["np.ndarray", str, "TextIO", dict],
218
+ data_or_path: Union["np.ndarray", str, pathlib.Path, "TextIO", dict],
218
219
  caption: Optional[str] = None,
219
220
  **kwargs: Optional[Union[str, "FileFormat3D"]],
220
221
  ) -> None:
221
222
  super().__init__(caption=caption)
222
223
 
223
- if hasattr(data_or_path, "name"):
224
- # if the file has a path, we just detect the type and copy it from there
224
+ if hasattr(data_or_path, "name") and not isinstance(data_or_path, pathlib.Path):
225
+ # if the file has a path, we just detect the type and copy it from there.
226
+ # this does not work for pathlib.Path objects,
227
+ # where `.name` returns the last directory in the path.
225
228
  data_or_path = data_or_path.name
226
229
 
227
230
  if hasattr(data_or_path, "read"):
@@ -247,7 +250,9 @@ class Object3D(BatchableMedia):
247
250
  f.write(object_3d)
248
251
 
249
252
  self._set_file(tmp_path, is_tmp=True, extension=extension)
250
- elif isinstance(data_or_path, str):
253
+ elif isinstance(data_or_path, (str, pathlib.Path)):
254
+ data_or_path = str(data_or_path)
255
+
251
256
  path = data_or_path
252
257
  extension = None
253
258
  for supported_type in Object3D.SUPPORTED_TYPES:
@@ -1,6 +1,7 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import os
4
+ import pathlib
4
5
  import shutil
5
6
  import sys
6
7
  from types import ModuleType
@@ -72,10 +73,12 @@ class _SavedModel(WBValue, Generic[SavedModelObjType]):
72
73
 
73
74
  _model_obj: SavedModelObjType | None
74
75
  _path: str | None
75
- _input_obj_or_path: SavedModelObjType | str
76
+ _input_obj_or_path: SavedModelObjType | str | pathlib.Path
76
77
 
77
78
  # Public Methods
78
- def __init__(self, obj_or_path: SavedModelObjType | str, **kwargs: Any) -> None:
79
+ def __init__(
80
+ self, obj_or_path: SavedModelObjType | str | pathlib.Path, **kwargs: Any
81
+ ) -> None:
79
82
  super().__init__()
80
83
  if self.__class__ == _SavedModel:
81
84
  raise TypeError(
@@ -84,9 +87,11 @@ class _SavedModel(WBValue, Generic[SavedModelObjType]):
84
87
  self._model_obj = None
85
88
  self._path = None
86
89
  self._input_obj_or_path = obj_or_path
87
- input_is_path = isinstance(obj_or_path, str) and os.path.exists(obj_or_path)
90
+ input_is_path = isinstance(obj_or_path, (str, pathlib.Path)) and os.path.exists(
91
+ obj_or_path
92
+ )
88
93
  if input_is_path:
89
- assert isinstance(obj_or_path, str) # mypy
94
+ obj_or_path = str(obj_or_path)
90
95
  self._set_obj(self._deserialize(obj_or_path))
91
96
  else:
92
97
  self._set_obj(obj_or_path)
@@ -140,7 +145,7 @@ class _SavedModel(WBValue, Generic[SavedModelObjType]):
140
145
  from wandb.sdk.wandb_run import Run
141
146
 
142
147
  if isinstance(run_or_artifact, Run):
143
- raise ValueError("SavedModel cannot be added to run - must use artifact")
148
+ raise TypeError("SavedModel cannot be added to run - must use artifact")
144
149
  artifact = run_or_artifact
145
150
  json_obj = {
146
151
  "type": self._log_type,
@@ -280,7 +285,7 @@ class _PicklingSavedModel(_SavedModel[SavedModelObjType]):
280
285
 
281
286
  def __init__(
282
287
  self,
283
- obj_or_path: SavedModelObjType | str,
288
+ obj_or_path: SavedModelObjType | str | pathlib.Path,
284
289
  dep_py_files: list[str] | None = None,
285
290
  ):
286
291
  super().__init__(obj_or_path)