wandb 0.16.6__py3-none-any.whl → 0.17.0__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- package_readme.md +95 -0
- wandb/__init__.py +2 -3
- wandb/agents/pyagent.py +0 -1
- wandb/analytics/sentry.py +2 -1
- wandb/apis/importers/internals/internal.py +0 -1
- wandb/apis/importers/internals/protocols.py +30 -56
- wandb/apis/importers/mlflow.py +13 -26
- wandb/apis/importers/wandb.py +8 -14
- wandb/apis/internal.py +0 -3
- wandb/apis/public/api.py +55 -3
- wandb/apis/public/artifacts.py +1 -0
- wandb/apis/public/files.py +1 -0
- wandb/apis/public/history.py +1 -0
- wandb/apis/public/jobs.py +17 -4
- wandb/apis/public/projects.py +1 -0
- wandb/apis/public/reports.py +1 -0
- wandb/apis/public/runs.py +15 -17
- wandb/apis/public/sweeps.py +1 -0
- wandb/apis/public/teams.py +1 -0
- wandb/apis/public/users.py +1 -0
- wandb/apis/reports/v1/_blocks.py +3 -7
- wandb/apis/reports/v2/gql.py +1 -0
- wandb/apis/reports/v2/interface.py +3 -4
- wandb/apis/reports/v2/internal.py +5 -8
- wandb/cli/cli.py +92 -22
- wandb/data_types.py +9 -6
- wandb/docker/__init__.py +1 -1
- wandb/env.py +38 -8
- wandb/errors/__init__.py +5 -0
- wandb/errors/term.py +10 -2
- wandb/filesync/step_checksum.py +1 -4
- wandb/filesync/step_prepare.py +4 -24
- wandb/filesync/step_upload.py +4 -106
- wandb/filesync/upload_job.py +0 -76
- wandb/integration/catboost/catboost.py +1 -1
- wandb/integration/fastai/__init__.py +1 -0
- wandb/integration/huggingface/resolver.py +2 -2
- wandb/integration/keras/__init__.py +1 -0
- wandb/integration/keras/callbacks/metrics_logger.py +1 -1
- wandb/integration/keras/keras.py +7 -7
- wandb/integration/langchain/wandb_tracer.py +1 -0
- wandb/integration/lightning/fabric/logger.py +1 -3
- wandb/integration/metaflow/metaflow.py +41 -6
- wandb/integration/openai/fine_tuning.py +3 -3
- wandb/integration/prodigy/prodigy.py +1 -1
- wandb/old/summary.py +1 -1
- wandb/plot/confusion_matrix.py +1 -1
- wandb/plot/pr_curve.py +2 -1
- wandb/plot/roc_curve.py +2 -1
- wandb/{plots → plot}/utils.py +13 -25
- wandb/proto/v3/wandb_internal_pb2.py +364 -332
- wandb/proto/v3/wandb_settings_pb2.py +2 -2
- wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
- wandb/proto/v4/wandb_internal_pb2.py +322 -316
- wandb/proto/v4/wandb_settings_pb2.py +2 -2
- wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
- wandb/proto/wandb_deprecated.py +7 -1
- wandb/proto/wandb_internal_codegen.py +3 -29
- wandb/sdk/artifacts/artifact.py +26 -11
- wandb/sdk/artifacts/artifact_download_logger.py +1 -0
- wandb/sdk/artifacts/artifact_file_cache.py +18 -4
- wandb/sdk/artifacts/artifact_instance_cache.py +1 -0
- wandb/sdk/artifacts/artifact_manifest.py +1 -0
- wandb/sdk/artifacts/artifact_manifest_entry.py +7 -3
- wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +1 -0
- wandb/sdk/artifacts/artifact_saver.py +2 -8
- wandb/sdk/artifacts/artifact_state.py +1 -0
- wandb/sdk/artifacts/artifact_ttl.py +1 -0
- wandb/sdk/artifacts/exceptions.py +1 -0
- wandb/sdk/artifacts/storage_handlers/azure_handler.py +1 -0
- wandb/sdk/artifacts/storage_handlers/gcs_handler.py +13 -18
- wandb/sdk/artifacts/storage_handlers/http_handler.py +1 -0
- wandb/sdk/artifacts/storage_handlers/local_file_handler.py +1 -0
- wandb/sdk/artifacts/storage_handlers/multi_handler.py +1 -0
- wandb/sdk/artifacts/storage_handlers/s3_handler.py +5 -3
- wandb/sdk/artifacts/storage_handlers/tracking_handler.py +1 -0
- wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +1 -0
- wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +1 -0
- wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +3 -42
- wandb/sdk/artifacts/storage_policy.py +2 -12
- wandb/sdk/data_types/_dtypes.py +8 -8
- wandb/sdk/data_types/base_types/media.py +3 -6
- wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +3 -1
- wandb/sdk/data_types/image.py +1 -1
- wandb/sdk/data_types/video.py +1 -1
- wandb/sdk/integration_utils/auto_logging.py +5 -6
- wandb/sdk/integration_utils/data_logging.py +10 -6
- wandb/sdk/interface/interface.py +68 -32
- wandb/sdk/interface/interface_shared.py +7 -13
- wandb/sdk/internal/datastore.py +1 -1
- wandb/sdk/internal/file_pusher.py +2 -5
- wandb/sdk/internal/file_stream.py +5 -18
- wandb/sdk/internal/handler.py +18 -2
- wandb/sdk/internal/internal.py +0 -1
- wandb/sdk/internal/internal_api.py +1 -129
- wandb/sdk/internal/internal_util.py +0 -1
- wandb/sdk/internal/job_builder.py +159 -45
- wandb/sdk/internal/profiler.py +1 -0
- wandb/sdk/internal/progress.py +0 -28
- wandb/sdk/internal/run.py +1 -0
- wandb/sdk/internal/sender.py +1 -2
- wandb/sdk/internal/system/assets/gpu_amd.py +44 -44
- wandb/sdk/internal/system/assets/gpu_apple.py +56 -11
- wandb/sdk/internal/system/assets/interfaces.py +6 -8
- wandb/sdk/internal/system/assets/open_metrics.py +2 -2
- wandb/sdk/internal/system/assets/trainium.py +1 -3
- wandb/sdk/launch/__init__.py +9 -1
- wandb/sdk/launch/_launch.py +4 -24
- wandb/sdk/launch/_launch_add.py +1 -3
- wandb/sdk/launch/_project_spec.py +186 -224
- wandb/sdk/launch/agent/agent.py +37 -13
- wandb/sdk/launch/agent/config.py +72 -14
- wandb/sdk/launch/builder/abstract.py +69 -1
- wandb/sdk/launch/builder/build.py +156 -555
- wandb/sdk/launch/builder/context_manager.py +235 -0
- wandb/sdk/launch/builder/docker_builder.py +8 -23
- wandb/sdk/launch/builder/kaniko_builder.py +12 -25
- wandb/sdk/launch/builder/noop.py +1 -0
- wandb/sdk/launch/builder/templates/dockerfile.py +92 -0
- wandb/sdk/launch/create_job.py +47 -37
- wandb/sdk/launch/environment/abstract.py +1 -0
- wandb/sdk/launch/environment/gcp_environment.py +1 -0
- wandb/sdk/launch/environment/local_environment.py +1 -0
- wandb/sdk/launch/inputs/files.py +148 -0
- wandb/sdk/launch/inputs/internal.py +217 -0
- wandb/sdk/launch/inputs/manage.py +95 -0
- wandb/sdk/launch/loader.py +1 -0
- wandb/sdk/launch/registry/abstract.py +1 -0
- wandb/sdk/launch/registry/azure_container_registry.py +1 -0
- wandb/sdk/launch/registry/elastic_container_registry.py +1 -0
- wandb/sdk/launch/registry/google_artifact_registry.py +2 -1
- wandb/sdk/launch/registry/local_registry.py +1 -0
- wandb/sdk/launch/runner/abstract.py +1 -0
- wandb/sdk/launch/runner/kubernetes_monitor.py +1 -0
- wandb/sdk/launch/runner/kubernetes_runner.py +9 -10
- wandb/sdk/launch/runner/local_container.py +2 -3
- wandb/sdk/launch/runner/local_process.py +8 -29
- wandb/sdk/launch/runner/sagemaker_runner.py +21 -20
- wandb/sdk/launch/runner/vertex_runner.py +8 -7
- wandb/sdk/launch/sweeps/scheduler.py +4 -3
- wandb/sdk/launch/sweeps/scheduler_sweep.py +2 -1
- wandb/sdk/launch/sweeps/utils.py +3 -3
- wandb/sdk/launch/utils.py +15 -140
- wandb/sdk/lib/_settings_toposort_generated.py +0 -5
- wandb/sdk/lib/fsm.py +8 -12
- wandb/sdk/lib/gitlib.py +4 -4
- wandb/sdk/lib/import_hooks.py +1 -1
- wandb/sdk/lib/lazyloader.py +0 -1
- wandb/sdk/lib/proto_util.py +23 -2
- wandb/sdk/lib/redirect.py +19 -14
- wandb/sdk/lib/retry.py +3 -2
- wandb/sdk/lib/tracelog.py +1 -1
- wandb/sdk/service/service.py +19 -16
- wandb/sdk/verify/verify.py +2 -1
- wandb/sdk/wandb_init.py +14 -55
- wandb/sdk/wandb_manager.py +2 -2
- wandb/sdk/wandb_require.py +5 -0
- wandb/sdk/wandb_run.py +114 -56
- wandb/sdk/wandb_settings.py +0 -48
- wandb/sdk/wandb_setup.py +1 -1
- wandb/sklearn/__init__.py +1 -0
- wandb/sklearn/plot/__init__.py +1 -0
- wandb/sklearn/plot/classifier.py +11 -12
- wandb/sklearn/plot/clusterer.py +2 -1
- wandb/sklearn/plot/regressor.py +1 -0
- wandb/sklearn/plot/shared.py +1 -0
- wandb/sklearn/utils.py +1 -0
- wandb/testing/relay.py +4 -4
- wandb/trigger.py +1 -0
- wandb/util.py +67 -54
- wandb/wandb_controller.py +2 -3
- wandb/wandb_torch.py +1 -2
- {wandb-0.16.6.dist-info → wandb-0.17.0.dist-info}/METADATA +67 -70
- {wandb-0.16.6.dist-info → wandb-0.17.0.dist-info}/RECORD +177 -187
- {wandb-0.16.6.dist-info → wandb-0.17.0.dist-info}/WHEEL +1 -2
- wandb/bin/apple_gpu_stats +0 -0
- wandb/catboost/__init__.py +0 -9
- wandb/fastai/__init__.py +0 -9
- wandb/keras/__init__.py +0 -18
- wandb/lightgbm/__init__.py +0 -9
- wandb/plots/__init__.py +0 -6
- wandb/plots/explain_text.py +0 -36
- wandb/plots/heatmap.py +0 -81
- wandb/plots/named_entity.py +0 -43
- wandb/plots/part_of_speech.py +0 -50
- wandb/plots/plot_definitions.py +0 -768
- wandb/plots/precision_recall.py +0 -121
- wandb/plots/roc.py +0 -103
- wandb/sacred/__init__.py +0 -3
- wandb/xgboost/__init__.py +0 -9
- wandb-0.16.6.dist-info/top_level.txt +0 -1
- {wandb-0.16.6.dist-info → wandb-0.17.0.dist-info}/entry_points.txt +0 -0
- {wandb-0.16.6.dist-info → wandb-0.17.0.dist-info/licenses}/LICENSE +0 -0
wandb/filesync/step_upload.py
CHANGED
@@ -1,6 +1,5 @@
|
|
1
1
|
"""Batching file prepare requests to our API."""
|
2
2
|
|
3
|
-
import asyncio
|
4
3
|
import concurrent.futures
|
5
4
|
import logging
|
6
5
|
import queue
|
@@ -8,7 +7,6 @@ import sys
|
|
8
7
|
import threading
|
9
8
|
from typing import (
|
10
9
|
TYPE_CHECKING,
|
11
|
-
Awaitable,
|
12
10
|
Callable,
|
13
11
|
MutableMapping,
|
14
12
|
MutableSequence,
|
@@ -43,7 +41,6 @@ if TYPE_CHECKING:
|
|
43
41
|
PreCommitFn = Callable[[], None]
|
44
42
|
OnRequestFinishFn = Callable[[], None]
|
45
43
|
SaveFn = Callable[["progress.ProgressFn"], bool]
|
46
|
-
SaveFnAsync = Callable[["progress.ProgressFn"], Awaitable[bool]]
|
47
44
|
|
48
45
|
logger = logging.getLogger(__name__)
|
49
46
|
|
@@ -55,7 +52,6 @@ class RequestUpload(NamedTuple):
|
|
55
52
|
md5: Optional[str]
|
56
53
|
copied: bool
|
57
54
|
save_fn: Optional[SaveFn]
|
58
|
-
save_fn_async: Optional[SaveFnAsync]
|
59
55
|
digest: Optional[str]
|
60
56
|
|
61
57
|
|
@@ -78,47 +74,6 @@ class EventJobDone(NamedTuple):
|
|
78
74
|
Event = Union[RequestUpload, RequestCommitArtifact, RequestFinish, EventJobDone]
|
79
75
|
|
80
76
|
|
81
|
-
class AsyncExecutor:
|
82
|
-
"""Runs async file uploads in a background thread."""
|
83
|
-
|
84
|
-
def __init__(
|
85
|
-
self,
|
86
|
-
pool: concurrent.futures.ThreadPoolExecutor,
|
87
|
-
concurrency_limit: Optional[int],
|
88
|
-
) -> None:
|
89
|
-
self.loop = asyncio.new_event_loop()
|
90
|
-
self.loop.set_default_executor(pool)
|
91
|
-
self.loop_thread = threading.Thread(
|
92
|
-
target=self.loop.run_forever,
|
93
|
-
daemon=True,
|
94
|
-
name="wandb-upload-async",
|
95
|
-
)
|
96
|
-
|
97
|
-
self.concurrency_limiter = asyncio.Semaphore(
|
98
|
-
value=concurrency_limit or 128,
|
99
|
-
# Before Python 3.10: if we don't set `loop=loop`,
|
100
|
-
# then the Semaphore will bind to the wrong event loop,
|
101
|
-
# causing errors when a coroutine tries to wait for it;
|
102
|
-
# see https://pastebin.com/XcrS9suX .
|
103
|
-
# After 3.10: the `loop` argument doesn't exist.
|
104
|
-
# So we need to only conditionally pass in `loop`.
|
105
|
-
**({} if sys.version_info >= (3, 10) else {"loop": self.loop}),
|
106
|
-
)
|
107
|
-
|
108
|
-
def start(self) -> None:
|
109
|
-
self.loop_thread.start()
|
110
|
-
|
111
|
-
def stop(self) -> None:
|
112
|
-
self.loop.call_soon_threadsafe(self.loop.stop)
|
113
|
-
|
114
|
-
def submit(self, coro: Awaitable[None]) -> None:
|
115
|
-
async def run_with_limiter() -> None:
|
116
|
-
async with self.concurrency_limiter:
|
117
|
-
await coro
|
118
|
-
|
119
|
-
asyncio.run_coroutine_threadsafe(run_with_limiter(), self.loop)
|
120
|
-
|
121
|
-
|
122
77
|
class StepUpload:
|
123
78
|
def __init__(
|
124
79
|
self,
|
@@ -142,15 +97,6 @@ class StepUpload:
|
|
142
97
|
max_workers=max_threads,
|
143
98
|
)
|
144
99
|
|
145
|
-
self._async_executor = (
|
146
|
-
AsyncExecutor(
|
147
|
-
pool=self._pool,
|
148
|
-
concurrency_limit=settings._async_upload_concurrency_limit,
|
149
|
-
)
|
150
|
-
if settings is not None and settings._async_upload_concurrency_limit
|
151
|
-
else None
|
152
|
-
)
|
153
|
-
|
154
100
|
# Indexed by files' `save_name`'s, which are their ID's in the Run.
|
155
101
|
self._running_jobs: MutableMapping[LogicalPath, RequestUpload] = {}
|
156
102
|
self._pending_jobs: MutableSequence[RequestUpload] = []
|
@@ -186,8 +132,6 @@ class StepUpload:
|
|
186
132
|
elif not self._running_jobs:
|
187
133
|
# Queue was empty and no jobs left.
|
188
134
|
self._pool.shutdown(wait=False)
|
189
|
-
if self._async_executor:
|
190
|
-
self._async_executor.stop()
|
191
135
|
if finish_callback:
|
192
136
|
finish_callback()
|
193
137
|
break
|
@@ -245,18 +189,9 @@ class StepUpload:
|
|
245
189
|
self._pending_jobs.append(event)
|
246
190
|
return
|
247
191
|
|
248
|
-
|
249
|
-
# (The `and save_fn_async is not None` is because the async code path
|
250
|
-
# doesn't support all uploads yet: even if the user has requested async,
|
251
|
-
# we sometimes need to use the sync method instead.)
|
252
|
-
self._spawn_upload_async(
|
253
|
-
event,
|
254
|
-
async_executor=self._async_executor,
|
255
|
-
)
|
256
|
-
else:
|
257
|
-
self._spawn_upload_sync(event)
|
192
|
+
self._spawn_upload(event)
|
258
193
|
|
259
|
-
def
|
194
|
+
def _spawn_upload(self, event: RequestUpload) -> None:
|
260
195
|
"""Spawn an upload job, and handles the bookkeeping of `self._running_jobs`.
|
261
196
|
|
262
197
|
Context: it's important that, whenever we add an entry to `self._running_jobs`,
|
@@ -285,35 +220,13 @@ class StepUpload:
|
|
285
220
|
|
286
221
|
def run_and_notify() -> None:
|
287
222
|
try:
|
288
|
-
self.
|
223
|
+
self._do_upload(event)
|
289
224
|
finally:
|
290
225
|
self._event_queue.put(EventJobDone(event, exc=sys.exc_info()[1]))
|
291
226
|
|
292
227
|
self._pool.submit(run_and_notify)
|
293
228
|
|
294
|
-
def
|
295
|
-
self,
|
296
|
-
event: RequestUpload,
|
297
|
-
async_executor: AsyncExecutor,
|
298
|
-
) -> None:
|
299
|
-
"""Equivalent to _spawn_upload_sync, but uses the async event loop instead of a thread, and requires `event.save_fn_async`.
|
300
|
-
|
301
|
-
Raises:
|
302
|
-
AssertionError: if `event.save_fn_async` is None.
|
303
|
-
"""
|
304
|
-
assert event.save_fn_async is not None
|
305
|
-
|
306
|
-
self._running_jobs[event.save_name] = event
|
307
|
-
|
308
|
-
async def run_and_notify() -> None:
|
309
|
-
try:
|
310
|
-
await self._do_upload_async(event)
|
311
|
-
finally:
|
312
|
-
self._event_queue.put(EventJobDone(event, exc=sys.exc_info()[1]))
|
313
|
-
|
314
|
-
async_executor.submit(run_and_notify())
|
315
|
-
|
316
|
-
def _do_upload_sync(self, event: RequestUpload) -> None:
|
229
|
+
def _do_upload(self, event: RequestUpload) -> None:
|
317
230
|
job = upload_job.UploadJob(
|
318
231
|
self._stats,
|
319
232
|
self._api,
|
@@ -329,19 +242,6 @@ class StepUpload:
|
|
329
242
|
)
|
330
243
|
job.run()
|
331
244
|
|
332
|
-
async def _do_upload_async(self, event: RequestUpload) -> None:
|
333
|
-
"""Upload a file and returns when it's done. Requires `event.save_fn_async`."""
|
334
|
-
assert event.save_fn_async is not None
|
335
|
-
job = upload_job.UploadJobAsync(
|
336
|
-
stats=self._stats,
|
337
|
-
api=self._api,
|
338
|
-
file_stream=self._file_stream,
|
339
|
-
silent=self.silent,
|
340
|
-
request=event,
|
341
|
-
save_fn_async=event.save_fn_async,
|
342
|
-
)
|
343
|
-
await job.run()
|
344
|
-
|
345
245
|
def _init_artifact(self, artifact_id: str) -> None:
|
346
246
|
self._artifacts[artifact_id] = {
|
347
247
|
"finalize": False,
|
@@ -385,8 +285,6 @@ class StepUpload:
|
|
385
285
|
|
386
286
|
def start(self) -> None:
|
387
287
|
self._thread.start()
|
388
|
-
if self._async_executor:
|
389
|
-
self._async_executor.start()
|
390
288
|
|
391
289
|
def is_alive(self) -> bool:
|
392
290
|
return self._thread.is_alive()
|
wandb/filesync/upload_job.py
CHANGED
@@ -1,4 +1,3 @@
|
|
1
|
-
import asyncio
|
2
1
|
import logging
|
3
2
|
import os
|
4
3
|
from typing import TYPE_CHECKING, Optional
|
@@ -141,78 +140,3 @@ class UploadJob:
|
|
141
140
|
|
142
141
|
def progress(self, total_bytes: int) -> None:
|
143
142
|
self._stats.update_uploaded_file(self.save_name, total_bytes)
|
144
|
-
|
145
|
-
|
146
|
-
class UploadJobAsync:
|
147
|
-
"""Roughly an async equivalent of UploadJob.
|
148
|
-
|
149
|
-
Important differences:
|
150
|
-
- `run` is a coroutine
|
151
|
-
- If `run()` fails, it falls back to the synchronous UploadJob
|
152
|
-
"""
|
153
|
-
|
154
|
-
def __init__(
|
155
|
-
self,
|
156
|
-
stats: "stats.Stats",
|
157
|
-
api: "internal_api.Api",
|
158
|
-
file_stream: "file_stream.FileStreamApi",
|
159
|
-
silent: bool,
|
160
|
-
request: "step_upload.RequestUpload",
|
161
|
-
save_fn_async: "step_upload.SaveFnAsync",
|
162
|
-
) -> None:
|
163
|
-
self._stats = stats
|
164
|
-
self._api = api
|
165
|
-
self._file_stream = file_stream
|
166
|
-
self.silent = silent
|
167
|
-
self._request = request
|
168
|
-
self._save_fn_async = save_fn_async
|
169
|
-
|
170
|
-
async def run(self) -> None:
|
171
|
-
try:
|
172
|
-
deduped = await self._save_fn_async(
|
173
|
-
lambda _, t: self._stats.update_uploaded_file(self._request.path, t)
|
174
|
-
)
|
175
|
-
except Exception as e:
|
176
|
-
# Async uploads aren't yet (2023-01) battle-tested.
|
177
|
-
# Fall back to the "normal" synchronous upload.
|
178
|
-
loop = asyncio.get_event_loop()
|
179
|
-
logger.exception("async upload failed", exc_info=e)
|
180
|
-
loop.run_in_executor(None, wandb._sentry.exception, e)
|
181
|
-
wandb.termwarn(
|
182
|
-
"Async file upload failed; falling back to sync", repeat=False
|
183
|
-
)
|
184
|
-
sync_job = UploadJob(
|
185
|
-
self._stats,
|
186
|
-
self._api,
|
187
|
-
self._file_stream,
|
188
|
-
self.silent,
|
189
|
-
self._request.save_name,
|
190
|
-
self._request.path,
|
191
|
-
self._request.artifact_id,
|
192
|
-
self._request.md5,
|
193
|
-
self._request.copied,
|
194
|
-
self._request.save_fn,
|
195
|
-
self._request.digest,
|
196
|
-
)
|
197
|
-
|
198
|
-
await loop.run_in_executor(None, sync_job.run)
|
199
|
-
else:
|
200
|
-
self._file_stream.push_success(
|
201
|
-
self._request.artifact_id, # type: ignore
|
202
|
-
self._request.save_name,
|
203
|
-
)
|
204
|
-
|
205
|
-
if deduped:
|
206
|
-
logger.info("Skipped uploading %s", self._request.path)
|
207
|
-
self._stats.set_file_deduped(self._request.path)
|
208
|
-
else:
|
209
|
-
logger.info("Uploaded file %s", self._request.path)
|
210
|
-
finally:
|
211
|
-
# If we fell back to the sync impl, the file will have already been deleted.
|
212
|
-
# Doesn't matter, we only try to delete it if it exists.
|
213
|
-
if self._request.copied:
|
214
|
-
try:
|
215
|
-
os.remove(self._request.path)
|
216
|
-
except OSError:
|
217
|
-
# The file has already been deleted, we don't have permissions, or something else we can't fix.
|
218
|
-
pass
|
@@ -81,7 +81,7 @@ def _checkpoint_artifact(
|
|
81
81
|
|
82
82
|
|
83
83
|
def _log_feature_importance(
|
84
|
-
model: Union[CatBoostClassifier, CatBoostRegressor]
|
84
|
+
model: Union[CatBoostClassifier, CatBoostRegressor],
|
85
85
|
) -> None:
|
86
86
|
"""Log feature importance with default settings."""
|
87
87
|
if wandb.run is None:
|
@@ -91,8 +91,8 @@ class HuggingFacePipelineRequestResponseResolver:
|
|
91
91
|
|
92
92
|
# TODO: This should have a dependency on PreTrainedModel. i.e. isinstance(PreTrainedModel)
|
93
93
|
# from transformers.modeling_utils import PreTrainedModel
|
94
|
-
# We do not want this dependency
|
95
|
-
# the structure of the pipeline which may have unintended consequences
|
94
|
+
# We do not want this dependency explicitly in our codebase so we make a very general
|
95
|
+
# assumption about the structure of the pipeline which may have unintended consequences
|
96
96
|
def _get_model(self, pipe) -> Optional[Any]:
|
97
97
|
"""Extracts model from the pipeline.
|
98
98
|
|
@@ -43,7 +43,7 @@ class WandbMetricsLogger(callbacks.Callback):
|
|
43
43
|
at the end of each epoch. If "batch", logs metrics at the end
|
44
44
|
of each batch. If an integer, logs metrics at the end of that
|
45
45
|
many batches. Defaults to "epoch".
|
46
|
-
initial_global_step: (int) Use this argument to
|
46
|
+
initial_global_step: (int) Use this argument to correctly log the
|
47
47
|
learning rate when you resume training from some `initial_epoch`,
|
48
48
|
and a learning rate scheduler is used. This can be computed as
|
49
49
|
`step_size * initial_step`. Defaults to 0.
|
wandb/integration/keras/keras.py
CHANGED
@@ -576,7 +576,7 @@ class WandbCallback(tf.keras.callbacks.Callback):
|
|
576
576
|
)
|
577
577
|
self._model_trained_since_last_eval = False
|
578
578
|
except Exception as e:
|
579
|
-
wandb.termwarn("Error
|
579
|
+
wandb.termwarn("Error during prediction logging for epoch: " + str(e))
|
580
580
|
|
581
581
|
def on_epoch_end(self, epoch, logs=None):
|
582
582
|
if logs is None:
|
@@ -616,9 +616,9 @@ class WandbCallback(tf.keras.callbacks.Callback):
|
|
616
616
|
self.current = logs.get(self.monitor)
|
617
617
|
if self.current and self.monitor_op(self.current, self.best):
|
618
618
|
if self.log_best_prefix:
|
619
|
-
wandb.run.summary[
|
620
|
-
|
621
|
-
|
619
|
+
wandb.run.summary[f"{self.log_best_prefix}{self.monitor}"] = (
|
620
|
+
self.current
|
621
|
+
)
|
622
622
|
wandb.run.summary["{}{}".format(self.log_best_prefix, "epoch")] = epoch
|
623
623
|
if self.verbose and not self.save_model:
|
624
624
|
print(
|
@@ -937,9 +937,9 @@ class WandbCallback(tf.keras.callbacks.Callback):
|
|
937
937
|
grads = self._grad_accumulator_callback.grads
|
938
938
|
metrics = {}
|
939
939
|
for weight, grad in zip(weights, grads):
|
940
|
-
metrics[
|
941
|
-
|
942
|
-
|
940
|
+
metrics["gradients/" + weight.name.split(":")[0] + ".gradient"] = (
|
941
|
+
wandb.Histogram(grad)
|
942
|
+
)
|
943
943
|
return metrics
|
944
944
|
|
945
945
|
def _log_dataframe(self):
|
@@ -14,6 +14,7 @@ integration will not break user code. The one exception to the rule is at import
|
|
14
14
|
LangChain is not installed, or the symbols are not in the same place, the appropriate error
|
15
15
|
will be raised when importing this module.
|
16
16
|
"""
|
17
|
+
|
17
18
|
from packaging import version
|
18
19
|
|
19
20
|
import wandb.util
|
@@ -401,9 +401,7 @@ class WandbLogger(Logger):
|
|
401
401
|
"*", step_metric="trainer/global_step", step_sync=True
|
402
402
|
)
|
403
403
|
|
404
|
-
self._experiment._label(
|
405
|
-
repo="lightning_fabric_logger"
|
406
|
-
) # pylint: disable=protected-access
|
404
|
+
self._experiment._label(repo="lightning_fabric_logger") # pylint: disable=protected-access
|
407
405
|
with telemetry.context(run=self._experiment) as tel:
|
408
406
|
tel.feature.lightning_fabric_logger = True
|
409
407
|
return self._experiment
|
@@ -36,7 +36,15 @@ try:
|
|
36
36
|
import pandas as pd
|
37
37
|
|
38
38
|
@typedispatch # noqa: F811
|
39
|
-
def _wandb_use(
|
39
|
+
def _wandb_use(
|
40
|
+
name: str,
|
41
|
+
data: pd.DataFrame,
|
42
|
+
datasets=False,
|
43
|
+
run=None,
|
44
|
+
testing=False,
|
45
|
+
*args,
|
46
|
+
**kwargs,
|
47
|
+
): # type: ignore
|
40
48
|
if testing:
|
41
49
|
return "datasets" if datasets else None
|
42
50
|
|
@@ -74,7 +82,15 @@ try:
|
|
74
82
|
import torch.nn as nn
|
75
83
|
|
76
84
|
@typedispatch # noqa: F811
|
77
|
-
def _wandb_use(
|
85
|
+
def _wandb_use(
|
86
|
+
name: str,
|
87
|
+
data: nn.Module,
|
88
|
+
models=False,
|
89
|
+
run=None,
|
90
|
+
testing=False,
|
91
|
+
*args,
|
92
|
+
**kwargs,
|
93
|
+
): # type: ignore
|
78
94
|
if testing:
|
79
95
|
return "models" if models else None
|
80
96
|
|
@@ -111,7 +127,15 @@ try:
|
|
111
127
|
from sklearn.base import BaseEstimator
|
112
128
|
|
113
129
|
@typedispatch # noqa: F811
|
114
|
-
def _wandb_use(
|
130
|
+
def _wandb_use(
|
131
|
+
name: str,
|
132
|
+
data: BaseEstimator,
|
133
|
+
models=False,
|
134
|
+
run=None,
|
135
|
+
testing=False,
|
136
|
+
*args,
|
137
|
+
**kwargs,
|
138
|
+
): # type: ignore
|
115
139
|
if testing:
|
116
140
|
return "models" if models else None
|
117
141
|
|
@@ -169,7 +193,14 @@ class ArtifactProxy:
|
|
169
193
|
|
170
194
|
|
171
195
|
@typedispatch # noqa: F811
|
172
|
-
def wandb_track(
|
196
|
+
def wandb_track(
|
197
|
+
name: str,
|
198
|
+
data: (dict, list, set, str, int, float, bool),
|
199
|
+
run=None,
|
200
|
+
testing=False,
|
201
|
+
*args,
|
202
|
+
**kwargs,
|
203
|
+
): # type: ignore
|
173
204
|
if testing:
|
174
205
|
return "scalar"
|
175
206
|
|
@@ -222,12 +253,16 @@ def wandb_use(name: str, data, *args, **kwargs):
|
|
222
253
|
|
223
254
|
|
224
255
|
@typedispatch # noqa: F811
|
225
|
-
def wandb_use(
|
256
|
+
def wandb_use(
|
257
|
+
name: str, data: (dict, list, set, str, int, float, bool), *args, **kwargs
|
258
|
+
): # type: ignore
|
226
259
|
pass # do nothing for these types
|
227
260
|
|
228
261
|
|
229
262
|
@typedispatch # noqa: F811
|
230
|
-
def _wandb_use(
|
263
|
+
def _wandb_use(
|
264
|
+
name: str, data: Path, datasets=False, run=None, testing=False, *args, **kwargs
|
265
|
+
): # type: ignore
|
231
266
|
if testing:
|
232
267
|
return "datasets" if datasets else None
|
233
268
|
|
@@ -310,9 +310,9 @@ class WandbLogger:
|
|
310
310
|
try:
|
311
311
|
hyperparams["n_epochs"] = hyperparameters.n_epochs
|
312
312
|
hyperparams["batch_size"] = hyperparameters.batch_size
|
313
|
-
hyperparams[
|
314
|
-
|
315
|
-
|
313
|
+
hyperparams["learning_rate_multiplier"] = (
|
314
|
+
hyperparameters.learning_rate_multiplier
|
315
|
+
)
|
316
316
|
except Exception:
|
317
317
|
# If unpacking fails, return the object to be logged as config
|
318
318
|
return None
|
wandb/old/summary.py
CHANGED
@@ -61,7 +61,7 @@ class SummarySubDict:
|
|
61
61
|
This should only be implemented by the "_root" child class.
|
62
62
|
|
63
63
|
We pass the child_dict so the item can be set on it or not as
|
64
|
-
appropriate. Returning None for a
|
64
|
+
appropriate. Returning None for a nonexistent path wouldn't be
|
65
65
|
distinguishable from that path being set to the value None.
|
66
66
|
"""
|
67
67
|
raise NotImplementedError
|
wandb/plot/confusion_matrix.py
CHANGED
@@ -48,7 +48,7 @@ def confusion_matrix(
|
|
48
48
|
|
49
49
|
assert (probs is None or preds is None) and not (
|
50
50
|
probs is None and preds is None
|
51
|
-
), "Must provide
|
51
|
+
), "Must provide probabilities or predictions but not both to confusion matrix"
|
52
52
|
|
53
53
|
if probs is not None:
|
54
54
|
preds = np.argmax(probs, axis=1).tolist()
|
wandb/plot/pr_curve.py
CHANGED
wandb/plot/roc_curve.py
CHANGED
wandb/{plots → plot}/utils.py
RENAMED
@@ -1,21 +1,9 @@
|
|
1
|
-
from
|
1
|
+
from typing import Iterable, Sequence
|
2
2
|
|
3
3
|
import wandb
|
4
4
|
from wandb import util
|
5
|
-
from wandb.sdk.lib import deprecate
|
6
5
|
|
7
6
|
|
8
|
-
def deprecation_notice() -> None:
|
9
|
-
deprecate.deprecate(
|
10
|
-
field_name=deprecate.Deprecated.plots,
|
11
|
-
warning_message=(
|
12
|
-
"wandb.plots.* functions are deprecated and will be removed in a future release. "
|
13
|
-
"Please use wandb.plot.* instead."
|
14
|
-
),
|
15
|
-
)
|
16
|
-
|
17
|
-
|
18
|
-
# Test assumptions for plotting parameters and datasets
|
19
7
|
def test_missing(**kwargs):
|
20
8
|
np = util.get_module("numpy", required="Logging plots requires numpy")
|
21
9
|
pd = util.get_module("pandas", required="Logging dataframes requires pandas")
|
@@ -25,7 +13,7 @@ def test_missing(**kwargs):
|
|
25
13
|
for k, v in kwargs.items():
|
26
14
|
# Missing/empty params/datapoint arrays
|
27
15
|
if v is None:
|
28
|
-
wandb.termerror("
|
16
|
+
wandb.termerror("{} is None. Please try again.".format(k))
|
29
17
|
test_passed = False
|
30
18
|
if (k == "X") or (k == "X_test"):
|
31
19
|
if isinstance(v, scipy.sparse.csr.csr_matrix):
|
@@ -64,8 +52,8 @@ def test_missing(**kwargs):
|
|
64
52
|
)
|
65
53
|
if non_nums > 0:
|
66
54
|
wandb.termerror(
|
67
|
-
"
|
68
|
-
|
55
|
+
f"{k} contains values that are not numbers. Please vectorize, "
|
56
|
+
f"label encode or one hot encode {k} and call the plotting function again."
|
69
57
|
)
|
70
58
|
test_passed = False
|
71
59
|
return test_passed
|
@@ -73,8 +61,8 @@ def test_missing(**kwargs):
|
|
73
61
|
|
74
62
|
def test_fitted(model):
|
75
63
|
np = util.get_module("numpy", required="Logging plots requires numpy")
|
76
|
-
|
77
|
-
|
64
|
+
_ = util.get_module("pandas", required="Logging dataframes requires pandas")
|
65
|
+
_ = util.get_module("scipy", required="Logging scipy matrices requires scipy")
|
78
66
|
scikit_utils = util.get_module(
|
79
67
|
"sklearn.utils",
|
80
68
|
required="roc requires the scikit utils submodule, install with `pip install scikit-learn`",
|
@@ -120,7 +108,7 @@ def test_fitted(model):
|
|
120
108
|
|
121
109
|
|
122
110
|
def encode_labels(df):
|
123
|
-
|
111
|
+
_ = util.get_module("pandas", required="Logging dataframes requires pandas")
|
124
112
|
preprocessing = util.get_module(
|
125
113
|
"sklearn.preprocessing",
|
126
114
|
"roc requires the scikit preprocessing submodule, install with `pip install scikit-learn`",
|
@@ -137,7 +125,7 @@ def encode_labels(df):
|
|
137
125
|
def test_types(**kwargs):
|
138
126
|
np = util.get_module("numpy", required="Logging plots requires numpy")
|
139
127
|
pd = util.get_module("pandas", required="Logging dataframes requires pandas")
|
140
|
-
|
128
|
+
_ = util.get_module("scipy", required="Logging scipy matrices requires scipy")
|
141
129
|
|
142
130
|
base = util.get_module(
|
143
131
|
"sklearn.base",
|
@@ -171,25 +159,25 @@ def test_types(**kwargs):
|
|
171
159
|
list,
|
172
160
|
),
|
173
161
|
):
|
174
|
-
wandb.termerror("
|
162
|
+
wandb.termerror("{} is not an array. Please try again.".format(k))
|
175
163
|
test_passed = False
|
176
164
|
# check for classifier types
|
177
165
|
if k == "model":
|
178
166
|
if (not base.is_classifier(v)) and (not base.is_regressor(v)):
|
179
167
|
wandb.termerror(
|
180
|
-
"
|
168
|
+
"{} is not a classifier or regressor. Please try again.".format(k)
|
181
169
|
)
|
182
170
|
test_passed = False
|
183
171
|
elif k == "clf" or k == "binary_clf":
|
184
172
|
if not (base.is_classifier(v)):
|
185
|
-
wandb.termerror("
|
173
|
+
wandb.termerror("{} is not a classifier. Please try again.".format(k))
|
186
174
|
test_passed = False
|
187
175
|
elif k == "regressor":
|
188
176
|
if not base.is_regressor(v):
|
189
|
-
wandb.termerror("
|
177
|
+
wandb.termerror("{} is not a regressor. Please try again.".format(k))
|
190
178
|
test_passed = False
|
191
179
|
elif k == "clusterer":
|
192
180
|
if not (getattr(v, "_estimator_type", None) == "clusterer"):
|
193
|
-
wandb.termerror("
|
181
|
+
wandb.termerror("{} is not a clusterer. Please try again.".format(k))
|
194
182
|
test_passed = False
|
195
183
|
return test_passed
|