wandb 0.19.8__py3-none-win_amd64.whl → 0.19.10__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (154)
  1. wandb/__init__.py +5 -1
  2. wandb/__init__.pyi +15 -8
  3. wandb/_pydantic/__init__.py +30 -0
  4. wandb/_pydantic/base.py +148 -0
  5. wandb/_pydantic/utils.py +66 -0
  6. wandb/_pydantic/v1_compat.py +284 -0
  7. wandb/apis/paginator.py +82 -38
  8. wandb/apis/public/__init__.py +2 -2
  9. wandb/apis/public/api.py +111 -53
  10. wandb/apis/public/artifacts.py +387 -639
  11. wandb/apis/public/automations.py +69 -0
  12. wandb/apis/public/files.py +2 -2
  13. wandb/apis/public/integrations.py +168 -0
  14. wandb/apis/public/projects.py +32 -2
  15. wandb/apis/public/reports.py +2 -2
  16. wandb/apis/public/runs.py +19 -11
  17. wandb/apis/public/utils.py +107 -1
  18. wandb/automations/__init__.py +81 -0
  19. wandb/automations/_filters/__init__.py +40 -0
  20. wandb/automations/_filters/expressions.py +179 -0
  21. wandb/automations/_filters/operators.py +267 -0
  22. wandb/automations/_filters/run_metrics.py +183 -0
  23. wandb/automations/_generated/__init__.py +184 -0
  24. wandb/automations/_generated/create_filter_trigger.py +21 -0
  25. wandb/automations/_generated/create_generic_webhook_integration.py +43 -0
  26. wandb/automations/_generated/delete_trigger.py +19 -0
  27. wandb/automations/_generated/enums.py +33 -0
  28. wandb/automations/_generated/fragments.py +343 -0
  29. wandb/automations/_generated/generic_webhook_integrations_by_entity.py +22 -0
  30. wandb/automations/_generated/get_triggers.py +24 -0
  31. wandb/automations/_generated/get_triggers_by_entity.py +24 -0
  32. wandb/automations/_generated/input_types.py +104 -0
  33. wandb/automations/_generated/integrations_by_entity.py +22 -0
  34. wandb/automations/_generated/operations.py +710 -0
  35. wandb/automations/_generated/slack_integrations_by_entity.py +22 -0
  36. wandb/automations/_generated/update_filter_trigger.py +21 -0
  37. wandb/automations/_utils.py +123 -0
  38. wandb/automations/_validators.py +73 -0
  39. wandb/automations/actions.py +205 -0
  40. wandb/automations/automations.py +109 -0
  41. wandb/automations/events.py +235 -0
  42. wandb/automations/integrations.py +26 -0
  43. wandb/automations/scopes.py +76 -0
  44. wandb/beta/workflows.py +9 -10
  45. wandb/bin/gpu_stats.exe +0 -0
  46. wandb/bin/wandb-core +0 -0
  47. wandb/cli/cli.py +3 -3
  48. wandb/integration/keras/keras.py +2 -1
  49. wandb/integration/langchain/wandb_tracer.py +2 -1
  50. wandb/integration/metaflow/metaflow.py +19 -17
  51. wandb/integration/sacred/__init__.py +1 -1
  52. wandb/jupyter.py +155 -133
  53. wandb/old/summary.py +0 -2
  54. wandb/proto/v3/wandb_internal_pb2.py +297 -292
  55. wandb/proto/v3/wandb_settings_pb2.py +2 -2
  56. wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
  57. wandb/proto/v4/wandb_internal_pb2.py +292 -292
  58. wandb/proto/v4/wandb_settings_pb2.py +2 -2
  59. wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
  60. wandb/proto/v5/wandb_internal_pb2.py +292 -292
  61. wandb/proto/v5/wandb_settings_pb2.py +2 -2
  62. wandb/proto/v5/wandb_telemetry_pb2.py +10 -10
  63. wandb/proto/v6/wandb_base_pb2.py +41 -0
  64. wandb/proto/v6/wandb_internal_pb2.py +393 -0
  65. wandb/proto/v6/wandb_server_pb2.py +78 -0
  66. wandb/proto/v6/wandb_settings_pb2.py +58 -0
  67. wandb/proto/v6/wandb_telemetry_pb2.py +52 -0
  68. wandb/proto/wandb_base_pb2.py +2 -0
  69. wandb/proto/wandb_deprecated.py +10 -0
  70. wandb/proto/wandb_internal_pb2.py +3 -1
  71. wandb/proto/wandb_server_pb2.py +2 -0
  72. wandb/proto/wandb_settings_pb2.py +2 -0
  73. wandb/proto/wandb_telemetry_pb2.py +2 -0
  74. wandb/sdk/artifacts/_generated/__init__.py +248 -0
  75. wandb/sdk/artifacts/_generated/artifact_collection_membership_files.py +43 -0
  76. wandb/sdk/artifacts/_generated/artifact_version_files.py +36 -0
  77. wandb/sdk/artifacts/_generated/create_artifact_collection_tag_assignments.py +36 -0
  78. wandb/sdk/artifacts/_generated/delete_artifact_collection_tag_assignments.py +25 -0
  79. wandb/sdk/artifacts/_generated/delete_artifact_portfolio.py +35 -0
  80. wandb/sdk/artifacts/_generated/delete_artifact_sequence.py +35 -0
  81. wandb/sdk/artifacts/_generated/enums.py +17 -0
  82. wandb/sdk/artifacts/_generated/fragments.py +186 -0
  83. wandb/sdk/artifacts/_generated/input_types.py +16 -0
  84. wandb/sdk/artifacts/_generated/move_artifact_collection.py +35 -0
  85. wandb/sdk/artifacts/_generated/operations.py +510 -0
  86. wandb/sdk/artifacts/_generated/project_artifact_collection.py +101 -0
  87. wandb/sdk/artifacts/_generated/project_artifact_collections.py +33 -0
  88. wandb/sdk/artifacts/_generated/project_artifact_type.py +24 -0
  89. wandb/sdk/artifacts/_generated/project_artifact_types.py +24 -0
  90. wandb/sdk/artifacts/_generated/project_artifacts.py +42 -0
  91. wandb/sdk/artifacts/_generated/run_input_artifacts.py +51 -0
  92. wandb/sdk/artifacts/_generated/run_output_artifacts.py +51 -0
  93. wandb/sdk/artifacts/_generated/update_artifact_portfolio.py +35 -0
  94. wandb/sdk/artifacts/_generated/update_artifact_sequence.py +35 -0
  95. wandb/sdk/artifacts/_graphql_fragments.py +56 -81
  96. wandb/sdk/artifacts/_validators.py +1 -0
  97. wandb/sdk/artifacts/artifact.py +110 -49
  98. wandb/sdk/artifacts/artifact_manifest_entry.py +2 -1
  99. wandb/sdk/artifacts/artifact_saver.py +16 -2
  100. wandb/sdk/artifacts/storage_handlers/azure_handler.py +1 -0
  101. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +23 -2
  102. wandb/sdk/data_types/audio.py +1 -3
  103. wandb/sdk/data_types/base_types/media.py +13 -7
  104. wandb/sdk/data_types/base_types/wb_value.py +34 -11
  105. wandb/sdk/data_types/html.py +36 -9
  106. wandb/sdk/data_types/image.py +56 -37
  107. wandb/sdk/data_types/molecule.py +1 -5
  108. wandb/sdk/data_types/object_3d.py +2 -1
  109. wandb/sdk/data_types/saved_model.py +7 -9
  110. wandb/sdk/data_types/table.py +5 -0
  111. wandb/sdk/data_types/trace_tree.py +2 -0
  112. wandb/sdk/data_types/utils.py +1 -1
  113. wandb/sdk/data_types/video.py +15 -30
  114. wandb/sdk/interface/interface.py +2 -0
  115. wandb/{apis/public → sdk/internal}/_generated/__init__.py +0 -6
  116. wandb/{apis/public → sdk/internal}/_generated/server_features_query.py +3 -3
  117. wandb/sdk/internal/internal_api.py +138 -47
  118. wandb/sdk/internal/profiler.py +6 -5
  119. wandb/sdk/internal/run.py +13 -6
  120. wandb/sdk/internal/sender.py +2 -0
  121. wandb/sdk/internal/sender_config.py +8 -11
  122. wandb/sdk/internal/settings_static.py +24 -2
  123. wandb/sdk/lib/apikey.py +40 -20
  124. wandb/sdk/lib/asyncio_compat.py +1 -1
  125. wandb/sdk/lib/deprecate.py +13 -22
  126. wandb/sdk/lib/disabled.py +2 -1
  127. wandb/sdk/lib/printer.py +37 -8
  128. wandb/sdk/lib/printer_asyncio.py +46 -0
  129. wandb/sdk/lib/redirect.py +10 -5
  130. wandb/sdk/lib/run_moment.py +4 -6
  131. wandb/sdk/lib/wb_logging.py +161 -0
  132. wandb/sdk/service/server_sock.py +19 -14
  133. wandb/sdk/service/service.py +9 -7
  134. wandb/sdk/service/streams.py +5 -0
  135. wandb/sdk/verify/verify.py +6 -3
  136. wandb/sdk/wandb_config.py +44 -43
  137. wandb/sdk/wandb_init.py +323 -141
  138. wandb/sdk/wandb_login.py +13 -4
  139. wandb/sdk/wandb_metadata.py +107 -91
  140. wandb/sdk/wandb_run.py +529 -325
  141. wandb/sdk/wandb_settings.py +422 -202
  142. wandb/sdk/wandb_setup.py +52 -1
  143. wandb/util.py +29 -29
  144. {wandb-0.19.8.dist-info → wandb-0.19.10.dist-info}/METADATA +7 -7
  145. {wandb-0.19.8.dist-info → wandb-0.19.10.dist-info}/RECORD +151 -94
  146. wandb/_globals.py +0 -19
  147. wandb/apis/public/_generated/base.py +0 -128
  148. wandb/apis/public/_generated/typing_compat.py +0 -14
  149. /wandb/{apis/public → sdk/internal}/_generated/enums.py +0 -0
  150. /wandb/{apis/public → sdk/internal}/_generated/input_types.py +0 -0
  151. /wandb/{apis/public → sdk/internal}/_generated/operations.py +0 -0
  152. {wandb-0.19.8.dist-info → wandb-0.19.10.dist-info}/WHEEL +0 -0
  153. {wandb-0.19.8.dist-info → wandb-0.19.10.dist-info}/entry_points.txt +0 -0
  154. {wandb-0.19.8.dist-info → wandb-0.19.10.dist-info}/licenses/LICENSE +0 -0
@@ -59,6 +59,8 @@ class ArtifactSaver:
59
59
 
60
60
  def save(
61
61
  self,
62
+ entity: str,
63
+ project: str,
62
64
  type: str,
63
65
  name: str,
64
66
  client_id: str,
@@ -76,6 +78,8 @@ class ArtifactSaver:
76
78
  base_id: str | None = None,
77
79
  ) -> dict | None:
78
80
  return self._save_internal(
81
+ entity,
82
+ project,
79
83
  type,
80
84
  name,
81
85
  client_id,
@@ -95,6 +99,8 @@ class ArtifactSaver:
95
99
 
96
100
  def _save_internal(
97
101
  self,
102
+ entity: str,
103
+ project: str,
98
104
  type: str,
99
105
  name: str,
100
106
  client_id: str,
@@ -140,7 +146,11 @@ class ArtifactSaver:
140
146
  base_id = latest["id"]
141
147
  if self._server_artifact["state"] == "COMMITTED":
142
148
  if use_after_commit:
143
- self._api.use_artifact(artifact_id)
149
+ self._api.use_artifact(
150
+ artifact_id,
151
+ artifact_entity_name=entity,
152
+ artifact_project_name=project,
153
+ )
144
154
  return self._server_artifact
145
155
  if (
146
156
  self._server_artifact["state"] != "PENDING"
@@ -244,7 +254,11 @@ class ArtifactSaver:
244
254
  step_prepare.shutdown()
245
255
 
246
256
  if finalize and use_after_commit:
247
- self._api.use_artifact(artifact_id)
257
+ self._api.use_artifact(
258
+ artifact_id,
259
+ artifact_entity_name=entity,
260
+ artifact_project_name=project,
261
+ )
248
262
 
249
263
  return self._server_artifact
250
264
 
@@ -171,6 +171,7 @@ class AzureHandler(StorageHandler):
171
171
  def _get_credential(
172
172
  self, account_url: str
173
173
  ) -> azure.identity.DefaultAzureCredential | str:
174
+ # NOTE: Always returns default credential for reinit="create_new" runs.
174
175
  if (
175
176
  wandb.run
176
177
  and wandb.run.settings.azure_account_url_to_access_key is not None
@@ -13,6 +13,7 @@ import requests
13
13
  import urllib3
14
14
 
15
15
  from wandb.errors.term import termwarn
16
+ from wandb.proto.wandb_internal_pb2 import ServerFeature
16
17
  from wandb.sdk.artifacts.artifact_file_cache import (
17
18
  ArtifactFileCache,
18
19
  get_artifact_file_cache,
@@ -144,9 +145,14 @@ class WandbStoragePolicy(StoragePolicy):
144
145
  http_headers["Authorization"] = f"Bearer {self._api.access_token}"
145
146
  elif _thread_local_api_settings.cookies is None:
146
147
  auth = ("api", self._api.api_key or "")
147
-
148
148
  response = self._session.get(
149
- self._file_url(self._api, artifact.entity, manifest_entry),
149
+ self._file_url(
150
+ self._api,
151
+ artifact.entity,
152
+ artifact.project,
153
+ artifact.name.split(":")[0],
154
+ manifest_entry,
155
+ ),
150
156
  auth=auth,
151
157
  cookies=_thread_local_api_settings.cookies,
152
158
  headers=http_headers,
@@ -187,6 +193,8 @@ class WandbStoragePolicy(StoragePolicy):
187
193
  self,
188
194
  api: InternalApi,
189
195
  entity_name: str,
196
+ project_name: str,
197
+ artifact_name: str,
190
198
  manifest_entry: ArtifactManifestEntry,
191
199
  ) -> str:
192
200
  storage_layout = self._config.get("storageLayout", StorageLayout.V1)
@@ -198,6 +206,19 @@ class WandbStoragePolicy(StoragePolicy):
198
206
  api.settings("base_url"), entity_name, md5_hex
199
207
  )
200
208
  elif storage_layout == StorageLayout.V2:
209
+ if api._check_server_feature_with_fallback(
210
+ ServerFeature.ARTIFACT_COLLECTION_MEMBERSHIP_FILE_DOWNLOAD_HANDLER # type: ignore
211
+ ):
212
+ return "{}/artifactsV2/{}/{}/{}/{}/{}/{}/{}".format(
213
+ api.settings("base_url"),
214
+ storage_region,
215
+ quote(entity_name),
216
+ quote(project_name),
217
+ quote(artifact_name),
218
+ quote(manifest_entry.birth_artifact_id or ""),
219
+ md5_hex,
220
+ manifest_entry.path.name,
221
+ )
201
222
  return "{}/artifactsV2/{}/{}/{}/{}".format(
202
223
  api.settings("base_url"),
203
224
  storage_region,
@@ -25,10 +25,9 @@ class Audio(BatchableMedia):
25
25
 
26
26
  def __init__(self, data_or_path, sample_rate=None, caption=None):
27
27
  """Accept a path to an audio file or a numpy array of audio data."""
28
- super().__init__()
28
+ super().__init__(caption=caption)
29
29
  self._duration = None
30
30
  self._sample_rate = sample_rate
31
- self._caption = caption
32
31
 
33
32
  if isinstance(data_or_path, str):
34
33
  if self.path_is_reference(data_or_path):
@@ -80,7 +79,6 @@ class Audio(BatchableMedia):
80
79
  json_dict.update(
81
80
  {
82
81
  "_type": self._log_type,
83
- "caption": self._caption,
84
82
  }
85
83
  )
86
84
  return json_dict
@@ -3,11 +3,10 @@ import os
3
3
  import platform
4
4
  import re
5
5
  import shutil
6
- from typing import TYPE_CHECKING, Optional, Sequence, Type, Union, cast
6
+ from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Type, Union, cast
7
7
 
8
8
  import wandb
9
9
  from wandb import util
10
- from wandb._globals import _datatypes_callback
11
10
  from wandb.sdk.lib import filesystem
12
11
  from wandb.sdk.lib.paths import LogicalPath
13
12
 
@@ -192,7 +191,7 @@ class Media(WBValue):
192
191
  shutil.move(self._path, new_path)
193
192
  self._path = new_path
194
193
  self._is_tmp = False
195
- _datatypes_callback(media_path)
194
+ run._publish_file(media_path)
196
195
  else:
197
196
  try:
198
197
  shutil.copy(self._path, new_path)
@@ -200,7 +199,7 @@ class Media(WBValue):
200
199
  if not ignore_copy_err:
201
200
  raise e
202
201
  self._path = new_path
203
- _datatypes_callback(media_path)
202
+ run._publish_file(media_path)
204
203
 
205
204
  def to_json(self, run: Union["LocalRun", "Artifact"]) -> dict:
206
205
  """Serialize the object into a JSON blob.
@@ -222,7 +221,10 @@ class Media(WBValue):
222
221
  from wandb.data_types import Audio
223
222
  from wandb.sdk.wandb_run import Run
224
223
 
225
- json_obj = {}
224
+ json_obj: Dict[str, Any] = {}
225
+
226
+ if self._caption is not None:
227
+ json_obj["caption"] = self._caption
226
228
 
227
229
  if isinstance(run, Run):
228
230
  json_obj.update(
@@ -232,6 +234,7 @@ class Media(WBValue):
232
234
  "size": self._size,
233
235
  }
234
236
  )
237
+
235
238
  artifact_entry_url = self._get_artifact_entry_ref_url()
236
239
  if artifact_entry_url is not None:
237
240
  json_obj["artifact_path"] = artifact_entry_url
@@ -337,8 +340,11 @@ class BatchableMedia(Media):
337
340
  organize files by name in the media directory.
338
341
  """
339
342
 
340
- def __init__(self) -> None:
341
- super().__init__()
343
+ def __init__(
344
+ self,
345
+ caption: Optional[str] = None,
346
+ ) -> None:
347
+ super().__init__(caption=caption)
342
348
 
343
349
  @classmethod
344
350
  def seq_to_json(
@@ -1,7 +1,7 @@
1
1
  from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, Type, Union
2
2
 
3
- import wandb
4
3
  from wandb import util
4
+ from wandb.sdk import wandb_setup
5
5
 
6
6
  if TYPE_CHECKING: # pragma: no cover
7
7
  from wandb.sdk.artifacts.artifact import Artifact
@@ -11,6 +11,31 @@ if TYPE_CHECKING: # pragma: no cover
11
11
  TypeMappingType = Dict[str, Type["WBValue"]]
12
12
 
13
13
 
14
+ def _is_maybe_offline() -> bool:
15
+ """Guess whether wandb is configured to be offline.
16
+
17
+ This is an anti-pattern because there is no library-level "offline" mode:
18
+ only runs can be offline. Online and offline runs can exist in the same
19
+ process. This function is a heuristic that works only if there is at most
20
+ one run in the process, and could otherwise produce unexpected results.
21
+
22
+ Returns:
23
+ Whether the user likely configured wandb to be offline.
24
+ """
25
+ singleton = wandb_setup._setup(start_service=False)
26
+
27
+ # First check: if there's a run, check if it is offline.
28
+ #
29
+ # This covers uses like `wandb.init(mode="offline")` which don't modify
30
+ # the singleton's settings.
31
+ if run := singleton.most_recent_active_run:
32
+ return run.offline
33
+
34
+ # Second check: default to global defaults derived from environment
35
+ # variables or passed explicitly to `wandb.setup()`.
36
+ return singleton.settings._offline
37
+
38
+
14
39
  def _server_accepts_client_ids() -> bool:
15
40
  from wandb.util import parse_version
16
41
 
@@ -25,15 +50,13 @@ def _server_accepts_client_ids() -> bool:
25
50
  # AS OF NOW, 2024/11/06, we assume that all customer's server deployments accept
26
51
  # client IDs.
27
52
 
28
- if util._is_offline():
29
- # If there are any users with issues on an older backend, customers can disable the
30
- # setting `allow_offline_artifacts` to revert the SDK's behavior back to not
31
- # using client IDs in offline mode.
32
- if wandb.run and not wandb.run.settings.allow_offline_artifacts:
33
- return False
34
- # Assume client IDs are accepted
53
+ if _is_maybe_offline():
54
+ singleton = wandb_setup._setup(start_service=False)
55
+
56
+ if run := singleton.most_recent_active_run:
57
+ return run._settings.allow_offline_artifacts
35
58
  else:
36
- return True
59
+ return singleton.settings.allow_offline_artifacts
37
60
 
38
61
  # If the script is online, request the max_cli_version and ensure the server
39
62
  # is of a high enough version.
@@ -240,7 +263,7 @@ class WBValue:
240
263
  self._artifact_target
241
264
  and self._artifact_target.name
242
265
  and self._artifact_target.artifact._is_draft_save_started()
243
- and not util._is_offline()
266
+ and not _is_maybe_offline()
244
267
  and not _server_accepts_client_ids()
245
268
  ):
246
269
  self._artifact_target.artifact.wait()
@@ -271,7 +294,7 @@ class WBValue:
271
294
  self._artifact_target
272
295
  and self._artifact_target.name
273
296
  and self._artifact_target.artifact._is_draft_save_started()
274
- and not util._is_offline()
297
+ and not _is_maybe_offline()
275
298
  and not _server_accepts_client_ids()
276
299
  ):
277
300
  self._artifact_target.artifact.wait()
@@ -16,19 +16,46 @@ if TYPE_CHECKING: # pragma: no cover
16
16
 
17
17
 
18
18
  class Html(BatchableMedia):
19
- """Wandb class for arbitrary html.
20
-
21
- Args:
22
- data: (string or io object) HTML to display in wandb
23
- inject: (boolean) Add a stylesheet to the HTML object. If set
24
- to False the HTML will pass through unchanged.
25
- """
19
+ """A class for logging HTML content to W&B."""
26
20
 
27
21
  _log_type = "html-file"
28
22
 
29
- def __init__(self, data: Union[str, "TextIO"], inject: bool = True) -> None:
23
+ def __init__(
24
+ self,
25
+ data: Union[str, "TextIO"],
26
+ inject: bool = True,
27
+ data_is_not_path: bool = False,
28
+ ) -> None:
29
+ """Creates a W&B HTML object.
30
+
31
+ It can be initialized by providing a path to a file:
32
+ ```
33
+ with wandb.init() as run:
34
+ run.log({"html": wandb.Html("./index.html")})
35
+ ```
36
+
37
+ Alternatively, it can be initialized by providing literal HTML,
38
+ in either a string or IO object:
39
+ ```
40
+ with wandb.init() as run:
41
+ run.log({"html": wandb.Html("<h1>Hello, world!</h1>")})
42
+ ```
43
+
44
+ Args:
45
+ data:
46
+ A string that is a path to a file with the extension ".html",
47
+ or a string or IO object containing literal HTML.
48
+ inject: Add a stylesheet to the HTML object. If set
49
+ to False the HTML will pass through unchanged.
50
+ data_is_not_path: If set to True, the data is never treated as a
51
+ path to a file, even if it names an existing ".html" file.
52
+ """
30
53
  super().__init__()
31
- data_is_path = isinstance(data, str) and os.path.exists(data)
54
+ data_is_path = (
55
+ isinstance(data, str)
56
+ and os.path.isfile(data)
57
+ and os.path.splitext(data)[1] == ".html"
58
+ ) and not data_is_not_path
32
59
  data_path = ""
33
60
  if data_is_path:
34
61
  assert isinstance(data, str)
@@ -34,8 +34,8 @@ if TYPE_CHECKING: # pragma: no cover
34
34
  TorchTensorType = Union["torch.Tensor", "torch.Variable"]
35
35
 
36
36
 
37
- def _server_accepts_image_filenames() -> bool:
38
- if util._is_offline():
37
+ def _server_accepts_image_filenames(run: "LocalRun") -> bool:
38
+ if run.offline:
39
39
  return True
40
40
 
41
41
  # Newer versions of wandb accept large image filenames arrays
@@ -51,15 +51,15 @@ def _server_accepts_image_filenames() -> bool:
51
51
  return accepts_image_filenames
52
52
 
53
53
 
54
- def _server_accepts_artifact_path() -> bool:
55
- from wandb.util import parse_version
54
+ def _server_accepts_artifact_path(run: "LocalRun") -> bool:
55
+ if run.offline:
56
+ return False
57
+
58
+ max_cli_version = util._get_max_cli_version()
59
+ if max_cli_version is None:
60
+ return False
56
61
 
57
- target_version = "0.12.14"
58
- max_cli_version = util._get_max_cli_version() if not util._is_offline() else None
59
- accepts_artifact_path: bool = max_cli_version is not None and parse_version(
60
- target_version
61
- ) <= parse_version(max_cli_version)
62
- return accepts_artifact_path
62
+ return util.parse_version("0.12.14") <= util.parse_version(max_cli_version)
63
63
 
64
64
 
65
65
  class Image(BatchableMedia):
@@ -152,12 +152,11 @@ class Image(BatchableMedia):
152
152
  masks: Optional[Union[Dict[str, "ImageMask"], Dict[str, dict]]] = None,
153
153
  file_type: Optional[str] = None,
154
154
  ) -> None:
155
- super().__init__()
155
+ super().__init__(caption=caption)
156
156
  # TODO: We should remove grouping, it's a terrible name and I don't
157
157
  # think anyone uses it.
158
158
 
159
159
  self._grouping = None
160
- self._caption = None
161
160
  self._width = None
162
161
  self._height = None
163
162
  self._image = None
@@ -193,9 +192,6 @@ class Image(BatchableMedia):
193
192
  if grouping is not None:
194
193
  self._grouping = grouping
195
194
 
196
- if caption is not None:
197
- self._caption = caption
198
-
199
195
  total_classes = {}
200
196
 
201
197
  if boxes:
@@ -297,10 +293,19 @@ class Image(BatchableMedia):
297
293
  "PIL.Image",
298
294
  required='wandb.Image needs the PIL package. To get it, run "pip install pillow".',
299
295
  )
296
+
297
+ accepted_formats = ["png", "jpg", "jpeg", "bmp"]
298
+ self.format = file_type or "png"
299
+
300
+ if self.format not in accepted_formats:
301
+ raise ValueError(f"file_type must be one of {accepted_formats}")
302
+
303
+ tmp_path = os.path.join(MEDIA_TMP.name, runid.generate_id() + "." + self.format)
304
+
300
305
  if util.is_matplotlib_typename(util.get_full_typename(data)):
301
306
  buf = BytesIO()
302
- util.ensure_matplotlib_figure(data).savefig(buf, format="png")
303
- self._image = pil_image.open(buf, formats=["PNG"])
307
+ util.ensure_matplotlib_figure(data).savefig(buf, format=self.format)
308
+ self._image = pil_image.open(buf)
304
309
  elif isinstance(data, pil_image.Image):
305
310
  self._image = data
306
311
  elif util.is_pytorch_tensor_typename(util.get_full_typename(data)):
@@ -312,26 +317,23 @@ class Image(BatchableMedia):
312
317
  if hasattr(data, "dtype") and str(data.dtype) == "torch.uint8":
313
318
  data = data.to(float)
314
319
  data = vis_util.make_grid(data, normalize=True)
320
+ mode = mode or self.guess_mode(data, file_type)
315
321
  self._image = pil_image.fromarray(
316
- data.mul(255).clamp(0, 255).byte().permute(1, 2, 0).cpu().numpy()
322
+ data.mul(255).clamp(0, 255).byte().permute(1, 2, 0).cpu().numpy(),
323
+ mode=mode,
317
324
  )
318
325
  else:
319
326
  if hasattr(data, "numpy"): # TF data eager tensors
320
327
  data = data.numpy()
321
328
  if data.ndim > 2:
322
329
  data = data.squeeze() # get rid of trivial dimensions as a convenience
330
+
331
+ mode = mode or self.guess_mode(data, file_type)
323
332
  self._image = pil_image.fromarray(
324
- self.to_uint8(data), mode=mode or self.guess_mode(data)
333
+ self.to_uint8(data),
334
+ mode=mode,
325
335
  )
326
- accepted_formats = ["png", "jpg", "jpeg", "bmp"]
327
- if file_type is None:
328
- self.format = "png"
329
- else:
330
- self.format = file_type
331
- assert (
332
- self.format in accepted_formats
333
- ), f"file_type must be one of {accepted_formats}"
334
- tmp_path = os.path.join(MEDIA_TMP.name, runid.generate_id() + "." + self.format)
336
+
335
337
  assert self._image is not None
336
338
  self._image.save(tmp_path, transparency=None)
337
339
  self._set_file(tmp_path, is_tmp=True)
@@ -399,7 +401,7 @@ class Image(BatchableMedia):
399
401
  )
400
402
 
401
403
  if (
402
- not _server_accepts_artifact_path()
404
+ not _server_accepts_artifact_path(run)
403
405
  or self._get_artifact_entry_ref_url() is None
404
406
  ):
405
407
  super().bind_to_run(run, key, step, id_, ignore_copy_err=ignore_copy_err)
@@ -430,8 +432,6 @@ class Image(BatchableMedia):
430
432
  json_dict["height"] = self._height
431
433
  if self._grouping:
432
434
  json_dict["grouping"] = self._grouping
433
- if self._caption:
434
- json_dict["caption"] = self._caption
435
435
 
436
436
  if isinstance(run_or_artifact, wandb.Artifact):
437
437
  artifact = run_or_artifact
@@ -471,15 +471,34 @@ class Image(BatchableMedia):
471
471
  }
472
472
  return json_dict
473
473
 
474
- def guess_mode(self, data: "np.ndarray") -> str:
474
+ def guess_mode(
475
+ self,
476
+ data: Union["np.ndarray", "torch.Tensor"],
477
+ file_type: Optional[str] = None,
478
+ ) -> str:
475
479
  """Guess what type of image the np.array is representing."""
476
480
  # TODO: do we want to support dimensions being at the beginning of the array?
477
- if data.ndim == 2:
481
+ ndims = data.ndim
482
+ if util.is_pytorch_tensor_typename(util.get_full_typename(data)):
483
+ # Torch tensors typically have the channels dimension first
484
+ num_channels = data.shape[0]
485
+ else:
486
+ num_channels = data.shape[-1]
487
+
488
+ if ndims == 2:
478
489
  return "L"
479
- elif data.shape[-1] == 3:
490
+ elif num_channels == 3:
480
491
  return "RGB"
481
- elif data.shape[-1] == 4:
482
- return "RGBA"
492
+ elif num_channels == 4:
493
+ if file_type in ["jpg", "jpeg"]:
494
+ wandb.termwarn(
495
+ "JPEG format does not support transparency. "
496
+ "Ignoring alpha channel.",
497
+ repeat=False,
498
+ )
499
+ return "RGB"
500
+ else:
501
+ return "RGBA"
483
502
  else:
484
503
  raise ValueError(
485
504
  "Un-supported shape for image conversion {}".format(list(data.shape))
@@ -556,7 +575,7 @@ class Image(BatchableMedia):
556
575
  "format": format,
557
576
  "count": num_images_to_log,
558
577
  }
559
- if _server_accepts_image_filenames():
578
+ if _server_accepts_image_filenames(run):
560
579
  meta["filenames"] = [
561
580
  obj.get("path", obj.get("artifact_path")) for obj in jsons
562
581
  ]
@@ -53,9 +53,7 @@ class Molecule(BatchableMedia):
53
53
  caption: Optional[str] = None,
54
54
  **kwargs: str,
55
55
  ) -> None:
56
- super().__init__()
57
-
58
- self._caption = caption
56
+ super().__init__(caption=caption)
59
57
 
60
58
  if hasattr(data_or_path, "name"):
61
59
  # if the file has a path, we just detect the type and copy it from there
@@ -208,8 +206,6 @@ class Molecule(BatchableMedia):
208
206
  def to_json(self, run_or_artifact: Union["LocalRun", "Artifact"]) -> dict:
209
207
  json_dict = super().to_json(run_or_artifact)
210
208
  json_dict["_type"] = self._log_type
211
- if self._caption:
212
- json_dict["caption"] = self._caption
213
209
  return json_dict
214
210
 
215
211
  @classmethod
@@ -215,9 +215,10 @@ class Object3D(BatchableMedia):
215
215
  def __init__(
216
216
  self,
217
217
  data_or_path: Union["np.ndarray", str, "TextIO", dict],
218
+ caption: Optional[str] = None,
218
219
  **kwargs: Optional[Union[str, "FileFormat3D"]],
219
220
  ) -> None:
220
- super().__init__()
221
+ super().__init__(caption=caption)
221
222
 
222
223
  if hasattr(data_or_path, "name"):
223
224
  # if the file has a path, we just detect the type and copy it from there
@@ -3,6 +3,7 @@ from __future__ import annotations
3
3
  import os
4
4
  import shutil
5
5
  import sys
6
+ from types import ModuleType
6
7
  from typing import TYPE_CHECKING, Any, ClassVar, Generic, TypeVar, cast
7
8
 
8
9
  import wandb
@@ -15,9 +16,6 @@ from ._private import MEDIA_TMP
15
16
  from .base_types.wb_value import WBValue
16
17
 
17
18
  if TYPE_CHECKING:
18
- from types import ModuleType
19
-
20
- import cloudpickle # type: ignore
21
19
  import sklearn # type: ignore
22
20
  import tensorflow # type: ignore
23
21
  import torch # type: ignore
@@ -264,9 +262,9 @@ class _SavedModel(WBValue, Generic[SavedModelObjType]):
264
262
  self._serialize(self._model_obj, target_path)
265
263
 
266
264
 
267
- def _get_cloudpickle() -> "cloudpickle":
265
+ def _get_cloudpickle() -> ModuleType:
268
266
  return cast(
269
- "cloudpickle",
267
+ ModuleType,
270
268
  util.get_module("cloudpickle", "ModelAdapter requires `cloudpickle`"),
271
269
  )
272
270
 
@@ -338,9 +336,9 @@ class _PicklingSavedModel(_SavedModel[SavedModelObjType]):
338
336
  return json_obj
339
337
 
340
338
 
341
- def _get_torch() -> "torch":
339
+ def _get_torch() -> ModuleType:
342
340
  return cast(
343
- "torch",
341
+ ModuleType,
344
342
  util.get_module("torch", "ModelAdapter requires `torch`"),
345
343
  )
346
344
 
@@ -366,9 +364,9 @@ class _PytorchSavedModel(_PicklingSavedModel["torch.nn.Module"]):
366
364
  )
367
365
 
368
366
 
369
- def _get_sklearn() -> "sklearn":
367
+ def _get_sklearn() -> ModuleType:
370
368
  return cast(
371
- "sklearn",
369
+ ModuleType,
372
370
  util.get_module("sklearn", "ModelAdapter requires `sklearn`"),
373
371
  )
374
372
 
@@ -480,6 +480,11 @@ class Table(Media):
480
480
  max_rows = Table.MAX_ROWS
481
481
  n_rows = len(self.data)
482
482
  if n_rows > max_rows and warn:
483
+ # NOTE: Never raises for reinit="create_new" runs.
484
+ # Since this is called by bind_to_run(), this can be fixed by
485
+ # propagating the run. It cannot be fixed for to_json() calls
486
+ # that are given an artifact, other than by deferring to singleton
487
+ # settings.
483
488
  if wandb.run and (
484
489
  wandb.run.settings.table_raise_on_max_row_limit_exceeded
485
490
  or wandb.run.settings.strict
@@ -431,6 +431,8 @@ class Trace:
431
431
  name: The name of the trace to be logged
432
432
  """
433
433
  trace_tree = WBTraceTree(self._span, self._model_dict)
434
+ # NOTE: Does not work for reinit="create_new" runs.
435
+ # This method should be deprecated and users should call run.log().
434
436
  assert (
435
437
  wandb.run is not None
436
438
  ), "You must call wandb.init() before logging a trace"
@@ -101,7 +101,7 @@ def val_to_json(
101
101
 
102
102
  items = _prune_max_seq(val)
103
103
 
104
- if _server_accepts_image_filenames():
104
+ if _server_accepts_image_filenames(run):
105
105
  for item in items:
106
106
  item.bind_to_run(
107
107
  run=run,