wandb 0.19.12rc1__py3-none-macosx_11_0_arm64.whl → 0.20.0__py3-none-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (172)
  1. wandb/__init__.py +1 -2
  2. wandb/__init__.pyi +3 -6
  3. wandb/_iterutils.py +26 -7
  4. wandb/_pydantic/__init__.py +2 -1
  5. wandb/_pydantic/utils.py +7 -0
  6. wandb/agents/pyagent.py +9 -15
  7. wandb/analytics/sentry.py +1 -2
  8. wandb/apis/attrs.py +3 -4
  9. wandb/apis/importers/internals/util.py +1 -1
  10. wandb/apis/importers/validation.py +2 -2
  11. wandb/apis/importers/wandb.py +30 -25
  12. wandb/apis/normalize.py +2 -2
  13. wandb/apis/public/__init__.py +1 -0
  14. wandb/apis/public/api.py +37 -33
  15. wandb/apis/public/artifacts.py +103 -72
  16. wandb/apis/public/jobs.py +3 -2
  17. wandb/apis/public/registries/registries_search.py +4 -2
  18. wandb/apis/public/registries/registry.py +1 -1
  19. wandb/apis/public/registries/utils.py +9 -9
  20. wandb/apis/public/runs.py +18 -6
  21. wandb/automations/_filters/expressions.py +1 -1
  22. wandb/automations/_filters/operators.py +1 -1
  23. wandb/automations/_filters/run_metrics.py +1 -1
  24. wandb/beta/workflows.py +6 -5
  25. wandb/bin/gpu_stats +0 -0
  26. wandb/bin/wandb-core +0 -0
  27. wandb/cli/cli.py +54 -73
  28. wandb/docker/__init__.py +21 -74
  29. wandb/docker/names.py +40 -0
  30. wandb/env.py +0 -1
  31. wandb/errors/util.py +1 -1
  32. wandb/filesync/step_checksum.py +1 -1
  33. wandb/filesync/step_upload.py +1 -1
  34. wandb/integration/diffusers/resolvers/multimodal.py +1 -2
  35. wandb/integration/gym/__init__.py +5 -6
  36. wandb/integration/keras/callbacks/model_checkpoint.py +2 -2
  37. wandb/integration/keras/keras.py +13 -19
  38. wandb/integration/kfp/kfp_patch.py +2 -3
  39. wandb/integration/langchain/wandb_tracer.py +1 -1
  40. wandb/integration/metaflow/metaflow.py +13 -13
  41. wandb/integration/openai/fine_tuning.py +3 -2
  42. wandb/integration/sagemaker/auth.py +2 -1
  43. wandb/integration/sklearn/utils.py +2 -1
  44. wandb/integration/tensorboard/__init__.py +1 -1
  45. wandb/integration/tensorboard/log.py +2 -5
  46. wandb/integration/tensorflow/__init__.py +2 -2
  47. wandb/jupyter.py +20 -17
  48. wandb/plot/confusion_matrix.py +1 -1
  49. wandb/plot/utils.py +8 -7
  50. wandb/proto/v3/wandb_internal_pb2.py +355 -335
  51. wandb/proto/v3/wandb_settings_pb2.py +2 -2
  52. wandb/proto/v3/wandb_telemetry_pb2.py +12 -12
  53. wandb/proto/v4/wandb_internal_pb2.py +339 -335
  54. wandb/proto/v4/wandb_settings_pb2.py +2 -2
  55. wandb/proto/v4/wandb_telemetry_pb2.py +12 -12
  56. wandb/proto/v5/wandb_internal_pb2.py +339 -335
  57. wandb/proto/v5/wandb_settings_pb2.py +2 -2
  58. wandb/proto/v5/wandb_telemetry_pb2.py +12 -12
  59. wandb/proto/v6/wandb_internal_pb2.py +339 -335
  60. wandb/proto/v6/wandb_settings_pb2.py +2 -2
  61. wandb/proto/v6/wandb_telemetry_pb2.py +12 -12
  62. wandb/proto/wandb_deprecated.py +6 -8
  63. wandb/sdk/artifacts/_internal_artifact.py +43 -0
  64. wandb/sdk/artifacts/_validators.py +55 -35
  65. wandb/sdk/artifacts/artifact.py +117 -115
  66. wandb/sdk/artifacts/artifact_download_logger.py +2 -0
  67. wandb/sdk/artifacts/artifact_saver.py +1 -3
  68. wandb/sdk/artifacts/artifact_state.py +2 -0
  69. wandb/sdk/artifacts/artifact_ttl.py +2 -0
  70. wandb/sdk/artifacts/exceptions.py +14 -0
  71. wandb/sdk/artifacts/staging.py +2 -0
  72. wandb/sdk/artifacts/storage_handlers/local_file_handler.py +2 -6
  73. wandb/sdk/artifacts/storage_handlers/multi_handler.py +1 -1
  74. wandb/sdk/artifacts/storage_handlers/tracking_handler.py +2 -6
  75. wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +1 -5
  76. wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +1 -1
  77. wandb/sdk/artifacts/storage_layout.py +2 -0
  78. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +3 -3
  79. wandb/sdk/backend/backend.py +11 -182
  80. wandb/sdk/data_types/_dtypes.py +2 -6
  81. wandb/sdk/data_types/audio.py +20 -3
  82. wandb/sdk/data_types/base_types/media.py +12 -7
  83. wandb/sdk/data_types/base_types/wb_value.py +8 -18
  84. wandb/sdk/data_types/bokeh.py +19 -2
  85. wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +17 -1
  86. wandb/sdk/data_types/helper_types/image_mask.py +7 -1
  87. wandb/sdk/data_types/html.py +4 -4
  88. wandb/sdk/data_types/image.py +164 -103
  89. wandb/sdk/data_types/molecule.py +6 -6
  90. wandb/sdk/data_types/object_3d.py +10 -5
  91. wandb/sdk/data_types/saved_model.py +11 -6
  92. wandb/sdk/data_types/table.py +313 -83
  93. wandb/sdk/data_types/table_decorators.py +108 -0
  94. wandb/sdk/data_types/utils.py +43 -7
  95. wandb/sdk/data_types/video.py +21 -3
  96. wandb/sdk/interface/interface.py +10 -0
  97. wandb/sdk/internal/datastore.py +2 -6
  98. wandb/sdk/internal/file_pusher.py +1 -5
  99. wandb/sdk/internal/file_stream.py +8 -17
  100. wandb/sdk/internal/handler.py +2 -2
  101. wandb/sdk/internal/incremental_table_util.py +53 -0
  102. wandb/sdk/internal/internal.py +3 -5
  103. wandb/sdk/internal/internal_api.py +66 -89
  104. wandb/sdk/internal/job_builder.py +2 -7
  105. wandb/sdk/internal/profiler.py +2 -2
  106. wandb/sdk/internal/progress.py +1 -3
  107. wandb/sdk/internal/run.py +1 -6
  108. wandb/sdk/internal/sender.py +24 -36
  109. wandb/sdk/internal/system/assets/aggregators.py +1 -7
  110. wandb/sdk/internal/system/assets/disk.py +3 -3
  111. wandb/sdk/internal/system/assets/gpu.py +4 -4
  112. wandb/sdk/internal/system/assets/gpu_amd.py +4 -4
  113. wandb/sdk/internal/system/assets/interfaces.py +6 -6
  114. wandb/sdk/internal/system/assets/tpu.py +1 -1
  115. wandb/sdk/internal/system/assets/trainium.py +6 -6
  116. wandb/sdk/internal/system/system_info.py +5 -7
  117. wandb/sdk/internal/system/system_monitor.py +4 -4
  118. wandb/sdk/internal/tb_watcher.py +5 -7
  119. wandb/sdk/launch/_launch.py +1 -1
  120. wandb/sdk/launch/_project_spec.py +19 -20
  121. wandb/sdk/launch/agent/agent.py +3 -3
  122. wandb/sdk/launch/agent/config.py +1 -1
  123. wandb/sdk/launch/agent/job_status_tracker.py +2 -2
  124. wandb/sdk/launch/builder/build.py +2 -3
  125. wandb/sdk/launch/builder/kaniko_builder.py +5 -4
  126. wandb/sdk/launch/environment/gcp_environment.py +1 -2
  127. wandb/sdk/launch/registry/azure_container_registry.py +2 -2
  128. wandb/sdk/launch/registry/elastic_container_registry.py +2 -2
  129. wandb/sdk/launch/registry/google_artifact_registry.py +3 -3
  130. wandb/sdk/launch/runner/abstract.py +5 -5
  131. wandb/sdk/launch/runner/kubernetes_monitor.py +2 -2
  132. wandb/sdk/launch/runner/kubernetes_runner.py +1 -1
  133. wandb/sdk/launch/runner/sagemaker_runner.py +2 -4
  134. wandb/sdk/launch/runner/vertex_runner.py +2 -7
  135. wandb/sdk/launch/sweeps/__init__.py +1 -1
  136. wandb/sdk/launch/sweeps/scheduler.py +2 -2
  137. wandb/sdk/launch/sweeps/utils.py +3 -3
  138. wandb/sdk/launch/utils.py +3 -4
  139. wandb/sdk/lib/apikey.py +5 -8
  140. wandb/sdk/lib/config_util.py +3 -3
  141. wandb/sdk/lib/fsm.py +3 -18
  142. wandb/sdk/lib/gitlib.py +6 -5
  143. wandb/sdk/lib/ipython.py +2 -2
  144. wandb/sdk/lib/json_util.py +9 -14
  145. wandb/sdk/lib/printer.py +3 -8
  146. wandb/sdk/lib/redirect.py +1 -1
  147. wandb/sdk/lib/retry.py +3 -7
  148. wandb/sdk/lib/run_moment.py +2 -2
  149. wandb/sdk/lib/service_connection.py +3 -1
  150. wandb/sdk/lib/service_token.py +1 -2
  151. wandb/sdk/mailbox/mailbox_handle.py +3 -7
  152. wandb/sdk/mailbox/response_handle.py +2 -6
  153. wandb/sdk/service/streams.py +3 -7
  154. wandb/sdk/verify/verify.py +5 -6
  155. wandb/sdk/wandb_config.py +1 -1
  156. wandb/sdk/wandb_init.py +38 -106
  157. wandb/sdk/wandb_login.py +7 -6
  158. wandb/sdk/wandb_run.py +52 -240
  159. wandb/sdk/wandb_settings.py +71 -60
  160. wandb/sdk/wandb_setup.py +40 -14
  161. wandb/sdk/wandb_watch.py +5 -7
  162. wandb/sync/__init__.py +1 -1
  163. wandb/sync/sync.py +13 -13
  164. wandb/util.py +17 -35
  165. wandb/wandb_agent.py +8 -11
  166. {wandb-0.19.12rc1.dist-info → wandb-0.20.0.dist-info}/METADATA +5 -5
  167. {wandb-0.19.12rc1.dist-info → wandb-0.20.0.dist-info}/RECORD +170 -168
  168. wandb/docker/auth.py +0 -435
  169. wandb/docker/www_authenticate.py +0 -94
  170. {wandb-0.19.12rc1.dist-info → wandb-0.20.0.dist-info}/WHEEL +0 -0
  171. {wandb-0.19.12rc1.dist-info → wandb-0.20.0.dist-info}/entry_points.txt +0 -0
  172. {wandb-0.19.12rc1.dist-info → wandb-0.20.0.dist-info}/licenses/LICENSE +0 -0
@@ -1,10 +1,13 @@
1
1
  import hashlib
2
2
  import logging
3
3
  import os
4
+ import pathlib
4
5
  from io import BytesIO
5
6
  from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Type, Union, cast
6
7
  from urllib import parse
7
8
 
9
+ from packaging.version import parse as parse_version
10
+
8
11
  import wandb
9
12
  from wandb import util
10
13
  from wandb.sdk.lib import hashutil, runid
@@ -30,10 +33,60 @@ if TYPE_CHECKING: # pragma: no cover
30
33
  ImageDataType = Union[
31
34
  "matplotlib.artist.Artist", "PILImage", "TorchTensorType", "np.ndarray"
32
35
  ]
33
- ImageDataOrPathType = Union[str, "Image", ImageDataType]
36
+ ImageDataOrPathType = Union[str, pathlib.Path, "Image", ImageDataType]
34
37
  TorchTensorType = Union["torch.Tensor", "torch.Variable"]
35
38
 
36
39
 
40
+ def _warn_on_invalid_data_range(
41
+ data: "np.ndarray",
42
+ normalize: bool = True,
43
+ ) -> None:
44
+ if not normalize:
45
+ return
46
+
47
+ np = util.get_module(
48
+ "numpy",
49
+ required="wandb.Image requires numpy if not supplying PIL Images: pip install numpy",
50
+ )
51
+
52
+ if np.min(data) < 0 or np.max(data) > 255:
53
+ wandb.termwarn(
54
+ "Data passed to `wandb.Image` should consist of values in the range [0, 255], "
55
+ "image data will be normalized to this range, "
56
+ "but behavior will be removed in a future version of wandb.",
57
+ repeat=False,
58
+ )
59
+
60
+
61
+ def _normalize(data: "np.ndarray") -> "np.ndarray":
62
+ """Normalizes and converts image pixel values to uint8 in the range [0, 255]."""
63
+ np = util.get_module(
64
+ "numpy",
65
+ required="wandb.Image requires numpy if not supplying PIL Images: pip install numpy",
66
+ )
67
+
68
+ # if an image has negative values, set all values to be in the range [0, 1]
69
+ # This can lead to inconsistent behavior when an image has only a single negative value
70
+ if np.min(data) < 0:
71
+ data = data - np.min(data)
72
+
73
+ if np.ptp(data) != 0:
74
+ data = data / np.ptp(data)
75
+
76
+ if np.max(data) <= 1.0:
77
+ data = (255 * data).astype(np.int32)
78
+
79
+ return data.clip(0, 255)
80
+
81
+
82
+ def _convert_to_uint8(data: "np.ndarray") -> "np.ndarray":
83
+ np = util.get_module(
84
+ "numpy",
85
+ required="wandb.Image requires numpy if not supplying PIL Images: pip install numpy",
86
+ )
87
+ return data.astype(np.uint8)
88
+
89
+
37
90
  def _server_accepts_image_filenames(run: "LocalRun") -> bool:
38
91
  if run.offline:
39
92
  return True
@@ -43,10 +96,9 @@ def _server_accepts_image_filenames(run: "LocalRun") -> bool:
43
96
  max_cli_version = util._get_max_cli_version()
44
97
  if max_cli_version is None:
45
98
  return False
46
- from wandb.util import parse_version
47
99
 
48
- accepts_image_filenames: bool = parse_version("0.12.10") <= parse_version(
49
- max_cli_version
100
+ accepts_image_filenames: bool = parse_version(max_cli_version) >= parse_version(
101
+ "0.12.10"
50
102
  )
51
103
  return accepts_image_filenames
52
104
 
@@ -59,69 +111,11 @@ def _server_accepts_artifact_path(run: "LocalRun") -> bool:
59
111
  if max_cli_version is None:
60
112
  return False
61
113
 
62
- return util.parse_version("0.12.14") <= util.parse_version(max_cli_version)
114
+ return parse_version(max_cli_version) >= parse_version("0.12.14")
63
115
 
64
116
 
65
117
  class Image(BatchableMedia):
66
- """Format images for logging to W&B.
67
-
68
- Args:
69
- data_or_path: (numpy array, string, io) Accepts numpy array of
70
- image data, or a PIL image. The class attempts to infer
71
- the data format and converts it.
72
- mode: (string) The PIL mode for an image. Most common are "L", "RGB",
73
- "RGBA". Full explanation at https://pillow.readthedocs.io/en/stable/handbook/concepts.html#modes
74
- caption: (string) Label for display of image.
75
-
76
- Note : When logging a `torch.Tensor` as a `wandb.Image`, images are normalized. If you do not want to normalize your images, please convert your tensors to a PIL Image.
77
-
78
- Examples:
79
- ### Create a wandb.Image from a numpy array
80
- ```python
81
- import numpy as np
82
- import wandb
83
-
84
- with wandb.init() as run:
85
- examples = []
86
- for i in range(3):
87
- pixels = np.random.randint(low=0, high=256, size=(100, 100, 3))
88
- image = wandb.Image(pixels, caption=f"random field {i}")
89
- examples.append(image)
90
- run.log({"examples": examples})
91
- ```
92
-
93
- ### Create a wandb.Image from a PILImage
94
- ```python
95
- import numpy as np
96
- from PIL import Image as PILImage
97
- import wandb
98
-
99
- with wandb.init() as run:
100
- examples = []
101
- for i in range(3):
102
- pixels = np.random.randint(
103
- low=0, high=256, size=(100, 100, 3), dtype=np.uint8
104
- )
105
- pil_image = PILImage.fromarray(pixels, mode="RGB")
106
- image = wandb.Image(pil_image, caption=f"random field {i}")
107
- examples.append(image)
108
- run.log({"examples": examples})
109
- ```
110
-
111
- ### log .jpg rather than .png (default)
112
- ```python
113
- import numpy as np
114
- import wandb
115
-
116
- with wandb.init() as run:
117
- examples = []
118
- for i in range(3):
119
- pixels = np.random.randint(low=0, high=256, size=(100, 100, 3))
120
- image = wandb.Image(pixels, caption=f"random field {i}", file_type="jpg")
121
- examples.append(image)
122
- run.log({"examples": examples})
123
- ```
124
- """
118
+ """A class for logging images to W&B."""
125
119
 
126
120
  MAX_ITEMS = 108
127
121
 
@@ -151,7 +145,85 @@ class Image(BatchableMedia):
151
145
  boxes: Optional[Union[Dict[str, "BoundingBoxes2D"], Dict[str, dict]]] = None,
152
146
  masks: Optional[Union[Dict[str, "ImageMask"], Dict[str, dict]]] = None,
153
147
  file_type: Optional[str] = None,
148
+ normalize: bool = True,
154
149
  ) -> None:
150
+ """Initialize a wandb.Image object.
151
+
152
+ Args:
153
+ data_or_path: Accepts numpy array/pytorch tensor of image data,
154
+ a PIL image object, or a path to an image file.
155
+
156
+ If a numpy array or pytorch tensor is provided,
157
+ the image data will be saved to the given file type.
158
+ If the values are not in the range [0, 255] or all values are in the range [0, 1],
159
+ the image pixel values will be normalized to the range [0, 255]
160
+ unless `normalize` is set to False.
161
+ - pytorch tensor should be in the format (channel, height, width)
162
+ - numpy array should be in the format (height, width, channel)
163
+ mode: The PIL mode for an image. Most common are "L", "RGB",
164
+ "RGBA". Full explanation at https://pillow.readthedocs.io/en/stable/handbook/concepts.html#modes
165
+ caption: Label for display of image.
166
+ grouping: The grouping number for the image.
167
+ classes: A list of class information for the image,
168
+ used for labeling bounding boxes, and image masks.
169
+ boxes: A dictionary containing bounding box information for the image.
170
+ see: https://docs.wandb.ai/ref/python/data-types/boundingboxes2d/
171
+ masks: A dictionary containing mask information for the image.
172
+ see: https://docs.wandb.ai/ref/python/data-types/imagemask/
173
+ file_type: The file type to save the image as.
174
+ This parameter has no effect if data_or_path is a path to an image file.
175
+ normalize: If True, normalize the image pixel values to fall within the range of [0, 255].
176
+ Normalize is only applied if data_or_path is a numpy array or pytorch tensor.
177
+
178
+ Examples:
179
+ ### Create a wandb.Image from a numpy array
180
+ ```python
181
+ import numpy as np
182
+ import wandb
183
+
184
+ with wandb.init() as run:
185
+ examples = []
186
+ for i in range(3):
187
+ pixels = np.random.randint(low=0, high=256, size=(100, 100, 3))
188
+ image = wandb.Image(pixels, caption=f"random field {i}")
189
+ examples.append(image)
190
+ run.log({"examples": examples})
191
+ ```
192
+
193
+ ### Create a wandb.Image from a PILImage
194
+ ```python
195
+ import numpy as np
196
+ from PIL import Image as PILImage
197
+ import wandb
198
+
199
+ with wandb.init() as run:
200
+ examples = []
201
+ for i in range(3):
202
+ pixels = np.random.randint(
203
+ low=0, high=256, size=(100, 100, 3), dtype=np.uint8
204
+ )
205
+ pil_image = PILImage.fromarray(pixels, mode="RGB")
206
+ image = wandb.Image(pil_image, caption=f"random field {i}")
207
+ examples.append(image)
208
+ run.log({"examples": examples})
209
+ ```
210
+
211
+ ### log .jpg rather than .png (default)
212
+ ```python
213
+ import numpy as np
214
+ import wandb
215
+
216
+ with wandb.init() as run:
217
+ examples = []
218
+ for i in range(3):
219
+ pixels = np.random.randint(low=0, high=256, size=(100, 100, 3))
220
+ image = wandb.Image(
221
+ pixels, caption=f"random field {i}", file_type="jpg"
222
+ )
223
+ examples.append(image)
224
+ run.log({"examples": examples})
225
+ ```
226
+ """
155
227
  super().__init__(caption=caption)
156
228
  # TODO: We should remove grouping, it's a terrible name and I don't
157
229
  # think anyone uses it.
@@ -169,13 +241,15 @@ class Image(BatchableMedia):
169
241
  # only overriding additional metadata passed in. If this pattern is compelling, we can generalize.
170
242
  if isinstance(data_or_path, Image):
171
243
  self._initialize_from_wbimage(data_or_path)
172
- elif isinstance(data_or_path, str):
244
+ elif isinstance(data_or_path, (str, pathlib.Path)):
245
+ data_or_path = str(data_or_path)
246
+
173
247
  if self.path_is_reference(data_or_path):
174
248
  self._initialize_from_reference(data_or_path)
175
249
  else:
176
250
  self._initialize_from_path(data_or_path)
177
251
  else:
178
- self._initialize_from_data(data_or_path, mode, file_type)
252
+ self._initialize_from_data(data_or_path, mode, file_type, normalize)
179
253
  self._set_initialization_meta(
180
254
  grouping, caption, classes, boxes, masks, file_type
181
255
  )
@@ -288,6 +362,7 @@ class Image(BatchableMedia):
288
362
  data: "ImageDataType",
289
363
  mode: Optional[str] = None,
290
364
  file_type: Optional[str] = None,
365
+ normalize: bool = True,
291
366
  ) -> None:
292
367
  pil_image = util.get_module(
293
368
  "PIL.Image",
@@ -309,28 +384,39 @@ class Image(BatchableMedia):
309
384
  elif isinstance(data, pil_image.Image):
310
385
  self._image = data
311
386
  elif util.is_pytorch_tensor_typename(util.get_full_typename(data)):
312
- vis_util = util.get_module(
313
- "torchvision.utils", "torchvision is required to render images"
314
- )
315
387
  if hasattr(data, "requires_grad") and data.requires_grad:
316
388
  data = data.detach() # type: ignore
317
389
  if hasattr(data, "dtype") and str(data.dtype) == "torch.uint8":
318
- data = data.to(float)
319
- data = vis_util.make_grid(data, normalize=True)
390
+ data = data.to(float) # type: ignore [union-attr]
320
391
  mode = mode or self.guess_mode(data, file_type)
392
+ data = data.permute(1, 2, 0).cpu().numpy() # type: ignore [union-attr]
393
+
394
+ _warn_on_invalid_data_range(data, normalize)
395
+
396
+ data = _normalize(data) if normalize else data # type: ignore [arg-type]
397
+ data = _convert_to_uint8(data)
398
+
399
+ if data.ndim > 2:
400
+ data = data.squeeze()
401
+
321
402
  self._image = pil_image.fromarray(
322
- data.mul(255).clamp(0, 255).byte().permute(1, 2, 0).cpu().numpy(),
403
+ data,
323
404
  mode=mode,
324
405
  )
325
406
  else:
326
407
  if hasattr(data, "numpy"): # TF data eager tensors
327
408
  data = data.numpy()
328
- if data.ndim > 2:
329
- data = data.squeeze() # get rid of trivial dimensions as a convenience
409
+ if data.ndim > 2: # type: ignore [union-attr]
410
+ # get rid of trivial dimensions as a convenience
411
+ data = data.squeeze() # type: ignore [union-attr]
412
+
413
+ _warn_on_invalid_data_range(data, normalize) # type: ignore [arg-type]
330
414
 
331
415
  mode = mode or self.guess_mode(data, file_type)
416
+ data = _normalize(data) if normalize else data # type: ignore [arg-type]
417
+ data = _convert_to_uint8(data) # type: ignore [arg-type]
332
418
  self._image = pil_image.fromarray(
333
- self.to_uint8(data),
419
+ data,
334
420
  mode=mode,
335
421
  )
336
422
 
@@ -459,7 +545,7 @@ class Image(BatchableMedia):
459
545
  }
460
546
 
461
547
  elif not isinstance(run_or_artifact, Run):
462
- raise ValueError("to_json accepts wandb_run.Run or wandb_artifact.Artifact")
548
+ raise TypeError("to_json accepts wandb_run.Run or wandb_artifact.Artifact")
463
549
 
464
550
  if self._boxes:
465
551
  json_dict["boxes"] = {
@@ -485,7 +571,7 @@ class Image(BatchableMedia):
485
571
  else:
486
572
  num_channels = data.shape[-1]
487
573
 
488
- if ndims == 2:
574
+ if ndims == 2 or num_channels == 1:
489
575
  return "L"
490
576
  elif num_channels == 3:
491
577
  return "RGB"
@@ -501,34 +587,9 @@ class Image(BatchableMedia):
501
587
  return "RGBA"
502
588
  else:
503
589
  raise ValueError(
504
- "Un-supported shape for image conversion {}".format(list(data.shape))
590
+ f"Un-supported shape for image conversion {list(data.shape)}"
505
591
  )
506
592
 
507
- @classmethod
508
- def to_uint8(cls, data: "np.ndarray") -> "np.ndarray":
509
- """Convert image data to uint8.
510
-
511
- Convert floating point image on the range [0,1] and integer images on the range
512
- [0,255] to uint8, clipping if necessary.
513
- """
514
- np = util.get_module(
515
- "numpy",
516
- required="wandb.Image requires numpy if not supplying PIL Images: pip install numpy",
517
- )
518
-
519
- # I think it's better to check the image range vs the data type, since many
520
- # image libraries will return floats between 0 and 255
521
-
522
- # some images have range -1...1 or 0-1
523
- dmin = np.min(data)
524
- if dmin < 0:
525
- data = (data - np.min(data)) / np.ptp(data)
526
- if np.max(data) <= 1.0:
527
- data = (data * 255).astype(np.int32)
528
-
529
- # assert issubclass(data.dtype.type, np.integer), 'Illegal image format.'
530
- return data.clip(0, 255).astype(np.uint8)
531
-
532
593
  @classmethod
533
594
  def seq_to_json(
534
595
  cls: Type["Image"],
@@ -26,7 +26,7 @@ class Molecule(BatchableMedia):
26
26
  """Wandb class for 3D Molecular data.
27
27
 
28
28
  Args:
29
- data_or_path: (string, io)
29
+ data_or_path: (pathlib.Path, string, io)
30
30
  Molecule can be initialized from a file name or an io object.
31
31
  caption: (string)
32
32
  Caption associated with the molecule for display.
@@ -49,7 +49,7 @@ class Molecule(BatchableMedia):
49
49
 
50
50
  def __init__(
51
51
  self,
52
- data_or_path: Union[str, "TextIO"],
52
+ data_or_path: Union[str, pathlib.Path, "TextIO"],
53
53
  caption: Optional[str] = None,
54
54
  **kwargs: str,
55
55
  ) -> None:
@@ -82,7 +82,9 @@ class Molecule(BatchableMedia):
82
82
  f.write(molecule)
83
83
 
84
84
  self._set_file(tmp_path, is_tmp=True)
85
- elif isinstance(data_or_path, str):
85
+ elif isinstance(data_or_path, (str, pathlib.Path)):
86
+ data_or_path = str(data_or_path)
87
+
86
88
  extension = os.path.splitext(data_or_path)[1][1:]
87
89
  if extension not in Molecule.SUPPORTED_TYPES:
88
90
  raise ValueError(
@@ -144,9 +146,7 @@ class Molecule(BatchableMedia):
144
146
  elif isinstance(data_or_path, rdkit_chem.rdchem.Mol):
145
147
  molecule = data_or_path
146
148
  else:
147
- raise ValueError(
148
- "Data must be file name or an rdkit.Chem.rdchem.Mol object"
149
- )
149
+ raise TypeError("Data must be file name or an rdkit.Chem.rdchem.Mol object")
150
150
 
151
151
  if convert_to_3d_and_optimize:
152
152
  molecule = rdkit_chem.AddHs(molecule)
@@ -2,6 +2,7 @@ import codecs
2
2
  import itertools
3
3
  import json
4
4
  import os
5
+ import pathlib
5
6
  from typing import (
6
7
  TYPE_CHECKING,
7
8
  ClassVar,
@@ -187,7 +188,7 @@ class Object3D(BatchableMedia):
187
188
  """Wandb class for 3D point clouds.
188
189
 
189
190
  Args:
190
- data_or_path: (numpy array, string, io)
191
+ data_or_path: (numpy array, pathlib.Path, string, io)
191
192
  Object3D can be initialized from a file or a numpy array.
192
193
 
193
194
  You can pass a path to a file or an io object and a file_type
@@ -214,14 +215,16 @@ class Object3D(BatchableMedia):
214
215
 
215
216
  def __init__(
216
217
  self,
217
- data_or_path: Union["np.ndarray", str, "TextIO", dict],
218
+ data_or_path: Union["np.ndarray", str, pathlib.Path, "TextIO", dict],
218
219
  caption: Optional[str] = None,
219
220
  **kwargs: Optional[Union[str, "FileFormat3D"]],
220
221
  ) -> None:
221
222
  super().__init__(caption=caption)
222
223
 
223
- if hasattr(data_or_path, "name"):
224
- # if the file has a path, we just detect the type and copy it from there
224
+ if hasattr(data_or_path, "name") and not isinstance(data_or_path, pathlib.Path):
225
+ # if the file has a path, we just detect the type and copy it from there.
226
+ # this does not work for pathlib.Path objects,
227
+ # where `.name` returns the last directory in the path.
225
228
  data_or_path = data_or_path.name
226
229
 
227
230
  if hasattr(data_or_path, "read"):
@@ -247,7 +250,9 @@ class Object3D(BatchableMedia):
247
250
  f.write(object_3d)
248
251
 
249
252
  self._set_file(tmp_path, is_tmp=True, extension=extension)
250
- elif isinstance(data_or_path, str):
253
+ elif isinstance(data_or_path, (str, pathlib.Path)):
254
+ data_or_path = str(data_or_path)
255
+
251
256
  path = data_or_path
252
257
  extension = None
253
258
  for supported_type in Object3D.SUPPORTED_TYPES:
@@ -1,6 +1,7 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import os
4
+ import pathlib
4
5
  import shutil
5
6
  import sys
6
7
  from types import ModuleType
@@ -72,10 +73,12 @@ class _SavedModel(WBValue, Generic[SavedModelObjType]):
72
73
 
73
74
  _model_obj: SavedModelObjType | None
74
75
  _path: str | None
75
- _input_obj_or_path: SavedModelObjType | str
76
+ _input_obj_or_path: SavedModelObjType | str | pathlib.Path
76
77
 
77
78
  # Public Methods
78
- def __init__(self, obj_or_path: SavedModelObjType | str, **kwargs: Any) -> None:
79
+ def __init__(
80
+ self, obj_or_path: SavedModelObjType | str | pathlib.Path, **kwargs: Any
81
+ ) -> None:
79
82
  super().__init__()
80
83
  if self.__class__ == _SavedModel:
81
84
  raise TypeError(
@@ -84,9 +87,11 @@ class _SavedModel(WBValue, Generic[SavedModelObjType]):
84
87
  self._model_obj = None
85
88
  self._path = None
86
89
  self._input_obj_or_path = obj_or_path
87
- input_is_path = isinstance(obj_or_path, str) and os.path.exists(obj_or_path)
90
+ input_is_path = isinstance(obj_or_path, (str, pathlib.Path)) and os.path.exists(
91
+ obj_or_path
92
+ )
88
93
  if input_is_path:
89
- assert isinstance(obj_or_path, str) # mypy
94
+ obj_or_path = str(obj_or_path)
90
95
  self._set_obj(self._deserialize(obj_or_path))
91
96
  else:
92
97
  self._set_obj(obj_or_path)
@@ -140,7 +145,7 @@ class _SavedModel(WBValue, Generic[SavedModelObjType]):
140
145
  from wandb.sdk.wandb_run import Run
141
146
 
142
147
  if isinstance(run_or_artifact, Run):
143
- raise ValueError("SavedModel cannot be added to run - must use artifact")
148
+ raise TypeError("SavedModel cannot be added to run - must use artifact")
144
149
  artifact = run_or_artifact
145
150
  json_obj = {
146
151
  "type": self._log_type,
@@ -280,7 +285,7 @@ class _PicklingSavedModel(_SavedModel[SavedModelObjType]):
280
285
 
281
286
  def __init__(
282
287
  self,
283
- obj_or_path: SavedModelObjType | str,
288
+ obj_or_path: SavedModelObjType | str | pathlib.Path,
284
289
  dep_py_files: list[str] | None = None,
285
290
  ):
286
291
  super().__init__(obj_or_path)