wandb 0.17.0rc2__py3-none-macosx_11_0_arm64.whl → 0.17.1__py3-none-macosx_11_0_arm64.whl
Sign up to get free protection for your applications and to get access to all the features.
- wandb/__init__.py +1 -2
- wandb/apis/importers/internals/internal.py +0 -1
- wandb/apis/importers/wandb.py +12 -7
- wandb/apis/internal.py +0 -3
- wandb/apis/public/api.py +213 -79
- wandb/apis/public/artifacts.py +335 -100
- wandb/apis/public/files.py +9 -9
- wandb/apis/public/jobs.py +16 -4
- wandb/apis/public/projects.py +26 -28
- wandb/apis/public/query_generator.py +1 -1
- wandb/apis/public/runs.py +163 -65
- wandb/apis/public/sweeps.py +2 -2
- wandb/apis/reports/__init__.py +1 -7
- wandb/apis/reports/v1/__init__.py +5 -27
- wandb/apis/reports/v2/__init__.py +7 -19
- wandb/apis/workspaces/__init__.py +8 -0
- wandb/beta/workflows.py +8 -3
- wandb/bin/apple_gpu_stats +0 -0
- wandb/bin/wandb-core +0 -0
- wandb/cli/cli.py +131 -59
- wandb/docker/__init__.py +1 -1
- wandb/errors/term.py +10 -2
- wandb/filesync/step_checksum.py +1 -4
- wandb/filesync/step_prepare.py +4 -24
- wandb/filesync/step_upload.py +5 -107
- wandb/filesync/upload_job.py +0 -76
- wandb/integration/gym/__init__.py +35 -15
- wandb/integration/openai/fine_tuning.py +21 -3
- wandb/integration/prodigy/prodigy.py +1 -1
- wandb/jupyter.py +16 -17
- wandb/plot/pr_curve.py +2 -1
- wandb/plot/roc_curve.py +2 -1
- wandb/{plots → plot}/utils.py +13 -25
- wandb/proto/v3/wandb_internal_pb2.py +54 -54
- wandb/proto/v3/wandb_settings_pb2.py +2 -2
- wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
- wandb/proto/v4/wandb_internal_pb2.py +54 -54
- wandb/proto/v4/wandb_settings_pb2.py +2 -2
- wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
- wandb/proto/v5/wandb_base_pb2.py +30 -0
- wandb/proto/v5/wandb_internal_pb2.py +355 -0
- wandb/proto/v5/wandb_server_pb2.py +63 -0
- wandb/proto/v5/wandb_settings_pb2.py +45 -0
- wandb/proto/v5/wandb_telemetry_pb2.py +41 -0
- wandb/proto/wandb_base_pb2.py +2 -0
- wandb/proto/wandb_deprecated.py +9 -1
- wandb/proto/wandb_generate_deprecated.py +34 -0
- wandb/proto/{wandb_internal_codegen.py → wandb_generate_proto.py} +1 -35
- wandb/proto/wandb_internal_pb2.py +2 -0
- wandb/proto/wandb_server_pb2.py +2 -0
- wandb/proto/wandb_settings_pb2.py +2 -0
- wandb/proto/wandb_telemetry_pb2.py +2 -0
- wandb/sdk/artifacts/artifact.py +68 -22
- wandb/sdk/artifacts/artifact_manifest.py +1 -1
- wandb/sdk/artifacts/artifact_manifest_entry.py +6 -3
- wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +1 -1
- wandb/sdk/artifacts/artifact_saver.py +1 -10
- wandb/sdk/artifacts/storage_handlers/local_file_handler.py +6 -2
- wandb/sdk/artifacts/storage_handlers/multi_handler.py +1 -1
- wandb/sdk/artifacts/storage_handlers/tracking_handler.py +6 -4
- wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +2 -42
- wandb/sdk/artifacts/storage_policy.py +1 -12
- wandb/sdk/data_types/image.py +1 -1
- wandb/sdk/data_types/video.py +4 -2
- wandb/sdk/interface/interface.py +13 -0
- wandb/sdk/interface/interface_shared.py +1 -1
- wandb/sdk/internal/file_pusher.py +2 -5
- wandb/sdk/internal/file_stream.py +6 -19
- wandb/sdk/internal/internal_api.py +148 -136
- wandb/sdk/internal/job_builder.py +207 -135
- wandb/sdk/internal/progress.py +0 -28
- wandb/sdk/internal/sender.py +102 -39
- wandb/sdk/internal/settings_static.py +8 -1
- wandb/sdk/internal/system/assets/trainium.py +3 -3
- wandb/sdk/internal/system/system_info.py +4 -2
- wandb/sdk/internal/update.py +1 -1
- wandb/sdk/launch/__init__.py +9 -1
- wandb/sdk/launch/_launch.py +4 -24
- wandb/sdk/launch/_launch_add.py +1 -3
- wandb/sdk/launch/_project_spec.py +184 -224
- wandb/sdk/launch/agent/agent.py +58 -18
- wandb/sdk/launch/agent/config.py +0 -3
- wandb/sdk/launch/builder/abstract.py +67 -0
- wandb/sdk/launch/builder/build.py +165 -576
- wandb/sdk/launch/builder/context_manager.py +235 -0
- wandb/sdk/launch/builder/docker_builder.py +7 -23
- wandb/sdk/launch/builder/kaniko_builder.py +10 -23
- wandb/sdk/launch/builder/templates/dockerfile.py +92 -0
- wandb/sdk/launch/create_job.py +51 -45
- wandb/sdk/launch/environment/aws_environment.py +26 -1
- wandb/sdk/launch/inputs/files.py +148 -0
- wandb/sdk/launch/inputs/internal.py +224 -0
- wandb/sdk/launch/inputs/manage.py +95 -0
- wandb/sdk/launch/runner/abstract.py +2 -2
- wandb/sdk/launch/runner/kubernetes_monitor.py +45 -12
- wandb/sdk/launch/runner/kubernetes_runner.py +6 -8
- wandb/sdk/launch/runner/local_container.py +2 -3
- wandb/sdk/launch/runner/local_process.py +8 -29
- wandb/sdk/launch/runner/sagemaker_runner.py +20 -14
- wandb/sdk/launch/runner/vertex_runner.py +8 -7
- wandb/sdk/launch/sweeps/scheduler.py +2 -0
- wandb/sdk/launch/sweeps/utils.py +2 -2
- wandb/sdk/launch/utils.py +16 -138
- wandb/sdk/lib/_settings_toposort_generated.py +2 -5
- wandb/sdk/lib/apikey.py +4 -2
- wandb/sdk/lib/config_util.py +3 -3
- wandb/sdk/lib/proto_util.py +22 -1
- wandb/sdk/lib/redirect.py +1 -1
- wandb/sdk/service/service.py +2 -1
- wandb/sdk/service/streams.py +5 -5
- wandb/sdk/wandb_init.py +25 -59
- wandb/sdk/wandb_login.py +28 -25
- wandb/sdk/wandb_run.py +112 -45
- wandb/sdk/wandb_settings.py +33 -64
- wandb/sdk/wandb_watch.py +1 -1
- wandb/sklearn/plot/classifier.py +4 -6
- wandb/sync/sync.py +2 -2
- wandb/testing/relay.py +32 -17
- wandb/util.py +36 -37
- wandb/wandb_agent.py +3 -3
- wandb/wandb_controller.py +3 -2
- {wandb-0.17.0rc2.dist-info → wandb-0.17.1.dist-info}/METADATA +7 -9
- {wandb-0.17.0rc2.dist-info → wandb-0.17.1.dist-info}/RECORD +126 -148
- {wandb-0.17.0rc2.dist-info → wandb-0.17.1.dist-info}/WHEEL +1 -1
- wandb/apis/reports/v1/_blocks.py +0 -1406
- wandb/apis/reports/v1/_helpers.py +0 -70
- wandb/apis/reports/v1/_panels.py +0 -1282
- wandb/apis/reports/v1/_templates.py +0 -478
- wandb/apis/reports/v1/blocks.py +0 -27
- wandb/apis/reports/v1/helpers.py +0 -2
- wandb/apis/reports/v1/mutations.py +0 -66
- wandb/apis/reports/v1/panels.py +0 -17
- wandb/apis/reports/v1/report.py +0 -268
- wandb/apis/reports/v1/runset.py +0 -144
- wandb/apis/reports/v1/templates.py +0 -7
- wandb/apis/reports/v1/util.py +0 -406
- wandb/apis/reports/v1/validators.py +0 -131
- wandb/apis/reports/v2/blocks.py +0 -25
- wandb/apis/reports/v2/expr_parsing.py +0 -257
- wandb/apis/reports/v2/gql.py +0 -68
- wandb/apis/reports/v2/interface.py +0 -1911
- wandb/apis/reports/v2/internal.py +0 -867
- wandb/apis/reports/v2/metrics.py +0 -6
- wandb/apis/reports/v2/panels.py +0 -15
- wandb/catboost/__init__.py +0 -9
- wandb/fastai/__init__.py +0 -9
- wandb/keras/__init__.py +0 -19
- wandb/lightgbm/__init__.py +0 -9
- wandb/plots/__init__.py +0 -6
- wandb/plots/explain_text.py +0 -36
- wandb/plots/heatmap.py +0 -81
- wandb/plots/named_entity.py +0 -43
- wandb/plots/part_of_speech.py +0 -50
- wandb/plots/plot_definitions.py +0 -768
- wandb/plots/precision_recall.py +0 -121
- wandb/plots/roc.py +0 -103
- wandb/sacred/__init__.py +0 -3
- wandb/xgboost/__init__.py +0 -9
- {wandb-0.17.0rc2.dist-info → wandb-0.17.1.dist-info}/entry_points.txt +0 -0
- {wandb-0.17.0rc2.dist-info → wandb-0.17.1.dist-info}/licenses/LICENSE +0 -0
wandb/filesync/step_upload.py
CHANGED
@@ -1,6 +1,5 @@
|
|
1
1
|
"""Batching file prepare requests to our API."""
|
2
2
|
|
3
|
-
import asyncio
|
4
3
|
import concurrent.futures
|
5
4
|
import logging
|
6
5
|
import queue
|
@@ -8,7 +7,6 @@ import sys
|
|
8
7
|
import threading
|
9
8
|
from typing import (
|
10
9
|
TYPE_CHECKING,
|
11
|
-
Awaitable,
|
12
10
|
Callable,
|
13
11
|
MutableMapping,
|
14
12
|
MutableSequence,
|
@@ -43,7 +41,6 @@ if TYPE_CHECKING:
|
|
43
41
|
PreCommitFn = Callable[[], None]
|
44
42
|
OnRequestFinishFn = Callable[[], None]
|
45
43
|
SaveFn = Callable[["progress.ProgressFn"], bool]
|
46
|
-
SaveFnAsync = Callable[["progress.ProgressFn"], Awaitable[bool]]
|
47
44
|
|
48
45
|
logger = logging.getLogger(__name__)
|
49
46
|
|
@@ -55,7 +52,6 @@ class RequestUpload(NamedTuple):
|
|
55
52
|
md5: Optional[str]
|
56
53
|
copied: bool
|
57
54
|
save_fn: Optional[SaveFn]
|
58
|
-
save_fn_async: Optional[SaveFnAsync]
|
59
55
|
digest: Optional[str]
|
60
56
|
|
61
57
|
|
@@ -78,47 +74,6 @@ class EventJobDone(NamedTuple):
|
|
78
74
|
Event = Union[RequestUpload, RequestCommitArtifact, RequestFinish, EventJobDone]
|
79
75
|
|
80
76
|
|
81
|
-
class AsyncExecutor:
|
82
|
-
"""Runs async file uploads in a background thread."""
|
83
|
-
|
84
|
-
def __init__(
|
85
|
-
self,
|
86
|
-
pool: concurrent.futures.ThreadPoolExecutor,
|
87
|
-
concurrency_limit: Optional[int],
|
88
|
-
) -> None:
|
89
|
-
self.loop = asyncio.new_event_loop()
|
90
|
-
self.loop.set_default_executor(pool)
|
91
|
-
self.loop_thread = threading.Thread(
|
92
|
-
target=self.loop.run_forever,
|
93
|
-
daemon=True,
|
94
|
-
name="wandb-upload-async",
|
95
|
-
)
|
96
|
-
|
97
|
-
self.concurrency_limiter = asyncio.Semaphore(
|
98
|
-
value=concurrency_limit or 128,
|
99
|
-
# Before Python 3.10: if we don't set `loop=loop`,
|
100
|
-
# then the Semaphore will bind to the wrong event loop,
|
101
|
-
# causing errors when a coroutine tries to wait for it;
|
102
|
-
# see https://pastebin.com/XcrS9suX .
|
103
|
-
# After 3.10: the `loop` argument doesn't exist.
|
104
|
-
# So we need to only conditionally pass in `loop`.
|
105
|
-
**({} if sys.version_info >= (3, 10) else {"loop": self.loop}),
|
106
|
-
)
|
107
|
-
|
108
|
-
def start(self) -> None:
|
109
|
-
self.loop_thread.start()
|
110
|
-
|
111
|
-
def stop(self) -> None:
|
112
|
-
self.loop.call_soon_threadsafe(self.loop.stop)
|
113
|
-
|
114
|
-
def submit(self, coro: Awaitable[None]) -> None:
|
115
|
-
async def run_with_limiter() -> None:
|
116
|
-
async with self.concurrency_limiter:
|
117
|
-
await coro
|
118
|
-
|
119
|
-
asyncio.run_coroutine_threadsafe(run_with_limiter(), self.loop)
|
120
|
-
|
121
|
-
|
122
77
|
class StepUpload:
|
123
78
|
def __init__(
|
124
79
|
self,
|
@@ -142,15 +97,6 @@ class StepUpload:
|
|
142
97
|
max_workers=max_threads,
|
143
98
|
)
|
144
99
|
|
145
|
-
self._async_executor = (
|
146
|
-
AsyncExecutor(
|
147
|
-
pool=self._pool,
|
148
|
-
concurrency_limit=settings._async_upload_concurrency_limit,
|
149
|
-
)
|
150
|
-
if settings is not None and settings._async_upload_concurrency_limit
|
151
|
-
else None
|
152
|
-
)
|
153
|
-
|
154
100
|
# Indexed by files' `save_name`'s, which are their ID's in the Run.
|
155
101
|
self._running_jobs: MutableMapping[LogicalPath, RequestUpload] = {}
|
156
102
|
self._pending_jobs: MutableSequence[RequestUpload] = []
|
@@ -186,8 +132,6 @@ class StepUpload:
|
|
186
132
|
elif not self._running_jobs:
|
187
133
|
# Queue was empty and no jobs left.
|
188
134
|
self._pool.shutdown(wait=False)
|
189
|
-
if self._async_executor:
|
190
|
-
self._async_executor.stop()
|
191
135
|
if finish_callback:
|
192
136
|
finish_callback()
|
193
137
|
break
|
@@ -235,7 +179,7 @@ class StepUpload:
|
|
235
179
|
self._artifacts[event.artifact_id]["pending_count"] += 1
|
236
180
|
self._start_upload_job(event)
|
237
181
|
else:
|
238
|
-
raise Exception("Programming error: unhandled event:
|
182
|
+
raise Exception("Programming error: unhandled event: {}".format(str(event)))
|
239
183
|
|
240
184
|
def _start_upload_job(self, event: RequestUpload) -> None:
|
241
185
|
# Operations on a single backend file must be serialized. if
|
@@ -245,18 +189,9 @@ class StepUpload:
|
|
245
189
|
self._pending_jobs.append(event)
|
246
190
|
return
|
247
191
|
|
248
|
-
|
249
|
-
# (The `and save_fn_async is not None` is because the async code path
|
250
|
-
# doesn't support all uploads yet: even if the user has requested async,
|
251
|
-
# we sometimes need to use the sync method instead.)
|
252
|
-
self._spawn_upload_async(
|
253
|
-
event,
|
254
|
-
async_executor=self._async_executor,
|
255
|
-
)
|
256
|
-
else:
|
257
|
-
self._spawn_upload_sync(event)
|
192
|
+
self._spawn_upload(event)
|
258
193
|
|
259
|
-
def
|
194
|
+
def _spawn_upload(self, event: RequestUpload) -> None:
|
260
195
|
"""Spawn an upload job, and handles the bookkeeping of `self._running_jobs`.
|
261
196
|
|
262
197
|
Context: it's important that, whenever we add an entry to `self._running_jobs`,
|
@@ -285,35 +220,13 @@ class StepUpload:
|
|
285
220
|
|
286
221
|
def run_and_notify() -> None:
|
287
222
|
try:
|
288
|
-
self.
|
223
|
+
self._do_upload(event)
|
289
224
|
finally:
|
290
225
|
self._event_queue.put(EventJobDone(event, exc=sys.exc_info()[1]))
|
291
226
|
|
292
227
|
self._pool.submit(run_and_notify)
|
293
228
|
|
294
|
-
def
|
295
|
-
self,
|
296
|
-
event: RequestUpload,
|
297
|
-
async_executor: AsyncExecutor,
|
298
|
-
) -> None:
|
299
|
-
"""Equivalent to _spawn_upload_sync, but uses the async event loop instead of a thread, and requires `event.save_fn_async`.
|
300
|
-
|
301
|
-
Raises:
|
302
|
-
AssertionError: if `event.save_fn_async` is None.
|
303
|
-
"""
|
304
|
-
assert event.save_fn_async is not None
|
305
|
-
|
306
|
-
self._running_jobs[event.save_name] = event
|
307
|
-
|
308
|
-
async def run_and_notify() -> None:
|
309
|
-
try:
|
310
|
-
await self._do_upload_async(event)
|
311
|
-
finally:
|
312
|
-
self._event_queue.put(EventJobDone(event, exc=sys.exc_info()[1]))
|
313
|
-
|
314
|
-
async_executor.submit(run_and_notify())
|
315
|
-
|
316
|
-
def _do_upload_sync(self, event: RequestUpload) -> None:
|
229
|
+
def _do_upload(self, event: RequestUpload) -> None:
|
317
230
|
job = upload_job.UploadJob(
|
318
231
|
self._stats,
|
319
232
|
self._api,
|
@@ -329,19 +242,6 @@ class StepUpload:
|
|
329
242
|
)
|
330
243
|
job.run()
|
331
244
|
|
332
|
-
async def _do_upload_async(self, event: RequestUpload) -> None:
|
333
|
-
"""Upload a file and returns when it's done. Requires `event.save_fn_async`."""
|
334
|
-
assert event.save_fn_async is not None
|
335
|
-
job = upload_job.UploadJobAsync(
|
336
|
-
stats=self._stats,
|
337
|
-
api=self._api,
|
338
|
-
file_stream=self._file_stream,
|
339
|
-
silent=self.silent,
|
340
|
-
request=event,
|
341
|
-
save_fn_async=event.save_fn_async,
|
342
|
-
)
|
343
|
-
await job.run()
|
344
|
-
|
345
245
|
def _init_artifact(self, artifact_id: str) -> None:
|
346
246
|
self._artifacts[artifact_id] = {
|
347
247
|
"finalize": False,
|
@@ -385,8 +285,6 @@ class StepUpload:
|
|
385
285
|
|
386
286
|
def start(self) -> None:
|
387
287
|
self._thread.start()
|
388
|
-
if self._async_executor:
|
389
|
-
self._async_executor.start()
|
390
288
|
|
391
289
|
def is_alive(self) -> bool:
|
392
290
|
return self._thread.is_alive()
|
wandb/filesync/upload_job.py
CHANGED
@@ -1,4 +1,3 @@
|
|
1
|
-
import asyncio
|
2
1
|
import logging
|
3
2
|
import os
|
4
3
|
from typing import TYPE_CHECKING, Optional
|
@@ -141,78 +140,3 @@ class UploadJob:
|
|
141
140
|
|
142
141
|
def progress(self, total_bytes: int) -> None:
|
143
142
|
self._stats.update_uploaded_file(self.save_name, total_bytes)
|
144
|
-
|
145
|
-
|
146
|
-
class UploadJobAsync:
|
147
|
-
"""Roughly an async equivalent of UploadJob.
|
148
|
-
|
149
|
-
Important differences:
|
150
|
-
- `run` is a coroutine
|
151
|
-
- If `run()` fails, it falls back to the synchronous UploadJob
|
152
|
-
"""
|
153
|
-
|
154
|
-
def __init__(
|
155
|
-
self,
|
156
|
-
stats: "stats.Stats",
|
157
|
-
api: "internal_api.Api",
|
158
|
-
file_stream: "file_stream.FileStreamApi",
|
159
|
-
silent: bool,
|
160
|
-
request: "step_upload.RequestUpload",
|
161
|
-
save_fn_async: "step_upload.SaveFnAsync",
|
162
|
-
) -> None:
|
163
|
-
self._stats = stats
|
164
|
-
self._api = api
|
165
|
-
self._file_stream = file_stream
|
166
|
-
self.silent = silent
|
167
|
-
self._request = request
|
168
|
-
self._save_fn_async = save_fn_async
|
169
|
-
|
170
|
-
async def run(self) -> None:
|
171
|
-
try:
|
172
|
-
deduped = await self._save_fn_async(
|
173
|
-
lambda _, t: self._stats.update_uploaded_file(self._request.path, t)
|
174
|
-
)
|
175
|
-
except Exception as e:
|
176
|
-
# Async uploads aren't yet (2023-01) battle-tested.
|
177
|
-
# Fall back to the "normal" synchronous upload.
|
178
|
-
loop = asyncio.get_event_loop()
|
179
|
-
logger.exception("async upload failed", exc_info=e)
|
180
|
-
loop.run_in_executor(None, wandb._sentry.exception, e)
|
181
|
-
wandb.termwarn(
|
182
|
-
"Async file upload failed; falling back to sync", repeat=False
|
183
|
-
)
|
184
|
-
sync_job = UploadJob(
|
185
|
-
self._stats,
|
186
|
-
self._api,
|
187
|
-
self._file_stream,
|
188
|
-
self.silent,
|
189
|
-
self._request.save_name,
|
190
|
-
self._request.path,
|
191
|
-
self._request.artifact_id,
|
192
|
-
self._request.md5,
|
193
|
-
self._request.copied,
|
194
|
-
self._request.save_fn,
|
195
|
-
self._request.digest,
|
196
|
-
)
|
197
|
-
|
198
|
-
await loop.run_in_executor(None, sync_job.run)
|
199
|
-
else:
|
200
|
-
self._file_stream.push_success(
|
201
|
-
self._request.artifact_id, # type: ignore
|
202
|
-
self._request.save_name,
|
203
|
-
)
|
204
|
-
|
205
|
-
if deduped:
|
206
|
-
logger.info("Skipped uploading %s", self._request.path)
|
207
|
-
self._stats.set_file_deduped(self._request.path)
|
208
|
-
else:
|
209
|
-
logger.info("Uploaded file %s", self._request.path)
|
210
|
-
finally:
|
211
|
-
# If we fell back to the sync impl, the file will have already been deleted.
|
212
|
-
# Doesn't matter, we only try to delete it if it exists.
|
213
|
-
if self._request.copied:
|
214
|
-
try:
|
215
|
-
os.remove(self._request.path)
|
216
|
-
except OSError:
|
217
|
-
# The file has already been deleted, we don't have permissions, or something else we can't fix.
|
218
|
-
pass
|
@@ -12,6 +12,8 @@ else:
|
|
12
12
|
|
13
13
|
|
14
14
|
_gym_version_lt_0_26: Optional[bool] = None
|
15
|
+
_gymnasium_version_lt_1_0_0: Optional[bool] = None
|
16
|
+
|
15
17
|
_required_error_msg = (
|
16
18
|
"Couldn't import the gymnasium python package, "
|
17
19
|
"install with `pip install gymnasium`"
|
@@ -35,14 +37,10 @@ def monitor():
|
|
35
37
|
if gym_lib is None:
|
36
38
|
raise wandb.Error(_required_error_msg)
|
37
39
|
|
38
|
-
vcr = wandb.util.get_module(
|
39
|
-
f"{gym_lib}.wrappers.monitoring.video_recorder",
|
40
|
-
required=_required_error_msg,
|
41
|
-
)
|
42
|
-
|
43
40
|
global _gym_version_lt_0_26
|
41
|
+
global _gymnasium_version_lt_1_0_0
|
44
42
|
|
45
|
-
if _gym_version_lt_0_26 is None:
|
43
|
+
if _gym_version_lt_0_26 is None or _gymnasium_version_lt_1_0_0 is None:
|
46
44
|
if gym_lib == "gym":
|
47
45
|
import gym
|
48
46
|
else:
|
@@ -50,15 +48,31 @@ def monitor():
|
|
50
48
|
|
51
49
|
from wandb.util import parse_version
|
52
50
|
|
53
|
-
|
54
|
-
|
51
|
+
gym_lib_version = parse_version(gym.__version__)
|
52
|
+
_gym_version_lt_0_26 = gym_lib_version < parse_version("0.26.0")
|
53
|
+
_gymnasium_version_lt_1_0_0 = gym_lib_version < parse_version("1.0.0a1")
|
54
|
+
|
55
|
+
path = "path" # Default path
|
56
|
+
if gym_lib == "gymnasium" and not _gymnasium_version_lt_1_0_0:
|
57
|
+
vcr_recorder_attribute = "RecordVideo"
|
58
|
+
wrappers = wandb.util.get_module(
|
59
|
+
f"{gym_lib}.wrappers",
|
60
|
+
required=_required_error_msg,
|
61
|
+
)
|
62
|
+
recorder = getattr(wrappers, vcr_recorder_attribute)
|
63
|
+
else:
|
64
|
+
vcr = wandb.util.get_module(
|
65
|
+
f"{gym_lib}.wrappers.monitoring.video_recorder",
|
66
|
+
required=_required_error_msg,
|
67
|
+
)
|
68
|
+
# Breaking change in gym 0.26.0
|
69
|
+
if _gym_version_lt_0_26:
|
70
|
+
vcr_recorder_attribute = "ImageEncoder"
|
71
|
+
recorder = getattr(vcr, vcr_recorder_attribute)
|
72
|
+
path = "output_path" # Override path for older gym versions
|
55
73
|
else:
|
56
|
-
|
57
|
-
|
58
|
-
# breaking change in gym 0.26.0
|
59
|
-
vcr_recorder_attribute = "ImageEncoder" if _gym_version_lt_0_26 else "VideoRecorder"
|
60
|
-
recorder = getattr(vcr, vcr_recorder_attribute)
|
61
|
-
path = "output_path" if _gym_version_lt_0_26 else "path"
|
74
|
+
vcr_recorder_attribute = "VideoRecorder"
|
75
|
+
recorder = getattr(vcr, vcr_recorder_attribute)
|
62
76
|
|
63
77
|
recorder.orig_close = recorder.close
|
64
78
|
|
@@ -77,9 +91,15 @@ def monitor():
|
|
77
91
|
if not _gym_version_lt_0_26:
|
78
92
|
recorder.__del__ = del_
|
79
93
|
recorder.close = close
|
94
|
+
|
95
|
+
if gym_lib == "gymnasium" and not _gymnasium_version_lt_1_0_0:
|
96
|
+
wrapper_name = vcr_recorder_attribute
|
97
|
+
else:
|
98
|
+
wrapper_name = f"monitoring.video_recorder.{vcr_recorder_attribute}"
|
99
|
+
|
80
100
|
wandb.patched["gym"].append(
|
81
101
|
[
|
82
|
-
f"{gym_lib}.wrappers.
|
102
|
+
f"{gym_lib}.wrappers.{wrapper_name}",
|
83
103
|
"close",
|
84
104
|
]
|
85
105
|
)
|
@@ -65,6 +65,8 @@ class WandbLogger:
|
|
65
65
|
overwrite: bool = False,
|
66
66
|
wait_for_job_success: bool = True,
|
67
67
|
log_datasets: bool = True,
|
68
|
+
model_artifact_name: str = "model-metadata",
|
69
|
+
model_artifact_type: str = "model",
|
68
70
|
**kwargs_wandb_init: Dict[str, Any],
|
69
71
|
) -> str:
|
70
72
|
"""Sync fine-tunes to Weights & Biases.
|
@@ -76,6 +78,8 @@ class WandbLogger:
|
|
76
78
|
:param entity: Username or team name where you're sending runs. By default, your default entity is used, which is usually your username.
|
77
79
|
:param overwrite: Forces logging and overwrite existing wandb run of the same fine-tune.
|
78
80
|
:param wait_for_job_success: Waits for the fine-tune to be complete and then log metrics to W&B. By default, it is True.
|
81
|
+
:param model_artifact_name: Name of the model artifact that is logged
|
82
|
+
:param model_artifact_type: Type of the model artifact that is logged
|
79
83
|
"""
|
80
84
|
if openai_client is None:
|
81
85
|
openai_client = OpenAI()
|
@@ -157,6 +161,8 @@ class WandbLogger:
|
|
157
161
|
overwrite,
|
158
162
|
show_individual_warnings,
|
159
163
|
log_datasets,
|
164
|
+
model_artifact_name,
|
165
|
+
model_artifact_type,
|
160
166
|
**kwargs_wandb_init,
|
161
167
|
)
|
162
168
|
|
@@ -201,6 +207,8 @@ class WandbLogger:
|
|
201
207
|
overwrite: bool,
|
202
208
|
show_individual_warnings: bool,
|
203
209
|
log_datasets: bool,
|
210
|
+
model_artifact_name: str,
|
211
|
+
model_artifact_type: str,
|
204
212
|
**kwargs_wandb_init: Dict[str, Any],
|
205
213
|
):
|
206
214
|
fine_tune_id = fine_tune.id
|
@@ -244,7 +252,15 @@ class WandbLogger:
|
|
244
252
|
cls._run.summary["fine_tuned_model"] = fine_tuned_model
|
245
253
|
|
246
254
|
# training/validation files and fine-tune details
|
247
|
-
cls._log_artifacts(
|
255
|
+
cls._log_artifacts(
|
256
|
+
fine_tune,
|
257
|
+
project,
|
258
|
+
entity,
|
259
|
+
log_datasets,
|
260
|
+
overwrite,
|
261
|
+
model_artifact_name,
|
262
|
+
model_artifact_type,
|
263
|
+
)
|
248
264
|
|
249
265
|
# mark run as complete
|
250
266
|
cls._run.summary["status"] = "succeeded"
|
@@ -341,6 +357,8 @@ class WandbLogger:
|
|
341
357
|
entity: Optional[str],
|
342
358
|
log_datasets: bool,
|
343
359
|
overwrite: bool,
|
360
|
+
model_artifact_name: str,
|
361
|
+
model_artifact_type: str,
|
344
362
|
) -> None:
|
345
363
|
if log_datasets:
|
346
364
|
wandb.termlog("Logging training/validation files...")
|
@@ -361,8 +379,8 @@ class WandbLogger:
|
|
361
379
|
# fine-tune details
|
362
380
|
fine_tune_id = fine_tune.id
|
363
381
|
artifact = wandb.Artifact(
|
364
|
-
|
365
|
-
type=
|
382
|
+
model_artifact_name,
|
383
|
+
type=model_artifact_type,
|
366
384
|
metadata=dict(fine_tune),
|
367
385
|
)
|
368
386
|
|
wandb/jupyter.py
CHANGED
@@ -288,35 +288,34 @@ def attempt_colab_login(app_url):
|
|
288
288
|
display.display(
|
289
289
|
display.Javascript(
|
290
290
|
"""
|
291
|
-
window._wandbApiKey = new Promise((resolve, reject) => {
|
292
|
-
function loadScript(url) {
|
293
|
-
return new Promise(function(resolve, reject) {
|
291
|
+
window._wandbApiKey = new Promise((resolve, reject) => {{
|
292
|
+
function loadScript(url) {{
|
293
|
+
return new Promise(function(resolve, reject) {{
|
294
294
|
let newScript = document.createElement("script");
|
295
295
|
newScript.onerror = reject;
|
296
296
|
newScript.onload = resolve;
|
297
297
|
document.body.appendChild(newScript);
|
298
298
|
newScript.src = url;
|
299
|
-
});
|
300
|
-
}
|
301
|
-
loadScript("https://cdn.jsdelivr.net/npm/postmate/build/postmate.min.js").then(() => {
|
299
|
+
}});
|
300
|
+
}}
|
301
|
+
loadScript("https://cdn.jsdelivr.net/npm/postmate/build/postmate.min.js").then(() => {{
|
302
302
|
const iframe = document.createElement('iframe')
|
303
303
|
iframe.style.cssText = "width:0;height:0;border:none"
|
304
304
|
document.body.appendChild(iframe)
|
305
|
-
const handshake = new Postmate({
|
305
|
+
const handshake = new Postmate({{
|
306
306
|
container: iframe,
|
307
|
-
url: '
|
308
|
-
});
|
307
|
+
url: '{}/authorize'
|
308
|
+
}});
|
309
309
|
const timeout = setTimeout(() => reject("Couldn't auto authenticate"), 5000)
|
310
|
-
handshake.then(function(child) {
|
311
|
-
child.on('authorize', data => {
|
310
|
+
handshake.then(function(child) {{
|
311
|
+
child.on('authorize', data => {{
|
312
312
|
clearTimeout(timeout)
|
313
313
|
resolve(data)
|
314
|
-
});
|
315
|
-
});
|
316
|
-
})
|
317
|
-
});
|
318
|
-
"""
|
319
|
-
% app_url.replace("http:", "https:")
|
314
|
+
}});
|
315
|
+
}});
|
316
|
+
}})
|
317
|
+
}});
|
318
|
+
""".format(app_url.replace("http:", "https:"))
|
320
319
|
)
|
321
320
|
)
|
322
321
|
try:
|
wandb/plot/pr_curve.py
CHANGED
wandb/plot/roc_curve.py
CHANGED
wandb/{plots → plot}/utils.py
RENAMED
@@ -1,21 +1,9 @@
|
|
1
|
-
from
|
1
|
+
from typing import Iterable, Sequence
|
2
2
|
|
3
3
|
import wandb
|
4
4
|
from wandb import util
|
5
|
-
from wandb.sdk.lib import deprecate
|
6
5
|
|
7
6
|
|
8
|
-
def deprecation_notice() -> None:
|
9
|
-
deprecate.deprecate(
|
10
|
-
field_name=deprecate.Deprecated.plots,
|
11
|
-
warning_message=(
|
12
|
-
"wandb.plots.* functions are deprecated and will be removed in a future release. "
|
13
|
-
"Please use wandb.plot.* instead."
|
14
|
-
),
|
15
|
-
)
|
16
|
-
|
17
|
-
|
18
|
-
# Test assumptions for plotting parameters and datasets
|
19
7
|
def test_missing(**kwargs):
|
20
8
|
np = util.get_module("numpy", required="Logging plots requires numpy")
|
21
9
|
pd = util.get_module("pandas", required="Logging dataframes requires pandas")
|
@@ -25,7 +13,7 @@ def test_missing(**kwargs):
|
|
25
13
|
for k, v in kwargs.items():
|
26
14
|
# Missing/empty params/datapoint arrays
|
27
15
|
if v is None:
|
28
|
-
wandb.termerror("
|
16
|
+
wandb.termerror("{} is None. Please try again.".format(k))
|
29
17
|
test_passed = False
|
30
18
|
if (k == "X") or (k == "X_test"):
|
31
19
|
if isinstance(v, scipy.sparse.csr.csr_matrix):
|
@@ -64,8 +52,8 @@ def test_missing(**kwargs):
|
|
64
52
|
)
|
65
53
|
if non_nums > 0:
|
66
54
|
wandb.termerror(
|
67
|
-
"
|
68
|
-
|
55
|
+
f"{k} contains values that are not numbers. Please vectorize, "
|
56
|
+
f"label encode or one hot encode {k} and call the plotting function again."
|
69
57
|
)
|
70
58
|
test_passed = False
|
71
59
|
return test_passed
|
@@ -73,8 +61,8 @@ def test_missing(**kwargs):
|
|
73
61
|
|
74
62
|
def test_fitted(model):
|
75
63
|
np = util.get_module("numpy", required="Logging plots requires numpy")
|
76
|
-
|
77
|
-
|
64
|
+
_ = util.get_module("pandas", required="Logging dataframes requires pandas")
|
65
|
+
_ = util.get_module("scipy", required="Logging scipy matrices requires scipy")
|
78
66
|
scikit_utils = util.get_module(
|
79
67
|
"sklearn.utils",
|
80
68
|
required="roc requires the scikit utils submodule, install with `pip install scikit-learn`",
|
@@ -120,7 +108,7 @@ def test_fitted(model):
|
|
120
108
|
|
121
109
|
|
122
110
|
def encode_labels(df):
|
123
|
-
|
111
|
+
_ = util.get_module("pandas", required="Logging dataframes requires pandas")
|
124
112
|
preprocessing = util.get_module(
|
125
113
|
"sklearn.preprocessing",
|
126
114
|
"roc requires the scikit preprocessing submodule, install with `pip install scikit-learn`",
|
@@ -137,7 +125,7 @@ def encode_labels(df):
|
|
137
125
|
def test_types(**kwargs):
|
138
126
|
np = util.get_module("numpy", required="Logging plots requires numpy")
|
139
127
|
pd = util.get_module("pandas", required="Logging dataframes requires pandas")
|
140
|
-
|
128
|
+
_ = util.get_module("scipy", required="Logging scipy matrices requires scipy")
|
141
129
|
|
142
130
|
base = util.get_module(
|
143
131
|
"sklearn.base",
|
@@ -171,25 +159,25 @@ def test_types(**kwargs):
|
|
171
159
|
list,
|
172
160
|
),
|
173
161
|
):
|
174
|
-
wandb.termerror("
|
162
|
+
wandb.termerror("{} is not an array. Please try again.".format(k))
|
175
163
|
test_passed = False
|
176
164
|
# check for classifier types
|
177
165
|
if k == "model":
|
178
166
|
if (not base.is_classifier(v)) and (not base.is_regressor(v)):
|
179
167
|
wandb.termerror(
|
180
|
-
"
|
168
|
+
"{} is not a classifier or regressor. Please try again.".format(k)
|
181
169
|
)
|
182
170
|
test_passed = False
|
183
171
|
elif k == "clf" or k == "binary_clf":
|
184
172
|
if not (base.is_classifier(v)):
|
185
|
-
wandb.termerror("
|
173
|
+
wandb.termerror("{} is not a classifier. Please try again.".format(k))
|
186
174
|
test_passed = False
|
187
175
|
elif k == "regressor":
|
188
176
|
if not base.is_regressor(v):
|
189
|
-
wandb.termerror("
|
177
|
+
wandb.termerror("{} is not a regressor. Please try again.".format(k))
|
190
178
|
test_passed = False
|
191
179
|
elif k == "clusterer":
|
192
180
|
if not (getattr(v, "_estimator_type", None) == "clusterer"):
|
193
|
-
wandb.termerror("
|
181
|
+
wandb.termerror("{} is not a clusterer. Please try again.".format(k))
|
194
182
|
test_passed = False
|
195
183
|
return test_passed
|