wandb 0.21.1__py3-none-win_amd64.whl → 0.21.2__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90)
  1. wandb/__init__.py +1 -1
  2. wandb/__init__.pyi +1 -1
  3. wandb/apis/public/api.py +1 -2
  4. wandb/apis/public/artifacts.py +3 -5
  5. wandb/apis/public/registries/_utils.py +14 -16
  6. wandb/apis/public/registries/registries_search.py +176 -289
  7. wandb/apis/public/reports.py +13 -10
  8. wandb/automations/_generated/delete_automation.py +1 -3
  9. wandb/automations/_generated/enums.py +13 -11
  10. wandb/bin/gpu_stats.exe +0 -0
  11. wandb/bin/wandb-core +0 -0
  12. wandb/cli/cli.py +47 -2
  13. wandb/integration/metaflow/data_pandas.py +2 -2
  14. wandb/integration/metaflow/data_pytorch.py +75 -0
  15. wandb/integration/metaflow/data_sklearn.py +76 -0
  16. wandb/integration/metaflow/metaflow.py +16 -87
  17. wandb/integration/weave/__init__.py +6 -0
  18. wandb/integration/weave/interface.py +49 -0
  19. wandb/integration/weave/weave.py +63 -0
  20. wandb/proto/v3/wandb_internal_pb2.py +3 -2
  21. wandb/proto/v4/wandb_internal_pb2.py +2 -2
  22. wandb/proto/v5/wandb_internal_pb2.py +2 -2
  23. wandb/proto/v6/wandb_internal_pb2.py +2 -2
  24. wandb/sdk/artifacts/_factories.py +17 -0
  25. wandb/sdk/artifacts/_generated/__init__.py +221 -13
  26. wandb/sdk/artifacts/_generated/artifact_by_id.py +17 -0
  27. wandb/sdk/artifacts/_generated/artifact_by_name.py +22 -0
  28. wandb/sdk/artifacts/_generated/artifact_collection_membership_file_urls.py +43 -0
  29. wandb/sdk/artifacts/_generated/artifact_created_by.py +47 -0
  30. wandb/sdk/artifacts/_generated/artifact_file_urls.py +22 -0
  31. wandb/sdk/artifacts/_generated/artifact_type.py +31 -0
  32. wandb/sdk/artifacts/_generated/artifact_used_by.py +43 -0
  33. wandb/sdk/artifacts/_generated/artifact_via_membership_by_name.py +26 -0
  34. wandb/sdk/artifacts/_generated/delete_artifact.py +28 -0
  35. wandb/sdk/artifacts/_generated/enums.py +5 -0
  36. wandb/sdk/artifacts/_generated/fetch_artifact_manifest.py +38 -0
  37. wandb/sdk/artifacts/_generated/fetch_registries.py +32 -0
  38. wandb/sdk/artifacts/_generated/fragments.py +279 -41
  39. wandb/sdk/artifacts/_generated/link_artifact.py +6 -0
  40. wandb/sdk/artifacts/_generated/operations.py +654 -51
  41. wandb/sdk/artifacts/_generated/registry_collections.py +34 -0
  42. wandb/sdk/artifacts/_generated/registry_versions.py +34 -0
  43. wandb/sdk/artifacts/_generated/unlink_artifact.py +25 -0
  44. wandb/sdk/artifacts/_graphql_fragments.py +3 -86
  45. wandb/sdk/artifacts/_validators.py +6 -4
  46. wandb/sdk/artifacts/artifact.py +406 -543
  47. wandb/sdk/artifacts/artifact_file_cache.py +10 -6
  48. wandb/sdk/artifacts/artifact_manifest.py +10 -9
  49. wandb/sdk/artifacts/artifact_manifest_entry.py +9 -10
  50. wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +5 -3
  51. wandb/sdk/artifacts/storage_handlers/http_handler.py +1 -1
  52. wandb/sdk/artifacts/storage_handlers/s3_handler.py +1 -1
  53. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +1 -1
  54. wandb/sdk/data_types/video.py +2 -2
  55. wandb/sdk/interface/interface_queue.py +1 -4
  56. wandb/sdk/interface/interface_shared.py +26 -37
  57. wandb/sdk/interface/interface_sock.py +24 -14
  58. wandb/sdk/internal/settings_static.py +2 -3
  59. wandb/sdk/launch/create_job.py +12 -1
  60. wandb/sdk/launch/runner/kubernetes_runner.py +24 -29
  61. wandb/sdk/lib/asyncio_compat.py +16 -16
  62. wandb/sdk/lib/asyncio_manager.py +252 -0
  63. wandb/sdk/lib/hashutil.py +13 -4
  64. wandb/sdk/lib/printer.py +2 -2
  65. wandb/sdk/lib/printer_asyncio.py +3 -1
  66. wandb/sdk/lib/retry.py +185 -78
  67. wandb/sdk/lib/service/service_client.py +106 -0
  68. wandb/sdk/lib/service/service_connection.py +20 -26
  69. wandb/sdk/lib/service/service_token.py +30 -13
  70. wandb/sdk/mailbox/mailbox.py +13 -5
  71. wandb/sdk/mailbox/mailbox_handle.py +22 -13
  72. wandb/sdk/mailbox/response_handle.py +42 -106
  73. wandb/sdk/mailbox/wait_with_progress.py +7 -42
  74. wandb/sdk/wandb_init.py +11 -25
  75. wandb/sdk/wandb_login.py +1 -1
  76. wandb/sdk/wandb_run.py +91 -55
  77. wandb/sdk/wandb_settings.py +45 -32
  78. wandb/sdk/wandb_setup.py +176 -96
  79. wandb/util.py +1 -1
  80. {wandb-0.21.1.dist-info → wandb-0.21.2.dist-info}/METADATA +1 -1
  81. {wandb-0.21.1.dist-info → wandb-0.21.2.dist-info}/RECORD +84 -68
  82. wandb/sdk/interface/interface_relay.py +0 -38
  83. wandb/sdk/interface/router.py +0 -89
  84. wandb/sdk/interface/router_queue.py +0 -43
  85. wandb/sdk/interface/router_relay.py +0 -50
  86. wandb/sdk/interface/router_sock.py +0 -32
  87. wandb/sdk/lib/sock_client.py +0 -232
  88. {wandb-0.21.1.dist-info → wandb-0.21.2.dist-info}/WHEEL +0 -0
  89. {wandb-0.21.1.dist-info → wandb-0.21.2.dist-info}/entry_points.txt +0 -0
  90. {wandb-0.21.1.dist-info → wandb-0.21.2.dist-info}/licenses/LICENSE +0 -0
wandb/sdk/launch/runner/kubernetes_runner.py CHANGED
@@ -7,7 +7,6 @@ import json
7
7
  import logging
8
8
  import os
9
9
  import time
10
- import uuid
11
10
  from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
12
11
 
13
12
  import yaml
@@ -28,6 +27,7 @@ from wandb.sdk.launch.runner.kubernetes_monitor import (
28
27
  CustomResource,
29
28
  LaunchKubernetesMonitor,
30
29
  )
30
+ from wandb.sdk.launch.utils import recursive_macro_sub
31
31
  from wandb.sdk.lib.retry import ExponentialBackoff, retry_async
32
32
  from wandb.util import get_module
33
33
 
@@ -62,6 +62,9 @@ from kubernetes_asyncio.client.api.core_v1_api import ( # type: ignore # noqa:
62
62
  from kubernetes_asyncio.client.api.custom_objects_api import ( # type: ignore # noqa: E402
63
63
  CustomObjectsApi,
64
64
  )
65
+ from kubernetes_asyncio.client.api.networking_v1_api import ( # type: ignore # noqa: E402
66
+ NetworkingV1Api,
67
+ )
65
68
  from kubernetes_asyncio.client.models.v1_secret import ( # type: ignore # noqa: E402
66
69
  V1Secret,
67
70
  )
@@ -85,6 +88,7 @@ class KubernetesSubmittedRun(AbstractRun):
85
88
  batch_api: "BatchV1Api",
86
89
  core_api: "CoreV1Api",
87
90
  apps_api: "AppsV1Api",
91
+ network_api: "NetworkingV1Api",
88
92
  name: str,
89
93
  namespace: Optional[str] = "default",
90
94
  secret: Optional["V1Secret"] = None,
@@ -103,6 +107,7 @@ class KubernetesSubmittedRun(AbstractRun):
103
107
  Arguments:
104
108
  batch_api: Kubernetes BatchV1Api object.
105
109
  core_api: Kubernetes CoreV1Api object.
110
+ network_api: Kubernetes NetworkV1Api object.
106
111
  name: Name of the job.
107
112
  namespace: Kubernetes namespace.
108
113
  secret: Kubernetes secret.
@@ -113,6 +118,7 @@ class KubernetesSubmittedRun(AbstractRun):
113
118
  self.batch_api = batch_api
114
119
  self.core_api = core_api
115
120
  self.apps_api = apps_api
121
+ self.network_api = network_api
116
122
  self.name = name
117
123
  self.namespace = namespace
118
124
  self._fail_count = 0
@@ -207,11 +213,9 @@ class KubernetesSubmittedRun(AbstractRun):
207
213
  (self.core_api, "service"),
208
214
  (self.batch_api, "job"),
209
215
  (self.core_api, "pod"),
210
- (self.core_api, "config_map"),
211
216
  (self.core_api, "secret"),
212
217
  (self.apps_api, "deployment"),
213
- (self.apps_api, "replica_set"),
214
- (self.apps_api, "daemon_set"),
218
+ (self.network_api, "network_policy"),
215
219
  ]
216
220
 
217
221
  for api_client, resource_type in resource_cleanups:
@@ -700,7 +704,6 @@ class KubernetesRunner(AbstractRunner):
700
704
  config: Dict[str, Any],
701
705
  namespace: str,
702
706
  run_id: str,
703
- auxiliary_resource_label_key: str,
704
707
  launch_project: LaunchProject,
705
708
  api_key_secret: Optional["V1Secret"] = None,
706
709
  wait_for_ready: bool = True,
@@ -713,7 +716,6 @@ class KubernetesRunner(AbstractRunner):
713
716
  config: The resource configuration to prepare.
714
717
  namespace: The namespace to create the resource in.
715
718
  run_id: The run ID to label the resource with.
716
- auxiliary_resource_label_key: The key of the auxiliary resource label.
717
719
  launch_project: The launch project to get environment variables from.
718
720
  api_key_secret: The API key secret to inject.
719
721
  wait_for_ready: Whether to wait for the resource to be ready after creation.
@@ -722,25 +724,8 @@ class KubernetesRunner(AbstractRunner):
722
724
  config.setdefault("metadata", {})
723
725
  config["metadata"].setdefault("labels", {})
724
726
  config["metadata"]["labels"][WANDB_K8S_RUN_ID] = run_id
725
- config["metadata"]["labels"][WANDB_K8S_LABEL_AUXILIARY_RESOURCE] = (
726
- auxiliary_resource_label_key
727
- )
728
727
  config["metadata"]["labels"]["wandb.ai/created-by"] = "launch-agent"
729
728
 
730
- if config.get("kind") == "Service" or config.get("kind") == "Deployment":
731
- config.setdefault("metadata", {})
732
- original_name = config["metadata"].get("name", config.get("kind"))
733
- safe_name = make_name_dns_safe(original_name)
734
- safe_entity = make_name_dns_safe(launch_project.target_entity or "")
735
- safe_project = make_name_dns_safe(launch_project.target_project or "")
736
- safe_run_id = make_name_dns_safe(run_id or "")
737
-
738
- new_name = f"{safe_name}-{safe_entity}-{safe_project}-{safe_run_id}"
739
- config["metadata"]["name"] = new_name
740
- wandb.termlog(
741
- f"{LOG_PREFIX}Modified {config.get('kind')} name from '{original_name}' to '{new_name}'"
742
- )
743
-
744
729
  env_vars = launch_project.get_env_vars_dict(
745
730
  self._api, MAX_ENV_LENGTHS[self.__class__.__name__]
746
731
  )
@@ -920,19 +905,29 @@ class KubernetesRunner(AbstractRunner):
920
905
  batch_api = kubernetes_asyncio.client.BatchV1Api(api_client)
921
906
  core_api = kubernetes_asyncio.client.CoreV1Api(api_client)
922
907
  apps_api = kubernetes_asyncio.client.AppsV1Api(api_client)
908
+ network_api = kubernetes_asyncio.client.NetworkingV1Api(api_client)
923
909
 
924
910
  namespace = self.get_namespace(resource_args, context)
925
911
  job, secret = await self._inject_defaults(
926
912
  resource_args, launch_project, image_uri, namespace, core_api
927
913
  )
928
914
 
929
- additional_services = launch_project.launch_spec.get("additional_services", [])
930
- auxiliary_resource_label_key = None
915
+ update_dict = {
916
+ "project_name": launch_project.target_project,
917
+ "entity_name": launch_project.target_entity,
918
+ "run_id": launch_project.run_id,
919
+ "run_name": launch_project.name,
920
+ "image_uri": image_uri,
921
+ "author": launch_project.author,
922
+ }
923
+ update_dict.update(os.environ)
924
+ additional_services: List[Dict[str, Any]] = recursive_macro_sub(
925
+ launch_project.launch_spec.get("additional_services", []), update_dict
926
+ )
931
927
  if additional_services:
932
928
  wandb.termlog(
933
929
  f"{LOG_PREFIX}Creating additional services: {additional_services}"
934
930
  )
935
- auxiliary_resource_label_key = f"aux-{uuid.uuid4()}"
936
931
 
937
932
  wait_for_ready = resource_args.get("wait_for_ready", True)
938
933
  wait_timeout = resource_args.get("wait_timeout", 300)
@@ -941,10 +936,9 @@ class KubernetesRunner(AbstractRunner):
941
936
  *[
942
937
  self._prepare_resource(
943
938
  api_client,
944
- resource.get("config"),
939
+ resource.get("config", {}),
945
940
  namespace,
946
941
  launch_project.run_id,
947
- auxiliary_resource_label_key,
948
942
  launch_project,
949
943
  secret,
950
944
  wait_for_ready,
@@ -982,10 +976,11 @@ class KubernetesRunner(AbstractRunner):
982
976
  batch_api,
983
977
  core_api,
984
978
  apps_api,
979
+ network_api,
985
980
  job_name,
986
981
  namespace,
987
982
  secret,
988
- auxiliary_resource_label_key,
983
+ f"aux-{launch_project.target_entity}-{launch_project.target_project}-{launch_project.run_id}",
989
984
  )
990
985
  if self.backend_config[PROJECT_SYNCHRONOUS]:
991
986
  await submitted_job.wait()
wandb/sdk/lib/asyncio_compat.py CHANGED
@@ -23,7 +23,7 @@ def run(fn: Callable[[], Coroutine[Any, Any, _T]]) -> _T:
23
23
  Note that due to starting a new thread, this is slightly slow.
24
24
  """
25
25
  with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
26
- runner = _Runner()
26
+ runner = CancellableRunner()
27
27
  future = executor.submit(runner.run, fn)
28
28
 
29
29
  try:
@@ -33,15 +33,16 @@ def run(fn: Callable[[], Coroutine[Any, Any, _T]]) -> _T:
33
33
  runner.cancel()
34
34
 
35
35
 
36
- class _RunnerCancelledError(Exception):
37
- """The `_Runner.run()` invocation was cancelled."""
36
+ class RunnerCancelledError(Exception):
37
+ """The `CancellableRunner.run()` invocation was cancelled."""
38
38
 
39
39
 
40
- class _Runner:
40
+ class CancellableRunner:
41
41
  """Runs an asyncio event loop allowing cancellation.
42
42
 
43
- This is like `asyncio.run()`, except it provides a `cancel()` method
44
- meant to be called in a `finally` block.
43
+ The `run()` method is like `asyncio.run()`. The `cancel()` method may
44
+ be used in a different thread, for instance in a `finally` block, to cancel
45
+ all tasks, and it is a no-op if `run()` completed.
45
46
 
46
47
  Without this, it is impossible to make `asyncio.run()` stop if it runs
47
48
  in a non-main thread. In particular, a KeyboardInterrupt causes the
@@ -69,7 +70,7 @@ class _Runner:
69
70
  The result of the coroutine returned by `fn`.
70
71
 
71
72
  Raises:
72
- _RunnerCancelledError: If `cancel()` is called.
73
+ RunnerCancelledError: If `cancel()` is called.
73
74
  """
74
75
  return asyncio.run(self._run_or_cancel(fn))
75
76
 
@@ -79,7 +80,7 @@ class _Runner:
79
80
  ) -> _T:
80
81
  with self._lock:
81
82
  if self._is_cancelled:
82
- raise _RunnerCancelledError()
83
+ raise RunnerCancelledError()
83
84
 
84
85
  self._loop = asyncio.get_running_loop()
85
86
  self._cancel_event = asyncio.Event()
@@ -97,7 +98,7 @@ class _Runner:
97
98
  if fn_task.done():
98
99
  return fn_task.result()
99
100
  else:
100
- raise _RunnerCancelledError()
101
+ raise RunnerCancelledError()
101
102
 
102
103
  finally:
103
104
  # NOTE: asyncio.run() cancels all tasks after the main task exits,
@@ -157,11 +158,9 @@ class TaskGroup:
157
158
  )
158
159
 
159
160
  for task in done:
160
- try:
161
+ with contextlib.suppress(asyncio.CancelledError):
161
162
  if exc := task.exception():
162
163
  raise exc
163
- except asyncio.CancelledError:
164
- pass
165
164
 
166
165
  def _cancel_all(self) -> None:
167
166
  """Cancel all tasks."""
@@ -199,15 +198,16 @@ async def open_task_group() -> AsyncIterator[TaskGroup]:
199
198
  def cancel_on_exit(coro: Coroutine[Any, Any, Any]) -> Iterator[None]:
200
199
  """Schedule a task, cancelling it when exiting the context manager.
201
200
 
202
- If the given coroutine raises an exception, that exception is raised
203
- when exiting the context manager.
201
+ If the context manager exits successfully but the given coroutine raises
202
+ an exception, that exception is reraised. The exception is suppressed
203
+ if the context manager raises an exception.
204
204
  """
205
205
  task = asyncio.create_task(coro)
206
206
 
207
207
  try:
208
208
  yield
209
- finally:
209
+
210
210
  if task.done() and (exception := task.exception()):
211
211
  raise exception
212
-
212
+ finally:
213
213
  task.cancel()
wandb/sdk/lib/asyncio_manager.py ADDED
@@ -0,0 +1,252 @@
1
+ """Implements an asyncio thread suitable for internal wandb use."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import asyncio
6
+ import concurrent.futures
7
+ import contextlib
8
+ import logging
9
+ import threading
10
+ from typing import Any, Callable, Coroutine, TypeVar
11
+
12
+ from . import asyncio_compat
13
+
14
+ _T = TypeVar("_T")
15
+
16
+ _logger = logging.getLogger(__name__)
17
+
18
+
19
+ class RunCancelledError(Exception):
20
+ """A function passed to AsyncioManager.run() was cancelled."""
21
+
22
+
23
+ class AlreadyJoinedError(Exception):
24
+ """AsyncioManager.run() used after join()."""
25
+
26
+
27
+ class AsyncioManager:
28
+ """Manages a thread running an asyncio loop.
29
+
30
+ The thread must be started using start() and should be joined using
31
+ join(). The thread is a daemon thread, so if join() is not invoked,
32
+ the asyncio work could end abruptly when all non-daemon threads exit.
33
+
34
+ The run() method allows invoking an async function in the asyncio thread
35
+ and waiting until it completes. The run_soon() method allows running
36
+ an async function without waiting for it.
37
+
38
+ Note that although tempting, it is **not** possible to write a safe
39
+ run_in_loop() method that chooses whether to use run() or execute a function
40
+ directly based on whether it's called from the asyncio thread: Suppose a
41
+ function bad() holds a threading.Lock while using run_in_loop() and an
42
+ asyncio task calling bad() is scheduled. If bad() is then invoked in a
43
+ different thread that reaches run_in_loop(), the aforementioned asyncio task
44
+ will deadlock. It is unreasonable to require that run_in_loop() never be
45
+ called while holding a lock (which would apply to the callers of its
46
+ callers, and so on), so it cannot safely exist.
47
+ """
48
+
49
+ def __init__(self) -> None:
50
+ self._runner = asyncio_compat.CancellableRunner()
51
+ self._thread = threading.Thread(
52
+ target=self._main,
53
+ name="wandb-AsyncioManager-main",
54
+ daemon=True,
55
+ )
56
+ self._lock = threading.Lock()
57
+
58
+ self._ready_event = threading.Event()
59
+ """Whether asyncio primitives have been initialized."""
60
+
61
+ self._joined = False
62
+ """Whether join() has been called. Guarded by _lock."""
63
+
64
+ self._loop: asyncio.AbstractEventLoop
65
+ """A handle for interacting with the asyncio event loop."""
66
+
67
+ self._done_event: asyncio.Event
68
+ """Indicates to the asyncio loop that join() was called."""
69
+
70
+ self._remaining_tasks = 0
71
+ """The number of tasks remaining. Guarded by _lock."""
72
+
73
+ self._task_finished_cond: asyncio.Condition
74
+ """Signalled when _remaining_tasks is decremented."""
75
+
76
+ def start(self) -> None:
77
+ """Start the asyncio thread."""
78
+ self._thread.start()
79
+
80
+ def join(self) -> None:
81
+ """Stop accepting new asyncio tasks and wait for the remaining ones."""
82
+ try:
83
+ with self._lock:
84
+ # If join() was already called, block until the thread completes
85
+ # and then return.
86
+ if self._joined:
87
+ self._thread.join()
88
+ return
89
+
90
+ self._joined = True
91
+
92
+ # Wait until _loop and _done_event are initialized.
93
+ self._ready_event.wait()
94
+
95
+ # Set the done event. The main function will exit once all
96
+ # tasks complete.
97
+ self._loop.call_soon_threadsafe(self._done_event.set)
98
+
99
+ self._thread.join()
100
+
101
+ finally:
102
+ # Any of the above may get interrupted by Ctrl+C, in which case we
103
+ # should cancel all tasks, since join() can only be called once.
104
+ # This only matters if the KeyboardInterrupt is suppressed.
105
+ self._runner.cancel()
106
+
107
+ def run(self, fn: Callable[[], Coroutine[Any, Any, _T]]) -> _T:
108
+ """Run an async function to completion.
109
+
110
+ The function is called in the asyncio thread. Blocks until start()
111
+ is called. This raises an error if called inside an async function,
112
+ and as a consequence, the caller may also not be called inside an
113
+ async function.
114
+
115
+ Args:
116
+ fn: The function to run.
117
+
118
+ Returns:
119
+ The return value of fn.
120
+
121
+ Raises:
122
+ Exception: Any exception raised by fn.
123
+ RunCancelledError: If fn is cancelled, particularly when join()
124
+ is interrupted by Ctrl+C or if it otherwise cancels itself.
125
+ AlreadyJoinedError: If join() was already called.
126
+ ValueError: If called inside an async function.
127
+ """
128
+ self._ready_event.wait()
129
+
130
+ if threading.current_thread().ident == self._thread.ident:
131
+ raise ValueError("Cannot use run() inside async loop.")
132
+
133
+ future = self._schedule(fn, daemon=False)
134
+
135
+ try:
136
+ return future.result()
137
+
138
+ except concurrent.futures.CancelledError:
139
+ raise RunCancelledError from None
140
+
141
+ except KeyboardInterrupt:
142
+ # If we're interrupted here, we only cancel this task rather than
143
+ # cancelling all tasks like in join(). This only matters if the
144
+ # interrupt is then suppressed (or delayed) in which case we
145
+ # should let other tasks progress.
146
+ future.cancel()
147
+ raise
148
+
149
+ def run_soon(
150
+ self,
151
+ fn: Callable[[], Coroutine[Any, Any, None]],
152
+ *,
153
+ daemon: bool = False,
154
+ name: str | None = None,
155
+ ) -> None:
156
+ """Run an async function without waiting for it to complete.
157
+
158
+ The function is called in the asyncio thread. Note that since that's
159
+ a daemon thread, it will not get joined when the main thread exits,
160
+ so fn can stop abruptly.
161
+
162
+ Unlike run(), it is OK to call this inside an async function.
163
+
164
+ Blocks until start() is called.
165
+
166
+ Args:
167
+ fn: The function to run.
168
+ daemon: If true, join() will cancel fn after all non-daemon
169
+ tasks complete. By default, join() blocks until fn
170
+ completes.
171
+ name: An optional name to give to long-running tasks which can
172
+ appear in error traces and be useful to debugging.
173
+
174
+ Raises:
175
+ AlreadyJoinedError: If join() was already called.
176
+ """
177
+
178
+ # Wrap exceptions so that they're not printed to console.
179
+ async def fn_wrap_exceptions() -> None:
180
+ try:
181
+ await fn()
182
+ except Exception:
183
+ _logger.exception("Uncaught exception in run_soon callback.")
184
+
185
+ _ = self._schedule(fn_wrap_exceptions, daemon=daemon, name=name)
186
+
187
+ def _schedule(
188
+ self,
189
+ fn: Callable[[], Coroutine[Any, Any, _T]],
190
+ daemon: bool,
191
+ name: str | None = None,
192
+ ) -> concurrent.futures.Future[_T]:
193
+ # Wait for _loop to be initialized.
194
+ self._ready_event.wait()
195
+
196
+ with self._lock:
197
+ if self._joined:
198
+ raise AlreadyJoinedError
199
+
200
+ if not daemon:
201
+ self._remaining_tasks += 1
202
+
203
+ return asyncio.run_coroutine_threadsafe(
204
+ self._wrap(fn, daemon=daemon, name=name),
205
+ self._loop,
206
+ )
207
+
208
+ async def _wrap(
209
+ self,
210
+ fn: Callable[[], Coroutine[Any, Any, _T]],
211
+ daemon: bool,
212
+ name: str | None,
213
+ ) -> _T:
214
+ """Run fn to completion and possibly decrement _remaining tasks."""
215
+ try:
216
+ if name and (task := asyncio.current_task()):
217
+ task.set_name(name)
218
+
219
+ return await fn()
220
+ finally:
221
+ if not daemon:
222
+ async with self._task_finished_cond:
223
+ with self._lock:
224
+ self._remaining_tasks -= 1
225
+ self._task_finished_cond.notify_all()
226
+
227
+ def _main(self) -> None:
228
+ """Run the asyncio loop until join() is called and all tasks finish."""
229
+ # A cancellation error is expected if join() is interrupted.
230
+ #
231
+ # Were it not suppressed, its stacktrace would get printed.
232
+ with contextlib.suppress(asyncio_compat.RunnerCancelledError):
233
+ self._runner.run(self._main_async)
234
+
235
+ async def _main_async(self) -> None:
236
+ """Wait until join() is called and all tasks finish."""
237
+ self._loop = asyncio.get_running_loop()
238
+ self._done_event = asyncio.Event()
239
+ self._task_finished_cond = asyncio.Condition()
240
+
241
+ self._ready_event.set()
242
+
243
+ # Wait until done.
244
+ await self._done_event.wait()
245
+
246
+ # Wait for all tasks to complete.
247
+ #
248
+ # Once we exit, asyncio will cancel any leftover tasks.
249
+ async with self._task_finished_cond:
250
+ await self._task_finished_cond.wait_for(
251
+ lambda: self._remaining_tasks <= 0,
252
+ )
wandb/sdk/lib/hashutil.py CHANGED
@@ -6,19 +6,28 @@ import logging
6
6
  import mmap
7
7
  import sys
8
8
  import time
9
- from typing import TYPE_CHECKING, NewType
9
+ from typing import TYPE_CHECKING
10
+
11
+ from typing_extensions import TypeAlias
10
12
 
11
13
  from wandb.sdk.lib.paths import StrPath
12
14
 
13
15
  if TYPE_CHECKING:
14
16
  import _hashlib # type: ignore[import-not-found]
15
17
 
16
- ETag = NewType("ETag", str)
17
- HexMD5 = NewType("HexMD5", str)
18
- B64MD5 = NewType("B64MD5", str)
19
18
 
20
19
  logger = logging.getLogger(__name__)
21
20
 
21
+ # In the future, consider relying on pydantic to validate these types via e.g.
22
+ # - Base64Str: https://docs.pydantic.dev/latest/api/types/#pydantic.types.Base64Str
23
+ # - a custom EncodedStr + Encoder impl: https://docs.pydantic.dev/latest/api/types/#pydantic.types.EncodedStr
24
+ #
25
+ # Note that so long as we continue to support Pydantic v1, the options above will require a compatible shim/backport
26
+ # implementation, since those types are not in Pydantic v1.
27
+ ETag: TypeAlias = str
28
+ HexMD5: TypeAlias = str
29
+ B64MD5: TypeAlias = str
30
+
22
31
 
23
32
  def _md5(data: bytes = b"") -> _hashlib.HASH:
24
33
  """Allow FIPS-compliant md5 hash when supported."""
wandb/sdk/lib/printer.py CHANGED
@@ -102,8 +102,8 @@ def new_printer(settings: wandb.Settings | None = None) -> Printer:
102
102
  has been called, then global settings are used. Otherwise,
103
103
  settings (such as silent mode) are ignored.
104
104
  """
105
- if not settings and (singleton := wandb_setup.singleton_if_setup()):
106
- settings = singleton.settings
105
+ if not settings and (s := wandb_setup.singleton().settings_if_loaded):
106
+ settings = s
107
107
 
108
108
  if ipython.in_jupyter():
109
109
  return _PrinterJupyter(settings=settings)
wandb/sdk/lib/printer_asyncio.py CHANGED
@@ -1,6 +1,7 @@
1
1
  import asyncio
2
2
  from typing import Callable, TypeVar
3
3
 
4
+ from wandb.sdk import wandb_setup
4
5
  from wandb.sdk.lib import asyncio_compat, printer
5
6
 
6
7
  _T = TypeVar("_T")
@@ -43,4 +44,5 @@ def run_async_with_spinner(
43
44
  func_running.set()
44
45
  return res
45
46
 
46
- return asyncio_compat.run(_loop_run_with_spinner)
47
+ asyncer = wandb_setup.singleton().asyncer
48
+ return asyncer.run(_loop_run_with_spinner)