wandb 0.13.10__py3-none-any.whl → 0.14.0__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- wandb/__init__.py +2 -3
- wandb/apis/__init__.py +1 -3
- wandb/apis/importers/__init__.py +4 -0
- wandb/apis/importers/base.py +312 -0
- wandb/apis/importers/mlflow.py +113 -0
- wandb/apis/internal.py +29 -2
- wandb/apis/normalize.py +6 -5
- wandb/apis/public.py +163 -180
- wandb/apis/reports/_templates.py +6 -12
- wandb/apis/reports/report.py +1 -1
- wandb/apis/reports/runset.py +1 -3
- wandb/apis/reports/util.py +12 -10
- wandb/beta/workflows.py +57 -34
- wandb/catboost/__init__.py +1 -2
- wandb/cli/cli.py +215 -133
- wandb/data_types.py +63 -56
- wandb/docker/__init__.py +78 -16
- wandb/docker/auth.py +21 -22
- wandb/env.py +0 -1
- wandb/errors/__init__.py +8 -116
- wandb/errors/term.py +1 -1
- wandb/fastai/__init__.py +1 -2
- wandb/filesync/dir_watcher.py +8 -5
- wandb/filesync/step_prepare.py +76 -75
- wandb/filesync/step_upload.py +1 -2
- wandb/integration/catboost/__init__.py +1 -3
- wandb/integration/catboost/catboost.py +8 -14
- wandb/integration/fastai/__init__.py +7 -13
- wandb/integration/gym/__init__.py +35 -4
- wandb/integration/keras/__init__.py +3 -3
- wandb/integration/keras/callbacks/metrics_logger.py +9 -8
- wandb/integration/keras/callbacks/model_checkpoint.py +9 -9
- wandb/integration/keras/callbacks/tables_builder.py +31 -19
- wandb/integration/kfp/kfp_patch.py +20 -17
- wandb/integration/kfp/wandb_logging.py +1 -2
- wandb/integration/lightgbm/__init__.py +21 -19
- wandb/integration/prodigy/prodigy.py +6 -7
- wandb/integration/sacred/__init__.py +9 -12
- wandb/integration/sagemaker/__init__.py +1 -3
- wandb/integration/sagemaker/auth.py +0 -1
- wandb/integration/sagemaker/config.py +1 -1
- wandb/integration/sagemaker/resources.py +1 -1
- wandb/integration/sb3/sb3.py +8 -4
- wandb/integration/tensorboard/__init__.py +1 -3
- wandb/integration/tensorboard/log.py +8 -8
- wandb/integration/tensorboard/monkeypatch.py +11 -9
- wandb/integration/tensorflow/__init__.py +1 -3
- wandb/integration/xgboost/__init__.py +4 -6
- wandb/integration/yolov8/__init__.py +7 -0
- wandb/integration/yolov8/yolov8.py +250 -0
- wandb/jupyter.py +31 -35
- wandb/lightgbm/__init__.py +1 -2
- wandb/old/settings.py +2 -2
- wandb/plot/bar.py +1 -2
- wandb/plot/confusion_matrix.py +1 -3
- wandb/plot/histogram.py +1 -2
- wandb/plot/line.py +1 -2
- wandb/plot/line_series.py +4 -4
- wandb/plot/pr_curve.py +17 -20
- wandb/plot/roc_curve.py +1 -3
- wandb/plot/scatter.py +1 -2
- wandb/proto/v3/wandb_server_pb2.py +85 -39
- wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
- wandb/proto/v4/wandb_server_pb2.py +51 -39
- wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
- wandb/sdk/__init__.py +1 -3
- wandb/sdk/backend/backend.py +1 -1
- wandb/sdk/data_types/_dtypes.py +38 -30
- wandb/sdk/data_types/base_types/json_metadata.py +1 -3
- wandb/sdk/data_types/base_types/media.py +17 -17
- wandb/sdk/data_types/base_types/wb_value.py +33 -26
- wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +91 -125
- wandb/sdk/data_types/helper_types/classes.py +1 -1
- wandb/sdk/data_types/helper_types/image_mask.py +12 -12
- wandb/sdk/data_types/histogram.py +5 -4
- wandb/sdk/data_types/html.py +1 -2
- wandb/sdk/data_types/image.py +11 -11
- wandb/sdk/data_types/molecule.py +3 -6
- wandb/sdk/data_types/object_3d.py +1 -2
- wandb/sdk/data_types/plotly.py +1 -2
- wandb/sdk/data_types/saved_model.py +10 -8
- wandb/sdk/data_types/video.py +1 -1
- wandb/sdk/integration_utils/data_logging.py +5 -5
- wandb/sdk/interface/artifacts.py +288 -266
- wandb/sdk/interface/interface.py +2 -3
- wandb/sdk/interface/interface_grpc.py +1 -1
- wandb/sdk/interface/interface_queue.py +1 -1
- wandb/sdk/interface/interface_relay.py +1 -1
- wandb/sdk/interface/interface_shared.py +1 -2
- wandb/sdk/interface/interface_sock.py +1 -1
- wandb/sdk/interface/message_future.py +1 -1
- wandb/sdk/interface/message_future_poll.py +1 -1
- wandb/sdk/interface/router.py +1 -1
- wandb/sdk/interface/router_queue.py +1 -1
- wandb/sdk/interface/router_relay.py +1 -1
- wandb/sdk/interface/router_sock.py +1 -1
- wandb/sdk/interface/summary_record.py +1 -1
- wandb/sdk/internal/artifacts.py +1 -1
- wandb/sdk/internal/datastore.py +2 -3
- wandb/sdk/internal/file_pusher.py +5 -3
- wandb/sdk/internal/file_stream.py +22 -19
- wandb/sdk/internal/handler.py +5 -4
- wandb/sdk/internal/internal.py +1 -1
- wandb/sdk/internal/internal_api.py +115 -55
- wandb/sdk/internal/job_builder.py +1 -3
- wandb/sdk/internal/profiler.py +1 -1
- wandb/sdk/internal/progress.py +4 -6
- wandb/sdk/internal/sample.py +1 -3
- wandb/sdk/internal/sender.py +28 -16
- wandb/sdk/internal/settings_static.py +5 -5
- wandb/sdk/internal/system/assets/__init__.py +1 -0
- wandb/sdk/internal/system/assets/cpu.py +3 -9
- wandb/sdk/internal/system/assets/disk.py +2 -4
- wandb/sdk/internal/system/assets/gpu.py +6 -18
- wandb/sdk/internal/system/assets/gpu_apple.py +2 -4
- wandb/sdk/internal/system/assets/interfaces.py +50 -22
- wandb/sdk/internal/system/assets/ipu.py +1 -3
- wandb/sdk/internal/system/assets/memory.py +7 -13
- wandb/sdk/internal/system/assets/network.py +4 -8
- wandb/sdk/internal/system/assets/open_metrics.py +283 -0
- wandb/sdk/internal/system/assets/tpu.py +1 -4
- wandb/sdk/internal/system/assets/trainium.py +26 -14
- wandb/sdk/internal/system/system_info.py +2 -3
- wandb/sdk/internal/system/system_monitor.py +52 -20
- wandb/sdk/internal/tb_watcher.py +12 -13
- wandb/sdk/launch/_project_spec.py +54 -65
- wandb/sdk/launch/agent/agent.py +374 -90
- wandb/sdk/launch/builder/abstract.py +61 -7
- wandb/sdk/launch/builder/build.py +81 -110
- wandb/sdk/launch/builder/docker_builder.py +181 -0
- wandb/sdk/launch/builder/kaniko_builder.py +419 -0
- wandb/sdk/launch/builder/noop.py +31 -12
- wandb/sdk/launch/builder/templates/_wandb_bootstrap.py +70 -20
- wandb/sdk/launch/environment/abstract.py +28 -0
- wandb/sdk/launch/environment/aws_environment.py +276 -0
- wandb/sdk/launch/environment/gcp_environment.py +271 -0
- wandb/sdk/launch/environment/local_environment.py +65 -0
- wandb/sdk/launch/github_reference.py +3 -8
- wandb/sdk/launch/launch.py +38 -29
- wandb/sdk/launch/launch_add.py +6 -8
- wandb/sdk/launch/loader.py +230 -0
- wandb/sdk/launch/registry/abstract.py +54 -0
- wandb/sdk/launch/registry/elastic_container_registry.py +163 -0
- wandb/sdk/launch/registry/google_artifact_registry.py +203 -0
- wandb/sdk/launch/registry/local_registry.py +62 -0
- wandb/sdk/launch/runner/abstract.py +1 -16
- wandb/sdk/launch/runner/{kubernetes.py → kubernetes_runner.py} +83 -95
- wandb/sdk/launch/runner/local_container.py +46 -22
- wandb/sdk/launch/runner/local_process.py +1 -4
- wandb/sdk/launch/runner/{aws.py → sagemaker_runner.py} +53 -212
- wandb/sdk/launch/runner/{gcp_vertex.py → vertex_runner.py} +38 -55
- wandb/sdk/launch/sweeps/__init__.py +3 -2
- wandb/sdk/launch/sweeps/scheduler.py +132 -39
- wandb/sdk/launch/sweeps/scheduler_sweep.py +80 -89
- wandb/sdk/launch/utils.py +101 -30
- wandb/sdk/launch/wandb_reference.py +2 -7
- wandb/sdk/lib/_settings_toposort_generate.py +166 -0
- wandb/sdk/lib/_settings_toposort_generated.py +201 -0
- wandb/sdk/lib/apikey.py +2 -4
- wandb/sdk/lib/config_util.py +4 -1
- wandb/sdk/lib/console.py +1 -3
- wandb/sdk/lib/deprecate.py +3 -3
- wandb/sdk/lib/file_stream_utils.py +7 -5
- wandb/sdk/lib/filenames.py +1 -1
- wandb/sdk/lib/filesystem.py +61 -5
- wandb/sdk/lib/git.py +1 -3
- wandb/sdk/lib/import_hooks.py +4 -7
- wandb/sdk/lib/ipython.py +8 -5
- wandb/sdk/lib/lazyloader.py +1 -3
- wandb/sdk/lib/mailbox.py +14 -4
- wandb/sdk/lib/proto_util.py +10 -5
- wandb/sdk/lib/redirect.py +15 -22
- wandb/sdk/lib/reporting.py +1 -3
- wandb/sdk/lib/retry.py +4 -5
- wandb/sdk/lib/runid.py +1 -3
- wandb/sdk/lib/server.py +15 -9
- wandb/sdk/lib/sock_client.py +1 -1
- wandb/sdk/lib/sparkline.py +1 -1
- wandb/sdk/lib/wburls.py +1 -1
- wandb/sdk/service/port_file.py +1 -2
- wandb/sdk/service/service.py +36 -13
- wandb/sdk/service/service_base.py +12 -1
- wandb/sdk/verify/verify.py +5 -7
- wandb/sdk/wandb_artifacts.py +142 -177
- wandb/sdk/wandb_config.py +5 -8
- wandb/sdk/wandb_helper.py +1 -1
- wandb/sdk/wandb_init.py +24 -13
- wandb/sdk/wandb_login.py +9 -9
- wandb/sdk/wandb_manager.py +39 -4
- wandb/sdk/wandb_metric.py +2 -6
- wandb/sdk/wandb_require.py +4 -15
- wandb/sdk/wandb_require_helpers.py +1 -9
- wandb/sdk/wandb_run.py +95 -141
- wandb/sdk/wandb_save.py +1 -3
- wandb/sdk/wandb_settings.py +149 -54
- wandb/sdk/wandb_setup.py +66 -46
- wandb/sdk/wandb_summary.py +13 -10
- wandb/sdk/wandb_sweep.py +6 -7
- wandb/sdk/wandb_watch.py +1 -1
- wandb/sklearn/calculate/confusion_matrix.py +1 -1
- wandb/sklearn/calculate/learning_curve.py +1 -1
- wandb/sklearn/calculate/summary_metrics.py +1 -3
- wandb/sklearn/plot/__init__.py +1 -1
- wandb/sklearn/plot/classifier.py +27 -18
- wandb/sklearn/plot/clusterer.py +4 -5
- wandb/sklearn/plot/regressor.py +4 -4
- wandb/sklearn/plot/shared.py +2 -2
- wandb/sync/__init__.py +1 -3
- wandb/sync/sync.py +4 -5
- wandb/testing/relay.py +11 -10
- wandb/trigger.py +1 -1
- wandb/util.py +106 -81
- wandb/viz.py +4 -4
- wandb/wandb_agent.py +50 -50
- wandb/wandb_controller.py +2 -3
- wandb/wandb_run.py +1 -2
- wandb/wandb_torch.py +1 -1
- wandb/xgboost/__init__.py +1 -2
- {wandb-0.13.10.dist-info → wandb-0.14.0.dist-info}/METADATA +6 -2
- {wandb-0.13.10.dist-info → wandb-0.14.0.dist-info}/RECORD +224 -209
- {wandb-0.13.10.dist-info → wandb-0.14.0.dist-info}/WHEEL +1 -1
- wandb/sdk/launch/builder/docker.py +0 -80
- wandb/sdk/launch/builder/kaniko.py +0 -393
- wandb/sdk/launch/builder/loader.py +0 -32
- wandb/sdk/launch/runner/loader.py +0 -50
- {wandb-0.13.10.dist-info → wandb-0.14.0.dist-info}/LICENSE +0 -0
- {wandb-0.13.10.dist-info → wandb-0.14.0.dist-info}/entry_points.txt +0 -0
- {wandb-0.13.10.dist-info → wandb-0.14.0.dist-info}/top_level.txt +0 -0
wandb/errors/__init__.py
CHANGED
@@ -2,30 +2,22 @@ __all__ = [
|
|
2
2
|
"Error",
|
3
3
|
"UsageError",
|
4
4
|
"CommError",
|
5
|
-
"
|
6
|
-
"DockerError",
|
7
|
-
"LogMultiprocessError",
|
8
|
-
"MultiprocessError",
|
9
|
-
"RequireError",
|
10
|
-
"ExecutionError",
|
11
|
-
"LaunchError",
|
12
|
-
"SweepError",
|
5
|
+
"UnsupportedError",
|
13
6
|
"WaitTimeoutError",
|
14
|
-
"ContextCancelledError",
|
15
|
-
"ServiceStartProcessError",
|
16
|
-
"ServiceStartTimeoutError",
|
17
|
-
"ServiceStartPortError",
|
18
7
|
]
|
19
8
|
|
20
|
-
from typing import
|
9
|
+
from typing import Optional
|
21
10
|
|
22
11
|
|
23
12
|
class Error(Exception):
|
24
13
|
"""Base W&B Error"""
|
25
14
|
|
26
|
-
def __init__(self, message) -> None:
|
15
|
+
def __init__(self, message, context: Optional[dict] = None) -> None:
|
27
16
|
super().__init__(message)
|
28
17
|
self.message = message
|
18
|
+
# sentry context capture
|
19
|
+
if context:
|
20
|
+
self.context = context
|
29
21
|
|
30
22
|
# For python 2 support
|
31
23
|
def encode(self, encoding):
|
@@ -47,111 +39,11 @@ class UsageError(Error):
|
|
47
39
|
pass
|
48
40
|
|
49
41
|
|
50
|
-
class
|
51
|
-
"""Raised when
|
52
|
-
|
53
|
-
pass
|
54
|
-
|
55
|
-
|
56
|
-
class LogMultiprocessError(LogError):
|
57
|
-
"""Raised when wandb.log() fails because of multiprocessing"""
|
58
|
-
|
59
|
-
pass
|
60
|
-
|
61
|
-
|
62
|
-
class MultiprocessError(Error):
|
63
|
-
"""Raised when fails because of multiprocessing"""
|
64
|
-
|
65
|
-
pass
|
66
|
-
|
67
|
-
|
68
|
-
class RequireError(Error):
|
69
|
-
"""Raised when wandb.require() fails"""
|
70
|
-
|
71
|
-
pass
|
72
|
-
|
73
|
-
|
74
|
-
class ExecutionError(Error):
|
75
|
-
"""Generic execution exception"""
|
76
|
-
|
77
|
-
pass
|
78
|
-
|
79
|
-
|
80
|
-
class DockerError(Error):
|
81
|
-
"""Raised when attempting to execute a docker command"""
|
82
|
-
|
83
|
-
def __init__(
|
84
|
-
self,
|
85
|
-
command_launched: List[str],
|
86
|
-
return_code: int,
|
87
|
-
stdout: Optional[bytes] = None,
|
88
|
-
stderr: Optional[bytes] = None,
|
89
|
-
) -> None:
|
90
|
-
command_launched_str = " ".join(command_launched)
|
91
|
-
error_msg = (
|
92
|
-
f"The docker command executed was `{command_launched_str}`.\n"
|
93
|
-
f"It returned with code {return_code}\n"
|
94
|
-
)
|
95
|
-
if stdout is not None:
|
96
|
-
error_msg += f"The content of stdout is '{stdout.decode()}'\n"
|
97
|
-
else:
|
98
|
-
error_msg += (
|
99
|
-
"The content of stdout can be found above the "
|
100
|
-
"stacktrace (it wasn't captured).\n"
|
101
|
-
)
|
102
|
-
if stderr is not None:
|
103
|
-
error_msg += f"The content of stderr is '{stderr.decode()}'\n"
|
104
|
-
else:
|
105
|
-
error_msg += (
|
106
|
-
"The content of stderr can be found above the "
|
107
|
-
"stacktrace (it wasn't captured)."
|
108
|
-
)
|
109
|
-
super().__init__(error_msg)
|
110
|
-
|
111
|
-
|
112
|
-
class LaunchError(Error):
|
113
|
-
"""Raised when a known error occurs in wandb launch"""
|
114
|
-
|
115
|
-
pass
|
116
|
-
|
117
|
-
|
118
|
-
class SweepError(Error):
|
119
|
-
"""Raised when a known error occurs with wandb sweeps"""
|
120
|
-
|
121
|
-
pass
|
42
|
+
class UnsupportedError(UsageError):
|
43
|
+
"""Raised when trying to use a feature that is not supported"""
|
122
44
|
|
123
45
|
|
124
46
|
class WaitTimeoutError(Error):
|
125
47
|
"""Raised when wait() timeout occurs before process is finished"""
|
126
48
|
|
127
49
|
pass
|
128
|
-
|
129
|
-
|
130
|
-
class MailboxError(Error):
|
131
|
-
"""Generic Mailbox Exception"""
|
132
|
-
|
133
|
-
pass
|
134
|
-
|
135
|
-
|
136
|
-
class ContextCancelledError(Error):
|
137
|
-
"""Context cancelled Exception"""
|
138
|
-
|
139
|
-
pass
|
140
|
-
|
141
|
-
|
142
|
-
class ServiceStartProcessError(Error):
|
143
|
-
"""Raised when a known error occurs when launching wandb service"""
|
144
|
-
|
145
|
-
pass
|
146
|
-
|
147
|
-
|
148
|
-
class ServiceStartTimeoutError(Error):
|
149
|
-
"""Raised when service start times out"""
|
150
|
-
|
151
|
-
pass
|
152
|
-
|
153
|
-
|
154
|
-
class ServiceStartPortError(Error):
|
155
|
-
"""Raised when service start fails to find a port"""
|
156
|
-
|
157
|
-
pass
|
wandb/errors/term.py
CHANGED
wandb/fastai/__init__.py
CHANGED
wandb/filesync/dir_watcher.py
CHANGED
@@ -71,7 +71,7 @@ class FileEventHandler(abc.ABC):
|
|
71
71
|
|
72
72
|
|
73
73
|
class PolicyNow(FileEventHandler):
|
74
|
-
"""This policy only uploads files now"""
|
74
|
+
"""This policy only uploads files now."""
|
75
75
|
|
76
76
|
def on_modified(self, force: bool = False) -> None:
|
77
77
|
# only upload if we've never uploaded or when .save is called
|
@@ -88,7 +88,7 @@ class PolicyNow(FileEventHandler):
|
|
88
88
|
|
89
89
|
|
90
90
|
class PolicyEnd(FileEventHandler):
|
91
|
-
"""This policy only updates at the end of the run"""
|
91
|
+
"""This policy only updates at the end of the run."""
|
92
92
|
|
93
93
|
def on_modified(self, force: bool = False) -> None:
|
94
94
|
pass
|
@@ -106,8 +106,11 @@ class PolicyEnd(FileEventHandler):
|
|
106
106
|
|
107
107
|
|
108
108
|
class PolicyLive(FileEventHandler):
|
109
|
-
"""
|
110
|
-
|
109
|
+
"""Event handler that uploads respecting throttling.
|
110
|
+
|
111
|
+
Uploads files every RATE_LIMIT_SECONDS, which changes as the size increases to deal
|
112
|
+
with throttling.
|
113
|
+
"""
|
111
114
|
|
112
115
|
RATE_LIMIT_SECONDS = 15
|
113
116
|
unit_dict = dict(util.POW_10_BYTES)
|
@@ -250,7 +253,7 @@ class DirWatcher:
|
|
250
253
|
feh.on_modified(force=True)
|
251
254
|
|
252
255
|
def _per_file_event_handler(self) -> "wd_events.FileSystemEventHandler":
|
253
|
-
"""Create a Watchdog file event handler that does different things for every file"""
|
256
|
+
"""Create a Watchdog file event handler that does different things for every file."""
|
254
257
|
file_event_handler = wd_events.PatternMatchingEventHandler()
|
255
258
|
file_event_handler.on_created = self._on_file_created
|
256
259
|
file_event_handler.on_modified = self._on_file_modified
|
wandb/filesync/step_prepare.py
CHANGED
@@ -1,11 +1,11 @@
|
|
1
1
|
"""Batching file prepare requests to our API."""
|
2
2
|
|
3
3
|
import queue
|
4
|
-
import sys
|
5
4
|
import threading
|
6
5
|
import time
|
7
6
|
from typing import (
|
8
7
|
TYPE_CHECKING,
|
8
|
+
Callable,
|
9
9
|
List,
|
10
10
|
Mapping,
|
11
11
|
NamedTuple,
|
@@ -16,31 +16,16 @@ from typing import (
|
|
16
16
|
)
|
17
17
|
|
18
18
|
if TYPE_CHECKING:
|
19
|
-
from wandb.sdk.internal import
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
|
24
|
-
from typing_extensions import Protocol
|
25
|
-
|
26
|
-
class DoPrepareFn(Protocol):
|
27
|
-
def __call__(self) -> "internal_api.CreateArtifactFileSpecInput":
|
28
|
-
pass
|
29
|
-
|
30
|
-
class OnPrepareFn(Protocol):
|
31
|
-
def __call__(
|
32
|
-
self,
|
33
|
-
upload_url: Optional[str], # GraphQL type File.uploadUrl
|
34
|
-
upload_headers: Sequence[str], # GraphQL type File.uploadHeaders
|
35
|
-
artifact_id: str, # GraphQL type File.artifact.id
|
36
|
-
) -> None:
|
37
|
-
pass
|
19
|
+
from wandb.sdk.internal.internal_api import (
|
20
|
+
Api,
|
21
|
+
CreateArtifactFileSpecInput,
|
22
|
+
CreateArtifactFilesResponseFile,
|
23
|
+
)
|
38
24
|
|
39
25
|
|
40
26
|
# Request for a file to be prepared.
|
41
27
|
class RequestPrepare(NamedTuple):
|
42
|
-
|
43
|
-
on_prepare: Optional["OnPrepareFn"]
|
28
|
+
file_spec: "CreateArtifactFileSpecInput"
|
44
29
|
response_queue: "queue.Queue[ResponsePrepare]"
|
45
30
|
|
46
31
|
|
@@ -54,7 +39,49 @@ class ResponsePrepare(NamedTuple):
|
|
54
39
|
birth_artifact_id: str
|
55
40
|
|
56
41
|
|
57
|
-
|
42
|
+
Request = Union[RequestPrepare, RequestFinish]
|
43
|
+
|
44
|
+
|
45
|
+
def _clamp(x: float, low: float, high: float) -> float:
|
46
|
+
return max(low, min(x, high))
|
47
|
+
|
48
|
+
|
49
|
+
def gather_batch(
|
50
|
+
request_queue: "queue.Queue[Request]",
|
51
|
+
batch_time: float,
|
52
|
+
inter_event_time: float,
|
53
|
+
max_batch_size: int,
|
54
|
+
clock: Callable[[], float] = time.monotonic,
|
55
|
+
) -> Tuple[bool, Sequence[RequestPrepare]]:
|
56
|
+
|
57
|
+
batch_start_time = clock()
|
58
|
+
remaining_time = batch_time
|
59
|
+
|
60
|
+
first_request = request_queue.get()
|
61
|
+
if isinstance(first_request, RequestFinish):
|
62
|
+
return True, []
|
63
|
+
|
64
|
+
batch: List[RequestPrepare] = [first_request]
|
65
|
+
|
66
|
+
while remaining_time > 0 and len(batch) < max_batch_size:
|
67
|
+
try:
|
68
|
+
request = request_queue.get(
|
69
|
+
timeout=_clamp(
|
70
|
+
x=inter_event_time,
|
71
|
+
low=1e-12, # 0 = "block forever", so just use something tiny
|
72
|
+
high=remaining_time,
|
73
|
+
),
|
74
|
+
)
|
75
|
+
if isinstance(request, RequestFinish):
|
76
|
+
return True, batch
|
77
|
+
|
78
|
+
batch.append(request)
|
79
|
+
remaining_time = batch_time - (clock() - batch_start_time)
|
80
|
+
|
81
|
+
except queue.Empty:
|
82
|
+
break
|
83
|
+
|
84
|
+
return False, batch
|
58
85
|
|
59
86
|
|
60
87
|
class StepPrepare:
|
@@ -66,68 +93,46 @@ class StepPrepare:
|
|
66
93
|
|
67
94
|
def __init__(
|
68
95
|
self,
|
69
|
-
api: "
|
96
|
+
api: "Api",
|
70
97
|
batch_time: float,
|
71
98
|
inter_event_time: float,
|
72
99
|
max_batch_size: int,
|
100
|
+
request_queue: Optional["queue.Queue[Request]"] = None,
|
73
101
|
) -> None:
|
74
102
|
self._api = api
|
75
103
|
self._inter_event_time = inter_event_time
|
76
104
|
self._batch_time = batch_time
|
77
105
|
self._max_batch_size = max_batch_size
|
78
|
-
self._request_queue: "queue.Queue[
|
79
|
-
queue.Queue()
|
80
|
-
)
|
106
|
+
self._request_queue: "queue.Queue[Request]" = request_queue or queue.Queue()
|
81
107
|
self._thread = threading.Thread(target=self._thread_body)
|
82
108
|
self._thread.daemon = True
|
83
109
|
|
84
110
|
def _thread_body(self) -> None:
|
85
111
|
while True:
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
112
|
+
finish, batch = gather_batch(
|
113
|
+
request_queue=self._request_queue,
|
114
|
+
batch_time=self._batch_time,
|
115
|
+
inter_event_time=self._inter_event_time,
|
116
|
+
max_batch_size=self._max_batch_size,
|
117
|
+
)
|
118
|
+
if batch:
|
119
|
+
prepare_response = self._prepare_batch(batch)
|
120
|
+
# send responses
|
121
|
+
for prepare_request in batch:
|
122
|
+
name = prepare_request.file_spec["name"]
|
123
|
+
response_file = prepare_response[name]
|
124
|
+
upload_url = response_file["uploadUrl"]
|
125
|
+
upload_headers = response_file["uploadHeaders"]
|
126
|
+
birth_artifact_id = response_file["artifact"]["id"]
|
127
|
+
prepare_request.response_queue.put(
|
128
|
+
ResponsePrepare(upload_url, upload_headers, birth_artifact_id)
|
101
129
|
)
|
102
|
-
prepare_request.response_queue.put(
|
103
|
-
ResponsePrepare(upload_url, upload_headers, birth_artifact_id)
|
104
|
-
)
|
105
130
|
if finish:
|
106
131
|
break
|
107
132
|
|
108
|
-
def _gather_batch(
|
109
|
-
self, first_request: RequestPrepare
|
110
|
-
) -> Tuple[bool, Sequence[RequestPrepare]]:
|
111
|
-
batch_start_time = time.time()
|
112
|
-
batch: List[RequestPrepare] = [first_request]
|
113
|
-
while True:
|
114
|
-
try:
|
115
|
-
request = self._request_queue.get(
|
116
|
-
block=True, timeout=self._inter_event_time
|
117
|
-
)
|
118
|
-
if isinstance(request, RequestFinish):
|
119
|
-
return True, batch
|
120
|
-
batch.append(request)
|
121
|
-
remaining_time = self._batch_time - (time.time() - batch_start_time)
|
122
|
-
if remaining_time < 0 or len(batch) >= self._max_batch_size:
|
123
|
-
break
|
124
|
-
except queue.Empty:
|
125
|
-
break
|
126
|
-
return False, batch
|
127
|
-
|
128
133
|
def _prepare_batch(
|
129
134
|
self, batch: Sequence[RequestPrepare]
|
130
|
-
) -> Mapping[str, "
|
135
|
+
) -> Mapping[str, "CreateArtifactFilesResponseFile"]:
|
131
136
|
"""Execute the prepareFiles API call.
|
132
137
|
|
133
138
|
Arguments:
|
@@ -137,14 +142,10 @@ class StepPrepare:
|
|
137
142
|
an uploadUrl key. The value of the uploadUrl key is None if the file
|
138
143
|
already exists, or a url string if the file should be uploaded.
|
139
144
|
"""
|
140
|
-
|
141
|
-
for prepare_request in batch:
|
142
|
-
file_spec = prepare_request.prepare_fn()
|
143
|
-
file_specs.append(file_spec)
|
144
|
-
return self._api.create_artifact_files(file_specs)
|
145
|
+
return self._api.create_artifact_files([req.file_spec for req in batch])
|
145
146
|
|
146
147
|
def prepare_async(
|
147
|
-
self,
|
148
|
+
self, file_spec: "CreateArtifactFileSpecInput"
|
148
149
|
) -> "queue.Queue[ResponsePrepare]":
|
149
150
|
"""Request the backend to prepare a file for upload.
|
150
151
|
|
@@ -153,11 +154,11 @@ class StepPrepare:
|
|
153
154
|
either a file upload url, or None if the file doesn't need to be uploaded.
|
154
155
|
"""
|
155
156
|
response_queue: "queue.Queue[ResponsePrepare]" = queue.Queue()
|
156
|
-
self._request_queue.put(RequestPrepare(
|
157
|
+
self._request_queue.put(RequestPrepare(file_spec, response_queue))
|
157
158
|
return response_queue
|
158
159
|
|
159
|
-
def prepare(self,
|
160
|
-
return self.prepare_async(
|
160
|
+
def prepare(self, file_spec: "CreateArtifactFileSpecInput") -> ResponsePrepare:
|
161
|
+
return self.prepare_async(file_spec).get()
|
161
162
|
|
162
163
|
def start(self) -> None:
|
163
164
|
self._thread.start()
|
wandb/filesync/step_upload.py
CHANGED
@@ -203,7 +203,7 @@ class StepUpload:
|
|
203
203
|
self._spawn_upload(job)
|
204
204
|
|
205
205
|
def _spawn_upload(self, job: upload_job.UploadJob) -> None:
|
206
|
-
"""
|
206
|
+
"""Spawn an upload job, and handles the bookkeeping of `self._running_jobs`.
|
207
207
|
|
208
208
|
Context: it's important that, whenever we add an entry to `self._running_jobs`,
|
209
209
|
we ensure that a corresponding `EventJobDone` message will eventually get handled;
|
@@ -214,7 +214,6 @@ class StepUpload:
|
|
214
214
|
to `self._running_jobs` is textually right next to the code that eventually enqueues
|
215
215
|
the `EventJobDone` message. This should help keep them in sync.
|
216
216
|
"""
|
217
|
-
|
218
217
|
# Adding the entry to `self._running_jobs` MUST happen in the main thread,
|
219
218
|
# NOT in the job that gets submitted to the thread-pool, to guard against
|
220
219
|
# this sequence of events:
|
@@ -1,6 +1,4 @@
|
|
1
|
-
"""
|
2
|
-
catboost init
|
3
|
-
"""
|
1
|
+
"""catboost init."""
|
4
2
|
|
5
3
|
from pathlib import Path
|
6
4
|
from types import SimpleNamespace
|
@@ -65,9 +63,7 @@ class WandbCallback:
|
|
65
63
|
def _checkpoint_artifact(
|
66
64
|
model: Union[CatBoostClassifier, CatBoostRegressor], aliases: List[str]
|
67
65
|
) -> None:
|
68
|
-
"""
|
69
|
-
Upload model checkpoint as W&B artifact
|
70
|
-
"""
|
66
|
+
"""Upload model checkpoint as W&B artifact."""
|
71
67
|
if wandb.run is None:
|
72
68
|
raise wandb.Error(
|
73
69
|
"You must call `wandb.init()` before `_checkpoint_artifact()`"
|
@@ -87,9 +83,7 @@ def _checkpoint_artifact(
|
|
87
83
|
def _log_feature_importance(
|
88
84
|
model: Union[CatBoostClassifier, CatBoostRegressor]
|
89
85
|
) -> None:
|
90
|
-
"""
|
91
|
-
Log feature importance with default settings.
|
92
|
-
"""
|
86
|
+
"""Log feature importance with default settings."""
|
93
87
|
if wandb.run is None:
|
94
88
|
raise wandb.Error(
|
95
89
|
"You must call `wandb.init()` before `_checkpoint_artifact()`"
|
@@ -119,7 +113,7 @@ def log_summary(
|
|
119
113
|
save_model_checkpoint: bool = False,
|
120
114
|
log_feature_importance: bool = True,
|
121
115
|
) -> None:
|
122
|
-
"""`log_summary` logs useful metrics about catboost model after training is done
|
116
|
+
"""`log_summary` logs useful metrics about catboost model after training is done.
|
123
117
|
|
124
118
|
Arguments:
|
125
119
|
model: it can be CatBoostClassifier or CatBoostRegressor.
|
@@ -136,13 +130,13 @@ def log_summary(
|
|
136
130
|
|
137
131
|
Example:
|
138
132
|
```python
|
139
|
-
train_pool = Pool(train[features], label=train[
|
140
|
-
test_pool = Pool(test[features], label=test[
|
133
|
+
train_pool = Pool(train[features], label=train["label"], cat_features=cat_features)
|
134
|
+
test_pool = Pool(test[features], label=test["label"], cat_features=cat_features)
|
141
135
|
|
142
136
|
model = CatBoostRegressor(
|
143
137
|
iterations=100,
|
144
|
-
loss_function=
|
145
|
-
eval_metric=
|
138
|
+
loss_function="Cox",
|
139
|
+
eval_metric="Cox",
|
146
140
|
)
|
147
141
|
|
148
142
|
model.fit(
|
@@ -1,5 +1,5 @@
|
|
1
|
-
"""
|
2
|
-
|
1
|
+
"""Hooks that add fast.ai v1 Learners to Weights & Biases through a callback.
|
2
|
+
|
3
3
|
Requested logged data can be configured through the callback constructor.
|
4
4
|
|
5
5
|
Examples:
|
@@ -61,8 +61,8 @@ except ImportError:
|
|
61
61
|
|
62
62
|
|
63
63
|
class WandbCallback(TrackerCallback):
|
64
|
-
"""
|
65
|
-
|
64
|
+
"""Callback for saving model topology, losses & metrics.
|
65
|
+
|
66
66
|
Optionally logs weights, gradients, sample predictions and best trained model.
|
67
67
|
|
68
68
|
Arguments:
|
@@ -92,7 +92,6 @@ class WandbCallback(TrackerCallback):
|
|
92
92
|
predictions: int = 36,
|
93
93
|
seed: int = 12345,
|
94
94
|
) -> None:
|
95
|
-
|
96
95
|
# Check if wandb.init has been called
|
97
96
|
if wandb.run is None:
|
98
97
|
raise ValueError("You must call wandb.init() before WandbCallback()")
|
@@ -119,8 +118,7 @@ class WandbCallback(TrackerCallback):
|
|
119
118
|
self.validation_data = [learn.data.valid_ds[i] for i in indices]
|
120
119
|
|
121
120
|
def on_train_begin(self, **kwargs: Any) -> None:
|
122
|
-
"""Call watch method to log model topology, gradients & weights"""
|
123
|
-
|
121
|
+
"""Call watch method to log model topology, gradients & weights."""
|
124
122
|
# Set self.best, method inherited from "TrackerCallback" by "SaveModelCallback"
|
125
123
|
super().on_train_begin()
|
126
124
|
|
@@ -134,8 +132,7 @@ class WandbCallback(TrackerCallback):
|
|
134
132
|
def on_epoch_end(
|
135
133
|
self, epoch: int, smooth_loss: float, last_metrics: list, **kwargs: Any
|
136
134
|
) -> None:
|
137
|
-
"""
|
138
|
-
|
135
|
+
"""Log training loss, validation loss and custom metrics & log prediction samples & save model."""
|
139
136
|
if self.save_model:
|
140
137
|
# Adapted from fast.ai "SaveModelCallback"
|
141
138
|
current = self.get_monitor_value()
|
@@ -174,7 +171,6 @@ class WandbCallback(TrackerCallback):
|
|
174
171
|
|
175
172
|
def on_train_end(self, **kwargs: Any) -> None:
|
176
173
|
"""Load the best model."""
|
177
|
-
|
178
174
|
if self.save_model:
|
179
175
|
# Adapted from fast.ai "SaveModelCallback"
|
180
176
|
if self.model_path.is_file():
|
@@ -183,8 +179,7 @@ class WandbCallback(TrackerCallback):
|
|
183
179
|
print(f"Loaded best saved model from {self.model_path}")
|
184
180
|
|
185
181
|
def _wandb_log_predictions(self) -> None:
|
186
|
-
"""Log prediction samples"""
|
187
|
-
|
182
|
+
"""Log prediction samples."""
|
188
183
|
pred_log = []
|
189
184
|
|
190
185
|
if self.validation_data is None:
|
@@ -234,7 +229,6 @@ class WandbCallback(TrackerCallback):
|
|
234
229
|
elif hasattr(y, "shape") and (
|
235
230
|
(len(y.shape) == 2) or (len(y.shape) == 3 and y.shape[0] in [1, 3, 4])
|
236
231
|
):
|
237
|
-
|
238
232
|
pred_log.extend(
|
239
233
|
[
|
240
234
|
wandb.Image(x.data, caption="Input data", grouping=3),
|
@@ -1,21 +1,52 @@
|
|
1
1
|
import re
|
2
|
+
import sys
|
2
3
|
from typing import Optional
|
3
4
|
|
4
5
|
import wandb
|
6
|
+
import wandb.util
|
7
|
+
|
8
|
+
if sys.version_info >= (3, 8):
|
9
|
+
from typing import Literal
|
10
|
+
else:
|
11
|
+
from typing_extensions import Literal
|
12
|
+
|
5
13
|
|
6
14
|
_gym_version_lt_0_26: Optional[bool] = None
|
15
|
+
_required_error_msg = (
|
16
|
+
"Couldn't import the gymnasium python package, "
|
17
|
+
"install with `pip install gymnasium`"
|
18
|
+
)
|
19
|
+
GymLib = Literal["gym", "gymnasium"]
|
7
20
|
|
8
21
|
|
9
22
|
def monitor():
|
23
|
+
"""Monitor a gym environment.
|
24
|
+
|
25
|
+
Supports both gym and gymnasium.
|
26
|
+
"""
|
27
|
+
gym_lib: Optional[GymLib] = None
|
28
|
+
|
29
|
+
# gym is not maintained anymore, gymnasium is the drop-in replacement - prefer it
|
30
|
+
if wandb.util.get_module("gymnasium") is not None:
|
31
|
+
gym_lib = "gymnasium"
|
32
|
+
elif wandb.util.get_module("gym") is not None:
|
33
|
+
gym_lib = "gym"
|
34
|
+
|
35
|
+
if gym_lib is None:
|
36
|
+
raise wandb.Error(_required_error_msg)
|
37
|
+
|
10
38
|
vcr = wandb.util.get_module(
|
11
|
-
"
|
12
|
-
required=
|
39
|
+
f"{gym_lib}.wrappers.monitoring.video_recorder",
|
40
|
+
required=_required_error_msg,
|
13
41
|
)
|
14
42
|
|
15
43
|
global _gym_version_lt_0_26
|
16
44
|
|
17
45
|
if _gym_version_lt_0_26 is None:
|
18
|
-
|
46
|
+
if gym_lib == "gym":
|
47
|
+
import gym
|
48
|
+
else:
|
49
|
+
import gymnasium as gym # type: ignore
|
19
50
|
from pkg_resources import parse_version
|
20
51
|
|
21
52
|
if parse_version(gym.__version__) < parse_version("0.26.0"):
|
@@ -47,7 +78,7 @@ def monitor():
|
|
47
78
|
recorder.close = close
|
48
79
|
wandb.patched["gym"].append(
|
49
80
|
[
|
50
|
-
f"
|
81
|
+
f"{gym_lib}.wrappers.monitoring.video_recorder.{vcr_recorder_attribute}",
|
51
82
|
"close",
|
52
83
|
]
|
53
84
|
)
|
@@ -1,6 +1,6 @@
|
|
1
|
-
"""
|
2
|
-
|
3
|
-
a deep learning API for [`TensorFlow`](https://www.tensorflow.org/).
|
1
|
+
"""Tools for integrating `wandb` with [`Keras`](https://keras.io/).
|
2
|
+
|
3
|
+
Keras is a deep learning API for [`TensorFlow`](https://www.tensorflow.org/).
|
4
4
|
"""
|
5
5
|
__all__ = (
|
6
6
|
"WandbCallback",
|