wandb 0.13.10__py3-none-any.whl → 0.14.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (228) hide show
  1. wandb/__init__.py +2 -3
  2. wandb/apis/__init__.py +1 -3
  3. wandb/apis/importers/__init__.py +4 -0
  4. wandb/apis/importers/base.py +312 -0
  5. wandb/apis/importers/mlflow.py +113 -0
  6. wandb/apis/internal.py +29 -2
  7. wandb/apis/normalize.py +6 -5
  8. wandb/apis/public.py +163 -180
  9. wandb/apis/reports/_templates.py +6 -12
  10. wandb/apis/reports/report.py +1 -1
  11. wandb/apis/reports/runset.py +1 -3
  12. wandb/apis/reports/util.py +12 -10
  13. wandb/beta/workflows.py +57 -34
  14. wandb/catboost/__init__.py +1 -2
  15. wandb/cli/cli.py +215 -133
  16. wandb/data_types.py +63 -56
  17. wandb/docker/__init__.py +78 -16
  18. wandb/docker/auth.py +21 -22
  19. wandb/env.py +0 -1
  20. wandb/errors/__init__.py +8 -116
  21. wandb/errors/term.py +1 -1
  22. wandb/fastai/__init__.py +1 -2
  23. wandb/filesync/dir_watcher.py +8 -5
  24. wandb/filesync/step_prepare.py +76 -75
  25. wandb/filesync/step_upload.py +1 -2
  26. wandb/integration/catboost/__init__.py +1 -3
  27. wandb/integration/catboost/catboost.py +8 -14
  28. wandb/integration/fastai/__init__.py +7 -13
  29. wandb/integration/gym/__init__.py +35 -4
  30. wandb/integration/keras/__init__.py +3 -3
  31. wandb/integration/keras/callbacks/metrics_logger.py +9 -8
  32. wandb/integration/keras/callbacks/model_checkpoint.py +9 -9
  33. wandb/integration/keras/callbacks/tables_builder.py +31 -19
  34. wandb/integration/kfp/kfp_patch.py +20 -17
  35. wandb/integration/kfp/wandb_logging.py +1 -2
  36. wandb/integration/lightgbm/__init__.py +21 -19
  37. wandb/integration/prodigy/prodigy.py +6 -7
  38. wandb/integration/sacred/__init__.py +9 -12
  39. wandb/integration/sagemaker/__init__.py +1 -3
  40. wandb/integration/sagemaker/auth.py +0 -1
  41. wandb/integration/sagemaker/config.py +1 -1
  42. wandb/integration/sagemaker/resources.py +1 -1
  43. wandb/integration/sb3/sb3.py +8 -4
  44. wandb/integration/tensorboard/__init__.py +1 -3
  45. wandb/integration/tensorboard/log.py +8 -8
  46. wandb/integration/tensorboard/monkeypatch.py +11 -9
  47. wandb/integration/tensorflow/__init__.py +1 -3
  48. wandb/integration/xgboost/__init__.py +4 -6
  49. wandb/integration/yolov8/__init__.py +7 -0
  50. wandb/integration/yolov8/yolov8.py +250 -0
  51. wandb/jupyter.py +31 -35
  52. wandb/lightgbm/__init__.py +1 -2
  53. wandb/old/settings.py +2 -2
  54. wandb/plot/bar.py +1 -2
  55. wandb/plot/confusion_matrix.py +1 -3
  56. wandb/plot/histogram.py +1 -2
  57. wandb/plot/line.py +1 -2
  58. wandb/plot/line_series.py +4 -4
  59. wandb/plot/pr_curve.py +17 -20
  60. wandb/plot/roc_curve.py +1 -3
  61. wandb/plot/scatter.py +1 -2
  62. wandb/proto/v3/wandb_server_pb2.py +85 -39
  63. wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
  64. wandb/proto/v4/wandb_server_pb2.py +51 -39
  65. wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
  66. wandb/sdk/__init__.py +1 -3
  67. wandb/sdk/backend/backend.py +1 -1
  68. wandb/sdk/data_types/_dtypes.py +38 -30
  69. wandb/sdk/data_types/base_types/json_metadata.py +1 -3
  70. wandb/sdk/data_types/base_types/media.py +17 -17
  71. wandb/sdk/data_types/base_types/wb_value.py +33 -26
  72. wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +91 -125
  73. wandb/sdk/data_types/helper_types/classes.py +1 -1
  74. wandb/sdk/data_types/helper_types/image_mask.py +12 -12
  75. wandb/sdk/data_types/histogram.py +5 -4
  76. wandb/sdk/data_types/html.py +1 -2
  77. wandb/sdk/data_types/image.py +11 -11
  78. wandb/sdk/data_types/molecule.py +3 -6
  79. wandb/sdk/data_types/object_3d.py +1 -2
  80. wandb/sdk/data_types/plotly.py +1 -2
  81. wandb/sdk/data_types/saved_model.py +10 -8
  82. wandb/sdk/data_types/video.py +1 -1
  83. wandb/sdk/integration_utils/data_logging.py +5 -5
  84. wandb/sdk/interface/artifacts.py +288 -266
  85. wandb/sdk/interface/interface.py +2 -3
  86. wandb/sdk/interface/interface_grpc.py +1 -1
  87. wandb/sdk/interface/interface_queue.py +1 -1
  88. wandb/sdk/interface/interface_relay.py +1 -1
  89. wandb/sdk/interface/interface_shared.py +1 -2
  90. wandb/sdk/interface/interface_sock.py +1 -1
  91. wandb/sdk/interface/message_future.py +1 -1
  92. wandb/sdk/interface/message_future_poll.py +1 -1
  93. wandb/sdk/interface/router.py +1 -1
  94. wandb/sdk/interface/router_queue.py +1 -1
  95. wandb/sdk/interface/router_relay.py +1 -1
  96. wandb/sdk/interface/router_sock.py +1 -1
  97. wandb/sdk/interface/summary_record.py +1 -1
  98. wandb/sdk/internal/artifacts.py +1 -1
  99. wandb/sdk/internal/datastore.py +2 -3
  100. wandb/sdk/internal/file_pusher.py +5 -3
  101. wandb/sdk/internal/file_stream.py +22 -19
  102. wandb/sdk/internal/handler.py +5 -4
  103. wandb/sdk/internal/internal.py +1 -1
  104. wandb/sdk/internal/internal_api.py +115 -55
  105. wandb/sdk/internal/job_builder.py +1 -3
  106. wandb/sdk/internal/profiler.py +1 -1
  107. wandb/sdk/internal/progress.py +4 -6
  108. wandb/sdk/internal/sample.py +1 -3
  109. wandb/sdk/internal/sender.py +28 -16
  110. wandb/sdk/internal/settings_static.py +5 -5
  111. wandb/sdk/internal/system/assets/__init__.py +1 -0
  112. wandb/sdk/internal/system/assets/cpu.py +3 -9
  113. wandb/sdk/internal/system/assets/disk.py +2 -4
  114. wandb/sdk/internal/system/assets/gpu.py +6 -18
  115. wandb/sdk/internal/system/assets/gpu_apple.py +2 -4
  116. wandb/sdk/internal/system/assets/interfaces.py +50 -22
  117. wandb/sdk/internal/system/assets/ipu.py +1 -3
  118. wandb/sdk/internal/system/assets/memory.py +7 -13
  119. wandb/sdk/internal/system/assets/network.py +4 -8
  120. wandb/sdk/internal/system/assets/open_metrics.py +283 -0
  121. wandb/sdk/internal/system/assets/tpu.py +1 -4
  122. wandb/sdk/internal/system/assets/trainium.py +26 -14
  123. wandb/sdk/internal/system/system_info.py +2 -3
  124. wandb/sdk/internal/system/system_monitor.py +52 -20
  125. wandb/sdk/internal/tb_watcher.py +12 -13
  126. wandb/sdk/launch/_project_spec.py +54 -65
  127. wandb/sdk/launch/agent/agent.py +374 -90
  128. wandb/sdk/launch/builder/abstract.py +61 -7
  129. wandb/sdk/launch/builder/build.py +81 -110
  130. wandb/sdk/launch/builder/docker_builder.py +181 -0
  131. wandb/sdk/launch/builder/kaniko_builder.py +419 -0
  132. wandb/sdk/launch/builder/noop.py +31 -12
  133. wandb/sdk/launch/builder/templates/_wandb_bootstrap.py +70 -20
  134. wandb/sdk/launch/environment/abstract.py +28 -0
  135. wandb/sdk/launch/environment/aws_environment.py +276 -0
  136. wandb/sdk/launch/environment/gcp_environment.py +271 -0
  137. wandb/sdk/launch/environment/local_environment.py +65 -0
  138. wandb/sdk/launch/github_reference.py +3 -8
  139. wandb/sdk/launch/launch.py +38 -29
  140. wandb/sdk/launch/launch_add.py +6 -8
  141. wandb/sdk/launch/loader.py +230 -0
  142. wandb/sdk/launch/registry/abstract.py +54 -0
  143. wandb/sdk/launch/registry/elastic_container_registry.py +163 -0
  144. wandb/sdk/launch/registry/google_artifact_registry.py +203 -0
  145. wandb/sdk/launch/registry/local_registry.py +62 -0
  146. wandb/sdk/launch/runner/abstract.py +1 -16
  147. wandb/sdk/launch/runner/{kubernetes.py → kubernetes_runner.py} +83 -95
  148. wandb/sdk/launch/runner/local_container.py +46 -22
  149. wandb/sdk/launch/runner/local_process.py +1 -4
  150. wandb/sdk/launch/runner/{aws.py → sagemaker_runner.py} +53 -212
  151. wandb/sdk/launch/runner/{gcp_vertex.py → vertex_runner.py} +38 -55
  152. wandb/sdk/launch/sweeps/__init__.py +3 -2
  153. wandb/sdk/launch/sweeps/scheduler.py +132 -39
  154. wandb/sdk/launch/sweeps/scheduler_sweep.py +80 -89
  155. wandb/sdk/launch/utils.py +101 -30
  156. wandb/sdk/launch/wandb_reference.py +2 -7
  157. wandb/sdk/lib/_settings_toposort_generate.py +166 -0
  158. wandb/sdk/lib/_settings_toposort_generated.py +201 -0
  159. wandb/sdk/lib/apikey.py +2 -4
  160. wandb/sdk/lib/config_util.py +4 -1
  161. wandb/sdk/lib/console.py +1 -3
  162. wandb/sdk/lib/deprecate.py +3 -3
  163. wandb/sdk/lib/file_stream_utils.py +7 -5
  164. wandb/sdk/lib/filenames.py +1 -1
  165. wandb/sdk/lib/filesystem.py +61 -5
  166. wandb/sdk/lib/git.py +1 -3
  167. wandb/sdk/lib/import_hooks.py +4 -7
  168. wandb/sdk/lib/ipython.py +8 -5
  169. wandb/sdk/lib/lazyloader.py +1 -3
  170. wandb/sdk/lib/mailbox.py +14 -4
  171. wandb/sdk/lib/proto_util.py +10 -5
  172. wandb/sdk/lib/redirect.py +15 -22
  173. wandb/sdk/lib/reporting.py +1 -3
  174. wandb/sdk/lib/retry.py +4 -5
  175. wandb/sdk/lib/runid.py +1 -3
  176. wandb/sdk/lib/server.py +15 -9
  177. wandb/sdk/lib/sock_client.py +1 -1
  178. wandb/sdk/lib/sparkline.py +1 -1
  179. wandb/sdk/lib/wburls.py +1 -1
  180. wandb/sdk/service/port_file.py +1 -2
  181. wandb/sdk/service/service.py +36 -13
  182. wandb/sdk/service/service_base.py +12 -1
  183. wandb/sdk/verify/verify.py +5 -7
  184. wandb/sdk/wandb_artifacts.py +142 -177
  185. wandb/sdk/wandb_config.py +5 -8
  186. wandb/sdk/wandb_helper.py +1 -1
  187. wandb/sdk/wandb_init.py +24 -13
  188. wandb/sdk/wandb_login.py +9 -9
  189. wandb/sdk/wandb_manager.py +39 -4
  190. wandb/sdk/wandb_metric.py +2 -6
  191. wandb/sdk/wandb_require.py +4 -15
  192. wandb/sdk/wandb_require_helpers.py +1 -9
  193. wandb/sdk/wandb_run.py +95 -141
  194. wandb/sdk/wandb_save.py +1 -3
  195. wandb/sdk/wandb_settings.py +149 -54
  196. wandb/sdk/wandb_setup.py +66 -46
  197. wandb/sdk/wandb_summary.py +13 -10
  198. wandb/sdk/wandb_sweep.py +6 -7
  199. wandb/sdk/wandb_watch.py +1 -1
  200. wandb/sklearn/calculate/confusion_matrix.py +1 -1
  201. wandb/sklearn/calculate/learning_curve.py +1 -1
  202. wandb/sklearn/calculate/summary_metrics.py +1 -3
  203. wandb/sklearn/plot/__init__.py +1 -1
  204. wandb/sklearn/plot/classifier.py +27 -18
  205. wandb/sklearn/plot/clusterer.py +4 -5
  206. wandb/sklearn/plot/regressor.py +4 -4
  207. wandb/sklearn/plot/shared.py +2 -2
  208. wandb/sync/__init__.py +1 -3
  209. wandb/sync/sync.py +4 -5
  210. wandb/testing/relay.py +11 -10
  211. wandb/trigger.py +1 -1
  212. wandb/util.py +106 -81
  213. wandb/viz.py +4 -4
  214. wandb/wandb_agent.py +50 -50
  215. wandb/wandb_controller.py +2 -3
  216. wandb/wandb_run.py +1 -2
  217. wandb/wandb_torch.py +1 -1
  218. wandb/xgboost/__init__.py +1 -2
  219. {wandb-0.13.10.dist-info → wandb-0.14.0.dist-info}/METADATA +6 -2
  220. {wandb-0.13.10.dist-info → wandb-0.14.0.dist-info}/RECORD +224 -209
  221. {wandb-0.13.10.dist-info → wandb-0.14.0.dist-info}/WHEEL +1 -1
  222. wandb/sdk/launch/builder/docker.py +0 -80
  223. wandb/sdk/launch/builder/kaniko.py +0 -393
  224. wandb/sdk/launch/builder/loader.py +0 -32
  225. wandb/sdk/launch/runner/loader.py +0 -50
  226. {wandb-0.13.10.dist-info → wandb-0.14.0.dist-info}/LICENSE +0 -0
  227. {wandb-0.13.10.dist-info → wandb-0.14.0.dist-info}/entry_points.txt +0 -0
  228. {wandb-0.13.10.dist-info → wandb-0.14.0.dist-info}/top_level.txt +0 -0
wandb/errors/__init__.py CHANGED
@@ -2,30 +2,22 @@ __all__ = [
2
2
  "Error",
3
3
  "UsageError",
4
4
  "CommError",
5
- "LogError",
6
- "DockerError",
7
- "LogMultiprocessError",
8
- "MultiprocessError",
9
- "RequireError",
10
- "ExecutionError",
11
- "LaunchError",
12
- "SweepError",
5
+ "UnsupportedError",
13
6
  "WaitTimeoutError",
14
- "ContextCancelledError",
15
- "ServiceStartProcessError",
16
- "ServiceStartTimeoutError",
17
- "ServiceStartPortError",
18
7
  ]
19
8
 
20
- from typing import List, Optional
9
+ from typing import Optional
21
10
 
22
11
 
23
12
  class Error(Exception):
24
13
  """Base W&B Error"""
25
14
 
26
- def __init__(self, message) -> None:
15
+ def __init__(self, message, context: Optional[dict] = None) -> None:
27
16
  super().__init__(message)
28
17
  self.message = message
18
+ # sentry context capture
19
+ if context:
20
+ self.context = context
29
21
 
30
22
  # For python 2 support
31
23
  def encode(self, encoding):
@@ -47,111 +39,11 @@ class UsageError(Error):
47
39
  pass
48
40
 
49
41
 
50
- class LogError(Error):
51
- """Raised when wandb.log() fails"""
52
-
53
- pass
54
-
55
-
56
- class LogMultiprocessError(LogError):
57
- """Raised when wandb.log() fails because of multiprocessing"""
58
-
59
- pass
60
-
61
-
62
- class MultiprocessError(Error):
63
- """Raised when fails because of multiprocessing"""
64
-
65
- pass
66
-
67
-
68
- class RequireError(Error):
69
- """Raised when wandb.require() fails"""
70
-
71
- pass
72
-
73
-
74
- class ExecutionError(Error):
75
- """Generic execution exception"""
76
-
77
- pass
78
-
79
-
80
- class DockerError(Error):
81
- """Raised when attempting to execute a docker command"""
82
-
83
- def __init__(
84
- self,
85
- command_launched: List[str],
86
- return_code: int,
87
- stdout: Optional[bytes] = None,
88
- stderr: Optional[bytes] = None,
89
- ) -> None:
90
- command_launched_str = " ".join(command_launched)
91
- error_msg = (
92
- f"The docker command executed was `{command_launched_str}`.\n"
93
- f"It returned with code {return_code}\n"
94
- )
95
- if stdout is not None:
96
- error_msg += f"The content of stdout is '{stdout.decode()}'\n"
97
- else:
98
- error_msg += (
99
- "The content of stdout can be found above the "
100
- "stacktrace (it wasn't captured).\n"
101
- )
102
- if stderr is not None:
103
- error_msg += f"The content of stderr is '{stderr.decode()}'\n"
104
- else:
105
- error_msg += (
106
- "The content of stderr can be found above the "
107
- "stacktrace (it wasn't captured)."
108
- )
109
- super().__init__(error_msg)
110
-
111
-
112
- class LaunchError(Error):
113
- """Raised when a known error occurs in wandb launch"""
114
-
115
- pass
116
-
117
-
118
- class SweepError(Error):
119
- """Raised when a known error occurs with wandb sweeps"""
120
-
121
- pass
42
+ class UnsupportedError(UsageError):
43
+ """Raised when trying to use a feature that is not supported"""
122
44
 
123
45
 
124
46
  class WaitTimeoutError(Error):
125
47
  """Raised when wait() timeout occurs before process is finished"""
126
48
 
127
49
  pass
128
-
129
-
130
- class MailboxError(Error):
131
- """Generic Mailbox Exception"""
132
-
133
- pass
134
-
135
-
136
- class ContextCancelledError(Error):
137
- """Context cancelled Exception"""
138
-
139
- pass
140
-
141
-
142
- class ServiceStartProcessError(Error):
143
- """Raised when a known error occurs when launching wandb service"""
144
-
145
- pass
146
-
147
-
148
- class ServiceStartTimeoutError(Error):
149
- """Raised when service start times out"""
150
-
151
- pass
152
-
153
-
154
- class ServiceStartPortError(Error):
155
- """Raised when service start fails to find a port"""
156
-
157
- pass
wandb/errors/term.py CHANGED
@@ -17,7 +17,7 @@ _show_errors = True
17
17
  _logger = None
18
18
 
19
19
 
20
- def termsetup(settings, logger):
20
+ def termsetup(settings, logger) -> None:
21
21
  global _silent, _show_info, _show_warnings, _show_errors, _logger
22
22
  _silent = settings.silent
23
23
  _show_info = settings.show_info
wandb/fastai/__init__.py CHANGED
@@ -1,5 +1,4 @@
1
- """
2
- Compatibility fastai module.
1
+ """Compatibility fastai module.
3
2
 
4
3
  In the future use:
5
4
  from wandb.integration.fastai import WandbCallback
@@ -71,7 +71,7 @@ class FileEventHandler(abc.ABC):
71
71
 
72
72
 
73
73
  class PolicyNow(FileEventHandler):
74
- """This policy only uploads files now"""
74
+ """This policy only uploads files now."""
75
75
 
76
76
  def on_modified(self, force: bool = False) -> None:
77
77
  # only upload if we've never uploaded or when .save is called
@@ -88,7 +88,7 @@ class PolicyNow(FileEventHandler):
88
88
 
89
89
 
90
90
  class PolicyEnd(FileEventHandler):
91
- """This policy only updates at the end of the run"""
91
+ """This policy only updates at the end of the run."""
92
92
 
93
93
  def on_modified(self, force: bool = False) -> None:
94
94
  pass
@@ -106,8 +106,11 @@ class PolicyEnd(FileEventHandler):
106
106
 
107
107
 
108
108
  class PolicyLive(FileEventHandler):
109
- """This policy will upload files every RATE_LIMIT_SECONDS as it
110
- changes throttling as the size increases"""
109
+ """Event handler that uploads respecting throttling.
110
+
111
+ Uploads files every RATE_LIMIT_SECONDS, which changes as the size increases to deal
112
+ with throttling.
113
+ """
111
114
 
112
115
  RATE_LIMIT_SECONDS = 15
113
116
  unit_dict = dict(util.POW_10_BYTES)
@@ -250,7 +253,7 @@ class DirWatcher:
250
253
  feh.on_modified(force=True)
251
254
 
252
255
  def _per_file_event_handler(self) -> "wd_events.FileSystemEventHandler":
253
- """Create a Watchdog file event handler that does different things for every file"""
256
+ """Create a Watchdog file event handler that does different things for every file."""
254
257
  file_event_handler = wd_events.PatternMatchingEventHandler()
255
258
  file_event_handler.on_created = self._on_file_created
256
259
  file_event_handler.on_modified = self._on_file_modified
@@ -1,11 +1,11 @@
1
1
  """Batching file prepare requests to our API."""
2
2
 
3
3
  import queue
4
- import sys
5
4
  import threading
6
5
  import time
7
6
  from typing import (
8
7
  TYPE_CHECKING,
8
+ Callable,
9
9
  List,
10
10
  Mapping,
11
11
  NamedTuple,
@@ -16,31 +16,16 @@ from typing import (
16
16
  )
17
17
 
18
18
  if TYPE_CHECKING:
19
- from wandb.sdk.internal import internal_api
20
-
21
- if sys.version_info >= (3, 8):
22
- from typing import Protocol
23
- else:
24
- from typing_extensions import Protocol
25
-
26
- class DoPrepareFn(Protocol):
27
- def __call__(self) -> "internal_api.CreateArtifactFileSpecInput":
28
- pass
29
-
30
- class OnPrepareFn(Protocol):
31
- def __call__(
32
- self,
33
- upload_url: Optional[str], # GraphQL type File.uploadUrl
34
- upload_headers: Sequence[str], # GraphQL type File.uploadHeaders
35
- artifact_id: str, # GraphQL type File.artifact.id
36
- ) -> None:
37
- pass
19
+ from wandb.sdk.internal.internal_api import (
20
+ Api,
21
+ CreateArtifactFileSpecInput,
22
+ CreateArtifactFilesResponseFile,
23
+ )
38
24
 
39
25
 
40
26
  # Request for a file to be prepared.
41
27
  class RequestPrepare(NamedTuple):
42
- prepare_fn: "DoPrepareFn"
43
- on_prepare: Optional["OnPrepareFn"]
28
+ file_spec: "CreateArtifactFileSpecInput"
44
29
  response_queue: "queue.Queue[ResponsePrepare]"
45
30
 
46
31
 
@@ -54,7 +39,49 @@ class ResponsePrepare(NamedTuple):
54
39
  birth_artifact_id: str
55
40
 
56
41
 
57
- Event = Union[RequestPrepare, RequestFinish, ResponsePrepare]
42
+ Request = Union[RequestPrepare, RequestFinish]
43
+
44
+
45
+ def _clamp(x: float, low: float, high: float) -> float:
46
+ return max(low, min(x, high))
47
+
48
+
49
+ def gather_batch(
50
+ request_queue: "queue.Queue[Request]",
51
+ batch_time: float,
52
+ inter_event_time: float,
53
+ max_batch_size: int,
54
+ clock: Callable[[], float] = time.monotonic,
55
+ ) -> Tuple[bool, Sequence[RequestPrepare]]:
56
+
57
+ batch_start_time = clock()
58
+ remaining_time = batch_time
59
+
60
+ first_request = request_queue.get()
61
+ if isinstance(first_request, RequestFinish):
62
+ return True, []
63
+
64
+ batch: List[RequestPrepare] = [first_request]
65
+
66
+ while remaining_time > 0 and len(batch) < max_batch_size:
67
+ try:
68
+ request = request_queue.get(
69
+ timeout=_clamp(
70
+ x=inter_event_time,
71
+ low=1e-12, # 0 = "block forever", so just use something tiny
72
+ high=remaining_time,
73
+ ),
74
+ )
75
+ if isinstance(request, RequestFinish):
76
+ return True, batch
77
+
78
+ batch.append(request)
79
+ remaining_time = batch_time - (clock() - batch_start_time)
80
+
81
+ except queue.Empty:
82
+ break
83
+
84
+ return False, batch
58
85
 
59
86
 
60
87
  class StepPrepare:
@@ -66,68 +93,46 @@ class StepPrepare:
66
93
 
67
94
  def __init__(
68
95
  self,
69
- api: "internal_api.Api",
96
+ api: "Api",
70
97
  batch_time: float,
71
98
  inter_event_time: float,
72
99
  max_batch_size: int,
100
+ request_queue: Optional["queue.Queue[Request]"] = None,
73
101
  ) -> None:
74
102
  self._api = api
75
103
  self._inter_event_time = inter_event_time
76
104
  self._batch_time = batch_time
77
105
  self._max_batch_size = max_batch_size
78
- self._request_queue: "queue.Queue[RequestPrepare | RequestFinish]" = (
79
- queue.Queue()
80
- )
106
+ self._request_queue: "queue.Queue[Request]" = request_queue or queue.Queue()
81
107
  self._thread = threading.Thread(target=self._thread_body)
82
108
  self._thread.daemon = True
83
109
 
84
110
  def _thread_body(self) -> None:
85
111
  while True:
86
- request = self._request_queue.get()
87
- if isinstance(request, RequestFinish):
88
- break
89
- finish, batch = self._gather_batch(request)
90
- prepare_response = self._prepare_batch(batch)
91
- # send responses
92
- for prepare_request in batch:
93
- name = prepare_request.prepare_fn()["name"]
94
- response_file = prepare_response[name]
95
- upload_url = response_file["uploadUrl"]
96
- upload_headers = response_file["uploadHeaders"]
97
- birth_artifact_id = response_file["artifact"]["id"]
98
- if prepare_request.on_prepare:
99
- prepare_request.on_prepare(
100
- upload_url, upload_headers, birth_artifact_id
112
+ finish, batch = gather_batch(
113
+ request_queue=self._request_queue,
114
+ batch_time=self._batch_time,
115
+ inter_event_time=self._inter_event_time,
116
+ max_batch_size=self._max_batch_size,
117
+ )
118
+ if batch:
119
+ prepare_response = self._prepare_batch(batch)
120
+ # send responses
121
+ for prepare_request in batch:
122
+ name = prepare_request.file_spec["name"]
123
+ response_file = prepare_response[name]
124
+ upload_url = response_file["uploadUrl"]
125
+ upload_headers = response_file["uploadHeaders"]
126
+ birth_artifact_id = response_file["artifact"]["id"]
127
+ prepare_request.response_queue.put(
128
+ ResponsePrepare(upload_url, upload_headers, birth_artifact_id)
101
129
  )
102
- prepare_request.response_queue.put(
103
- ResponsePrepare(upload_url, upload_headers, birth_artifact_id)
104
- )
105
130
  if finish:
106
131
  break
107
132
 
108
- def _gather_batch(
109
- self, first_request: RequestPrepare
110
- ) -> Tuple[bool, Sequence[RequestPrepare]]:
111
- batch_start_time = time.time()
112
- batch: List[RequestPrepare] = [first_request]
113
- while True:
114
- try:
115
- request = self._request_queue.get(
116
- block=True, timeout=self._inter_event_time
117
- )
118
- if isinstance(request, RequestFinish):
119
- return True, batch
120
- batch.append(request)
121
- remaining_time = self._batch_time - (time.time() - batch_start_time)
122
- if remaining_time < 0 or len(batch) >= self._max_batch_size:
123
- break
124
- except queue.Empty:
125
- break
126
- return False, batch
127
-
128
133
  def _prepare_batch(
129
134
  self, batch: Sequence[RequestPrepare]
130
- ) -> Mapping[str, "internal_api.CreateArtifactFilesResponseFile"]:
135
+ ) -> Mapping[str, "CreateArtifactFilesResponseFile"]:
131
136
  """Execute the prepareFiles API call.
132
137
 
133
138
  Arguments:
@@ -137,14 +142,10 @@ class StepPrepare:
137
142
  an uploadUrl key. The value of the uploadUrl key is None if the file
138
143
  already exists, or a url string if the file should be uploaded.
139
144
  """
140
- file_specs: List["internal_api.CreateArtifactFileSpecInput"] = []
141
- for prepare_request in batch:
142
- file_spec = prepare_request.prepare_fn()
143
- file_specs.append(file_spec)
144
- return self._api.create_artifact_files(file_specs)
145
+ return self._api.create_artifact_files([req.file_spec for req in batch])
145
146
 
146
147
  def prepare_async(
147
- self, prepare_fn: "DoPrepareFn", on_prepare: Optional["OnPrepareFn"] = None
148
+ self, file_spec: "CreateArtifactFileSpecInput"
148
149
  ) -> "queue.Queue[ResponsePrepare]":
149
150
  """Request the backend to prepare a file for upload.
150
151
 
@@ -153,11 +154,11 @@ class StepPrepare:
153
154
  either a file upload url, or None if the file doesn't need to be uploaded.
154
155
  """
155
156
  response_queue: "queue.Queue[ResponsePrepare]" = queue.Queue()
156
- self._request_queue.put(RequestPrepare(prepare_fn, on_prepare, response_queue))
157
+ self._request_queue.put(RequestPrepare(file_spec, response_queue))
157
158
  return response_queue
158
159
 
159
- def prepare(self, prepare_fn: "DoPrepareFn") -> ResponsePrepare:
160
- return self.prepare_async(prepare_fn).get()
160
+ def prepare(self, file_spec: "CreateArtifactFileSpecInput") -> ResponsePrepare:
161
+ return self.prepare_async(file_spec).get()
161
162
 
162
163
  def start(self) -> None:
163
164
  self._thread.start()
@@ -203,7 +203,7 @@ class StepUpload:
203
203
  self._spawn_upload(job)
204
204
 
205
205
  def _spawn_upload(self, job: upload_job.UploadJob) -> None:
206
- """Spawns an upload job, and handles the bookkeeping of `self._running_jobs`.
206
+ """Spawn an upload job, and handles the bookkeeping of `self._running_jobs`.
207
207
 
208
208
  Context: it's important that, whenever we add an entry to `self._running_jobs`,
209
209
  we ensure that a corresponding `EventJobDone` message will eventually get handled;
@@ -214,7 +214,6 @@ class StepUpload:
214
214
  to `self._running_jobs` is textually right next to the code that eventually enqueues
215
215
  the `EventJobDone` message. This should help keep them in sync.
216
216
  """
217
-
218
217
  # Adding the entry to `self._running_jobs` MUST happen in the main thread,
219
218
  # NOT in the job that gets submitted to the thread-pool, to guard against
220
219
  # this sequence of events:
@@ -1,6 +1,4 @@
1
- """
2
- W&B callback for CatBoost
3
- """
1
+ """W&B callback for CatBoost."""
4
2
 
5
3
  from .catboost import WandbCallback, log_summary
6
4
 
@@ -1,6 +1,4 @@
1
- """
2
- catboost init
3
- """
1
+ """catboost init."""
4
2
 
5
3
  from pathlib import Path
6
4
  from types import SimpleNamespace
@@ -65,9 +63,7 @@ class WandbCallback:
65
63
  def _checkpoint_artifact(
66
64
  model: Union[CatBoostClassifier, CatBoostRegressor], aliases: List[str]
67
65
  ) -> None:
68
- """
69
- Upload model checkpoint as W&B artifact
70
- """
66
+ """Upload model checkpoint as W&B artifact."""
71
67
  if wandb.run is None:
72
68
  raise wandb.Error(
73
69
  "You must call `wandb.init()` before `_checkpoint_artifact()`"
@@ -87,9 +83,7 @@ def _checkpoint_artifact(
87
83
  def _log_feature_importance(
88
84
  model: Union[CatBoostClassifier, CatBoostRegressor]
89
85
  ) -> None:
90
- """
91
- Log feature importance with default settings.
92
- """
86
+ """Log feature importance with default settings."""
93
87
  if wandb.run is None:
94
88
  raise wandb.Error(
95
89
  "You must call `wandb.init()` before `_checkpoint_artifact()`"
@@ -119,7 +113,7 @@ def log_summary(
119
113
  save_model_checkpoint: bool = False,
120
114
  log_feature_importance: bool = True,
121
115
  ) -> None:
122
- """`log_summary` logs useful metrics about catboost model after training is done
116
+ """`log_summary` logs useful metrics about catboost model after training is done.
123
117
 
124
118
  Arguments:
125
119
  model: it can be CatBoostClassifier or CatBoostRegressor.
@@ -136,13 +130,13 @@ def log_summary(
136
130
 
137
131
  Example:
138
132
  ```python
139
- train_pool = Pool(train[features], label=train['label'], cat_features=cat_features)
140
- test_pool = Pool(test[features], label=test['label'], cat_features=cat_features)
133
+ train_pool = Pool(train[features], label=train["label"], cat_features=cat_features)
134
+ test_pool = Pool(test[features], label=test["label"], cat_features=cat_features)
141
135
 
142
136
  model = CatBoostRegressor(
143
137
  iterations=100,
144
- loss_function='Cox',
145
- eval_metric='Cox',
138
+ loss_function="Cox",
139
+ eval_metric="Cox",
146
140
  )
147
141
 
148
142
  model.fit(
@@ -1,5 +1,5 @@
1
- """
2
- This module hooks fast.ai v1 Learners to Weights & Biases through a callback.
1
+ """Hooks that add fast.ai v1 Learners to Weights & Biases through a callback.
2
+
3
3
  Requested logged data can be configured through the callback constructor.
4
4
 
5
5
  Examples:
@@ -61,8 +61,8 @@ except ImportError:
61
61
 
62
62
 
63
63
  class WandbCallback(TrackerCallback):
64
- """
65
- Automatically saves model topology, losses & metrics.
64
+ """Callback for saving model topology, losses & metrics.
65
+
66
66
  Optionally logs weights, gradients, sample predictions and best trained model.
67
67
 
68
68
  Arguments:
@@ -92,7 +92,6 @@ class WandbCallback(TrackerCallback):
92
92
  predictions: int = 36,
93
93
  seed: int = 12345,
94
94
  ) -> None:
95
-
96
95
  # Check if wandb.init has been called
97
96
  if wandb.run is None:
98
97
  raise ValueError("You must call wandb.init() before WandbCallback()")
@@ -119,8 +118,7 @@ class WandbCallback(TrackerCallback):
119
118
  self.validation_data = [learn.data.valid_ds[i] for i in indices]
120
119
 
121
120
  def on_train_begin(self, **kwargs: Any) -> None:
122
- """Call watch method to log model topology, gradients & weights"""
123
-
121
+ """Call watch method to log model topology, gradients & weights."""
124
122
  # Set self.best, method inherited from "TrackerCallback" by "SaveModelCallback"
125
123
  super().on_train_begin()
126
124
 
@@ -134,8 +132,7 @@ class WandbCallback(TrackerCallback):
134
132
  def on_epoch_end(
135
133
  self, epoch: int, smooth_loss: float, last_metrics: list, **kwargs: Any
136
134
  ) -> None:
137
- """Logs training loss, validation loss and custom metrics & log prediction samples & save model"""
138
-
135
+ """Log training loss, validation loss and custom metrics & log prediction samples & save model."""
139
136
  if self.save_model:
140
137
  # Adapted from fast.ai "SaveModelCallback"
141
138
  current = self.get_monitor_value()
@@ -174,7 +171,6 @@ class WandbCallback(TrackerCallback):
174
171
 
175
172
  def on_train_end(self, **kwargs: Any) -> None:
176
173
  """Load the best model."""
177
-
178
174
  if self.save_model:
179
175
  # Adapted from fast.ai "SaveModelCallback"
180
176
  if self.model_path.is_file():
@@ -183,8 +179,7 @@ class WandbCallback(TrackerCallback):
183
179
  print(f"Loaded best saved model from {self.model_path}")
184
180
 
185
181
  def _wandb_log_predictions(self) -> None:
186
- """Log prediction samples"""
187
-
182
+ """Log prediction samples."""
188
183
  pred_log = []
189
184
 
190
185
  if self.validation_data is None:
@@ -234,7 +229,6 @@ class WandbCallback(TrackerCallback):
234
229
  elif hasattr(y, "shape") and (
235
230
  (len(y.shape) == 2) or (len(y.shape) == 3 and y.shape[0] in [1, 3, 4])
236
231
  ):
237
-
238
232
  pred_log.extend(
239
233
  [
240
234
  wandb.Image(x.data, caption="Input data", grouping=3),
@@ -1,21 +1,52 @@
1
1
  import re
2
+ import sys
2
3
  from typing import Optional
3
4
 
4
5
  import wandb
6
+ import wandb.util
7
+
8
+ if sys.version_info >= (3, 8):
9
+ from typing import Literal
10
+ else:
11
+ from typing_extensions import Literal
12
+
5
13
 
6
14
  _gym_version_lt_0_26: Optional[bool] = None
15
+ _required_error_msg = (
16
+ "Couldn't import the gymnasium python package, "
17
+ "install with `pip install gymnasium`"
18
+ )
19
+ GymLib = Literal["gym", "gymnasium"]
7
20
 
8
21
 
9
22
  def monitor():
23
+ """Monitor a gym environment.
24
+
25
+ Supports both gym and gymnasium.
26
+ """
27
+ gym_lib: Optional[GymLib] = None
28
+
29
+ # gym is not maintained anymore, gymnasium is the drop-in replacement - prefer it
30
+ if wandb.util.get_module("gymnasium") is not None:
31
+ gym_lib = "gymnasium"
32
+ elif wandb.util.get_module("gym") is not None:
33
+ gym_lib = "gym"
34
+
35
+ if gym_lib is None:
36
+ raise wandb.Error(_required_error_msg)
37
+
10
38
  vcr = wandb.util.get_module(
11
- "gym.wrappers.monitoring.video_recorder",
12
- required="Couldn't import the gym python package, install with `pip install gym`",
39
+ f"{gym_lib}.wrappers.monitoring.video_recorder",
40
+ required=_required_error_msg,
13
41
  )
14
42
 
15
43
  global _gym_version_lt_0_26
16
44
 
17
45
  if _gym_version_lt_0_26 is None:
18
- import gym # type: ignore
46
+ if gym_lib == "gym":
47
+ import gym
48
+ else:
49
+ import gymnasium as gym # type: ignore
19
50
  from pkg_resources import parse_version
20
51
 
21
52
  if parse_version(gym.__version__) < parse_version("0.26.0"):
@@ -47,7 +78,7 @@ def monitor():
47
78
  recorder.close = close
48
79
  wandb.patched["gym"].append(
49
80
  [
50
- f"gym.wrappers.monitoring.video_recorder.{vcr_recorder_attribute}",
81
+ f"{gym_lib}.wrappers.monitoring.video_recorder.{vcr_recorder_attribute}",
51
82
  "close",
52
83
  ]
53
84
  )
@@ -1,6 +1,6 @@
1
- """
2
- Tools for integrating `wandb` with [`Keras`](https://keras.io/),
3
- a deep learning API for [`TensorFlow`](https://www.tensorflow.org/).
1
+ """Tools for integrating `wandb` with [`Keras`](https://keras.io/).
2
+
3
+ Keras is a deep learning API for [`TensorFlow`](https://www.tensorflow.org/).
4
4
  """
5
5
  __all__ = (
6
6
  "WandbCallback",