wandb 0.21.3__py3-none-musllinux_1_2_aarch64.whl → 0.22.0__py3-none-musllinux_1_2_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wandb/__init__.py +1 -1
- wandb/__init__.pyi +1 -1
- wandb/_analytics.py +65 -0
- wandb/_iterutils.py +8 -0
- wandb/_pydantic/__init__.py +10 -11
- wandb/_pydantic/base.py +3 -53
- wandb/_pydantic/field_types.py +29 -0
- wandb/_pydantic/v1_compat.py +47 -30
- wandb/_strutils.py +40 -0
- wandb/apis/public/__init__.py +42 -0
- wandb/apis/public/api.py +17 -4
- wandb/apis/public/artifacts.py +5 -4
- wandb/apis/public/automations.py +2 -1
- wandb/apis/public/registries/_freezable_list.py +6 -6
- wandb/apis/public/registries/_utils.py +2 -1
- wandb/apis/public/registries/registries_search.py +4 -0
- wandb/apis/public/registries/registry.py +7 -0
- wandb/apis/public/runs.py +24 -6
- wandb/automations/_filters/expressions.py +3 -2
- wandb/automations/_filters/operators.py +2 -1
- wandb/automations/_validators.py +20 -0
- wandb/automations/actions.py +4 -2
- wandb/automations/events.py +4 -5
- wandb/bin/gpu_stats +0 -0
- wandb/bin/wandb-core +0 -0
- wandb/cli/beta.py +48 -130
- wandb/cli/beta_sync.py +226 -0
- wandb/integration/dspy/__init__.py +5 -0
- wandb/integration/dspy/dspy.py +422 -0
- wandb/integration/weave/weave.py +55 -0
- wandb/proto/v3/wandb_internal_pb2.py +234 -224
- wandb/proto/v3/wandb_server_pb2.py +38 -57
- wandb/proto/v3/wandb_sync_pb2.py +87 -0
- wandb/proto/v3/wandb_telemetry_pb2.py +12 -12
- wandb/proto/v4/wandb_internal_pb2.py +226 -224
- wandb/proto/v4/wandb_server_pb2.py +38 -41
- wandb/proto/v4/wandb_sync_pb2.py +38 -0
- wandb/proto/v4/wandb_telemetry_pb2.py +12 -12
- wandb/proto/v5/wandb_internal_pb2.py +226 -224
- wandb/proto/v5/wandb_server_pb2.py +38 -41
- wandb/proto/v5/wandb_sync_pb2.py +39 -0
- wandb/proto/v5/wandb_telemetry_pb2.py +12 -12
- wandb/proto/v6/wandb_base_pb2.py +3 -3
- wandb/proto/v6/wandb_internal_pb2.py +229 -227
- wandb/proto/v6/wandb_server_pb2.py +41 -44
- wandb/proto/v6/wandb_settings_pb2.py +3 -3
- wandb/proto/v6/wandb_sync_pb2.py +49 -0
- wandb/proto/v6/wandb_telemetry_pb2.py +15 -15
- wandb/proto/wandb_generate_proto.py +1 -0
- wandb/proto/wandb_sync_pb2.py +12 -0
- wandb/sdk/artifacts/_validators.py +50 -49
- wandb/sdk/artifacts/artifact.py +7 -7
- wandb/sdk/artifacts/exceptions.py +2 -1
- wandb/sdk/artifacts/storage_handlers/gcs_handler.py +1 -1
- wandb/sdk/artifacts/storage_handlers/http_handler.py +1 -3
- wandb/sdk/artifacts/storage_handlers/local_file_handler.py +1 -1
- wandb/sdk/artifacts/storage_handlers/s3_handler.py +3 -2
- wandb/sdk/artifacts/storage_policies/_factories.py +63 -0
- wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +59 -124
- wandb/sdk/interface/interface.py +10 -0
- wandb/sdk/interface/interface_shared.py +9 -0
- wandb/sdk/lib/asyncio_compat.py +88 -23
- wandb/sdk/lib/gql_request.py +18 -7
- wandb/sdk/lib/printer.py +9 -13
- wandb/sdk/lib/progress.py +8 -6
- wandb/sdk/lib/service/service_connection.py +42 -12
- wandb/sdk/mailbox/wait_with_progress.py +1 -1
- wandb/sdk/wandb_init.py +9 -9
- wandb/sdk/wandb_run.py +13 -1
- wandb/sdk/wandb_settings.py +55 -0
- wandb/wandb_agent.py +35 -4
- {wandb-0.21.3.dist-info → wandb-0.22.0.dist-info}/METADATA +1 -1
- {wandb-0.21.3.dist-info → wandb-0.22.0.dist-info}/RECORD +818 -806
- {wandb-0.21.3.dist-info → wandb-0.22.0.dist-info}/WHEEL +0 -0
- {wandb-0.21.3.dist-info → wandb-0.22.0.dist-info}/entry_points.txt +0 -0
- {wandb-0.21.3.dist-info → wandb-0.22.0.dist-info}/licenses/LICENSE +0 -0
wandb/cli/beta_sync.py
ADDED
@@ -0,0 +1,226 @@
|
|
1
|
+
"""Implements `wandb sync` using wandb-core."""
|
2
|
+
|
3
|
+
from __future__ import annotations
|
4
|
+
|
5
|
+
import asyncio
|
6
|
+
import pathlib
|
7
|
+
import time
|
8
|
+
from itertools import filterfalse
|
9
|
+
from typing import Iterable, Iterator
|
10
|
+
|
11
|
+
import click
|
12
|
+
|
13
|
+
import wandb
|
14
|
+
from wandb.proto.wandb_sync_pb2 import ServerSyncResponse
|
15
|
+
from wandb.sdk import wandb_setup
|
16
|
+
from wandb.sdk.lib import asyncio_compat
|
17
|
+
from wandb.sdk.lib.printer import ERROR, Printer, new_printer
|
18
|
+
from wandb.sdk.lib.progress import progress_printer
|
19
|
+
from wandb.sdk.lib.service.service_connection import ServiceConnection
|
20
|
+
from wandb.sdk.mailbox.mailbox_handle import MailboxHandle
|
21
|
+
|
22
|
+
_MAX_LIST_LINES = 20
|
23
|
+
_POLL_WAIT_SECONDS = 0.1
|
24
|
+
_SLEEP = asyncio.sleep # patched in tests
|
25
|
+
|
26
|
+
|
27
|
+
def sync(
    paths: list[pathlib.Path],
    *,
    dry_run: bool,
    skip_synced: bool,
    verbose: bool,
    parallelism: int,
) -> None:
    """Replay one or more .wandb files.

    Args:
        paths: One or more .wandb files, run directories containing
            .wandb files, and wandb directories containing run directories.
        dry_run: If true, just prints what it would do and exits.
        skip_synced: If true, skips files that have already been synced
            as indicated by a .wandb.synced marker file in the same directory.
        verbose: Verbose mode for printing more info.
        parallelism: Max number of runs to sync at a time.
    """
    # Deduplicate by resolved (absolute, symlink-free) path.
    wandb_files = {
        found.resolve()
        for path in paths
        for found in _find_wandb_files(path, skip_synced=skip_synced)
    }

    if not wandb_files:
        click.echo("No files to sync.")
        return

    if dry_run:
        click.echo(f"Would sync {len(wandb_files)} file(s):")
        _print_sorted_paths(wandb_files, verbose=verbose)
        return

    click.echo(f"Syncing {len(wandb_files)} file(s):")
    _print_sorted_paths(wandb_files, verbose=verbose)

    setup = wandb_setup.singleton()
    connection = setup.ensure_service()
    status_printer = new_printer()

    async def _run_sync() -> None:
        # Defer coroutine creation to the asyncer's event loop.
        await _do_sync(
            wandb_files,
            service=connection,
            settings=setup.settings,
            printer=status_printer,
            parallelism=parallelism,
        )

    setup.asyncer.run(_run_sync)
|
75
|
+
|
76
|
+
|
77
|
+
async def _do_sync(
    wandb_files: set[pathlib.Path],
    *,
    service: ServiceConnection,
    settings: wandb.Settings,
    printer: Printer,
    parallelism: int,
) -> None:
    """Sync the specified files.

    This is factored out to make the progress animation testable.

    Args:
        wandb_files: Resolved paths of the .wandb files to sync.
        service: Connection to the wandb-core service process.
        settings: Settings used for the sync operation.
        printer: Printer used for status output.
        parallelism: Max number of runs to sync at a time.
    """
    # Register the files with the service first; the response carries an ID
    # that identifies this sync operation in the follow-up requests.
    init_result = await service.init_sync(
        wandb_files,
        settings,
    ).wait_async(timeout=5)

    # Start the actual sync; the returned handle resolves when it completes.
    sync_handle = service.sync(init_result.id, parallelism=parallelism)

    # Poll and display status until the sync operation finishes.
    await _SyncStatusLoop(
        init_result.id,
        service,
        printer,
    ).wait_with_progress(sync_handle)
|
101
|
+
|
102
|
+
|
103
|
+
class _SyncStatusLoop:
    """Displays a sync operation's status until it completes."""

    def __init__(
        self,
        id: str,
        service: ServiceConnection,
        printer: Printer,
    ) -> None:
        # ID of the sync operation on the service, used to poll its status.
        self._id = id
        self._service = service
        self._printer = printer

        # time.monotonic() of the previous poll; None before the first poll.
        self._rate_limit_last_time: float | None = None
        # Set once the sync operation's mailbox handle resolves.
        self._done = asyncio.Event()

    async def wait_with_progress(
        self,
        handle: MailboxHandle[ServerSyncResponse],
    ) -> None:
        """Display status updates until the handle completes."""
        # Run the waiter and the progress loop concurrently; the group exits
        # when both finish (the progress loop stops once _done is set).
        async with asyncio_compat.open_task_group() as group:
            group.start_soon(self._wait_then_mark_done(handle))
            group.start_soon(self._show_progress_until_done())

    async def _wait_then_mark_done(
        self,
        handle: MailboxHandle[ServerSyncResponse],
    ) -> None:
        # Block until the sync completes, surface any final errors, then
        # signal the progress loop to stop.
        response = await handle.wait_async(timeout=None)
        if messages := list(response.errors):
            self._printer.display(messages, level=ERROR)
        self._done.set()

    async def _show_progress_until_done(self) -> None:
        """Show rate-limited status updates until _done is set."""
        with progress_printer(self._printer, "Syncing...") as progress:
            while not await self._rate_limit_check_done():
                handle = self._service.sync_status(self._id)
                response = await handle.wait_async(timeout=None)

                # Only errors new since the previous poll are reported here.
                if messages := list(response.new_errors):
                    self._printer.display(messages, level=ERROR)
                progress.update(response.stats)

    async def _rate_limit_check_done(self) -> bool:
        """Wait for rate limit and return whether _done is set."""
        now = time.monotonic()
        last_time = self._rate_limit_last_time
        self._rate_limit_last_time = now

        if last_time and (time_since_last := now - last_time) < _POLL_WAIT_SECONDS:
            # Sleep out the remainder of the poll interval, but wake up
            # immediately if the operation finishes in the meantime.
            await asyncio_compat.race(
                _SLEEP(_POLL_WAIT_SECONDS - time_since_last),
                self._done.wait(),
            )

        return self._done.is_set()
|
161
|
+
|
162
|
+
|
163
|
+
def _find_wandb_files(
    path: pathlib.Path,
    *,
    skip_synced: bool,
) -> Iterator[pathlib.Path]:
    """Returns paths to the .wandb files to sync."""
    candidates: Iterator[pathlib.Path] = _expand_wandb_files(path)
    if skip_synced:
        # Drop files that already have a .wandb.synced marker next to them.
        candidates = filterfalse(_is_synced, candidates)
    yield from candidates
|
173
|
+
|
174
|
+
|
175
|
+
def _expand_wandb_files(
    path: pathlib.Path,
) -> Iterator[pathlib.Path]:
    """Iterate over .wandb files selected by the path."""
    # Case 1: the path is a .wandb file itself.
    if path.suffix == ".wandb":
        yield path
        return

    # Case 2: the path is a run directory with .wandb files directly inside.
    direct_matches = list(path.glob("*.wandb"))
    if direct_matches:
        yield from direct_matches
        return

    # Case 3: the path is a wandb directory containing run directories.
    yield from path.glob("*/*.wandb")
|
194
|
+
|
195
|
+
|
196
|
+
def _is_synced(path: pathlib.Path) -> bool:
    """Returns whether the .wandb file is synced.

    A file counts as synced when a sibling marker file with the
    ".wandb.synced" suffix exists next to it.
    """
    marker = path.with_suffix(".wandb.synced")
    return marker.exists()
|
199
|
+
|
200
|
+
|
201
|
+
def _print_sorted_paths(paths: Iterable[pathlib.Path], verbose: bool) -> None:
    """Print file paths, sorting them and truncating the list if needed.

    Args:
        paths: Paths to print. Must be absolute with symlinks resolved.
        verbose: If true, doesn't truncate paths.
    """
    # Prefer to print paths relative to the current working directory.
    cwd = pathlib.Path(".").resolve()

    def _format(path: pathlib.Path) -> str:
        try:
            return str(path.relative_to(cwd))
        except ValueError:
            # Not under cwd; fall back to the absolute path.
            return str(path)

    sorted_paths = sorted(_format(p) for p in paths)
    max_lines = len(sorted_paths) if verbose else _MAX_LIST_LINES

    for shown in sorted_paths[:max_lines]:
        click.echo(f"  {shown}")

    remaining = len(sorted_paths) - max_lines
    if remaining > 0:
        click.echo(f"  +{remaining:,d} more (pass --verbose to see all)")
|
@@ -0,0 +1,422 @@
|
|
1
|
+
"""DSPy ↔ Weights & Biases integration."""
|
2
|
+
|
3
|
+
from __future__ import annotations
|
4
|
+
|
5
|
+
import logging
|
6
|
+
import os
|
7
|
+
from collections.abc import Mapping, Sequence
|
8
|
+
from typing import Any, Literal
|
9
|
+
|
10
|
+
import wandb
|
11
|
+
import wandb.util
|
12
|
+
from wandb.sdk.wandb_run import Run
|
13
|
+
|
14
|
+
dspy = wandb.util.get_module(
    name="dspy",
    required=(
        "To use the W&B DSPy integration you need to have the `dspy` "
        "python package installed. Install it with `uv pip install dspy`."
    ),
    lazy=False,
)
if dspy is not None:
    # Compare the numeric major version, not the raw version string:
    # lexicographic string comparison breaks at double digits
    # ("10.0.0" < "3.0.0" as strings).
    _dspy_major = dspy.__version__.split(".", 1)[0]
    assert _dspy_major.isdigit() and int(_dspy_major) >= 3, (
        "DSPy 3.0.0 or higher is required. You have " + dspy.__version__
    )


logger = logging.getLogger(__name__)
|
29
|
+
|
30
|
+
|
31
|
+
def _flatten_rows(rows: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Flatten a list of nested row dicts into flat key/value dicts.

    Nested dict keys are joined with "." (e.g. {"a": {"b": 1}} becomes
    {"a.b": 1}); non-dict values are kept as-is.

    Args:
        rows (list[dict[str, Any]]): List of nested dictionaries to flatten.

    Returns:
        list[dict[str, Any]]: List of flattened dictionaries.

    """

    def _flatten(
        d: dict[str, Any], parent_key: str = "", sep: str = "."
    ) -> dict[str, Any]:
        flat: dict[str, Any] = {}
        for key, value in d.items():
            full_key = f"{parent_key}{sep}{key}" if parent_key else key
            if isinstance(value, dict):
                # Recurse into nested dicts, prefixing with the joined key.
                flat.update(_flatten(value, full_key, sep=sep))
            else:
                flat[full_key] = value
        return flat

    return [_flatten(row) for row in rows]
|
55
|
+
|
56
|
+
|
57
|
+
class WandbDSPyCallback(dspy.utils.BaseCallback):
    """W&B callback for tracking DSPy evaluation and optimization.

    This callback logs evaluation scores, per-step predictions (optional), and
    a table capturing the DSPy program signature over time. It can also save
    the best program as a W&B Artifact for reproducibility.

    Examples:
        Basic usage within DSPy settings:

        ```python
        import dspy
        import wandb
        from wandb.integration.dspy import WandbDSPyCallback

        with wandb.init(project="dspy-optimization") as run:
            dspy.settings.callbacks.append(WandbDSPyCallback(run=run))
            # Run your DSPy optimization/evaluation
        ```
    """

    def __init__(self, log_results: bool = True, run: Run | None = None) -> None:
        """Initialize the callback.

        Args:
            log_results (bool): Whether to log per-evaluation prediction tables.
            run (Run | None): Optional W&B run to use. Defaults to the
                current global run if available.

        Raises:
            wandb.Error: If no active run is provided or found.
        """
        # If no run is provided, use the current global run if available.
        if run is None:
            if wandb.run is None:
                raise wandb.Error(
                    "You must call `wandb.init()` before instantiating WandbDSPyCallback()."
                )
            run = wandb.run

        self.log_results = log_results

        with wandb.wandb_lib.telemetry.context(run=run) as tel:
            tel.feature.dspy_callback = True

        self._run = run
        # Config is logged once, on the first on_evaluate_start call.
        self._did_log_config: bool = False
        # Flattened program-signature metadata from the most recent eval start.
        self._program_info: dict[str, Any] = {}
        # Incremental table of signature info, one row per evaluation.
        self._program_table: wandb.Table | None = None
        # Step counter shared by all logged metrics and tables.
        self._row_idx: int = 0

    def _flatten_dict(
        self, nested: Any, parent_key: str = "", sep: str = "."
    ) -> dict[str, Any]:
        """Recursively flatten arbitrarily nested mappings and sequences.

        Args:
            nested (Any): Nested structure of mappings/lists to flatten.
            parent_key (str): Prefix to prepend to keys in the flattened output.
            sep (str): Key separator for nested fields.

        Returns:
            dict[str, Any]: Flattened dictionary representation.
        """
        flat: dict[str, Any] = {}

        def _walk(obj: Any, base: str) -> None:
            if isinstance(obj, Mapping):
                for k, v in obj.items():
                    new_key = f"{base}{sep}{k}" if base else str(k)
                    _walk(v, new_key)
            elif isinstance(obj, Sequence) and not isinstance(
                obj, (str, bytes, bytearray)
            ):
                # Sequences are flattened by index; strings/bytes are scalars.
                for idx, v in enumerate(obj):
                    new_key = f"{base}{sep}{idx}" if base else str(idx)
                    _walk(v, new_key)
            else:
                # Base can be empty only if the top-level is a scalar; guard against that.
                key = base if base else ""
                if key:
                    flat[key] = obj

        _walk(nested, parent_key)
        return flat

    def _extract_fields(self, fields: Mapping[str, Any]) -> dict[str, str]:
        """Convert signature fields to a flat mapping of strings.

        Note:
            The input is a dict-like mapping from field names to field
            metadata. Values are stringified for logging.

        Args:
            fields (Mapping[str, Any]): Mapping of field name to metadata.

        Returns:
            dict[str, str]: Mapping of field name to string value.
        """
        # NOTE: the parameter was previously annotated list[dict[str, Any]],
        # but .items() is called on it, so it must be a mapping.
        return {k: str(v) for k, v in fields.items()}

    def _extract_program_info(self, program_obj: Any) -> dict[str, Any]:
        """Extract signature-related info from a DSPy program.

        Attempts to read the program signature, instructions, input and output
        fields from a DSPy `Predict` parameter if available.

        Args:
            program_obj (Any): DSPy program/module instance.

        Returns:
            dict[str, Any]: Flattened dictionary of signature metadata.
        """
        info_dict: dict[str, Any] = {}

        if program_obj is None:
            return info_dict

        try:
            # Use the first Predict parameter's signature as the program's.
            sig = next(
                param.signature
                for _, param in program_obj.named_parameters()
                if isinstance(param, dspy.Predict)
            )

            if getattr(sig, "signature", None):
                info_dict["signature"] = sig.signature
            if getattr(sig, "instructions", None):
                info_dict["instructions"] = sig.instructions
            if getattr(sig, "input_fields", None):
                info_dict["input_fields"] = self._extract_fields(sig.input_fields)
            if getattr(sig, "output_fields", None):
                info_dict["output_fields"] = self._extract_fields(sig.output_fields)

            return self._flatten_dict(info_dict)
        except Exception as e:
            # Best-effort extraction: never let metadata collection break the
            # user's evaluation loop.
            logger.warning(
                "Failed to extract program info from Evaluate instance: %s", e
            )
            return info_dict

    def on_evaluate_start(
        self,
        call_id: str,
        instance: Any,
        inputs: dict[str, Any],
    ) -> None:
        """Handle start of a DSPy evaluation call.

        Logs non-private fields from the evaluator instance to W&B config and
        captures program signature info for later logging.

        Args:
            call_id (str): Unique identifier for the evaluation call.
            instance (Any): The evaluation instance (e.g., `dspy.Evaluate`).
            inputs (dict[str, Any]): Inputs passed to the evaluation (may
                include a `program` key with the DSPy program).
        """
        # 1) Log the evaluator's public attributes to the run config, once.
        if not self._did_log_config:
            instance_vars = vars(instance) if hasattr(instance, "__dict__") else {}
            serializable = {
                k: v for k, v in instance_vars.items() if not k.startswith("_")
            }
            if "devset" in serializable:
                # we don't want to log the devset in the config
                del serializable["devset"]

            self._run.config.update(serializable)
            self._did_log_config = True

        # 2) Build/append program signature tables from the 'program' inputs
        if program_obj := inputs.get("program"):
            self._program_info = self._extract_program_info(program_obj)

    def on_evaluate_end(
        self,
        call_id: str,
        outputs: Any | None,
        exception: Exception | None = None,
    ) -> None:
        """Handle end of a DSPy evaluation call.

        If available, logs a numeric `score` metric and (optionally) per-step
        prediction tables. Always appends a row to the program-signature table.

        Args:
            call_id (str): Unique identifier for the evaluation call.
            outputs (Any | None): Evaluation outputs; supports
                `dspy.evaluate.evaluate.EvaluationResult`.
            exception (Exception | None): Exception raised during evaluation, if any.
        """
        # The `BaseCallback` does not define the interface for the `outputs` parameter,
        # Currently, we know of `EvaluationResult` which is a subclass of `dspy.Prediction`.
        # We currently support this type and will warn the user if a different type is passed.
        score: float | None = None
        if exception is None:
            if isinstance(outputs, dspy.evaluate.evaluate.EvaluationResult):
                # Log the float score as a wandb metric. Use the run passed to
                # this callback (not the global `wandb.log`) so an explicitly
                # provided run is respected.
                score = outputs.score
                self._run.log({"score": float(score)}, step=self._row_idx)

                # Log the predictions as a separate table for each eval end.
                # We know that results if of type `list[tuple["dspy.Example", "dspy.Example", Any]]`
                results = outputs.results
                if self.log_results:
                    rows = self._parse_results(results)
                    if rows:
                        self._log_predictions_table(rows)
            else:
                wandb.termwarn(
                    f"on_evaluate_end received unexpected outputs type: {type(outputs)}. "
                    "Expected dspy.evaluate.evaluate.EvaluationResult; skipping logging score and `log_results`."
                )
        else:
            wandb.termwarn(
                f"on_evaluate_end received exception: {exception}. "
                "Skipping logging score and `log_results`."
            )

        # Log the program signature iteratively
        if self._program_table is None:
            columns = ["step", *self._program_info.keys()]
            if isinstance(score, float):
                columns.append("score")
            self._program_table = wandb.Table(columns=columns, log_mode="INCREMENTAL")

        values = list(self._program_info.values())
        # Only append a score when the table actually has a "score" column;
        # if the first evaluation failed, the table was created without one,
        # and add_data would otherwise receive more values than columns.
        if isinstance(score, float) and "score" in self._program_table.columns:
            values.append(score)

        self._program_table.add_data(
            self._row_idx,
            *values,
        )
        self._run.log({"program_signature": self._program_table}, step=self._row_idx)

        self._row_idx += 1

    def _parse_results(
        self,
        results: list[tuple[dspy.Example, dspy.Prediction | dspy.Completions, bool]],
    ) -> list[dict[str, Any]]:
        """Normalize evaluation results into serializable row dicts.

        Args:
            results (list[tuple]): Sequence of `(example, prediction, is_correct)`
                tuples from DSPy evaluation.

        Returns:
            list[dict[str, Any]]: Rows with `example`, `prediction`, `is_correct`.
        """
        _rows: list[dict[str, Any]] = []
        for example, prediction, is_correct in results:
            if isinstance(prediction, dspy.Prediction):
                prediction_dict: Any = prediction.toDict()
            elif isinstance(prediction, dspy.Completions):
                prediction_dict = prediction.items()
            else:
                # Unknown prediction type: fall back to an empty mapping rather
                # than raising UnboundLocalError on `prediction_dict` below.
                prediction_dict = {}

            row: dict[str, Any] = {
                "example": example.toDict(),
                "prediction": prediction_dict,
                "is_correct": is_correct,
            }
            _rows.append(row)

        return _rows

    def _log_predictions_table(self, rows: list[dict[str, Any]]) -> None:
        """Log a W&B Table of predictions for the current evaluation step.

        Args:
            rows (list[dict[str, Any]]): Prediction rows to log. Must be
                non-empty (the caller guards this).
        """
        rows = _flatten_rows(rows)
        columns = list(rows[0].keys())

        data: list[list[Any]] = [list(row.values()) for row in rows]

        preds_table = wandb.Table(columns=columns, data=data, log_mode="IMMUTABLE")
        self._run.log({f"predictions_{self._row_idx}": preds_table}, step=self._row_idx)

    def log_best_model(
        self,
        model: dspy.Module,
        *,
        save_program: bool = True,
        save_dir: str | None = None,
        filetype: Literal["json", "pkl"] = "json",
        aliases: Sequence[str] = ("best", "latest"),
        artifact_name: str = "dspy-program",
    ) -> None:
        """Save and log the best DSPy program as a W&B Artifact.

        You can choose to save the full program (architecture + state) or only
        the state to a single file (JSON or pickle).

        Args:
            model (dspy.Module): DSPy module to save.
            save_program (bool): Save full program directory if True; otherwise
                save only the state file. Defaults to `True`.
            save_dir (str): Directory to store program files before logging. Defaults to a
                subdirectory `dspy_program` within the active run's files directory
                (i.e., `wandb.run.dir`).
            filetype (Literal["json", "pkl"]): State file format when
                `save_program` is False. Defaults to `json`.
            aliases (Sequence[str]): Aliases for the logged Artifact version. Defaults to `("best", "latest")`.
            artifact_name (str): Base name for the Artifact. Defaults to `dspy-program`.

        Examples:
            Save the complete program and add aliases:

            ```python
            callback.log_best_model(
                optimized_program, save_program=True, aliases=("best", "production")
            )
            ```

            Save only the state as JSON:

            ```python
            callback.log_best_model(
                optimized_program, save_program=False, filetype="json"
            )
            ```
        """
        # Derive metadata to help discoverability in the UI
        info_dict = self._extract_program_info(model)
        metadata = {
            "dspy_version": getattr(dspy, "__version__", "unknown"),
            "module_class": model.__class__.__name__,
            **info_dict,
        }
        artifact = wandb.Artifact(
            name=f"{artifact_name}-{self._run.id}",
            type="model",
            metadata=metadata,
        )

        # Resolve and normalize the save directory in a cross-platform way
        if save_dir is None:
            save_dir = os.path.join(self._run.dir, "dspy_program")
        save_dir = os.path.normpath(save_dir)

        try:
            os.makedirs(save_dir, exist_ok=True)
        except Exception as exc:
            # Best-effort: warn and bail rather than crashing the user's run.
            wandb.termwarn(
                f"Could not create or access directory '{save_dir}': {exc}. Skipping artifact logging."
            )
            return
        # Save per requested mode
        if save_program:
            model.save(save_dir, save_program=True)
            artifact.add_dir(save_dir)
        else:
            filename = f"program.{filetype}"
            file_path = os.path.join(save_dir, filename)
            model.save(file_path, save_program=False)
            artifact.add_file(file_path)

        self._run.log_artifact(artifact, aliases=list(aliases))
|