wandb 0.17.0rc1__py3-none-any.whl → 0.17.1__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (173) hide show
  1. wandb/__init__.py +1 -2
  2. wandb/apis/importers/internals/internal.py +0 -1
  3. wandb/apis/importers/wandb.py +12 -7
  4. wandb/apis/internal.py +0 -3
  5. wandb/apis/public/api.py +213 -79
  6. wandb/apis/public/artifacts.py +335 -100
  7. wandb/apis/public/files.py +9 -9
  8. wandb/apis/public/jobs.py +16 -4
  9. wandb/apis/public/projects.py +26 -28
  10. wandb/apis/public/query_generator.py +1 -1
  11. wandb/apis/public/runs.py +163 -65
  12. wandb/apis/public/sweeps.py +2 -2
  13. wandb/apis/reports/__init__.py +1 -7
  14. wandb/apis/reports/v1/__init__.py +5 -27
  15. wandb/apis/reports/v2/__init__.py +7 -19
  16. wandb/apis/workspaces/__init__.py +8 -0
  17. wandb/beta/workflows.py +8 -3
  18. wandb/cli/cli.py +131 -59
  19. wandb/data_types.py +6 -3
  20. wandb/docker/__init__.py +2 -2
  21. wandb/env.py +3 -3
  22. wandb/errors/term.py +10 -2
  23. wandb/filesync/step_checksum.py +1 -4
  24. wandb/filesync/step_prepare.py +4 -24
  25. wandb/filesync/step_upload.py +5 -107
  26. wandb/filesync/upload_job.py +0 -76
  27. wandb/integration/gym/__init__.py +35 -15
  28. wandb/integration/huggingface/resolver.py +2 -2
  29. wandb/integration/keras/callbacks/metrics_logger.py +1 -1
  30. wandb/integration/keras/keras.py +1 -1
  31. wandb/integration/openai/fine_tuning.py +21 -3
  32. wandb/integration/prodigy/prodigy.py +1 -1
  33. wandb/jupyter.py +16 -17
  34. wandb/old/summary.py +1 -1
  35. wandb/plot/confusion_matrix.py +1 -1
  36. wandb/plot/pr_curve.py +2 -1
  37. wandb/plot/roc_curve.py +2 -1
  38. wandb/{plots → plot}/utils.py +13 -25
  39. wandb/proto/v3/wandb_internal_pb2.py +54 -54
  40. wandb/proto/v3/wandb_settings_pb2.py +2 -2
  41. wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
  42. wandb/proto/v4/wandb_internal_pb2.py +54 -54
  43. wandb/proto/v4/wandb_settings_pb2.py +2 -2
  44. wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
  45. wandb/proto/v5/wandb_base_pb2.py +30 -0
  46. wandb/proto/v5/wandb_internal_pb2.py +355 -0
  47. wandb/proto/v5/wandb_server_pb2.py +63 -0
  48. wandb/proto/v5/wandb_settings_pb2.py +45 -0
  49. wandb/proto/v5/wandb_telemetry_pb2.py +41 -0
  50. wandb/proto/wandb_base_pb2.py +2 -0
  51. wandb/proto/wandb_deprecated.py +9 -1
  52. wandb/proto/wandb_generate_deprecated.py +34 -0
  53. wandb/proto/{wandb_internal_codegen.py → wandb_generate_proto.py} +1 -35
  54. wandb/proto/wandb_internal_pb2.py +2 -0
  55. wandb/proto/wandb_server_pb2.py +2 -0
  56. wandb/proto/wandb_settings_pb2.py +2 -0
  57. wandb/proto/wandb_telemetry_pb2.py +2 -0
  58. wandb/sdk/artifacts/artifact.py +68 -22
  59. wandb/sdk/artifacts/artifact_manifest.py +1 -1
  60. wandb/sdk/artifacts/artifact_manifest_entry.py +6 -3
  61. wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +1 -1
  62. wandb/sdk/artifacts/artifact_saver.py +1 -10
  63. wandb/sdk/artifacts/storage_handlers/local_file_handler.py +6 -2
  64. wandb/sdk/artifacts/storage_handlers/multi_handler.py +1 -1
  65. wandb/sdk/artifacts/storage_handlers/tracking_handler.py +6 -4
  66. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +2 -42
  67. wandb/sdk/artifacts/storage_policy.py +1 -12
  68. wandb/sdk/data_types/_dtypes.py +8 -8
  69. wandb/sdk/data_types/image.py +2 -2
  70. wandb/sdk/data_types/video.py +5 -3
  71. wandb/sdk/integration_utils/data_logging.py +5 -5
  72. wandb/sdk/interface/interface.py +14 -1
  73. wandb/sdk/interface/interface_shared.py +1 -1
  74. wandb/sdk/internal/file_pusher.py +2 -5
  75. wandb/sdk/internal/file_stream.py +6 -19
  76. wandb/sdk/internal/internal_api.py +148 -136
  77. wandb/sdk/internal/job_builder.py +208 -136
  78. wandb/sdk/internal/progress.py +0 -28
  79. wandb/sdk/internal/sender.py +102 -39
  80. wandb/sdk/internal/settings_static.py +8 -1
  81. wandb/sdk/internal/system/assets/trainium.py +3 -3
  82. wandb/sdk/internal/system/system_info.py +4 -2
  83. wandb/sdk/internal/update.py +1 -1
  84. wandb/sdk/launch/__init__.py +9 -1
  85. wandb/sdk/launch/_launch.py +4 -24
  86. wandb/sdk/launch/_launch_add.py +1 -3
  87. wandb/sdk/launch/_project_spec.py +187 -225
  88. wandb/sdk/launch/agent/agent.py +59 -19
  89. wandb/sdk/launch/agent/config.py +0 -3
  90. wandb/sdk/launch/builder/abstract.py +68 -1
  91. wandb/sdk/launch/builder/build.py +165 -576
  92. wandb/sdk/launch/builder/context_manager.py +235 -0
  93. wandb/sdk/launch/builder/docker_builder.py +7 -23
  94. wandb/sdk/launch/builder/kaniko_builder.py +12 -25
  95. wandb/sdk/launch/builder/templates/dockerfile.py +92 -0
  96. wandb/sdk/launch/create_job.py +51 -45
  97. wandb/sdk/launch/environment/aws_environment.py +26 -1
  98. wandb/sdk/launch/inputs/files.py +148 -0
  99. wandb/sdk/launch/inputs/internal.py +224 -0
  100. wandb/sdk/launch/inputs/manage.py +95 -0
  101. wandb/sdk/launch/registry/google_artifact_registry.py +1 -1
  102. wandb/sdk/launch/runner/abstract.py +2 -2
  103. wandb/sdk/launch/runner/kubernetes_monitor.py +45 -12
  104. wandb/sdk/launch/runner/kubernetes_runner.py +6 -8
  105. wandb/sdk/launch/runner/local_container.py +2 -3
  106. wandb/sdk/launch/runner/local_process.py +8 -29
  107. wandb/sdk/launch/runner/sagemaker_runner.py +20 -14
  108. wandb/sdk/launch/runner/vertex_runner.py +8 -7
  109. wandb/sdk/launch/sweeps/scheduler.py +5 -3
  110. wandb/sdk/launch/sweeps/scheduler_sweep.py +1 -1
  111. wandb/sdk/launch/sweeps/utils.py +4 -4
  112. wandb/sdk/launch/utils.py +16 -138
  113. wandb/sdk/lib/_settings_toposort_generated.py +2 -5
  114. wandb/sdk/lib/apikey.py +4 -2
  115. wandb/sdk/lib/config_util.py +3 -3
  116. wandb/sdk/lib/import_hooks.py +1 -1
  117. wandb/sdk/lib/proto_util.py +22 -1
  118. wandb/sdk/lib/redirect.py +20 -15
  119. wandb/sdk/lib/tracelog.py +1 -1
  120. wandb/sdk/service/service.py +2 -1
  121. wandb/sdk/service/streams.py +5 -5
  122. wandb/sdk/wandb_init.py +25 -59
  123. wandb/sdk/wandb_login.py +28 -25
  124. wandb/sdk/wandb_run.py +123 -53
  125. wandb/sdk/wandb_settings.py +33 -64
  126. wandb/sdk/wandb_setup.py +1 -1
  127. wandb/sdk/wandb_watch.py +1 -1
  128. wandb/sklearn/plot/classifier.py +10 -12
  129. wandb/sklearn/plot/clusterer.py +1 -1
  130. wandb/sync/sync.py +2 -2
  131. wandb/testing/relay.py +32 -17
  132. wandb/util.py +36 -37
  133. wandb/wandb_agent.py +3 -3
  134. wandb/wandb_controller.py +5 -4
  135. {wandb-0.17.0rc1.dist-info → wandb-0.17.1.dist-info}/METADATA +8 -10
  136. {wandb-0.17.0rc1.dist-info → wandb-0.17.1.dist-info}/RECORD +139 -161
  137. {wandb-0.17.0rc1.dist-info → wandb-0.17.1.dist-info}/WHEEL +1 -1
  138. wandb/apis/reports/v1/_blocks.py +0 -1406
  139. wandb/apis/reports/v1/_helpers.py +0 -70
  140. wandb/apis/reports/v1/_panels.py +0 -1282
  141. wandb/apis/reports/v1/_templates.py +0 -478
  142. wandb/apis/reports/v1/blocks.py +0 -27
  143. wandb/apis/reports/v1/helpers.py +0 -2
  144. wandb/apis/reports/v1/mutations.py +0 -66
  145. wandb/apis/reports/v1/panels.py +0 -17
  146. wandb/apis/reports/v1/report.py +0 -268
  147. wandb/apis/reports/v1/runset.py +0 -144
  148. wandb/apis/reports/v1/templates.py +0 -7
  149. wandb/apis/reports/v1/util.py +0 -406
  150. wandb/apis/reports/v1/validators.py +0 -131
  151. wandb/apis/reports/v2/blocks.py +0 -25
  152. wandb/apis/reports/v2/expr_parsing.py +0 -257
  153. wandb/apis/reports/v2/gql.py +0 -68
  154. wandb/apis/reports/v2/interface.py +0 -1911
  155. wandb/apis/reports/v2/internal.py +0 -867
  156. wandb/apis/reports/v2/metrics.py +0 -6
  157. wandb/apis/reports/v2/panels.py +0 -15
  158. wandb/catboost/__init__.py +0 -9
  159. wandb/fastai/__init__.py +0 -9
  160. wandb/keras/__init__.py +0 -19
  161. wandb/lightgbm/__init__.py +0 -9
  162. wandb/plots/__init__.py +0 -6
  163. wandb/plots/explain_text.py +0 -36
  164. wandb/plots/heatmap.py +0 -81
  165. wandb/plots/named_entity.py +0 -43
  166. wandb/plots/part_of_speech.py +0 -50
  167. wandb/plots/plot_definitions.py +0 -768
  168. wandb/plots/precision_recall.py +0 -121
  169. wandb/plots/roc.py +0 -103
  170. wandb/sacred/__init__.py +0 -3
  171. wandb/xgboost/__init__.py +0 -9
  172. {wandb-0.17.0rc1.dist-info → wandb-0.17.1.dist-info}/entry_points.txt +0 -0
  173. {wandb-0.17.0rc1.dist-info → wandb-0.17.1.dist-info}/licenses/LICENSE +0 -0
@@ -1,6 +1,5 @@
1
1
  """Batching file prepare requests to our API."""
2
2
 
3
- import asyncio
4
3
  import concurrent.futures
5
4
  import logging
6
5
  import queue
@@ -8,7 +7,6 @@ import sys
8
7
  import threading
9
8
  from typing import (
10
9
  TYPE_CHECKING,
11
- Awaitable,
12
10
  Callable,
13
11
  MutableMapping,
14
12
  MutableSequence,
@@ -43,7 +41,6 @@ if TYPE_CHECKING:
43
41
  PreCommitFn = Callable[[], None]
44
42
  OnRequestFinishFn = Callable[[], None]
45
43
  SaveFn = Callable[["progress.ProgressFn"], bool]
46
- SaveFnAsync = Callable[["progress.ProgressFn"], Awaitable[bool]]
47
44
 
48
45
  logger = logging.getLogger(__name__)
49
46
 
@@ -55,7 +52,6 @@ class RequestUpload(NamedTuple):
55
52
  md5: Optional[str]
56
53
  copied: bool
57
54
  save_fn: Optional[SaveFn]
58
- save_fn_async: Optional[SaveFnAsync]
59
55
  digest: Optional[str]
60
56
 
61
57
 
@@ -78,47 +74,6 @@ class EventJobDone(NamedTuple):
78
74
  Event = Union[RequestUpload, RequestCommitArtifact, RequestFinish, EventJobDone]
79
75
 
80
76
 
81
- class AsyncExecutor:
82
- """Runs async file uploads in a background thread."""
83
-
84
- def __init__(
85
- self,
86
- pool: concurrent.futures.ThreadPoolExecutor,
87
- concurrency_limit: Optional[int],
88
- ) -> None:
89
- self.loop = asyncio.new_event_loop()
90
- self.loop.set_default_executor(pool)
91
- self.loop_thread = threading.Thread(
92
- target=self.loop.run_forever,
93
- daemon=True,
94
- name="wandb-upload-async",
95
- )
96
-
97
- self.concurrency_limiter = asyncio.Semaphore(
98
- value=concurrency_limit or 128,
99
- # Before Python 3.10: if we don't set `loop=loop`,
100
- # then the Semaphore will bind to the wrong event loop,
101
- # causing errors when a coroutine tries to wait for it;
102
- # see https://pastebin.com/XcrS9suX .
103
- # After 3.10: the `loop` argument doesn't exist.
104
- # So we need to only conditionally pass in `loop`.
105
- **({} if sys.version_info >= (3, 10) else {"loop": self.loop}),
106
- )
107
-
108
- def start(self) -> None:
109
- self.loop_thread.start()
110
-
111
- def stop(self) -> None:
112
- self.loop.call_soon_threadsafe(self.loop.stop)
113
-
114
- def submit(self, coro: Awaitable[None]) -> None:
115
- async def run_with_limiter() -> None:
116
- async with self.concurrency_limiter:
117
- await coro
118
-
119
- asyncio.run_coroutine_threadsafe(run_with_limiter(), self.loop)
120
-
121
-
122
77
  class StepUpload:
123
78
  def __init__(
124
79
  self,
@@ -142,15 +97,6 @@ class StepUpload:
142
97
  max_workers=max_threads,
143
98
  )
144
99
 
145
- self._async_executor = (
146
- AsyncExecutor(
147
- pool=self._pool,
148
- concurrency_limit=settings._async_upload_concurrency_limit,
149
- )
150
- if settings is not None and settings._async_upload_concurrency_limit
151
- else None
152
- )
153
-
154
100
  # Indexed by files' `save_name`'s, which are their ID's in the Run.
155
101
  self._running_jobs: MutableMapping[LogicalPath, RequestUpload] = {}
156
102
  self._pending_jobs: MutableSequence[RequestUpload] = []
@@ -186,8 +132,6 @@ class StepUpload:
186
132
  elif not self._running_jobs:
187
133
  # Queue was empty and no jobs left.
188
134
  self._pool.shutdown(wait=False)
189
- if self._async_executor:
190
- self._async_executor.stop()
191
135
  if finish_callback:
192
136
  finish_callback()
193
137
  break
@@ -235,7 +179,7 @@ class StepUpload:
235
179
  self._artifacts[event.artifact_id]["pending_count"] += 1
236
180
  self._start_upload_job(event)
237
181
  else:
238
- raise Exception("Programming error: unhandled event: %s" % str(event))
182
+ raise Exception("Programming error: unhandled event: {}".format(str(event)))
239
183
 
240
184
  def _start_upload_job(self, event: RequestUpload) -> None:
241
185
  # Operations on a single backend file must be serialized. if
@@ -245,18 +189,9 @@ class StepUpload:
245
189
  self._pending_jobs.append(event)
246
190
  return
247
191
 
248
- if self._async_executor and event.save_fn_async is not None:
249
- # (The `and save_fn_async is not None` is because the async code path
250
- # doesn't support all uploads yet: even if the user has requested async,
251
- # we sometimes need to use the sync method instead.)
252
- self._spawn_upload_async(
253
- event,
254
- async_executor=self._async_executor,
255
- )
256
- else:
257
- self._spawn_upload_sync(event)
192
+ self._spawn_upload(event)
258
193
 
259
- def _spawn_upload_sync(self, event: RequestUpload) -> None:
194
+ def _spawn_upload(self, event: RequestUpload) -> None:
260
195
  """Spawn an upload job, and handles the bookkeeping of `self._running_jobs`.
261
196
 
262
197
  Context: it's important that, whenever we add an entry to `self._running_jobs`,
@@ -285,35 +220,13 @@ class StepUpload:
285
220
 
286
221
  def run_and_notify() -> None:
287
222
  try:
288
- self._do_upload_sync(event)
223
+ self._do_upload(event)
289
224
  finally:
290
225
  self._event_queue.put(EventJobDone(event, exc=sys.exc_info()[1]))
291
226
 
292
227
  self._pool.submit(run_and_notify)
293
228
 
294
- def _spawn_upload_async(
295
- self,
296
- event: RequestUpload,
297
- async_executor: AsyncExecutor,
298
- ) -> None:
299
- """Equivalent to _spawn_upload_sync, but uses the async event loop instead of a thread, and requires `event.save_fn_async`.
300
-
301
- Raises:
302
- AssertionError: if `event.save_fn_async` is None.
303
- """
304
- assert event.save_fn_async is not None
305
-
306
- self._running_jobs[event.save_name] = event
307
-
308
- async def run_and_notify() -> None:
309
- try:
310
- await self._do_upload_async(event)
311
- finally:
312
- self._event_queue.put(EventJobDone(event, exc=sys.exc_info()[1]))
313
-
314
- async_executor.submit(run_and_notify())
315
-
316
- def _do_upload_sync(self, event: RequestUpload) -> None:
229
+ def _do_upload(self, event: RequestUpload) -> None:
317
230
  job = upload_job.UploadJob(
318
231
  self._stats,
319
232
  self._api,
@@ -329,19 +242,6 @@ class StepUpload:
329
242
  )
330
243
  job.run()
331
244
 
332
- async def _do_upload_async(self, event: RequestUpload) -> None:
333
- """Upload a file and returns when it's done. Requires `event.save_fn_async`."""
334
- assert event.save_fn_async is not None
335
- job = upload_job.UploadJobAsync(
336
- stats=self._stats,
337
- api=self._api,
338
- file_stream=self._file_stream,
339
- silent=self.silent,
340
- request=event,
341
- save_fn_async=event.save_fn_async,
342
- )
343
- await job.run()
344
-
345
245
  def _init_artifact(self, artifact_id: str) -> None:
346
246
  self._artifacts[artifact_id] = {
347
247
  "finalize": False,
@@ -385,8 +285,6 @@ class StepUpload:
385
285
 
386
286
  def start(self) -> None:
387
287
  self._thread.start()
388
- if self._async_executor:
389
- self._async_executor.start()
390
288
 
391
289
  def is_alive(self) -> bool:
392
290
  return self._thread.is_alive()
@@ -1,4 +1,3 @@
1
- import asyncio
2
1
  import logging
3
2
  import os
4
3
  from typing import TYPE_CHECKING, Optional
@@ -141,78 +140,3 @@ class UploadJob:
141
140
 
142
141
  def progress(self, total_bytes: int) -> None:
143
142
  self._stats.update_uploaded_file(self.save_name, total_bytes)
144
-
145
-
146
- class UploadJobAsync:
147
- """Roughly an async equivalent of UploadJob.
148
-
149
- Important differences:
150
- - `run` is a coroutine
151
- - If `run()` fails, it falls back to the synchronous UploadJob
152
- """
153
-
154
- def __init__(
155
- self,
156
- stats: "stats.Stats",
157
- api: "internal_api.Api",
158
- file_stream: "file_stream.FileStreamApi",
159
- silent: bool,
160
- request: "step_upload.RequestUpload",
161
- save_fn_async: "step_upload.SaveFnAsync",
162
- ) -> None:
163
- self._stats = stats
164
- self._api = api
165
- self._file_stream = file_stream
166
- self.silent = silent
167
- self._request = request
168
- self._save_fn_async = save_fn_async
169
-
170
- async def run(self) -> None:
171
- try:
172
- deduped = await self._save_fn_async(
173
- lambda _, t: self._stats.update_uploaded_file(self._request.path, t)
174
- )
175
- except Exception as e:
176
- # Async uploads aren't yet (2023-01) battle-tested.
177
- # Fall back to the "normal" synchronous upload.
178
- loop = asyncio.get_event_loop()
179
- logger.exception("async upload failed", exc_info=e)
180
- loop.run_in_executor(None, wandb._sentry.exception, e)
181
- wandb.termwarn(
182
- "Async file upload failed; falling back to sync", repeat=False
183
- )
184
- sync_job = UploadJob(
185
- self._stats,
186
- self._api,
187
- self._file_stream,
188
- self.silent,
189
- self._request.save_name,
190
- self._request.path,
191
- self._request.artifact_id,
192
- self._request.md5,
193
- self._request.copied,
194
- self._request.save_fn,
195
- self._request.digest,
196
- )
197
-
198
- await loop.run_in_executor(None, sync_job.run)
199
- else:
200
- self._file_stream.push_success(
201
- self._request.artifact_id, # type: ignore
202
- self._request.save_name,
203
- )
204
-
205
- if deduped:
206
- logger.info("Skipped uploading %s", self._request.path)
207
- self._stats.set_file_deduped(self._request.path)
208
- else:
209
- logger.info("Uploaded file %s", self._request.path)
210
- finally:
211
- # If we fell back to the sync impl, the file will have already been deleted.
212
- # Doesn't matter, we only try to delete it if it exists.
213
- if self._request.copied:
214
- try:
215
- os.remove(self._request.path)
216
- except OSError:
217
- # The file has already been deleted, we don't have permissions, or something else we can't fix.
218
- pass
@@ -12,6 +12,8 @@ else:
12
12
 
13
13
 
14
14
  _gym_version_lt_0_26: Optional[bool] = None
15
+ _gymnasium_version_lt_1_0_0: Optional[bool] = None
16
+
15
17
  _required_error_msg = (
16
18
  "Couldn't import the gymnasium python package, "
17
19
  "install with `pip install gymnasium`"
@@ -35,14 +37,10 @@ def monitor():
35
37
  if gym_lib is None:
36
38
  raise wandb.Error(_required_error_msg)
37
39
 
38
- vcr = wandb.util.get_module(
39
- f"{gym_lib}.wrappers.monitoring.video_recorder",
40
- required=_required_error_msg,
41
- )
42
-
43
40
  global _gym_version_lt_0_26
41
+ global _gymnasium_version_lt_1_0_0
44
42
 
45
- if _gym_version_lt_0_26 is None:
43
+ if _gym_version_lt_0_26 is None or _gymnasium_version_lt_1_0_0 is None:
46
44
  if gym_lib == "gym":
47
45
  import gym
48
46
  else:
@@ -50,15 +48,31 @@ def monitor():
50
48
 
51
49
  from wandb.util import parse_version
52
50
 
53
- if parse_version(gym.__version__) < parse_version("0.26.0"):
54
- _gym_version_lt_0_26 = True
51
+ gym_lib_version = parse_version(gym.__version__)
52
+ _gym_version_lt_0_26 = gym_lib_version < parse_version("0.26.0")
53
+ _gymnasium_version_lt_1_0_0 = gym_lib_version < parse_version("1.0.0a1")
54
+
55
+ path = "path" # Default path
56
+ if gym_lib == "gymnasium" and not _gymnasium_version_lt_1_0_0:
57
+ vcr_recorder_attribute = "RecordVideo"
58
+ wrappers = wandb.util.get_module(
59
+ f"{gym_lib}.wrappers",
60
+ required=_required_error_msg,
61
+ )
62
+ recorder = getattr(wrappers, vcr_recorder_attribute)
63
+ else:
64
+ vcr = wandb.util.get_module(
65
+ f"{gym_lib}.wrappers.monitoring.video_recorder",
66
+ required=_required_error_msg,
67
+ )
68
+ # Breaking change in gym 0.26.0
69
+ if _gym_version_lt_0_26:
70
+ vcr_recorder_attribute = "ImageEncoder"
71
+ recorder = getattr(vcr, vcr_recorder_attribute)
72
+ path = "output_path" # Override path for older gym versions
55
73
  else:
56
- _gym_version_lt_0_26 = False
57
-
58
- # breaking change in gym 0.26.0
59
- vcr_recorder_attribute = "ImageEncoder" if _gym_version_lt_0_26 else "VideoRecorder"
60
- recorder = getattr(vcr, vcr_recorder_attribute)
61
- path = "output_path" if _gym_version_lt_0_26 else "path"
74
+ vcr_recorder_attribute = "VideoRecorder"
75
+ recorder = getattr(vcr, vcr_recorder_attribute)
62
76
 
63
77
  recorder.orig_close = recorder.close
64
78
 
@@ -77,9 +91,15 @@ def monitor():
77
91
  if not _gym_version_lt_0_26:
78
92
  recorder.__del__ = del_
79
93
  recorder.close = close
94
+
95
+ if gym_lib == "gymnasium" and not _gymnasium_version_lt_1_0_0:
96
+ wrapper_name = vcr_recorder_attribute
97
+ else:
98
+ wrapper_name = f"monitoring.video_recorder.{vcr_recorder_attribute}"
99
+
80
100
  wandb.patched["gym"].append(
81
101
  [
82
- f"{gym_lib}.wrappers.monitoring.video_recorder.{vcr_recorder_attribute}",
102
+ f"{gym_lib}.wrappers.{wrapper_name}",
83
103
  "close",
84
104
  ]
85
105
  )
@@ -91,8 +91,8 @@ class HuggingFacePipelineRequestResponseResolver:
91
91
 
92
92
  # TODO: This should have a dependency on PreTrainedModel. i.e. isinstance(PreTrainedModel)
93
93
  # from transformers.modeling_utils import PreTrainedModel
94
- # We do not want this dependency explicity in our codebase so we make a very general assumption about
95
- # the structure of the pipeline which may have unintended consequences
94
+ # We do not want this dependency explicitly in our codebase so we make a very general
95
+ # assumption about the structure of the pipeline which may have unintended consequences
96
96
  def _get_model(self, pipe) -> Optional[Any]:
97
97
  """Extracts model from the pipeline.
98
98
 
@@ -43,7 +43,7 @@ class WandbMetricsLogger(callbacks.Callback):
43
43
  at the end of each epoch. If "batch", logs metrics at the end
44
44
  of each batch. If an integer, logs metrics at the end of that
45
45
  many batches. Defaults to "epoch".
46
- initial_global_step: (int) Use this argument to correcly log the
46
+ initial_global_step: (int) Use this argument to correctly log the
47
47
  learning rate when you resume training from some `initial_epoch`,
48
48
  and a learning rate scheduler is used. This can be computed as
49
49
  `step_size * initial_step`. Defaults to 0.
@@ -576,7 +576,7 @@ class WandbCallback(tf.keras.callbacks.Callback):
576
576
  )
577
577
  self._model_trained_since_last_eval = False
578
578
  except Exception as e:
579
- wandb.termwarn("Error durring prediction logging for epoch: " + str(e))
579
+ wandb.termwarn("Error during prediction logging for epoch: " + str(e))
580
580
 
581
581
  def on_epoch_end(self, epoch, logs=None):
582
582
  if logs is None:
@@ -65,6 +65,8 @@ class WandbLogger:
65
65
  overwrite: bool = False,
66
66
  wait_for_job_success: bool = True,
67
67
  log_datasets: bool = True,
68
+ model_artifact_name: str = "model-metadata",
69
+ model_artifact_type: str = "model",
68
70
  **kwargs_wandb_init: Dict[str, Any],
69
71
  ) -> str:
70
72
  """Sync fine-tunes to Weights & Biases.
@@ -76,6 +78,8 @@ class WandbLogger:
76
78
  :param entity: Username or team name where you're sending runs. By default, your default entity is used, which is usually your username.
77
79
  :param overwrite: Forces logging and overwrite existing wandb run of the same fine-tune.
78
80
  :param wait_for_job_success: Waits for the fine-tune to be complete and then log metrics to W&B. By default, it is True.
81
+ :param model_artifact_name: Name of the model artifact that is logged
82
+ :param model_artifact_type: Type of the model artifact that is logged
79
83
  """
80
84
  if openai_client is None:
81
85
  openai_client = OpenAI()
@@ -157,6 +161,8 @@ class WandbLogger:
157
161
  overwrite,
158
162
  show_individual_warnings,
159
163
  log_datasets,
164
+ model_artifact_name,
165
+ model_artifact_type,
160
166
  **kwargs_wandb_init,
161
167
  )
162
168
 
@@ -201,6 +207,8 @@ class WandbLogger:
201
207
  overwrite: bool,
202
208
  show_individual_warnings: bool,
203
209
  log_datasets: bool,
210
+ model_artifact_name: str,
211
+ model_artifact_type: str,
204
212
  **kwargs_wandb_init: Dict[str, Any],
205
213
  ):
206
214
  fine_tune_id = fine_tune.id
@@ -244,7 +252,15 @@ class WandbLogger:
244
252
  cls._run.summary["fine_tuned_model"] = fine_tuned_model
245
253
 
246
254
  # training/validation files and fine-tune details
247
- cls._log_artifacts(fine_tune, project, entity, log_datasets, overwrite)
255
+ cls._log_artifacts(
256
+ fine_tune,
257
+ project,
258
+ entity,
259
+ log_datasets,
260
+ overwrite,
261
+ model_artifact_name,
262
+ model_artifact_type,
263
+ )
248
264
 
249
265
  # mark run as complete
250
266
  cls._run.summary["status"] = "succeeded"
@@ -341,6 +357,8 @@ class WandbLogger:
341
357
  entity: Optional[str],
342
358
  log_datasets: bool,
343
359
  overwrite: bool,
360
+ model_artifact_name: str,
361
+ model_artifact_type: str,
344
362
  ) -> None:
345
363
  if log_datasets:
346
364
  wandb.termlog("Logging training/validation files...")
@@ -361,8 +379,8 @@ class WandbLogger:
361
379
  # fine-tune details
362
380
  fine_tune_id = fine_tune.id
363
381
  artifact = wandb.Artifact(
364
- "model_metadata",
365
- type="model",
382
+ model_artifact_name,
383
+ type=model_artifact_type,
366
384
  metadata=dict(fine_tune),
367
385
  )
368
386
 
@@ -26,7 +26,7 @@ from PIL import Image
26
26
 
27
27
  import wandb
28
28
  from wandb import util
29
- from wandb.plots.utils import test_missing
29
+ from wandb.plot.utils import test_missing
30
30
  from wandb.sdk.lib import telemetry as wb_telemetry
31
31
 
32
32
 
wandb/jupyter.py CHANGED
@@ -288,35 +288,34 @@ def attempt_colab_login(app_url):
288
288
  display.display(
289
289
  display.Javascript(
290
290
  """
291
- window._wandbApiKey = new Promise((resolve, reject) => {
292
- function loadScript(url) {
293
- return new Promise(function(resolve, reject) {
291
+ window._wandbApiKey = new Promise((resolve, reject) => {{
292
+ function loadScript(url) {{
293
+ return new Promise(function(resolve, reject) {{
294
294
  let newScript = document.createElement("script");
295
295
  newScript.onerror = reject;
296
296
  newScript.onload = resolve;
297
297
  document.body.appendChild(newScript);
298
298
  newScript.src = url;
299
- });
300
- }
301
- loadScript("https://cdn.jsdelivr.net/npm/postmate/build/postmate.min.js").then(() => {
299
+ }});
300
+ }}
301
+ loadScript("https://cdn.jsdelivr.net/npm/postmate/build/postmate.min.js").then(() => {{
302
302
  const iframe = document.createElement('iframe')
303
303
  iframe.style.cssText = "width:0;height:0;border:none"
304
304
  document.body.appendChild(iframe)
305
- const handshake = new Postmate({
305
+ const handshake = new Postmate({{
306
306
  container: iframe,
307
- url: '%s/authorize'
308
- });
307
+ url: '{}/authorize'
308
+ }});
309
309
  const timeout = setTimeout(() => reject("Couldn't auto authenticate"), 5000)
310
- handshake.then(function(child) {
311
- child.on('authorize', data => {
310
+ handshake.then(function(child) {{
311
+ child.on('authorize', data => {{
312
312
  clearTimeout(timeout)
313
313
  resolve(data)
314
- });
315
- });
316
- })
317
- });
318
- """
319
- % app_url.replace("http:", "https:")
314
+ }});
315
+ }});
316
+ }})
317
+ }});
318
+ """.format(app_url.replace("http:", "https:"))
320
319
  )
321
320
  )
322
321
  try:
wandb/old/summary.py CHANGED
@@ -61,7 +61,7 @@ class SummarySubDict:
61
61
  This should only be implemented by the "_root" child class.
62
62
 
63
63
  We pass the child_dict so the item can be set on it or not as
64
- appropriate. Returning None for a nonexistant path wouldn't be
64
+ appropriate. Returning None for a nonexistent path wouldn't be
65
65
  distinguishable from that path being set to the value None.
66
66
  """
67
67
  raise NotImplementedError
@@ -48,7 +48,7 @@ def confusion_matrix(
48
48
 
49
49
  assert (probs is None or preds is None) and not (
50
50
  probs is None and preds is None
51
- ), "Must provide probabilties or predictions but not both to confusion matrix"
51
+ ), "Must provide probabilities or predictions but not both to confusion matrix"
52
52
 
53
53
  if probs is not None:
54
54
  preds = np.argmax(probs, axis=1).tolist()
wandb/plot/pr_curve.py CHANGED
@@ -2,7 +2,8 @@ from typing import Optional
2
2
 
3
3
  import wandb
4
4
  from wandb import util
5
- from wandb.plots.utils import test_missing, test_types
5
+
6
+ from .utils import test_missing, test_types
6
7
 
7
8
 
8
9
  def pr_curve(
wandb/plot/roc_curve.py CHANGED
@@ -2,7 +2,8 @@ from typing import Optional
2
2
 
3
3
  import wandb
4
4
  from wandb import util
5
- from wandb.plots.utils import test_missing, test_types
5
+
6
+ from .utils import test_missing, test_types
6
7
 
7
8
 
8
9
  def roc_curve(
@@ -1,21 +1,9 @@
1
- from collections.abc import Iterable, Sequence
1
+ from typing import Iterable, Sequence
2
2
 
3
3
  import wandb
4
4
  from wandb import util
5
- from wandb.sdk.lib import deprecate
6
5
 
7
6
 
8
- def deprecation_notice() -> None:
9
- deprecate.deprecate(
10
- field_name=deprecate.Deprecated.plots,
11
- warning_message=(
12
- "wandb.plots.* functions are deprecated and will be removed in a future release. "
13
- "Please use wandb.plot.* instead."
14
- ),
15
- )
16
-
17
-
18
- # Test assumptions for plotting parameters and datasets
19
7
  def test_missing(**kwargs):
20
8
  np = util.get_module("numpy", required="Logging plots requires numpy")
21
9
  pd = util.get_module("pandas", required="Logging dataframes requires pandas")
@@ -25,7 +13,7 @@ def test_missing(**kwargs):
25
13
  for k, v in kwargs.items():
26
14
  # Missing/empty params/datapoint arrays
27
15
  if v is None:
28
- wandb.termerror("%s is None. Please try again." % (k))
16
+ wandb.termerror("{} is None. Please try again.".format(k))
29
17
  test_passed = False
30
18
  if (k == "X") or (k == "X_test"):
31
19
  if isinstance(v, scipy.sparse.csr.csr_matrix):
@@ -64,8 +52,8 @@ def test_missing(**kwargs):
64
52
  )
65
53
  if non_nums > 0:
66
54
  wandb.termerror(
67
- "%s contains values that are not numbers. Please vectorize, label encode or one hot encode %s and call the plotting function again."
68
- % (k, k)
55
+ f"{k} contains values that are not numbers. Please vectorize, "
56
+ f"label encode or one hot encode {k} and call the plotting function again."
69
57
  )
70
58
  test_passed = False
71
59
  return test_passed
@@ -73,8 +61,8 @@ def test_missing(**kwargs):
73
61
 
74
62
  def test_fitted(model):
75
63
  np = util.get_module("numpy", required="Logging plots requires numpy")
76
- pd = util.get_module("pandas", required="Logging dataframes requires pandas")
77
- scipy = util.get_module("scipy", required="Logging scipy matrices requires scipy")
64
+ _ = util.get_module("pandas", required="Logging dataframes requires pandas")
65
+ _ = util.get_module("scipy", required="Logging scipy matrices requires scipy")
78
66
  scikit_utils = util.get_module(
79
67
  "sklearn.utils",
80
68
  required="roc requires the scikit utils submodule, install with `pip install scikit-learn`",
@@ -120,7 +108,7 @@ def test_fitted(model):
120
108
 
121
109
 
122
110
  def encode_labels(df):
123
- pd = util.get_module("pandas", required="Logging dataframes requires pandas")
111
+ _ = util.get_module("pandas", required="Logging dataframes requires pandas")
124
112
  preprocessing = util.get_module(
125
113
  "sklearn.preprocessing",
126
114
  "roc requires the scikit preprocessing submodule, install with `pip install scikit-learn`",
@@ -137,7 +125,7 @@ def encode_labels(df):
137
125
  def test_types(**kwargs):
138
126
  np = util.get_module("numpy", required="Logging plots requires numpy")
139
127
  pd = util.get_module("pandas", required="Logging dataframes requires pandas")
140
- scipy = util.get_module("scipy", required="Logging scipy matrices requires scipy")
128
+ _ = util.get_module("scipy", required="Logging scipy matrices requires scipy")
141
129
 
142
130
  base = util.get_module(
143
131
  "sklearn.base",
@@ -171,25 +159,25 @@ def test_types(**kwargs):
171
159
  list,
172
160
  ),
173
161
  ):
174
- wandb.termerror("%s is not an array. Please try again." % (k))
162
+ wandb.termerror("{} is not an array. Please try again.".format(k))
175
163
  test_passed = False
176
164
  # check for classifier types
177
165
  if k == "model":
178
166
  if (not base.is_classifier(v)) and (not base.is_regressor(v)):
179
167
  wandb.termerror(
180
- "%s is not a classifier or regressor. Please try again." % (k)
168
+ "{} is not a classifier or regressor. Please try again.".format(k)
181
169
  )
182
170
  test_passed = False
183
171
  elif k == "clf" or k == "binary_clf":
184
172
  if not (base.is_classifier(v)):
185
- wandb.termerror("%s is not a classifier. Please try again." % (k))
173
+ wandb.termerror("{} is not a classifier. Please try again.".format(k))
186
174
  test_passed = False
187
175
  elif k == "regressor":
188
176
  if not base.is_regressor(v):
189
- wandb.termerror("%s is not a regressor. Please try again." % (k))
177
+ wandb.termerror("{} is not a regressor. Please try again.".format(k))
190
178
  test_passed = False
191
179
  elif k == "clusterer":
192
180
  if not (getattr(v, "_estimator_type", None) == "clusterer"):
193
- wandb.termerror("%s is not a clusterer. Please try again." % (k))
181
+ wandb.termerror("{} is not a clusterer. Please try again.".format(k))
194
182
  test_passed = False
195
183
  return test_passed