wandb 0.16.6__py3-none-any.whl → 0.17.0__py3-none-any.whl

Files changed (193)
  1. package_readme.md +95 -0
  2. wandb/__init__.py +2 -3
  3. wandb/agents/pyagent.py +0 -1
  4. wandb/analytics/sentry.py +2 -1
  5. wandb/apis/importers/internals/internal.py +0 -1
  6. wandb/apis/importers/internals/protocols.py +30 -56
  7. wandb/apis/importers/mlflow.py +13 -26
  8. wandb/apis/importers/wandb.py +8 -14
  9. wandb/apis/internal.py +0 -3
  10. wandb/apis/public/api.py +55 -3
  11. wandb/apis/public/artifacts.py +1 -0
  12. wandb/apis/public/files.py +1 -0
  13. wandb/apis/public/history.py +1 -0
  14. wandb/apis/public/jobs.py +17 -4
  15. wandb/apis/public/projects.py +1 -0
  16. wandb/apis/public/reports.py +1 -0
  17. wandb/apis/public/runs.py +15 -17
  18. wandb/apis/public/sweeps.py +1 -0
  19. wandb/apis/public/teams.py +1 -0
  20. wandb/apis/public/users.py +1 -0
  21. wandb/apis/reports/v1/_blocks.py +3 -7
  22. wandb/apis/reports/v2/gql.py +1 -0
  23. wandb/apis/reports/v2/interface.py +3 -4
  24. wandb/apis/reports/v2/internal.py +5 -8
  25. wandb/cli/cli.py +92 -22
  26. wandb/data_types.py +9 -6
  27. wandb/docker/__init__.py +1 -1
  28. wandb/env.py +38 -8
  29. wandb/errors/__init__.py +5 -0
  30. wandb/errors/term.py +10 -2
  31. wandb/filesync/step_checksum.py +1 -4
  32. wandb/filesync/step_prepare.py +4 -24
  33. wandb/filesync/step_upload.py +4 -106
  34. wandb/filesync/upload_job.py +0 -76
  35. wandb/integration/catboost/catboost.py +1 -1
  36. wandb/integration/fastai/__init__.py +1 -0
  37. wandb/integration/huggingface/resolver.py +2 -2
  38. wandb/integration/keras/__init__.py +1 -0
  39. wandb/integration/keras/callbacks/metrics_logger.py +1 -1
  40. wandb/integration/keras/keras.py +7 -7
  41. wandb/integration/langchain/wandb_tracer.py +1 -0
  42. wandb/integration/lightning/fabric/logger.py +1 -3
  43. wandb/integration/metaflow/metaflow.py +41 -6
  44. wandb/integration/openai/fine_tuning.py +3 -3
  45. wandb/integration/prodigy/prodigy.py +1 -1
  46. wandb/old/summary.py +1 -1
  47. wandb/plot/confusion_matrix.py +1 -1
  48. wandb/plot/pr_curve.py +2 -1
  49. wandb/plot/roc_curve.py +2 -1
  50. wandb/{plots → plot}/utils.py +13 -25
  51. wandb/proto/v3/wandb_internal_pb2.py +364 -332
  52. wandb/proto/v3/wandb_settings_pb2.py +2 -2
  53. wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
  54. wandb/proto/v4/wandb_internal_pb2.py +322 -316
  55. wandb/proto/v4/wandb_settings_pb2.py +2 -2
  56. wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
  57. wandb/proto/wandb_deprecated.py +7 -1
  58. wandb/proto/wandb_internal_codegen.py +3 -29
  59. wandb/sdk/artifacts/artifact.py +26 -11
  60. wandb/sdk/artifacts/artifact_download_logger.py +1 -0
  61. wandb/sdk/artifacts/artifact_file_cache.py +18 -4
  62. wandb/sdk/artifacts/artifact_instance_cache.py +1 -0
  63. wandb/sdk/artifacts/artifact_manifest.py +1 -0
  64. wandb/sdk/artifacts/artifact_manifest_entry.py +7 -3
  65. wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +1 -0
  66. wandb/sdk/artifacts/artifact_saver.py +2 -8
  67. wandb/sdk/artifacts/artifact_state.py +1 -0
  68. wandb/sdk/artifacts/artifact_ttl.py +1 -0
  69. wandb/sdk/artifacts/exceptions.py +1 -0
  70. wandb/sdk/artifacts/storage_handlers/azure_handler.py +1 -0
  71. wandb/sdk/artifacts/storage_handlers/gcs_handler.py +13 -18
  72. wandb/sdk/artifacts/storage_handlers/http_handler.py +1 -0
  73. wandb/sdk/artifacts/storage_handlers/local_file_handler.py +1 -0
  74. wandb/sdk/artifacts/storage_handlers/multi_handler.py +1 -0
  75. wandb/sdk/artifacts/storage_handlers/s3_handler.py +5 -3
  76. wandb/sdk/artifacts/storage_handlers/tracking_handler.py +1 -0
  77. wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +1 -0
  78. wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +1 -0
  79. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +3 -42
  80. wandb/sdk/artifacts/storage_policy.py +2 -12
  81. wandb/sdk/data_types/_dtypes.py +8 -8
  82. wandb/sdk/data_types/base_types/media.py +3 -6
  83. wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +3 -1
  84. wandb/sdk/data_types/image.py +1 -1
  85. wandb/sdk/data_types/video.py +1 -1
  86. wandb/sdk/integration_utils/auto_logging.py +5 -6
  87. wandb/sdk/integration_utils/data_logging.py +10 -6
  88. wandb/sdk/interface/interface.py +68 -32
  89. wandb/sdk/interface/interface_shared.py +7 -13
  90. wandb/sdk/internal/datastore.py +1 -1
  91. wandb/sdk/internal/file_pusher.py +2 -5
  92. wandb/sdk/internal/file_stream.py +5 -18
  93. wandb/sdk/internal/handler.py +18 -2
  94. wandb/sdk/internal/internal.py +0 -1
  95. wandb/sdk/internal/internal_api.py +1 -129
  96. wandb/sdk/internal/internal_util.py +0 -1
  97. wandb/sdk/internal/job_builder.py +159 -45
  98. wandb/sdk/internal/profiler.py +1 -0
  99. wandb/sdk/internal/progress.py +0 -28
  100. wandb/sdk/internal/run.py +1 -0
  101. wandb/sdk/internal/sender.py +1 -2
  102. wandb/sdk/internal/system/assets/gpu_amd.py +44 -44
  103. wandb/sdk/internal/system/assets/gpu_apple.py +56 -11
  104. wandb/sdk/internal/system/assets/interfaces.py +6 -8
  105. wandb/sdk/internal/system/assets/open_metrics.py +2 -2
  106. wandb/sdk/internal/system/assets/trainium.py +1 -3
  107. wandb/sdk/launch/__init__.py +9 -1
  108. wandb/sdk/launch/_launch.py +4 -24
  109. wandb/sdk/launch/_launch_add.py +1 -3
  110. wandb/sdk/launch/_project_spec.py +186 -224
  111. wandb/sdk/launch/agent/agent.py +37 -13
  112. wandb/sdk/launch/agent/config.py +72 -14
  113. wandb/sdk/launch/builder/abstract.py +69 -1
  114. wandb/sdk/launch/builder/build.py +156 -555
  115. wandb/sdk/launch/builder/context_manager.py +235 -0
  116. wandb/sdk/launch/builder/docker_builder.py +8 -23
  117. wandb/sdk/launch/builder/kaniko_builder.py +12 -25
  118. wandb/sdk/launch/builder/noop.py +1 -0
  119. wandb/sdk/launch/builder/templates/dockerfile.py +92 -0
  120. wandb/sdk/launch/create_job.py +47 -37
  121. wandb/sdk/launch/environment/abstract.py +1 -0
  122. wandb/sdk/launch/environment/gcp_environment.py +1 -0
  123. wandb/sdk/launch/environment/local_environment.py +1 -0
  124. wandb/sdk/launch/inputs/files.py +148 -0
  125. wandb/sdk/launch/inputs/internal.py +217 -0
  126. wandb/sdk/launch/inputs/manage.py +95 -0
  127. wandb/sdk/launch/loader.py +1 -0
  128. wandb/sdk/launch/registry/abstract.py +1 -0
  129. wandb/sdk/launch/registry/azure_container_registry.py +1 -0
  130. wandb/sdk/launch/registry/elastic_container_registry.py +1 -0
  131. wandb/sdk/launch/registry/google_artifact_registry.py +2 -1
  132. wandb/sdk/launch/registry/local_registry.py +1 -0
  133. wandb/sdk/launch/runner/abstract.py +1 -0
  134. wandb/sdk/launch/runner/kubernetes_monitor.py +1 -0
  135. wandb/sdk/launch/runner/kubernetes_runner.py +9 -10
  136. wandb/sdk/launch/runner/local_container.py +2 -3
  137. wandb/sdk/launch/runner/local_process.py +8 -29
  138. wandb/sdk/launch/runner/sagemaker_runner.py +21 -20
  139. wandb/sdk/launch/runner/vertex_runner.py +8 -7
  140. wandb/sdk/launch/sweeps/scheduler.py +4 -3
  141. wandb/sdk/launch/sweeps/scheduler_sweep.py +2 -1
  142. wandb/sdk/launch/sweeps/utils.py +3 -3
  143. wandb/sdk/launch/utils.py +15 -140
  144. wandb/sdk/lib/_settings_toposort_generated.py +0 -5
  145. wandb/sdk/lib/fsm.py +8 -12
  146. wandb/sdk/lib/gitlib.py +4 -4
  147. wandb/sdk/lib/import_hooks.py +1 -1
  148. wandb/sdk/lib/lazyloader.py +0 -1
  149. wandb/sdk/lib/proto_util.py +23 -2
  150. wandb/sdk/lib/redirect.py +19 -14
  151. wandb/sdk/lib/retry.py +3 -2
  152. wandb/sdk/lib/tracelog.py +1 -1
  153. wandb/sdk/service/service.py +19 -16
  154. wandb/sdk/verify/verify.py +2 -1
  155. wandb/sdk/wandb_init.py +14 -55
  156. wandb/sdk/wandb_manager.py +2 -2
  157. wandb/sdk/wandb_require.py +5 -0
  158. wandb/sdk/wandb_run.py +114 -56
  159. wandb/sdk/wandb_settings.py +0 -48
  160. wandb/sdk/wandb_setup.py +1 -1
  161. wandb/sklearn/__init__.py +1 -0
  162. wandb/sklearn/plot/__init__.py +1 -0
  163. wandb/sklearn/plot/classifier.py +11 -12
  164. wandb/sklearn/plot/clusterer.py +2 -1
  165. wandb/sklearn/plot/regressor.py +1 -0
  166. wandb/sklearn/plot/shared.py +1 -0
  167. wandb/sklearn/utils.py +1 -0
  168. wandb/testing/relay.py +4 -4
  169. wandb/trigger.py +1 -0
  170. wandb/util.py +67 -54
  171. wandb/wandb_controller.py +2 -3
  172. wandb/wandb_torch.py +1 -2
  173. {wandb-0.16.6.dist-info → wandb-0.17.0.dist-info}/METADATA +67 -70
  174. {wandb-0.16.6.dist-info → wandb-0.17.0.dist-info}/RECORD +177 -187
  175. {wandb-0.16.6.dist-info → wandb-0.17.0.dist-info}/WHEEL +1 -2
  176. wandb/bin/apple_gpu_stats +0 -0
  177. wandb/catboost/__init__.py +0 -9
  178. wandb/fastai/__init__.py +0 -9
  179. wandb/keras/__init__.py +0 -18
  180. wandb/lightgbm/__init__.py +0 -9
  181. wandb/plots/__init__.py +0 -6
  182. wandb/plots/explain_text.py +0 -36
  183. wandb/plots/heatmap.py +0 -81
  184. wandb/plots/named_entity.py +0 -43
  185. wandb/plots/part_of_speech.py +0 -50
  186. wandb/plots/plot_definitions.py +0 -768
  187. wandb/plots/precision_recall.py +0 -121
  188. wandb/plots/roc.py +0 -103
  189. wandb/sacred/__init__.py +0 -3
  190. wandb/xgboost/__init__.py +0 -9
  191. wandb-0.16.6.dist-info/top_level.txt +0 -1
  192. {wandb-0.16.6.dist-info → wandb-0.17.0.dist-info}/entry_points.txt +0 -0
  193. {wandb-0.16.6.dist-info → wandb-0.17.0.dist-info/licenses}/LICENSE +0 -0
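Among the deletions near the end of the list are the long-deprecated `wandb.plots.*` helpers and the thin top-level shim packages (`wandb.catboost`, `wandb.fastai`, `wandb.keras`, `wandb.lightgbm`, `wandb.xgboost`, `wandb.sacred`). A minimal migration sketch, assuming the `wandb.integration.*` and `wandb.plot.*` entry points that remain in the list above:

```python
# 0.16.x (shim module deleted in 0.17.0):
#   from wandb.keras import WandbCallback
# 0.17.0:
from wandb.integration.keras import WandbCallback

# The other removed shims follow the same pattern: wandb.catboost, wandb.fastai,
# wandb.lightgbm and wandb.xgboost live under wandb.integration.<name>, and the
# deleted wandb.plots.* helpers map to wandb.plot.*
# (e.g. wandb.plots.precision_recall -> wandb.plot.pr_curve,
#       wandb.plots.roc -> wandb.plot.roc_curve).
```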
wandb/filesync/step_upload.py CHANGED
@@ -1,6 +1,5 @@
 """Batching file prepare requests to our API."""

-import asyncio
 import concurrent.futures
 import logging
 import queue
@@ -8,7 +7,6 @@ import sys
 import threading
 from typing import (
     TYPE_CHECKING,
-    Awaitable,
     Callable,
     MutableMapping,
     MutableSequence,
@@ -43,7 +41,6 @@ if TYPE_CHECKING:
 PreCommitFn = Callable[[], None]
 OnRequestFinishFn = Callable[[], None]
 SaveFn = Callable[["progress.ProgressFn"], bool]
-SaveFnAsync = Callable[["progress.ProgressFn"], Awaitable[bool]]

 logger = logging.getLogger(__name__)

@@ -55,7 +52,6 @@ class RequestUpload(NamedTuple):
     md5: Optional[str]
     copied: bool
     save_fn: Optional[SaveFn]
-    save_fn_async: Optional[SaveFnAsync]
     digest: Optional[str]


@@ -78,47 +74,6 @@ class EventJobDone(NamedTuple):
 Event = Union[RequestUpload, RequestCommitArtifact, RequestFinish, EventJobDone]


-class AsyncExecutor:
-    """Runs async file uploads in a background thread."""
-
-    def __init__(
-        self,
-        pool: concurrent.futures.ThreadPoolExecutor,
-        concurrency_limit: Optional[int],
-    ) -> None:
-        self.loop = asyncio.new_event_loop()
-        self.loop.set_default_executor(pool)
-        self.loop_thread = threading.Thread(
-            target=self.loop.run_forever,
-            daemon=True,
-            name="wandb-upload-async",
-        )
-
-        self.concurrency_limiter = asyncio.Semaphore(
-            value=concurrency_limit or 128,
-            # Before Python 3.10: if we don't set `loop=loop`,
-            # then the Semaphore will bind to the wrong event loop,
-            # causing errors when a coroutine tries to wait for it;
-            # see https://pastebin.com/XcrS9suX .
-            # After 3.10: the `loop` argument doesn't exist.
-            # So we need to only conditionally pass in `loop`.
-            **({} if sys.version_info >= (3, 10) else {"loop": self.loop}),
-        )
-
-    def start(self) -> None:
-        self.loop_thread.start()
-
-    def stop(self) -> None:
-        self.loop.call_soon_threadsafe(self.loop.stop)
-
-    def submit(self, coro: Awaitable[None]) -> None:
-        async def run_with_limiter() -> None:
-            async with self.concurrency_limiter:
-                await coro
-
-        asyncio.run_coroutine_threadsafe(run_with_limiter(), self.loop)
-
-
 class StepUpload:
     def __init__(
         self,
@@ -142,15 +97,6 @@ class StepUpload:
             max_workers=max_threads,
         )

-        self._async_executor = (
-            AsyncExecutor(
-                pool=self._pool,
-                concurrency_limit=settings._async_upload_concurrency_limit,
-            )
-            if settings is not None and settings._async_upload_concurrency_limit
-            else None
-        )
-
         # Indexed by files' `save_name`'s, which are their ID's in the Run.
         self._running_jobs: MutableMapping[LogicalPath, RequestUpload] = {}
         self._pending_jobs: MutableSequence[RequestUpload] = []
@@ -186,8 +132,6 @@ class StepUpload:
             elif not self._running_jobs:
                 # Queue was empty and no jobs left.
                 self._pool.shutdown(wait=False)
-                if self._async_executor:
-                    self._async_executor.stop()
                 if finish_callback:
                     finish_callback()
                 break
@@ -245,18 +189,9 @@ class StepUpload:
             self._pending_jobs.append(event)
             return

-        if self._async_executor and event.save_fn_async is not None:
-            # (The `and save_fn_async is not None` is because the async code path
-            # doesn't support all uploads yet: even if the user has requested async,
-            # we sometimes need to use the sync method instead.)
-            self._spawn_upload_async(
-                event,
-                async_executor=self._async_executor,
-            )
-        else:
-            self._spawn_upload_sync(event)
+        self._spawn_upload(event)

-    def _spawn_upload_sync(self, event: RequestUpload) -> None:
+    def _spawn_upload(self, event: RequestUpload) -> None:
         """Spawn an upload job, and handles the bookkeeping of `self._running_jobs`.

         Context: it's important that, whenever we add an entry to `self._running_jobs`,
@@ -285,35 +220,13 @@ class StepUpload:

         def run_and_notify() -> None:
             try:
-                self._do_upload_sync(event)
+                self._do_upload(event)
             finally:
                 self._event_queue.put(EventJobDone(event, exc=sys.exc_info()[1]))

         self._pool.submit(run_and_notify)

-    def _spawn_upload_async(
-        self,
-        event: RequestUpload,
-        async_executor: AsyncExecutor,
-    ) -> None:
-        """Equivalent to _spawn_upload_sync, but uses the async event loop instead of a thread, and requires `event.save_fn_async`.
-
-        Raises:
-            AssertionError: if `event.save_fn_async` is None.
-        """
-        assert event.save_fn_async is not None
-
-        self._running_jobs[event.save_name] = event
-
-        async def run_and_notify() -> None:
-            try:
-                await self._do_upload_async(event)
-            finally:
-                self._event_queue.put(EventJobDone(event, exc=sys.exc_info()[1]))
-
-        async_executor.submit(run_and_notify())
-
-    def _do_upload_sync(self, event: RequestUpload) -> None:
+    def _do_upload(self, event: RequestUpload) -> None:
         job = upload_job.UploadJob(
             self._stats,
             self._api,
@@ -329,19 +242,6 @@ class StepUpload:
         )
         job.run()

-    async def _do_upload_async(self, event: RequestUpload) -> None:
-        """Upload a file and returns when it's done. Requires `event.save_fn_async`."""
-        assert event.save_fn_async is not None
-        job = upload_job.UploadJobAsync(
-            stats=self._stats,
-            api=self._api,
-            file_stream=self._file_stream,
-            silent=self.silent,
-            request=event,
-            save_fn_async=event.save_fn_async,
-        )
-        await job.run()
-
     def _init_artifact(self, artifact_id: str) -> None:
         self._artifacts[artifact_id] = {
             "finalize": False,
@@ -385,8 +285,6 @@

     def start(self) -> None:
         self._thread.start()
-        if self._async_executor:
-            self._async_executor.start()

     def is_alive(self) -> bool:
         return self._thread.is_alive()
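With `AsyncExecutor`, `save_fn_async`, and the `_spawn_upload_async`/`_do_upload_async` path removed, every upload now goes through the single thread-pool path shown above. A minimal sketch of that remaining pattern (names simplified from the diff; not the actual wandb API):

```python
import concurrent.futures
import queue
import sys

event_queue: queue.Queue = queue.Queue()
pool = concurrent.futures.ThreadPoolExecutor(max_workers=4)


def do_upload(event):
    """Placeholder for the blocking UploadJob.run() call."""


def spawn_upload(event):
    def run_and_notify():
        try:
            do_upload(event)
        finally:
            # Success or failure, report completion (plus any exception)
            # back to the event-loop thread that owns the bookkeeping.
            event_queue.put(("job_done", event, sys.exc_info()[1]))

    pool.submit(run_and_notify)
```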
wandb/filesync/upload_job.py CHANGED
@@ -1,4 +1,3 @@
-import asyncio
 import logging
 import os
 from typing import TYPE_CHECKING, Optional
@@ -141,78 +140,3 @@ class UploadJob:

     def progress(self, total_bytes: int) -> None:
         self._stats.update_uploaded_file(self.save_name, total_bytes)
-
-
-class UploadJobAsync:
-    """Roughly an async equivalent of UploadJob.
-
-    Important differences:
-    - `run` is a coroutine
-    - If `run()` fails, it falls back to the synchronous UploadJob
-    """
-
-    def __init__(
-        self,
-        stats: "stats.Stats",
-        api: "internal_api.Api",
-        file_stream: "file_stream.FileStreamApi",
-        silent: bool,
-        request: "step_upload.RequestUpload",
-        save_fn_async: "step_upload.SaveFnAsync",
-    ) -> None:
-        self._stats = stats
-        self._api = api
-        self._file_stream = file_stream
-        self.silent = silent
-        self._request = request
-        self._save_fn_async = save_fn_async
-
-    async def run(self) -> None:
-        try:
-            deduped = await self._save_fn_async(
-                lambda _, t: self._stats.update_uploaded_file(self._request.path, t)
-            )
-        except Exception as e:
-            # Async uploads aren't yet (2023-01) battle-tested.
-            # Fall back to the "normal" synchronous upload.
-            loop = asyncio.get_event_loop()
-            logger.exception("async upload failed", exc_info=e)
-            loop.run_in_executor(None, wandb._sentry.exception, e)
-            wandb.termwarn(
-                "Async file upload failed; falling back to sync", repeat=False
-            )
-            sync_job = UploadJob(
-                self._stats,
-                self._api,
-                self._file_stream,
-                self.silent,
-                self._request.save_name,
-                self._request.path,
-                self._request.artifact_id,
-                self._request.md5,
-                self._request.copied,
-                self._request.save_fn,
-                self._request.digest,
-            )
-
-            await loop.run_in_executor(None, sync_job.run)
-        else:
-            self._file_stream.push_success(
-                self._request.artifact_id,  # type: ignore
-                self._request.save_name,
-            )
-
-            if deduped:
-                logger.info("Skipped uploading %s", self._request.path)
-                self._stats.set_file_deduped(self._request.path)
-            else:
-                logger.info("Uploaded file %s", self._request.path)
-        finally:
-            # If we fell back to the sync impl, the file will have already been deleted.
-            # Doesn't matter, we only try to delete it if it exists.
-            if self._request.copied:
-                try:
-                    os.remove(self._request.path)
-                except OSError:
-                    # The file has already been deleted, we don't have permissions, or something else we can't fix.
-                    pass
wandb/integration/catboost/catboost.py CHANGED
@@ -81,7 +81,7 @@ def _checkpoint_artifact(


 def _log_feature_importance(
-    model: Union[CatBoostClassifier, CatBoostRegressor]
+    model: Union[CatBoostClassifier, CatBoostRegressor],
 ) -> None:
     """Log feature importance with default settings."""
     if wandb.run is None:
wandb/integration/fastai/__init__.py CHANGED
@@ -35,6 +35,7 @@ Examples:
     learn.fit(..., callbacks=WandbCallback(learn, ...))
     ```
 """
+
 import random
 import sys
 from pathlib import Path
wandb/integration/huggingface/resolver.py CHANGED
@@ -91,8 +91,8 @@ class HuggingFacePipelineRequestResponseResolver:

     # TODO: This should have a dependency on PreTrainedModel. i.e. isinstance(PreTrainedModel)
     # from transformers.modeling_utils import PreTrainedModel
-    # We do not want this dependency explicity in our codebase so we make a very general assumption about
-    # the structure of the pipeline which may have unintended consequences
+    # We do not want this dependency explicitly in our codebase so we make a very general
+    # assumption about the structure of the pipeline which may have unintended consequences
     def _get_model(self, pipe) -> Optional[Any]:
         """Extracts model from the pipeline.

wandb/integration/keras/__init__.py CHANGED
@@ -2,6 +2,7 @@

 Keras is a deep learning API for [`TensorFlow`](https://www.tensorflow.org/).
 """
+
 __all__ = (
     "WandbCallback",
     "WandbMetricsLogger",
wandb/integration/keras/callbacks/metrics_logger.py CHANGED
@@ -43,7 +43,7 @@ class WandbMetricsLogger(callbacks.Callback):
             at the end of each epoch. If "batch", logs metrics at the end
             of each batch. If an integer, logs metrics at the end of that
             many batches. Defaults to "epoch".
-        initial_global_step: (int) Use this argument to correcly log the
+        initial_global_step: (int) Use this argument to correctly log the
             learning rate when you resume training from some `initial_epoch`,
             and a learning rate scheduler is used. This can be computed as
             `step_size * initial_step`. Defaults to 0.
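The corrected docstring above describes computing `initial_global_step` when resuming training with a learning-rate scheduler. A small sketch of that usage, assuming a Keras model compiled elsewhere and a hypothetical project name:

```python
import wandb
from wandb.integration.keras import WandbMetricsLogger

wandb.init(project="resume-demo")  # hypothetical project; a run must be active

steps_per_epoch = 100   # assumed: len(train_data) // batch_size
initial_epoch = 5       # epoch the resumed training starts from

metrics_logger = WandbMetricsLogger(
    log_freq="epoch",
    initial_global_step=steps_per_epoch * initial_epoch,  # step_size * initial_step
)
# model.fit(..., initial_epoch=initial_epoch, callbacks=[metrics_logger])
```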
wandb/integration/keras/keras.py CHANGED
@@ -576,7 +576,7 @@ class WandbCallback(tf.keras.callbacks.Callback):
                 )
                 self._model_trained_since_last_eval = False
             except Exception as e:
-                wandb.termwarn("Error durring prediction logging for epoch: " + str(e))
+                wandb.termwarn("Error during prediction logging for epoch: " + str(e))

     def on_epoch_end(self, epoch, logs=None):
         if logs is None:
@@ -616,9 +616,9 @@ class WandbCallback(tf.keras.callbacks.Callback):
         self.current = logs.get(self.monitor)
         if self.current and self.monitor_op(self.current, self.best):
             if self.log_best_prefix:
-                wandb.run.summary[
-                    f"{self.log_best_prefix}{self.monitor}"
-                ] = self.current
+                wandb.run.summary[f"{self.log_best_prefix}{self.monitor}"] = (
+                    self.current
+                )
                 wandb.run.summary["{}{}".format(self.log_best_prefix, "epoch")] = epoch
                 if self.verbose and not self.save_model:
                     print(
@@ -937,9 +937,9 @@ class WandbCallback(tf.keras.callbacks.Callback):
         grads = self._grad_accumulator_callback.grads
         metrics = {}
         for weight, grad in zip(weights, grads):
-            metrics[
-                "gradients/" + weight.name.split(":")[0] + ".gradient"
-            ] = wandb.Histogram(grad)
+            metrics["gradients/" + weight.name.split(":")[0] + ".gradient"] = (
+                wandb.Histogram(grad)
+            )
         return metrics

     def _log_dataframe(self):
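The reformatted block above builds per-weight gradient histograms. For reference, a sketch of how that path is enabled from user code (assuming the classic `WandbCallback`; gradient logging needs `training_data` and an active run):

```python
import numpy as np
import wandb
from wandb.integration.keras import WandbCallback

wandb.init(project="keras-gradients-demo")  # hypothetical project

x_train = np.random.rand(32, 4).astype("float32")  # made-up data
y_train = np.random.rand(32, 1).astype("float32")

callback = WandbCallback(log_gradients=True, training_data=(x_train, y_train))
# model.fit(x_train, y_train, callbacks=[callback])
```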
wandb/integration/langchain/wandb_tracer.py CHANGED
@@ -14,6 +14,7 @@ integration will not break user code. The one exception to the rule is at import
 LangChain is not installed, or the symbols are not in the same place, the appropriate error
 will be raised when importing this module.
 """
+
 from packaging import version

 import wandb.util
wandb/integration/lightning/fabric/logger.py CHANGED
@@ -401,9 +401,7 @@ class WandbLogger(Logger):
                     "*", step_metric="trainer/global_step", step_sync=True
                 )

-            self._experiment._label(
-                repo="lightning_fabric_logger"
-            )  # pylint: disable=protected-access
+            self._experiment._label(repo="lightning_fabric_logger")  # pylint: disable=protected-access
             with telemetry.context(run=self._experiment) as tel:
                 tel.feature.lightning_fabric_logger = True
         return self._experiment
wandb/integration/metaflow/metaflow.py CHANGED
@@ -36,7 +36,15 @@ try:
     import pandas as pd

     @typedispatch  # noqa: F811
-    def _wandb_use(name: str, data: pd.DataFrame, datasets=False, run=None, testing=False, *args, **kwargs):  # type: ignore
+    def _wandb_use(
+        name: str,
+        data: pd.DataFrame,
+        datasets=False,
+        run=None,
+        testing=False,
+        *args,
+        **kwargs,
+    ):  # type: ignore
         if testing:
             return "datasets" if datasets else None

@@ -74,7 +82,15 @@ try:
     import torch.nn as nn

     @typedispatch  # noqa: F811
-    def _wandb_use(name: str, data: nn.Module, models=False, run=None, testing=False, *args, **kwargs):  # type: ignore
+    def _wandb_use(
+        name: str,
+        data: nn.Module,
+        models=False,
+        run=None,
+        testing=False,
+        *args,
+        **kwargs,
+    ):  # type: ignore
         if testing:
             return "models" if models else None

@@ -111,7 +127,15 @@ try:
     from sklearn.base import BaseEstimator

     @typedispatch  # noqa: F811
-    def _wandb_use(name: str, data: BaseEstimator, models=False, run=None, testing=False, *args, **kwargs):  # type: ignore
+    def _wandb_use(
+        name: str,
+        data: BaseEstimator,
+        models=False,
+        run=None,
+        testing=False,
+        *args,
+        **kwargs,
+    ):  # type: ignore
         if testing:
             return "models" if models else None

@@ -169,7 +193,14 @@ class ArtifactProxy:


 @typedispatch  # noqa: F811
-def wandb_track(name: str, data: (dict, list, set, str, int, float, bool), run=None, testing=False, *args, **kwargs):  # type: ignore
+def wandb_track(
+    name: str,
+    data: (dict, list, set, str, int, float, bool),
+    run=None,
+    testing=False,
+    *args,
+    **kwargs,
+):  # type: ignore
     if testing:
         return "scalar"

@@ -222,12 +253,16 @@ def wandb_use(name: str, data, *args, **kwargs):


 @typedispatch  # noqa: F811
-def wandb_use(name: str, data: (dict, list, set, str, int, float, bool), *args, **kwargs):  # type: ignore
+def wandb_use(
+    name: str, data: (dict, list, set, str, int, float, bool), *args, **kwargs
+):  # type: ignore
     pass  # do nothing for these types


 @typedispatch  # noqa: F811
-def _wandb_use(name: str, data: Path, datasets=False, run=None, testing=False, *args, **kwargs):  # type: ignore
+def _wandb_use(
+    name: str, data: Path, datasets=False, run=None, testing=False, *args, **kwargs
+):  # type: ignore
     if testing:
         return "datasets" if datasets else None

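These `@typedispatch` overloads back the Metaflow integration: attributes assigned on a step are routed by type (DataFrame, `nn.Module`, `BaseEstimator`, `Path`, plain scalars) to the appropriate tracking or artifact-use logic. A sketch of the user-facing side, assuming the documented `wandb_log` decorator and a hypothetical project name:

```python
from metaflow import FlowSpec, step

import wandb
from wandb.integration.metaflow import wandb_log


class TrainFlow(FlowSpec):
    @wandb_log(datasets=True, models=True, settings=wandb.Settings(project="metaflow-demo"))
    @step
    def start(self):
        # A pandas DataFrame or torch nn.Module assigned to self here would be
        # picked up by the dispatchers above and logged as a wandb artifact.
        self.next(self.end)

    @step
    def end(self):
        pass
```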
wandb/integration/openai/fine_tuning.py CHANGED
@@ -310,9 +310,9 @@ class WandbLogger:
         try:
             hyperparams["n_epochs"] = hyperparameters.n_epochs
             hyperparams["batch_size"] = hyperparameters.batch_size
-            hyperparams[
-                "learning_rate_multiplier"
-            ] = hyperparameters.learning_rate_multiplier
+            hyperparams["learning_rate_multiplier"] = (
+                hyperparameters.learning_rate_multiplier
+            )
         except Exception:
             # If unpacking fails, return the object to be logged as config
             return None
wandb/integration/prodigy/prodigy.py CHANGED
@@ -26,7 +26,7 @@ from PIL import Image

 import wandb
 from wandb import util
-from wandb.plots.utils import test_missing
+from wandb.plot.utils import test_missing
 from wandb.sdk.lib import telemetry as wb_telemetry


wandb/old/summary.py CHANGED
@@ -61,7 +61,7 @@ class SummarySubDict:
         This should only be implemented by the "_root" child class.

         We pass the child_dict so the item can be set on it or not as
-        appropriate. Returning None for a nonexistant path wouldn't be
+        appropriate. Returning None for a nonexistent path wouldn't be
         distinguishable from that path being set to the value None.
         """
         raise NotImplementedError
wandb/plot/confusion_matrix.py CHANGED
@@ -48,7 +48,7 @@ def confusion_matrix(

     assert (probs is None or preds is None) and not (
         probs is None and preds is None
-    ), "Must provide probabilties or predictions but not both to confusion matrix"
+    ), "Must provide probabilities or predictions but not both to confusion matrix"

     if probs is not None:
         preds = np.argmax(probs, axis=1).tolist()
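The corrected assertion message states the contract: pass exactly one of `probs` or `preds`. An illustrative call satisfying it (values and project name made up):

```python
import wandb

wandb.init(project="confusion-matrix-demo")  # hypothetical project

y_true = [0, 1, 1, 2]
preds = [0, 1, 2, 2]
wandb.log(
    {
        "conf_mat": wandb.plot.confusion_matrix(
            y_true=y_true, preds=preds, class_names=["cat", "dog", "bird"]
        )
    }
)
```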
wandb/plot/pr_curve.py CHANGED
@@ -2,7 +2,8 @@ from typing import Optional

 import wandb
 from wandb import util
-from wandb.plots.utils import test_missing, test_types
+
+from .utils import test_missing, test_types


 def pr_curve(
wandb/plot/roc_curve.py CHANGED
@@ -2,7 +2,8 @@ from typing import Optional

 import wandb
 from wandb import util
-from wandb.plots.utils import test_missing, test_types
+
+from .utils import test_missing, test_types


 def roc_curve(
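Both helpers now import `test_missing`/`test_types` from the package-local `.utils` module instead of the removed `wandb.plots.utils`. The public entry points are unchanged; an illustrative use (made-up values):

```python
import wandb

y_true = [0, 1, 1, 0]
y_probas = [[0.8, 0.2], [0.3, 0.7], [0.4, 0.6], [0.6, 0.4]]

pr = wandb.plot.pr_curve(y_true, y_probas, labels=["neg", "pos"])
roc = wandb.plot.roc_curve(y_true, y_probas, labels=["neg", "pos"])
# With an active run: wandb.log({"pr": pr, "roc": roc})
```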
wandb/{plots → plot}/utils.py RENAMED
@@ -1,21 +1,9 @@
-from collections.abc import Iterable, Sequence
+from typing import Iterable, Sequence

 import wandb
 from wandb import util
-from wandb.sdk.lib import deprecate


-def deprecation_notice() -> None:
-    deprecate.deprecate(
-        field_name=deprecate.Deprecated.plots,
-        warning_message=(
-            "wandb.plots.* functions are deprecated and will be removed in a future release. "
-            "Please use wandb.plot.* instead."
-        ),
-    )
-
-
-# Test assumptions for plotting parameters and datasets
 def test_missing(**kwargs):
     np = util.get_module("numpy", required="Logging plots requires numpy")
     pd = util.get_module("pandas", required="Logging dataframes requires pandas")
@@ -25,7 +13,7 @@ def test_missing(**kwargs):
     for k, v in kwargs.items():
         # Missing/empty params/datapoint arrays
         if v is None:
-            wandb.termerror("%s is None. Please try again." % (k))
+            wandb.termerror("{} is None. Please try again.".format(k))
             test_passed = False
         if (k == "X") or (k == "X_test"):
             if isinstance(v, scipy.sparse.csr.csr_matrix):
@@ -64,8 +52,8 @@ def test_missing(**kwargs):
                 )
                 if non_nums > 0:
                     wandb.termerror(
-                        "%s contains values that are not numbers. Please vectorize, label encode or one hot encode %s and call the plotting function again."
-                        % (k, k)
+                        f"{k} contains values that are not numbers. Please vectorize, "
+                        f"label encode or one hot encode {k} and call the plotting function again."
                     )
                     test_passed = False
     return test_passed
@@ -73,8 +61,8 @@ def test_missing(**kwargs):

 def test_fitted(model):
     np = util.get_module("numpy", required="Logging plots requires numpy")
-    pd = util.get_module("pandas", required="Logging dataframes requires pandas")
-    scipy = util.get_module("scipy", required="Logging scipy matrices requires scipy")
+    _ = util.get_module("pandas", required="Logging dataframes requires pandas")
+    _ = util.get_module("scipy", required="Logging scipy matrices requires scipy")
     scikit_utils = util.get_module(
         "sklearn.utils",
         required="roc requires the scikit utils submodule, install with `pip install scikit-learn`",
@@ -120,7 +108,7 @@ def test_fitted(model):


 def encode_labels(df):
-    pd = util.get_module("pandas", required="Logging dataframes requires pandas")
+    _ = util.get_module("pandas", required="Logging dataframes requires pandas")
     preprocessing = util.get_module(
         "sklearn.preprocessing",
         "roc requires the scikit preprocessing submodule, install with `pip install scikit-learn`",
@@ -137,7 +125,7 @@ def encode_labels(df):
 def test_types(**kwargs):
     np = util.get_module("numpy", required="Logging plots requires numpy")
     pd = util.get_module("pandas", required="Logging dataframes requires pandas")
-    scipy = util.get_module("scipy", required="Logging scipy matrices requires scipy")
+    _ = util.get_module("scipy", required="Logging scipy matrices requires scipy")

     base = util.get_module(
         "sklearn.base",
@@ -171,25 +159,25 @@ def test_types(**kwargs):
                 list,
             ),
         ):
-            wandb.termerror("%s is not an array. Please try again." % (k))
+            wandb.termerror("{} is not an array. Please try again.".format(k))
             test_passed = False
         # check for classifier types
         if k == "model":
             if (not base.is_classifier(v)) and (not base.is_regressor(v)):
                 wandb.termerror(
-                    "%s is not a classifier or regressor. Please try again." % (k)
+                    "{} is not a classifier or regressor. Please try again.".format(k)
                 )
                 test_passed = False
         elif k == "clf" or k == "binary_clf":
             if not (base.is_classifier(v)):
-                wandb.termerror("%s is not a classifier. Please try again." % (k))
+                wandb.termerror("{} is not a classifier. Please try again.".format(k))
                 test_passed = False
         elif k == "regressor":
             if not base.is_regressor(v):
-                wandb.termerror("%s is not a regressor. Please try again." % (k))
+                wandb.termerror("{} is not a regressor. Please try again.".format(k))
                 test_passed = False
         elif k == "clusterer":
             if not (getattr(v, "_estimator_type", None) == "clusterer"):
-                wandb.termerror("%s is not a clusterer. Please try again." % (k))
+                wandb.termerror("{} is not a clusterer. Please try again.".format(k))
                 test_passed = False
     return test_passed