wandb 0.21.2__py3-none-win_amd64.whl → 0.21.4__py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wandb/__init__.py +1 -1
- wandb/__init__.pyi +1 -1
- wandb/_analytics.py +65 -0
- wandb/_iterutils.py +8 -0
- wandb/_pydantic/__init__.py +10 -11
- wandb/_pydantic/base.py +3 -53
- wandb/_pydantic/field_types.py +29 -0
- wandb/_pydantic/v1_compat.py +47 -30
- wandb/_strutils.py +40 -0
- wandb/apis/public/api.py +17 -4
- wandb/apis/public/artifacts.py +5 -4
- wandb/apis/public/automations.py +2 -1
- wandb/apis/public/registries/_freezable_list.py +6 -6
- wandb/apis/public/registries/_utils.py +2 -1
- wandb/apis/public/registries/registries_search.py +4 -0
- wandb/apis/public/registries/registry.py +7 -0
- wandb/automations/_filters/expressions.py +3 -2
- wandb/automations/_filters/operators.py +2 -1
- wandb/automations/_validators.py +20 -0
- wandb/automations/actions.py +4 -2
- wandb/automations/events.py +4 -5
- wandb/bin/gpu_stats.exe +0 -0
- wandb/bin/wandb-core +0 -0
- wandb/cli/beta.py +48 -130
- wandb/cli/beta_sync.py +226 -0
- wandb/cli/cli.py +1 -1
- wandb/integration/dspy/__init__.py +5 -0
- wandb/integration/dspy/dspy.py +422 -0
- wandb/integration/weave/weave.py +55 -0
- wandb/proto/v3/wandb_server_pb2.py +38 -57
- wandb/proto/v3/wandb_sync_pb2.py +87 -0
- wandb/proto/v3/wandb_telemetry_pb2.py +12 -12
- wandb/proto/v4/wandb_server_pb2.py +38 -41
- wandb/proto/v4/wandb_sync_pb2.py +38 -0
- wandb/proto/v4/wandb_telemetry_pb2.py +12 -12
- wandb/proto/v5/wandb_server_pb2.py +38 -41
- wandb/proto/v5/wandb_sync_pb2.py +39 -0
- wandb/proto/v5/wandb_telemetry_pb2.py +12 -12
- wandb/proto/v6/wandb_server_pb2.py +38 -41
- wandb/proto/v6/wandb_sync_pb2.py +49 -0
- wandb/proto/v6/wandb_telemetry_pb2.py +12 -12
- wandb/proto/wandb_generate_proto.py +1 -0
- wandb/proto/wandb_sync_pb2.py +12 -0
- wandb/sdk/artifacts/_validators.py +50 -49
- wandb/sdk/artifacts/artifact.py +11 -11
- wandb/sdk/artifacts/artifact_file_cache.py +1 -1
- wandb/sdk/artifacts/artifact_manifest_entry.py +6 -8
- wandb/sdk/artifacts/exceptions.py +2 -1
- wandb/sdk/artifacts/storage_handlers/gcs_handler.py +1 -1
- wandb/sdk/artifacts/storage_handlers/s3_handler.py +2 -1
- wandb/sdk/launch/inputs/internal.py +25 -24
- wandb/sdk/launch/inputs/schema.py +31 -1
- wandb/sdk/lib/asyncio_compat.py +88 -23
- wandb/sdk/lib/gql_request.py +18 -7
- wandb/sdk/lib/paths.py +23 -21
- wandb/sdk/lib/printer.py +9 -13
- wandb/sdk/lib/progress.py +8 -6
- wandb/sdk/lib/service/service_connection.py +42 -12
- wandb/sdk/mailbox/wait_with_progress.py +1 -1
- wandb/sdk/wandb_init.py +0 -8
- wandb/sdk/wandb_run.py +14 -2
- wandb/sdk/wandb_settings.py +55 -0
- wandb/sdk/wandb_setup.py +2 -2
- {wandb-0.21.2.dist-info → wandb-0.21.4.dist-info}/METADATA +2 -2
- {wandb-0.21.2.dist-info → wandb-0.21.4.dist-info}/RECORD +68 -57
- {wandb-0.21.2.dist-info → wandb-0.21.4.dist-info}/WHEEL +0 -0
- {wandb-0.21.2.dist-info → wandb-0.21.4.dist-info}/entry_points.txt +0 -0
- {wandb-0.21.2.dist-info → wandb-0.21.4.dist-info}/licenses/LICENSE +0 -0
@@ -87,7 +87,7 @@ class ArtifactFileCache:
|
|
87
87
|
) -> tuple[FilePathStr, bool, Opener]:
|
88
88
|
opener = self._opener(path, size, skip_cache=skip_cache)
|
89
89
|
hit = path.is_file() and path.stat().st_size == size
|
90
|
-
return FilePathStr(
|
90
|
+
return FilePathStr(path), hit, opener
|
91
91
|
|
92
92
|
def cleanup(
|
93
93
|
self,
|
@@ -11,8 +11,8 @@ from typing import TYPE_CHECKING
|
|
11
11
|
from urllib.parse import urlparse
|
12
12
|
|
13
13
|
from wandb.proto.wandb_deprecated import Deprecated
|
14
|
-
from wandb.sdk.lib import filesystem
|
15
14
|
from wandb.sdk.lib.deprecate import deprecate
|
15
|
+
from wandb.sdk.lib.filesystem import copy_or_overwrite_changed
|
16
16
|
from wandb.sdk.lib.hashutil import (
|
17
17
|
B64MD5,
|
18
18
|
ETag,
|
@@ -184,13 +184,11 @@ class ArtifactManifestEntry:
|
|
184
184
|
executor=executor,
|
185
185
|
multipart=multipart,
|
186
186
|
)
|
187
|
-
|
188
|
-
|
189
|
-
|
190
|
-
|
191
|
-
|
192
|
-
str(filesystem.copy_or_overwrite_changed(cache_path, dest_path))
|
193
|
-
)
|
187
|
+
return FilePathStr(
|
188
|
+
dest_path
|
189
|
+
if skip_cache
|
190
|
+
else copy_or_overwrite_changed(cache_path, dest_path)
|
191
|
+
)
|
194
192
|
|
195
193
|
def ref_target(self) -> FilePathStr | URIStr:
|
196
194
|
"""Get the reference URL that is targeted by this artifact entry.
|
@@ -5,6 +5,7 @@ from __future__ import annotations
|
|
5
5
|
from typing import TYPE_CHECKING, TypeVar
|
6
6
|
|
7
7
|
from wandb import errors
|
8
|
+
from wandb._strutils import nameof
|
8
9
|
|
9
10
|
if TYPE_CHECKING:
|
10
11
|
from wandb.sdk.artifacts.artifact import Artifact
|
@@ -39,7 +40,7 @@ class ArtifactNotLoggedError(ArtifactStatusError):
|
|
39
40
|
*_, name = fullname.split(".")
|
40
41
|
msg = (
|
41
42
|
f"{fullname!r} used prior to logging artifact or while in offline mode. "
|
42
|
-
f"Call {
|
43
|
+
f"Call {nameof(obj.wait)}() before accessing logged artifact properties."
|
43
44
|
)
|
44
45
|
super().__init__(msg=msg, name=name, obj=obj)
|
45
46
|
|
@@ -198,7 +198,7 @@ class GCSHandler(StorageHandler):
|
|
198
198
|
posix_ref = posix_path / relpath
|
199
199
|
return ArtifactManifestEntry(
|
200
200
|
path=posix_name,
|
201
|
-
ref=URIStr(f"{self._scheme}://{
|
201
|
+
ref=URIStr(f"{self._scheme}://{posix_ref}"),
|
202
202
|
digest=obj.etag,
|
203
203
|
size=obj.size,
|
204
204
|
extra={"versionID": obj.generation},
|
@@ -10,6 +10,7 @@ from typing import TYPE_CHECKING, Sequence
|
|
10
10
|
from urllib.parse import parse_qsl, urlparse
|
11
11
|
|
12
12
|
from wandb import util
|
13
|
+
from wandb._strutils import ensureprefix
|
13
14
|
from wandb.errors import CommError
|
14
15
|
from wandb.errors.term import termlog
|
15
16
|
from wandb.sdk.artifacts.artifact_file_cache import get_artifact_file_cache
|
@@ -328,7 +329,7 @@ class S3Handler(StorageHandler):
|
|
328
329
|
return True
|
329
330
|
|
330
331
|
# Enforce HTTPS otherwise
|
331
|
-
https_url = url
|
332
|
+
https_url = ensureprefix(url, "https://")
|
332
333
|
netloc = urlparse(https_url).netloc
|
333
334
|
return bool(
|
334
335
|
# Match for https://cwobject.com
|
@@ -171,6 +171,29 @@ def _replace_refs_and_allofs(schema: dict, defs: Optional[dict]) -> dict:
|
|
171
171
|
return ret
|
172
172
|
|
173
173
|
|
174
|
+
def _prepare_schema(schema: Any) -> dict:
|
175
|
+
"""Prepare a schema for validation.
|
176
|
+
|
177
|
+
This function prepares a schema for validation by:
|
178
|
+
1. Converting a Pydantic model instance or class to a dict
|
179
|
+
2. Replacing $ref with their associated definition in defs
|
180
|
+
3. Removing any "allOf" lists that only have one item, "lifting" the item up
|
181
|
+
|
182
|
+
We support both an instance of a pydantic BaseModel class (e.g. schema=MySchema(...))
|
183
|
+
or the BaseModel class itself (e.g. schema=MySchema)
|
184
|
+
"""
|
185
|
+
if hasattr(schema, "model_json_schema") and callable(
|
186
|
+
schema.model_json_schema # type: ignore
|
187
|
+
):
|
188
|
+
schema = schema.model_json_schema()
|
189
|
+
if not isinstance(schema, dict):
|
190
|
+
raise LaunchError(
|
191
|
+
"schema must be a dict, Pydantic model instance, or Pydantic model class."
|
192
|
+
)
|
193
|
+
defs = schema.pop("$defs", None)
|
194
|
+
return _replace_refs_and_allofs(schema, defs)
|
195
|
+
|
196
|
+
|
174
197
|
def _validate_schema(schema: dict) -> None:
|
175
198
|
jsonschema = get_module(
|
176
199
|
"jsonschema",
|
@@ -210,18 +233,7 @@ def handle_config_file_input(
|
|
210
233
|
dest,
|
211
234
|
)
|
212
235
|
if schema:
|
213
|
-
|
214
|
-
# or the BaseModel class itself (e.g. schema=MySchema)
|
215
|
-
if hasattr(schema, "model_json_schema") and callable(
|
216
|
-
schema.model_json_schema # type: ignore
|
217
|
-
):
|
218
|
-
schema = schema.model_json_schema()
|
219
|
-
if not isinstance(schema, dict):
|
220
|
-
raise LaunchError(
|
221
|
-
"schema must be a dict, Pydantic model instance, or Pydantic model class."
|
222
|
-
)
|
223
|
-
defs = schema.pop("$defs", None)
|
224
|
-
schema = _replace_refs_and_allofs(schema, defs)
|
236
|
+
schema = _prepare_schema(schema)
|
225
237
|
_validate_schema(schema)
|
226
238
|
arguments = JobInputArguments(
|
227
239
|
include=include,
|
@@ -251,18 +263,7 @@ def handle_run_config_input(
|
|
251
263
|
when a run is created.
|
252
264
|
"""
|
253
265
|
if schema:
|
254
|
-
|
255
|
-
# or the BaseModel class itself (e.g. schema=MySchema)
|
256
|
-
if hasattr(schema, "model_json_schema") and callable(
|
257
|
-
schema.model_json_schema # type: ignore
|
258
|
-
):
|
259
|
-
schema = schema.model_json_schema()
|
260
|
-
if not isinstance(schema, dict):
|
261
|
-
raise LaunchError(
|
262
|
-
"schema must be a dict, Pydantic model instance, or Pydantic model class."
|
263
|
-
)
|
264
|
-
defs = schema.pop("$defs", None)
|
265
|
-
schema = _replace_refs_and_allofs(schema, defs)
|
266
|
+
schema = _prepare_schema(schema)
|
266
267
|
_validate_schema(schema)
|
267
268
|
arguments = JobInputArguments(
|
268
269
|
include=include,
|
@@ -3,7 +3,7 @@ META_SCHEMA = {
|
|
3
3
|
"properties": {
|
4
4
|
"type": {
|
5
5
|
"type": "string",
|
6
|
-
"enum": ["boolean", "integer", "number", "string", "object"],
|
6
|
+
"enum": ["boolean", "integer", "number", "string", "object", "array"],
|
7
7
|
},
|
8
8
|
"title": {"type": "string"},
|
9
9
|
"description": {"type": "string"},
|
@@ -11,6 +11,11 @@ META_SCHEMA = {
|
|
11
11
|
"enum": {"type": "array", "items": {"type": ["integer", "number", "string"]}},
|
12
12
|
"properties": {"type": "object", "patternProperties": {".*": {"$ref": "#"}}},
|
13
13
|
"allOf": {"type": "array", "items": {"$ref": "#"}},
|
14
|
+
# Array-specific properties
|
15
|
+
"items": {"$ref": "#"},
|
16
|
+
"uniqueItems": {"type": "boolean"},
|
17
|
+
"minItems": {"type": "integer", "minimum": 0},
|
18
|
+
"maxItems": {"type": "integer", "minimum": 0},
|
14
19
|
},
|
15
20
|
"allOf": [
|
16
21
|
{
|
@@ -35,6 +40,31 @@ META_SCHEMA = {
|
|
35
40
|
}
|
36
41
|
},
|
37
42
|
},
|
43
|
+
{
|
44
|
+
"if": {"properties": {"type": {"const": "array"}}},
|
45
|
+
"then": {
|
46
|
+
"required": ["items"],
|
47
|
+
"properties": {
|
48
|
+
"items": {
|
49
|
+
"properties": {
|
50
|
+
"type": {"enum": ["integer", "number", "string"]},
|
51
|
+
"enum": {
|
52
|
+
"type": "array",
|
53
|
+
"items": {"type": ["integer", "number", "string"]},
|
54
|
+
},
|
55
|
+
"title": {"type": "string"},
|
56
|
+
"description": {"type": "string"},
|
57
|
+
"format": {"type": "string"},
|
58
|
+
},
|
59
|
+
"required": ["type", "enum"],
|
60
|
+
"unevaluatedProperties": False,
|
61
|
+
},
|
62
|
+
"uniqueItems": {"type": "boolean"},
|
63
|
+
"minItems": {"type": "integer", "minimum": 0},
|
64
|
+
"maxItems": {"type": "integer", "minimum": 0},
|
65
|
+
},
|
66
|
+
},
|
67
|
+
},
|
38
68
|
],
|
39
69
|
"unevaluatedProperties": False,
|
40
70
|
}
|
wandb/sdk/lib/asyncio_compat.py
CHANGED
@@ -7,7 +7,7 @@ import concurrent
|
|
7
7
|
import concurrent.futures
|
8
8
|
import contextlib
|
9
9
|
import threading
|
10
|
-
from typing import Any, AsyncIterator, Callable, Coroutine,
|
10
|
+
from typing import Any, AsyncIterator, Callable, Coroutine, TypeVar
|
11
11
|
|
12
12
|
_T = TypeVar("_T")
|
13
13
|
|
@@ -143,34 +143,71 @@ class TaskGroup:
|
|
143
143
|
"""
|
144
144
|
self._tasks.append(asyncio.create_task(coro))
|
145
145
|
|
146
|
-
async def _wait_all(self) -> None:
|
147
|
-
"""Block until
|
146
|
+
async def _wait_all(self, *, race: bool, timeout: float | None) -> None:
|
147
|
+
"""Block until tasks complete.
|
148
|
+
|
149
|
+
Args:
|
150
|
+
race: If true, blocks until the first task completes and then
|
151
|
+
cancels the rest. Otherwise, waits for all tasks or until
|
152
|
+
the first exception.
|
153
|
+
timeout: How long to wait.
|
148
154
|
|
149
155
|
Raises:
|
156
|
+
TimeoutError: If the timeout expires.
|
150
157
|
Exception: If one or more tasks raises an exception, one of these
|
151
158
|
is raised arbitrarily.
|
152
159
|
"""
|
153
|
-
|
160
|
+
if not self._tasks:
|
161
|
+
return
|
162
|
+
|
163
|
+
if race:
|
164
|
+
return_when = asyncio.FIRST_COMPLETED
|
165
|
+
else:
|
166
|
+
return_when = asyncio.FIRST_EXCEPTION
|
167
|
+
|
168
|
+
done, pending = await asyncio.wait(
|
154
169
|
self._tasks,
|
155
|
-
|
156
|
-
|
157
|
-
return_when=concurrent.futures.FIRST_EXCEPTION,
|
170
|
+
timeout=timeout,
|
171
|
+
return_when=return_when,
|
158
172
|
)
|
159
173
|
|
174
|
+
if not done:
|
175
|
+
raise TimeoutError(f"Timed out after {timeout} seconds.")
|
176
|
+
|
177
|
+
# If any of the finished tasks raised an exception, pick the first one.
|
160
178
|
for task in done:
|
161
|
-
|
162
|
-
|
163
|
-
raise exc
|
179
|
+
if exc := task.exception():
|
180
|
+
raise exc
|
164
181
|
|
165
|
-
|
166
|
-
|
182
|
+
# Wait for remaining tasks to clean up, then re-raise any exceptions
|
183
|
+
# that arise. Note that pending is only non-empty when race=True.
|
184
|
+
for task in pending:
|
185
|
+
task.cancel()
|
186
|
+
await asyncio.gather(*pending, return_exceptions=True)
|
187
|
+
for task in pending:
|
188
|
+
if task.cancelled():
|
189
|
+
continue
|
190
|
+
if exc := task.exception():
|
191
|
+
raise exc
|
192
|
+
|
193
|
+
async def _cancel_all(self) -> None:
|
194
|
+
"""Cancel all tasks.
|
195
|
+
|
196
|
+
Blocks until cancelled tasks complete to allow them to clean up.
|
197
|
+
Ignores exceptions.
|
198
|
+
"""
|
167
199
|
for task in self._tasks:
|
168
200
|
# NOTE: It is safe to cancel tasks that have already completed.
|
169
201
|
task.cancel()
|
202
|
+
await asyncio.gather(*self._tasks, return_exceptions=True)
|
170
203
|
|
171
204
|
|
172
205
|
@contextlib.asynccontextmanager
|
173
|
-
async def open_task_group(
|
206
|
+
async def open_task_group(
|
207
|
+
*,
|
208
|
+
exit_timeout: float | None = None,
|
209
|
+
race: bool = False,
|
210
|
+
) -> AsyncIterator[TaskGroup]:
|
174
211
|
"""Create a task group.
|
175
212
|
|
176
213
|
`asyncio` gained task groups in Python 3.11.
|
@@ -184,30 +221,58 @@ async def open_task_group() -> AsyncIterator[TaskGroup]:
|
|
184
221
|
NOTE: Subtask exceptions do not propagate until the context manager exits.
|
185
222
|
This means that the task group cannot cancel code running inside the
|
186
223
|
`async with` block .
|
224
|
+
|
225
|
+
Args:
|
226
|
+
exit_timeout: An optional timeout in seconds. When exiting the
|
227
|
+
context manager, if tasks don't complete in this time,
|
228
|
+
they are cancelled and a TimeoutError is raised.
|
229
|
+
race: If true, all pending tasks are cancelled once any task
|
230
|
+
in the group completes. Prefer to use the race() function instead.
|
231
|
+
|
232
|
+
Raises:
|
233
|
+
TimeoutError: if exit_timeout is specified and tasks don't finish
|
234
|
+
in time.
|
187
235
|
"""
|
188
236
|
task_group = TaskGroup()
|
189
237
|
|
190
238
|
try:
|
191
239
|
yield task_group
|
192
|
-
await task_group._wait_all()
|
240
|
+
await task_group._wait_all(race=race, timeout=exit_timeout)
|
193
241
|
finally:
|
194
|
-
task_group._cancel_all()
|
242
|
+
await task_group._cancel_all()
|
195
243
|
|
196
244
|
|
197
|
-
@contextlib.
|
198
|
-
def cancel_on_exit(coro: Coroutine[Any, Any, Any]) ->
|
245
|
+
@contextlib.asynccontextmanager
|
246
|
+
async def cancel_on_exit(coro: Coroutine[Any, Any, Any]) -> AsyncIterator[None]:
|
199
247
|
"""Schedule a task, cancelling it when exiting the context manager.
|
200
248
|
|
201
249
|
If the context manager exits successfully but the given coroutine raises
|
202
250
|
an exception, that exception is reraised. The exception is suppressed
|
203
251
|
if the context manager raises an exception.
|
204
252
|
"""
|
205
|
-
task = asyncio.create_task(coro)
|
206
253
|
|
207
|
-
|
254
|
+
async def stop_immediately():
|
255
|
+
pass
|
256
|
+
|
257
|
+
async with open_task_group(race=True) as group:
|
258
|
+
group.start_soon(stop_immediately())
|
259
|
+
group.start_soon(coro)
|
208
260
|
yield
|
209
261
|
|
210
|
-
|
211
|
-
|
212
|
-
|
213
|
-
|
262
|
+
|
263
|
+
async def race(*coros: Coroutine[Any, Any, Any]) -> None:
|
264
|
+
"""Wait until the first completed task.
|
265
|
+
|
266
|
+
After any coroutine completes, all others are cancelled.
|
267
|
+
If the current task is cancelled, all coroutines are cancelled too.
|
268
|
+
|
269
|
+
If coroutines complete simultaneously and any one of them raises
|
270
|
+
an exception, an arbitrary one is propagated. Similarly, if any coroutines
|
271
|
+
raise exceptions during cancellation, one of them propagates.
|
272
|
+
|
273
|
+
Args:
|
274
|
+
coros: Coroutines to race.
|
275
|
+
"""
|
276
|
+
async with open_task_group(race=True) as tg:
|
277
|
+
for coro in coros:
|
278
|
+
tg.start_soon(coro)
|
wandb/sdk/lib/gql_request.py
CHANGED
@@ -4,7 +4,9 @@ Note: This was originally wandb/vendor/gql-0.2.0/wandb_gql/transport/requests.py
|
|
4
4
|
The only substantial change is to reuse a requests.Session object.
|
5
5
|
"""
|
6
6
|
|
7
|
-
from
|
7
|
+
from __future__ import annotations
|
8
|
+
|
9
|
+
from typing import Any, Callable
|
8
10
|
|
9
11
|
import requests
|
10
12
|
from wandb_gql.transport.http import HTTPTransport
|
@@ -12,15 +14,17 @@ from wandb_graphql.execution import ExecutionResult
|
|
12
14
|
from wandb_graphql.language import ast
|
13
15
|
from wandb_graphql.language.printer import print_ast
|
14
16
|
|
17
|
+
from wandb._analytics import tracked_func
|
18
|
+
|
15
19
|
|
16
20
|
class GraphQLSession(HTTPTransport):
|
17
21
|
def __init__(
|
18
22
|
self,
|
19
23
|
url: str,
|
20
|
-
auth:
|
24
|
+
auth: tuple[str, str] | Callable | None = None,
|
21
25
|
use_json: bool = False,
|
22
|
-
timeout:
|
23
|
-
proxies:
|
26
|
+
timeout: int | float | None = None,
|
27
|
+
proxies: dict[str, str] | None = None,
|
24
28
|
**kwargs: Any,
|
25
29
|
) -> None:
|
26
30
|
"""Setup a session for sending GraphQL queries and mutations.
|
@@ -42,15 +46,22 @@ class GraphQLSession(HTTPTransport):
|
|
42
46
|
def execute(
|
43
47
|
self,
|
44
48
|
document: ast.Node,
|
45
|
-
variable_values:
|
46
|
-
timeout:
|
49
|
+
variable_values: dict[str, Any] | None = None,
|
50
|
+
timeout: int | float | None = None,
|
47
51
|
) -> ExecutionResult:
|
48
52
|
query_str = print_ast(document)
|
49
53
|
payload = {"query": query_str, "variables": variable_values or {}}
|
50
54
|
|
51
55
|
data_key = "json" if self.use_json else "data"
|
56
|
+
|
57
|
+
headers = self.headers.copy() if self.headers else {}
|
58
|
+
|
59
|
+
# If we're tracking a calling python function, include it in the headers
|
60
|
+
if func_info := tracked_func():
|
61
|
+
headers.update(func_info.to_headers())
|
62
|
+
|
52
63
|
post_args = {
|
53
|
-
"headers":
|
64
|
+
"headers": headers or None,
|
54
65
|
"cookies": self.cookies,
|
55
66
|
"timeout": timeout or self.default_timeout,
|
56
67
|
data_key: payload,
|
wandb/sdk/lib/paths.py
CHANGED
@@ -1,18 +1,20 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
1
3
|
import os
|
2
4
|
import platform
|
3
5
|
from functools import wraps
|
4
6
|
from pathlib import PurePath, PurePosixPath
|
5
|
-
from typing import Any,
|
7
|
+
from typing import Any, Union
|
8
|
+
|
9
|
+
from typing_extensions import TypeAlias
|
6
10
|
|
7
11
|
# Path _inputs_ should generally accept any kind of path. This is named the same and
|
8
12
|
# modeled after the hint defined in the Python standard library's `typeshed`:
|
9
13
|
# https://github.com/python/typeshed/blob/0b1cd5989669544866213807afa833a88f649ee7/stdlib/_typeshed/__init__.pyi#L56-L65
|
10
|
-
StrPath = Union[str, "os.PathLike[str]"]
|
11
|
-
|
12
|
-
# A native path to a file on a local filesystem.
|
13
|
-
FilePathStr = NewType("FilePathStr", str)
|
14
|
+
StrPath: TypeAlias = Union[str, "os.PathLike[str]"]
|
14
15
|
|
15
|
-
|
16
|
+
FilePathStr: TypeAlias = str #: A native path to a file on a local filesystem.
|
17
|
+
URIStr: TypeAlias = str
|
16
18
|
|
17
19
|
|
18
20
|
class LogicalPath(str):
|
@@ -54,7 +56,7 @@ class LogicalPath(str):
|
|
54
56
|
# will result in different outputs on different platforms; however, it doesn't alter
|
55
57
|
# absolute paths or check for prohibited characters etc.
|
56
58
|
|
57
|
-
def __new__(cls, path: StrPath) ->
|
59
|
+
def __new__(cls, path: StrPath) -> LogicalPath:
|
58
60
|
if isinstance(path, LogicalPath):
|
59
61
|
return super().__new__(cls, path)
|
60
62
|
if hasattr(path, "as_posix"):
|
@@ -77,30 +79,30 @@ class LogicalPath(str):
|
|
77
79
|
"""Convert this path to a PurePosixPath."""
|
78
80
|
return PurePosixPath(self)
|
79
81
|
|
80
|
-
def __getattr__(self,
|
82
|
+
def __getattr__(self, name: str) -> Any:
|
81
83
|
"""Act like a subclass of PurePosixPath for all methods not defined on str."""
|
82
84
|
try:
|
83
|
-
|
84
|
-
except AttributeError
|
85
|
-
|
85
|
+
attr = getattr(self.to_path(), name)
|
86
|
+
except AttributeError:
|
87
|
+
classname = type(self).__qualname__
|
88
|
+
raise AttributeError(f"{classname!r} has no attribute {name!r}") from None
|
86
89
|
|
87
|
-
if isinstance(
|
88
|
-
return LogicalPath(
|
90
|
+
if isinstance(attr, PurePosixPath):
|
91
|
+
return LogicalPath(attr)
|
89
92
|
|
90
93
|
# If the result is a callable (a method), wrap it so that it has the same
|
91
94
|
# behavior: if the call result returns a PurePosixPath, return a LogicalPath.
|
92
|
-
if callable(
|
95
|
+
if callable(fn := attr):
|
93
96
|
|
94
|
-
@wraps(
|
97
|
+
@wraps(fn)
|
95
98
|
def wrapper(*args: Any, **kwargs: Any) -> Any:
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
return inner_result
|
99
|
+
if isinstance(res := fn(*args, **kwargs), PurePosixPath):
|
100
|
+
return LogicalPath(res)
|
101
|
+
return res
|
100
102
|
|
101
103
|
return wrapper
|
102
|
-
return
|
104
|
+
return attr
|
103
105
|
|
104
|
-
def __truediv__(self, other: StrPath) ->
|
106
|
+
def __truediv__(self, other: StrPath) -> LogicalPath:
|
105
107
|
"""Act like a PurePosixPath for the / operator, but return a LogicalPath."""
|
106
108
|
return LogicalPath(self.to_path() / LogicalPath(other))
|
wandb/sdk/lib/printer.py
CHANGED
@@ -156,14 +156,10 @@ class Printer(abc.ABC):
|
|
156
156
|
"""
|
157
157
|
|
158
158
|
@abc.abstractmethod
|
159
|
-
def progress_close(self
|
159
|
+
def progress_close(self) -> None:
|
160
160
|
"""Close the progress indicator.
|
161
161
|
|
162
162
|
After this, `progress_update` should not be used.
|
163
|
-
|
164
|
-
Args:
|
165
|
-
text: The final text to set on the progress indicator.
|
166
|
-
Ignored in Jupyter notebooks.
|
167
163
|
"""
|
168
164
|
|
169
165
|
@staticmethod
|
@@ -342,13 +338,10 @@ class _PrinterTerm(Printer):
|
|
342
338
|
wandb.termlog(f"{next(self._progress)} {text}", newline=False)
|
343
339
|
|
344
340
|
@override
|
345
|
-
def progress_close(self
|
341
|
+
def progress_close(self) -> None:
|
346
342
|
if self._settings and self._settings.silent:
|
347
343
|
return
|
348
344
|
|
349
|
-
text = text or " " * 79
|
350
|
-
wandb.termlog(text)
|
351
|
-
|
352
345
|
@property
|
353
346
|
@override
|
354
347
|
def supports_html(self) -> bool:
|
@@ -462,11 +455,14 @@ class _PrinterJupyter(Printer):
|
|
462
455
|
display_id=True,
|
463
456
|
)
|
464
457
|
|
465
|
-
if handle:
|
458
|
+
if not handle:
|
459
|
+
yield None
|
460
|
+
return
|
461
|
+
|
462
|
+
try:
|
466
463
|
yield _DynamicJupyterText(handle)
|
464
|
+
finally:
|
467
465
|
handle.update(self._ipython_display.HTML(""))
|
468
|
-
else:
|
469
|
-
yield None
|
470
466
|
|
471
467
|
@override
|
472
468
|
def display(
|
@@ -539,7 +535,7 @@ class _PrinterJupyter(Printer):
|
|
539
535
|
self._progress.update(percent_done, text)
|
540
536
|
|
541
537
|
@override
|
542
|
-
def progress_close(self
|
538
|
+
def progress_close(self) -> None:
|
543
539
|
if self._progress:
|
544
540
|
self._progress.close()
|
545
541
|
|
wandb/sdk/lib/progress.py
CHANGED
@@ -70,12 +70,14 @@ def progress_printer(
|
|
70
70
|
default_text: The text to show if no information is available.
|
71
71
|
"""
|
72
72
|
with printer.dynamic_text() as text_area:
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
73
|
+
try:
|
74
|
+
yield ProgressPrinter(
|
75
|
+
printer,
|
76
|
+
text_area,
|
77
|
+
default_text=default_text,
|
78
|
+
)
|
79
|
+
finally:
|
80
|
+
printer.progress_close()
|
79
81
|
|
80
82
|
|
81
83
|
class ProgressPrinter:
|
@@ -1,10 +1,11 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
3
|
import atexit
|
4
|
+
import pathlib
|
4
5
|
from typing import Callable
|
5
6
|
|
6
7
|
from wandb.proto import wandb_server_pb2 as spb
|
7
|
-
from wandb.proto import wandb_settings_pb2
|
8
|
+
from wandb.proto import wandb_settings_pb2, wandb_sync_pb2
|
8
9
|
from wandb.sdk import wandb_settings
|
9
10
|
from wandb.sdk.interface.interface import InterfaceBase
|
10
11
|
from wandb.sdk.interface.interface_sock import InterfaceSock
|
@@ -12,6 +13,7 @@ from wandb.sdk.lib import asyncio_manager
|
|
12
13
|
from wandb.sdk.lib.exit_hooks import ExitHooks
|
13
14
|
from wandb.sdk.lib.service.service_client import ServiceClient
|
14
15
|
from wandb.sdk.mailbox import HandleAbandonedError, MailboxClosedError
|
16
|
+
from wandb.sdk.mailbox.mailbox_handle import MailboxHandle
|
15
17
|
|
16
18
|
from . import service_process, service_token
|
17
19
|
|
@@ -96,6 +98,45 @@ class ServiceConnection:
|
|
96
98
|
"""Returns an interface for communicating with the service."""
|
97
99
|
return InterfaceSock(self._client, stream_id=stream_id)
|
98
100
|
|
101
|
+
def init_sync(
|
102
|
+
self,
|
103
|
+
paths: set[pathlib.Path],
|
104
|
+
settings: wandb_settings.Settings,
|
105
|
+
) -> MailboxHandle[wandb_sync_pb2.ServerInitSyncResponse]:
|
106
|
+
"""Send a ServerInitSyncRequest."""
|
107
|
+
init_sync = wandb_sync_pb2.ServerInitSyncRequest(
|
108
|
+
path=(str(path) for path in paths),
|
109
|
+
settings=settings.to_proto(),
|
110
|
+
)
|
111
|
+
request = spb.ServerRequest(init_sync=init_sync)
|
112
|
+
|
113
|
+
handle = self._client.deliver(request)
|
114
|
+
return handle.map(lambda r: r.init_sync_response)
|
115
|
+
|
116
|
+
def sync(
|
117
|
+
self,
|
118
|
+
id: str,
|
119
|
+
*,
|
120
|
+
parallelism: int,
|
121
|
+
) -> MailboxHandle[wandb_sync_pb2.ServerSyncResponse]:
|
122
|
+
"""Send a ServerSyncRequest."""
|
123
|
+
sync = wandb_sync_pb2.ServerSyncRequest(id=id, parallelism=parallelism)
|
124
|
+
request = spb.ServerRequest(sync=sync)
|
125
|
+
|
126
|
+
handle = self._client.deliver(request)
|
127
|
+
return handle.map(lambda r: r.sync_response)
|
128
|
+
|
129
|
+
def sync_status(
|
130
|
+
self,
|
131
|
+
id: str,
|
132
|
+
) -> MailboxHandle[wandb_sync_pb2.ServerSyncStatusResponse]:
|
133
|
+
"""Send a ServerSyncStatusRequest."""
|
134
|
+
sync_status = wandb_sync_pb2.ServerSyncStatusRequest(id=id)
|
135
|
+
request = spb.ServerRequest(sync_status=sync_status)
|
136
|
+
|
137
|
+
handle = self._client.deliver(request)
|
138
|
+
return handle.map(lambda r: r.sync_status_response)
|
139
|
+
|
99
140
|
def inform_init(
|
100
141
|
self,
|
101
142
|
settings: wandb_settings_pb2.Settings,
|
@@ -143,17 +184,6 @@ class ServiceConnection:
|
|
143
184
|
else:
|
144
185
|
return response.inform_attach_response.settings
|
145
186
|
|
146
|
-
def inform_start(
|
147
|
-
self,
|
148
|
-
settings: wandb_settings_pb2.Settings,
|
149
|
-
run_id: str,
|
150
|
-
) -> None:
|
151
|
-
"""Send a start request to the service."""
|
152
|
-
request = spb.ServerInformStartRequest()
|
153
|
-
request.settings.CopyFrom(settings)
|
154
|
-
request._info.stream_id = run_id
|
155
|
-
self._client.publish(spb.ServerRequest(inform_start=request))
|
156
|
-
|
157
187
|
def teardown(self, exit_code: int) -> int | None:
|
158
188
|
"""Close the connection.
|
159
189
|
|