wandb 0.17.0rc2__py3-none-macosx_11_0_arm64.whl → 0.17.1__py3-none-macosx_11_0_arm64.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (160) hide show
  1. wandb/__init__.py +1 -2
  2. wandb/apis/importers/internals/internal.py +0 -1
  3. wandb/apis/importers/wandb.py +12 -7
  4. wandb/apis/internal.py +0 -3
  5. wandb/apis/public/api.py +213 -79
  6. wandb/apis/public/artifacts.py +335 -100
  7. wandb/apis/public/files.py +9 -9
  8. wandb/apis/public/jobs.py +16 -4
  9. wandb/apis/public/projects.py +26 -28
  10. wandb/apis/public/query_generator.py +1 -1
  11. wandb/apis/public/runs.py +163 -65
  12. wandb/apis/public/sweeps.py +2 -2
  13. wandb/apis/reports/__init__.py +1 -7
  14. wandb/apis/reports/v1/__init__.py +5 -27
  15. wandb/apis/reports/v2/__init__.py +7 -19
  16. wandb/apis/workspaces/__init__.py +8 -0
  17. wandb/beta/workflows.py +8 -3
  18. wandb/bin/apple_gpu_stats +0 -0
  19. wandb/bin/wandb-core +0 -0
  20. wandb/cli/cli.py +131 -59
  21. wandb/docker/__init__.py +1 -1
  22. wandb/errors/term.py +10 -2
  23. wandb/filesync/step_checksum.py +1 -4
  24. wandb/filesync/step_prepare.py +4 -24
  25. wandb/filesync/step_upload.py +5 -107
  26. wandb/filesync/upload_job.py +0 -76
  27. wandb/integration/gym/__init__.py +35 -15
  28. wandb/integration/openai/fine_tuning.py +21 -3
  29. wandb/integration/prodigy/prodigy.py +1 -1
  30. wandb/jupyter.py +16 -17
  31. wandb/plot/pr_curve.py +2 -1
  32. wandb/plot/roc_curve.py +2 -1
  33. wandb/{plots → plot}/utils.py +13 -25
  34. wandb/proto/v3/wandb_internal_pb2.py +54 -54
  35. wandb/proto/v3/wandb_settings_pb2.py +2 -2
  36. wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
  37. wandb/proto/v4/wandb_internal_pb2.py +54 -54
  38. wandb/proto/v4/wandb_settings_pb2.py +2 -2
  39. wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
  40. wandb/proto/v5/wandb_base_pb2.py +30 -0
  41. wandb/proto/v5/wandb_internal_pb2.py +355 -0
  42. wandb/proto/v5/wandb_server_pb2.py +63 -0
  43. wandb/proto/v5/wandb_settings_pb2.py +45 -0
  44. wandb/proto/v5/wandb_telemetry_pb2.py +41 -0
  45. wandb/proto/wandb_base_pb2.py +2 -0
  46. wandb/proto/wandb_deprecated.py +9 -1
  47. wandb/proto/wandb_generate_deprecated.py +34 -0
  48. wandb/proto/{wandb_internal_codegen.py → wandb_generate_proto.py} +1 -35
  49. wandb/proto/wandb_internal_pb2.py +2 -0
  50. wandb/proto/wandb_server_pb2.py +2 -0
  51. wandb/proto/wandb_settings_pb2.py +2 -0
  52. wandb/proto/wandb_telemetry_pb2.py +2 -0
  53. wandb/sdk/artifacts/artifact.py +68 -22
  54. wandb/sdk/artifacts/artifact_manifest.py +1 -1
  55. wandb/sdk/artifacts/artifact_manifest_entry.py +6 -3
  56. wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +1 -1
  57. wandb/sdk/artifacts/artifact_saver.py +1 -10
  58. wandb/sdk/artifacts/storage_handlers/local_file_handler.py +6 -2
  59. wandb/sdk/artifacts/storage_handlers/multi_handler.py +1 -1
  60. wandb/sdk/artifacts/storage_handlers/tracking_handler.py +6 -4
  61. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +2 -42
  62. wandb/sdk/artifacts/storage_policy.py +1 -12
  63. wandb/sdk/data_types/image.py +1 -1
  64. wandb/sdk/data_types/video.py +4 -2
  65. wandb/sdk/interface/interface.py +13 -0
  66. wandb/sdk/interface/interface_shared.py +1 -1
  67. wandb/sdk/internal/file_pusher.py +2 -5
  68. wandb/sdk/internal/file_stream.py +6 -19
  69. wandb/sdk/internal/internal_api.py +148 -136
  70. wandb/sdk/internal/job_builder.py +207 -135
  71. wandb/sdk/internal/progress.py +0 -28
  72. wandb/sdk/internal/sender.py +102 -39
  73. wandb/sdk/internal/settings_static.py +8 -1
  74. wandb/sdk/internal/system/assets/trainium.py +3 -3
  75. wandb/sdk/internal/system/system_info.py +4 -2
  76. wandb/sdk/internal/update.py +1 -1
  77. wandb/sdk/launch/__init__.py +9 -1
  78. wandb/sdk/launch/_launch.py +4 -24
  79. wandb/sdk/launch/_launch_add.py +1 -3
  80. wandb/sdk/launch/_project_spec.py +184 -224
  81. wandb/sdk/launch/agent/agent.py +58 -18
  82. wandb/sdk/launch/agent/config.py +0 -3
  83. wandb/sdk/launch/builder/abstract.py +67 -0
  84. wandb/sdk/launch/builder/build.py +165 -576
  85. wandb/sdk/launch/builder/context_manager.py +235 -0
  86. wandb/sdk/launch/builder/docker_builder.py +7 -23
  87. wandb/sdk/launch/builder/kaniko_builder.py +10 -23
  88. wandb/sdk/launch/builder/templates/dockerfile.py +92 -0
  89. wandb/sdk/launch/create_job.py +51 -45
  90. wandb/sdk/launch/environment/aws_environment.py +26 -1
  91. wandb/sdk/launch/inputs/files.py +148 -0
  92. wandb/sdk/launch/inputs/internal.py +224 -0
  93. wandb/sdk/launch/inputs/manage.py +95 -0
  94. wandb/sdk/launch/runner/abstract.py +2 -2
  95. wandb/sdk/launch/runner/kubernetes_monitor.py +45 -12
  96. wandb/sdk/launch/runner/kubernetes_runner.py +6 -8
  97. wandb/sdk/launch/runner/local_container.py +2 -3
  98. wandb/sdk/launch/runner/local_process.py +8 -29
  99. wandb/sdk/launch/runner/sagemaker_runner.py +20 -14
  100. wandb/sdk/launch/runner/vertex_runner.py +8 -7
  101. wandb/sdk/launch/sweeps/scheduler.py +2 -0
  102. wandb/sdk/launch/sweeps/utils.py +2 -2
  103. wandb/sdk/launch/utils.py +16 -138
  104. wandb/sdk/lib/_settings_toposort_generated.py +2 -5
  105. wandb/sdk/lib/apikey.py +4 -2
  106. wandb/sdk/lib/config_util.py +3 -3
  107. wandb/sdk/lib/proto_util.py +22 -1
  108. wandb/sdk/lib/redirect.py +1 -1
  109. wandb/sdk/service/service.py +2 -1
  110. wandb/sdk/service/streams.py +5 -5
  111. wandb/sdk/wandb_init.py +25 -59
  112. wandb/sdk/wandb_login.py +28 -25
  113. wandb/sdk/wandb_run.py +112 -45
  114. wandb/sdk/wandb_settings.py +33 -64
  115. wandb/sdk/wandb_watch.py +1 -1
  116. wandb/sklearn/plot/classifier.py +4 -6
  117. wandb/sync/sync.py +2 -2
  118. wandb/testing/relay.py +32 -17
  119. wandb/util.py +36 -37
  120. wandb/wandb_agent.py +3 -3
  121. wandb/wandb_controller.py +3 -2
  122. {wandb-0.17.0rc2.dist-info → wandb-0.17.1.dist-info}/METADATA +7 -9
  123. {wandb-0.17.0rc2.dist-info → wandb-0.17.1.dist-info}/RECORD +126 -148
  124. {wandb-0.17.0rc2.dist-info → wandb-0.17.1.dist-info}/WHEEL +1 -1
  125. wandb/apis/reports/v1/_blocks.py +0 -1406
  126. wandb/apis/reports/v1/_helpers.py +0 -70
  127. wandb/apis/reports/v1/_panels.py +0 -1282
  128. wandb/apis/reports/v1/_templates.py +0 -478
  129. wandb/apis/reports/v1/blocks.py +0 -27
  130. wandb/apis/reports/v1/helpers.py +0 -2
  131. wandb/apis/reports/v1/mutations.py +0 -66
  132. wandb/apis/reports/v1/panels.py +0 -17
  133. wandb/apis/reports/v1/report.py +0 -268
  134. wandb/apis/reports/v1/runset.py +0 -144
  135. wandb/apis/reports/v1/templates.py +0 -7
  136. wandb/apis/reports/v1/util.py +0 -406
  137. wandb/apis/reports/v1/validators.py +0 -131
  138. wandb/apis/reports/v2/blocks.py +0 -25
  139. wandb/apis/reports/v2/expr_parsing.py +0 -257
  140. wandb/apis/reports/v2/gql.py +0 -68
  141. wandb/apis/reports/v2/interface.py +0 -1911
  142. wandb/apis/reports/v2/internal.py +0 -867
  143. wandb/apis/reports/v2/metrics.py +0 -6
  144. wandb/apis/reports/v2/panels.py +0 -15
  145. wandb/catboost/__init__.py +0 -9
  146. wandb/fastai/__init__.py +0 -9
  147. wandb/keras/__init__.py +0 -19
  148. wandb/lightgbm/__init__.py +0 -9
  149. wandb/plots/__init__.py +0 -6
  150. wandb/plots/explain_text.py +0 -36
  151. wandb/plots/heatmap.py +0 -81
  152. wandb/plots/named_entity.py +0 -43
  153. wandb/plots/part_of_speech.py +0 -50
  154. wandb/plots/plot_definitions.py +0 -768
  155. wandb/plots/precision_recall.py +0 -121
  156. wandb/plots/roc.py +0 -103
  157. wandb/sacred/__init__.py +0 -3
  158. wandb/xgboost/__init__.py +0 -9
  159. {wandb-0.17.0rc2.dist-info → wandb-0.17.1.dist-info}/entry_points.txt +0 -0
  160. {wandb-0.17.0rc2.dist-info → wandb-0.17.1.dist-info}/licenses/LICENSE +0 -0
@@ -1,6 +1,5 @@
1
1
  """Batching file prepare requests to our API."""
2
2
 
3
- import asyncio
4
3
  import concurrent.futures
5
4
  import logging
6
5
  import queue
@@ -8,7 +7,6 @@ import sys
8
7
  import threading
9
8
  from typing import (
10
9
  TYPE_CHECKING,
11
- Awaitable,
12
10
  Callable,
13
11
  MutableMapping,
14
12
  MutableSequence,
@@ -43,7 +41,6 @@ if TYPE_CHECKING:
43
41
  PreCommitFn = Callable[[], None]
44
42
  OnRequestFinishFn = Callable[[], None]
45
43
  SaveFn = Callable[["progress.ProgressFn"], bool]
46
- SaveFnAsync = Callable[["progress.ProgressFn"], Awaitable[bool]]
47
44
 
48
45
  logger = logging.getLogger(__name__)
49
46
 
@@ -55,7 +52,6 @@ class RequestUpload(NamedTuple):
55
52
  md5: Optional[str]
56
53
  copied: bool
57
54
  save_fn: Optional[SaveFn]
58
- save_fn_async: Optional[SaveFnAsync]
59
55
  digest: Optional[str]
60
56
 
61
57
 
@@ -78,47 +74,6 @@ class EventJobDone(NamedTuple):
78
74
  Event = Union[RequestUpload, RequestCommitArtifact, RequestFinish, EventJobDone]
79
75
 
80
76
 
81
- class AsyncExecutor:
82
- """Runs async file uploads in a background thread."""
83
-
84
- def __init__(
85
- self,
86
- pool: concurrent.futures.ThreadPoolExecutor,
87
- concurrency_limit: Optional[int],
88
- ) -> None:
89
- self.loop = asyncio.new_event_loop()
90
- self.loop.set_default_executor(pool)
91
- self.loop_thread = threading.Thread(
92
- target=self.loop.run_forever,
93
- daemon=True,
94
- name="wandb-upload-async",
95
- )
96
-
97
- self.concurrency_limiter = asyncio.Semaphore(
98
- value=concurrency_limit or 128,
99
- # Before Python 3.10: if we don't set `loop=loop`,
100
- # then the Semaphore will bind to the wrong event loop,
101
- # causing errors when a coroutine tries to wait for it;
102
- # see https://pastebin.com/XcrS9suX .
103
- # After 3.10: the `loop` argument doesn't exist.
104
- # So we need to only conditionally pass in `loop`.
105
- **({} if sys.version_info >= (3, 10) else {"loop": self.loop}),
106
- )
107
-
108
- def start(self) -> None:
109
- self.loop_thread.start()
110
-
111
- def stop(self) -> None:
112
- self.loop.call_soon_threadsafe(self.loop.stop)
113
-
114
- def submit(self, coro: Awaitable[None]) -> None:
115
- async def run_with_limiter() -> None:
116
- async with self.concurrency_limiter:
117
- await coro
118
-
119
- asyncio.run_coroutine_threadsafe(run_with_limiter(), self.loop)
120
-
121
-
122
77
  class StepUpload:
123
78
  def __init__(
124
79
  self,
@@ -142,15 +97,6 @@ class StepUpload:
142
97
  max_workers=max_threads,
143
98
  )
144
99
 
145
- self._async_executor = (
146
- AsyncExecutor(
147
- pool=self._pool,
148
- concurrency_limit=settings._async_upload_concurrency_limit,
149
- )
150
- if settings is not None and settings._async_upload_concurrency_limit
151
- else None
152
- )
153
-
154
100
  # Indexed by files' `save_name`'s, which are their ID's in the Run.
155
101
  self._running_jobs: MutableMapping[LogicalPath, RequestUpload] = {}
156
102
  self._pending_jobs: MutableSequence[RequestUpload] = []
@@ -186,8 +132,6 @@ class StepUpload:
186
132
  elif not self._running_jobs:
187
133
  # Queue was empty and no jobs left.
188
134
  self._pool.shutdown(wait=False)
189
- if self._async_executor:
190
- self._async_executor.stop()
191
135
  if finish_callback:
192
136
  finish_callback()
193
137
  break
@@ -235,7 +179,7 @@ class StepUpload:
235
179
  self._artifacts[event.artifact_id]["pending_count"] += 1
236
180
  self._start_upload_job(event)
237
181
  else:
238
- raise Exception("Programming error: unhandled event: %s" % str(event))
182
+ raise Exception("Programming error: unhandled event: {}".format(str(event)))
239
183
 
240
184
  def _start_upload_job(self, event: RequestUpload) -> None:
241
185
  # Operations on a single backend file must be serialized. if
@@ -245,18 +189,9 @@ class StepUpload:
245
189
  self._pending_jobs.append(event)
246
190
  return
247
191
 
248
- if self._async_executor and event.save_fn_async is not None:
249
- # (The `and save_fn_async is not None` is because the async code path
250
- # doesn't support all uploads yet: even if the user has requested async,
251
- # we sometimes need to use the sync method instead.)
252
- self._spawn_upload_async(
253
- event,
254
- async_executor=self._async_executor,
255
- )
256
- else:
257
- self._spawn_upload_sync(event)
192
+ self._spawn_upload(event)
258
193
 
259
- def _spawn_upload_sync(self, event: RequestUpload) -> None:
194
+ def _spawn_upload(self, event: RequestUpload) -> None:
260
195
  """Spawn an upload job, and handles the bookkeeping of `self._running_jobs`.
261
196
 
262
197
  Context: it's important that, whenever we add an entry to `self._running_jobs`,
@@ -285,35 +220,13 @@ class StepUpload:
285
220
 
286
221
  def run_and_notify() -> None:
287
222
  try:
288
- self._do_upload_sync(event)
223
+ self._do_upload(event)
289
224
  finally:
290
225
  self._event_queue.put(EventJobDone(event, exc=sys.exc_info()[1]))
291
226
 
292
227
  self._pool.submit(run_and_notify)
293
228
 
294
- def _spawn_upload_async(
295
- self,
296
- event: RequestUpload,
297
- async_executor: AsyncExecutor,
298
- ) -> None:
299
- """Equivalent to _spawn_upload_sync, but uses the async event loop instead of a thread, and requires `event.save_fn_async`.
300
-
301
- Raises:
302
- AssertionError: if `event.save_fn_async` is None.
303
- """
304
- assert event.save_fn_async is not None
305
-
306
- self._running_jobs[event.save_name] = event
307
-
308
- async def run_and_notify() -> None:
309
- try:
310
- await self._do_upload_async(event)
311
- finally:
312
- self._event_queue.put(EventJobDone(event, exc=sys.exc_info()[1]))
313
-
314
- async_executor.submit(run_and_notify())
315
-
316
- def _do_upload_sync(self, event: RequestUpload) -> None:
229
+ def _do_upload(self, event: RequestUpload) -> None:
317
230
  job = upload_job.UploadJob(
318
231
  self._stats,
319
232
  self._api,
@@ -329,19 +242,6 @@ class StepUpload:
329
242
  )
330
243
  job.run()
331
244
 
332
- async def _do_upload_async(self, event: RequestUpload) -> None:
333
- """Upload a file and returns when it's done. Requires `event.save_fn_async`."""
334
- assert event.save_fn_async is not None
335
- job = upload_job.UploadJobAsync(
336
- stats=self._stats,
337
- api=self._api,
338
- file_stream=self._file_stream,
339
- silent=self.silent,
340
- request=event,
341
- save_fn_async=event.save_fn_async,
342
- )
343
- await job.run()
344
-
345
245
  def _init_artifact(self, artifact_id: str) -> None:
346
246
  self._artifacts[artifact_id] = {
347
247
  "finalize": False,
@@ -385,8 +285,6 @@ class StepUpload:
385
285
 
386
286
  def start(self) -> None:
387
287
  self._thread.start()
388
- if self._async_executor:
389
- self._async_executor.start()
390
288
 
391
289
  def is_alive(self) -> bool:
392
290
  return self._thread.is_alive()
@@ -1,4 +1,3 @@
1
- import asyncio
2
1
  import logging
3
2
  import os
4
3
  from typing import TYPE_CHECKING, Optional
@@ -141,78 +140,3 @@ class UploadJob:
141
140
 
142
141
  def progress(self, total_bytes: int) -> None:
143
142
  self._stats.update_uploaded_file(self.save_name, total_bytes)
144
-
145
-
146
- class UploadJobAsync:
147
- """Roughly an async equivalent of UploadJob.
148
-
149
- Important differences:
150
- - `run` is a coroutine
151
- - If `run()` fails, it falls back to the synchronous UploadJob
152
- """
153
-
154
- def __init__(
155
- self,
156
- stats: "stats.Stats",
157
- api: "internal_api.Api",
158
- file_stream: "file_stream.FileStreamApi",
159
- silent: bool,
160
- request: "step_upload.RequestUpload",
161
- save_fn_async: "step_upload.SaveFnAsync",
162
- ) -> None:
163
- self._stats = stats
164
- self._api = api
165
- self._file_stream = file_stream
166
- self.silent = silent
167
- self._request = request
168
- self._save_fn_async = save_fn_async
169
-
170
- async def run(self) -> None:
171
- try:
172
- deduped = await self._save_fn_async(
173
- lambda _, t: self._stats.update_uploaded_file(self._request.path, t)
174
- )
175
- except Exception as e:
176
- # Async uploads aren't yet (2023-01) battle-tested.
177
- # Fall back to the "normal" synchronous upload.
178
- loop = asyncio.get_event_loop()
179
- logger.exception("async upload failed", exc_info=e)
180
- loop.run_in_executor(None, wandb._sentry.exception, e)
181
- wandb.termwarn(
182
- "Async file upload failed; falling back to sync", repeat=False
183
- )
184
- sync_job = UploadJob(
185
- self._stats,
186
- self._api,
187
- self._file_stream,
188
- self.silent,
189
- self._request.save_name,
190
- self._request.path,
191
- self._request.artifact_id,
192
- self._request.md5,
193
- self._request.copied,
194
- self._request.save_fn,
195
- self._request.digest,
196
- )
197
-
198
- await loop.run_in_executor(None, sync_job.run)
199
- else:
200
- self._file_stream.push_success(
201
- self._request.artifact_id, # type: ignore
202
- self._request.save_name,
203
- )
204
-
205
- if deduped:
206
- logger.info("Skipped uploading %s", self._request.path)
207
- self._stats.set_file_deduped(self._request.path)
208
- else:
209
- logger.info("Uploaded file %s", self._request.path)
210
- finally:
211
- # If we fell back to the sync impl, the file will have already been deleted.
212
- # Doesn't matter, we only try to delete it if it exists.
213
- if self._request.copied:
214
- try:
215
- os.remove(self._request.path)
216
- except OSError:
217
- # The file has already been deleted, we don't have permissions, or something else we can't fix.
218
- pass
@@ -12,6 +12,8 @@ else:
12
12
 
13
13
 
14
14
  _gym_version_lt_0_26: Optional[bool] = None
15
+ _gymnasium_version_lt_1_0_0: Optional[bool] = None
16
+
15
17
  _required_error_msg = (
16
18
  "Couldn't import the gymnasium python package, "
17
19
  "install with `pip install gymnasium`"
@@ -35,14 +37,10 @@ def monitor():
35
37
  if gym_lib is None:
36
38
  raise wandb.Error(_required_error_msg)
37
39
 
38
- vcr = wandb.util.get_module(
39
- f"{gym_lib}.wrappers.monitoring.video_recorder",
40
- required=_required_error_msg,
41
- )
42
-
43
40
  global _gym_version_lt_0_26
41
+ global _gymnasium_version_lt_1_0_0
44
42
 
45
- if _gym_version_lt_0_26 is None:
43
+ if _gym_version_lt_0_26 is None or _gymnasium_version_lt_1_0_0 is None:
46
44
  if gym_lib == "gym":
47
45
  import gym
48
46
  else:
@@ -50,15 +48,31 @@ def monitor():
50
48
 
51
49
  from wandb.util import parse_version
52
50
 
53
- if parse_version(gym.__version__) < parse_version("0.26.0"):
54
- _gym_version_lt_0_26 = True
51
+ gym_lib_version = parse_version(gym.__version__)
52
+ _gym_version_lt_0_26 = gym_lib_version < parse_version("0.26.0")
53
+ _gymnasium_version_lt_1_0_0 = gym_lib_version < parse_version("1.0.0a1")
54
+
55
+ path = "path" # Default path
56
+ if gym_lib == "gymnasium" and not _gymnasium_version_lt_1_0_0:
57
+ vcr_recorder_attribute = "RecordVideo"
58
+ wrappers = wandb.util.get_module(
59
+ f"{gym_lib}.wrappers",
60
+ required=_required_error_msg,
61
+ )
62
+ recorder = getattr(wrappers, vcr_recorder_attribute)
63
+ else:
64
+ vcr = wandb.util.get_module(
65
+ f"{gym_lib}.wrappers.monitoring.video_recorder",
66
+ required=_required_error_msg,
67
+ )
68
+ # Breaking change in gym 0.26.0
69
+ if _gym_version_lt_0_26:
70
+ vcr_recorder_attribute = "ImageEncoder"
71
+ recorder = getattr(vcr, vcr_recorder_attribute)
72
+ path = "output_path" # Override path for older gym versions
55
73
  else:
56
- _gym_version_lt_0_26 = False
57
-
58
- # breaking change in gym 0.26.0
59
- vcr_recorder_attribute = "ImageEncoder" if _gym_version_lt_0_26 else "VideoRecorder"
60
- recorder = getattr(vcr, vcr_recorder_attribute)
61
- path = "output_path" if _gym_version_lt_0_26 else "path"
74
+ vcr_recorder_attribute = "VideoRecorder"
75
+ recorder = getattr(vcr, vcr_recorder_attribute)
62
76
 
63
77
  recorder.orig_close = recorder.close
64
78
 
@@ -77,9 +91,15 @@ def monitor():
77
91
  if not _gym_version_lt_0_26:
78
92
  recorder.__del__ = del_
79
93
  recorder.close = close
94
+
95
+ if gym_lib == "gymnasium" and not _gymnasium_version_lt_1_0_0:
96
+ wrapper_name = vcr_recorder_attribute
97
+ else:
98
+ wrapper_name = f"monitoring.video_recorder.{vcr_recorder_attribute}"
99
+
80
100
  wandb.patched["gym"].append(
81
101
  [
82
- f"{gym_lib}.wrappers.monitoring.video_recorder.{vcr_recorder_attribute}",
102
+ f"{gym_lib}.wrappers.{wrapper_name}",
83
103
  "close",
84
104
  ]
85
105
  )
@@ -65,6 +65,8 @@ class WandbLogger:
65
65
  overwrite: bool = False,
66
66
  wait_for_job_success: bool = True,
67
67
  log_datasets: bool = True,
68
+ model_artifact_name: str = "model-metadata",
69
+ model_artifact_type: str = "model",
68
70
  **kwargs_wandb_init: Dict[str, Any],
69
71
  ) -> str:
70
72
  """Sync fine-tunes to Weights & Biases.
@@ -76,6 +78,8 @@ class WandbLogger:
76
78
  :param entity: Username or team name where you're sending runs. By default, your default entity is used, which is usually your username.
77
79
  :param overwrite: Forces logging and overwrite existing wandb run of the same fine-tune.
78
80
  :param wait_for_job_success: Waits for the fine-tune to be complete and then log metrics to W&B. By default, it is True.
81
+ :param model_artifact_name: Name of the model artifact that is logged
82
+ :param model_artifact_type: Type of the model artifact that is logged
79
83
  """
80
84
  if openai_client is None:
81
85
  openai_client = OpenAI()
@@ -157,6 +161,8 @@ class WandbLogger:
157
161
  overwrite,
158
162
  show_individual_warnings,
159
163
  log_datasets,
164
+ model_artifact_name,
165
+ model_artifact_type,
160
166
  **kwargs_wandb_init,
161
167
  )
162
168
 
@@ -201,6 +207,8 @@ class WandbLogger:
201
207
  overwrite: bool,
202
208
  show_individual_warnings: bool,
203
209
  log_datasets: bool,
210
+ model_artifact_name: str,
211
+ model_artifact_type: str,
204
212
  **kwargs_wandb_init: Dict[str, Any],
205
213
  ):
206
214
  fine_tune_id = fine_tune.id
@@ -244,7 +252,15 @@ class WandbLogger:
244
252
  cls._run.summary["fine_tuned_model"] = fine_tuned_model
245
253
 
246
254
  # training/validation files and fine-tune details
247
- cls._log_artifacts(fine_tune, project, entity, log_datasets, overwrite)
255
+ cls._log_artifacts(
256
+ fine_tune,
257
+ project,
258
+ entity,
259
+ log_datasets,
260
+ overwrite,
261
+ model_artifact_name,
262
+ model_artifact_type,
263
+ )
248
264
 
249
265
  # mark run as complete
250
266
  cls._run.summary["status"] = "succeeded"
@@ -341,6 +357,8 @@ class WandbLogger:
341
357
  entity: Optional[str],
342
358
  log_datasets: bool,
343
359
  overwrite: bool,
360
+ model_artifact_name: str,
361
+ model_artifact_type: str,
344
362
  ) -> None:
345
363
  if log_datasets:
346
364
  wandb.termlog("Logging training/validation files...")
@@ -361,8 +379,8 @@ class WandbLogger:
361
379
  # fine-tune details
362
380
  fine_tune_id = fine_tune.id
363
381
  artifact = wandb.Artifact(
364
- "model_metadata",
365
- type="model",
382
+ model_artifact_name,
383
+ type=model_artifact_type,
366
384
  metadata=dict(fine_tune),
367
385
  )
368
386
 
@@ -26,7 +26,7 @@ from PIL import Image
26
26
 
27
27
  import wandb
28
28
  from wandb import util
29
- from wandb.plots.utils import test_missing
29
+ from wandb.plot.utils import test_missing
30
30
  from wandb.sdk.lib import telemetry as wb_telemetry
31
31
 
32
32
 
wandb/jupyter.py CHANGED
@@ -288,35 +288,34 @@ def attempt_colab_login(app_url):
288
288
  display.display(
289
289
  display.Javascript(
290
290
  """
291
- window._wandbApiKey = new Promise((resolve, reject) => {
292
- function loadScript(url) {
293
- return new Promise(function(resolve, reject) {
291
+ window._wandbApiKey = new Promise((resolve, reject) => {{
292
+ function loadScript(url) {{
293
+ return new Promise(function(resolve, reject) {{
294
294
  let newScript = document.createElement("script");
295
295
  newScript.onerror = reject;
296
296
  newScript.onload = resolve;
297
297
  document.body.appendChild(newScript);
298
298
  newScript.src = url;
299
- });
300
- }
301
- loadScript("https://cdn.jsdelivr.net/npm/postmate/build/postmate.min.js").then(() => {
299
+ }});
300
+ }}
301
+ loadScript("https://cdn.jsdelivr.net/npm/postmate/build/postmate.min.js").then(() => {{
302
302
  const iframe = document.createElement('iframe')
303
303
  iframe.style.cssText = "width:0;height:0;border:none"
304
304
  document.body.appendChild(iframe)
305
- const handshake = new Postmate({
305
+ const handshake = new Postmate({{
306
306
  container: iframe,
307
- url: '%s/authorize'
308
- });
307
+ url: '{}/authorize'
308
+ }});
309
309
  const timeout = setTimeout(() => reject("Couldn't auto authenticate"), 5000)
310
- handshake.then(function(child) {
311
- child.on('authorize', data => {
310
+ handshake.then(function(child) {{
311
+ child.on('authorize', data => {{
312
312
  clearTimeout(timeout)
313
313
  resolve(data)
314
- });
315
- });
316
- })
317
- });
318
- """
319
- % app_url.replace("http:", "https:")
314
+ }});
315
+ }});
316
+ }})
317
+ }});
318
+ """.format(app_url.replace("http:", "https:"))
320
319
  )
321
320
  )
322
321
  try:
wandb/plot/pr_curve.py CHANGED
@@ -2,7 +2,8 @@ from typing import Optional
2
2
 
3
3
  import wandb
4
4
  from wandb import util
5
- from wandb.plots.utils import test_missing, test_types
5
+
6
+ from .utils import test_missing, test_types
6
7
 
7
8
 
8
9
  def pr_curve(
wandb/plot/roc_curve.py CHANGED
@@ -2,7 +2,8 @@ from typing import Optional
2
2
 
3
3
  import wandb
4
4
  from wandb import util
5
- from wandb.plots.utils import test_missing, test_types
5
+
6
+ from .utils import test_missing, test_types
6
7
 
7
8
 
8
9
  def roc_curve(
@@ -1,21 +1,9 @@
1
- from collections.abc import Iterable, Sequence
1
+ from typing import Iterable, Sequence
2
2
 
3
3
  import wandb
4
4
  from wandb import util
5
- from wandb.sdk.lib import deprecate
6
5
 
7
6
 
8
- def deprecation_notice() -> None:
9
- deprecate.deprecate(
10
- field_name=deprecate.Deprecated.plots,
11
- warning_message=(
12
- "wandb.plots.* functions are deprecated and will be removed in a future release. "
13
- "Please use wandb.plot.* instead."
14
- ),
15
- )
16
-
17
-
18
- # Test assumptions for plotting parameters and datasets
19
7
  def test_missing(**kwargs):
20
8
  np = util.get_module("numpy", required="Logging plots requires numpy")
21
9
  pd = util.get_module("pandas", required="Logging dataframes requires pandas")
@@ -25,7 +13,7 @@ def test_missing(**kwargs):
25
13
  for k, v in kwargs.items():
26
14
  # Missing/empty params/datapoint arrays
27
15
  if v is None:
28
- wandb.termerror("%s is None. Please try again." % (k))
16
+ wandb.termerror("{} is None. Please try again.".format(k))
29
17
  test_passed = False
30
18
  if (k == "X") or (k == "X_test"):
31
19
  if isinstance(v, scipy.sparse.csr.csr_matrix):
@@ -64,8 +52,8 @@ def test_missing(**kwargs):
64
52
  )
65
53
  if non_nums > 0:
66
54
  wandb.termerror(
67
- "%s contains values that are not numbers. Please vectorize, label encode or one hot encode %s and call the plotting function again."
68
- % (k, k)
55
+ f"{k} contains values that are not numbers. Please vectorize, "
56
+ f"label encode or one hot encode {k} and call the plotting function again."
69
57
  )
70
58
  test_passed = False
71
59
  return test_passed
@@ -73,8 +61,8 @@ def test_missing(**kwargs):
73
61
 
74
62
  def test_fitted(model):
75
63
  np = util.get_module("numpy", required="Logging plots requires numpy")
76
- pd = util.get_module("pandas", required="Logging dataframes requires pandas")
77
- scipy = util.get_module("scipy", required="Logging scipy matrices requires scipy")
64
+ _ = util.get_module("pandas", required="Logging dataframes requires pandas")
65
+ _ = util.get_module("scipy", required="Logging scipy matrices requires scipy")
78
66
  scikit_utils = util.get_module(
79
67
  "sklearn.utils",
80
68
  required="roc requires the scikit utils submodule, install with `pip install scikit-learn`",
@@ -120,7 +108,7 @@ def test_fitted(model):
120
108
 
121
109
 
122
110
  def encode_labels(df):
123
- pd = util.get_module("pandas", required="Logging dataframes requires pandas")
111
+ _ = util.get_module("pandas", required="Logging dataframes requires pandas")
124
112
  preprocessing = util.get_module(
125
113
  "sklearn.preprocessing",
126
114
  "roc requires the scikit preprocessing submodule, install with `pip install scikit-learn`",
@@ -137,7 +125,7 @@ def encode_labels(df):
137
125
  def test_types(**kwargs):
138
126
  np = util.get_module("numpy", required="Logging plots requires numpy")
139
127
  pd = util.get_module("pandas", required="Logging dataframes requires pandas")
140
- scipy = util.get_module("scipy", required="Logging scipy matrices requires scipy")
128
+ _ = util.get_module("scipy", required="Logging scipy matrices requires scipy")
141
129
 
142
130
  base = util.get_module(
143
131
  "sklearn.base",
@@ -171,25 +159,25 @@ def test_types(**kwargs):
171
159
  list,
172
160
  ),
173
161
  ):
174
- wandb.termerror("%s is not an array. Please try again." % (k))
162
+ wandb.termerror("{} is not an array. Please try again.".format(k))
175
163
  test_passed = False
176
164
  # check for classifier types
177
165
  if k == "model":
178
166
  if (not base.is_classifier(v)) and (not base.is_regressor(v)):
179
167
  wandb.termerror(
180
- "%s is not a classifier or regressor. Please try again." % (k)
168
+ "{} is not a classifier or regressor. Please try again.".format(k)
181
169
  )
182
170
  test_passed = False
183
171
  elif k == "clf" or k == "binary_clf":
184
172
  if not (base.is_classifier(v)):
185
- wandb.termerror("%s is not a classifier. Please try again." % (k))
173
+ wandb.termerror("{} is not a classifier. Please try again.".format(k))
186
174
  test_passed = False
187
175
  elif k == "regressor":
188
176
  if not base.is_regressor(v):
189
- wandb.termerror("%s is not a regressor. Please try again." % (k))
177
+ wandb.termerror("{} is not a regressor. Please try again.".format(k))
190
178
  test_passed = False
191
179
  elif k == "clusterer":
192
180
  if not (getattr(v, "_estimator_type", None) == "clusterer"):
193
- wandb.termerror("%s is not a clusterer. Please try again." % (k))
181
+ wandb.termerror("{} is not a clusterer. Please try again.".format(k))
194
182
  test_passed = False
195
183
  return test_passed