docent-python 0.1.41a0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of docent-python might be problematic.
- docent/__init__.py +4 -0
- docent/_llm_util/__init__.py +0 -0
- docent/_llm_util/data_models/__init__.py +0 -0
- docent/_llm_util/data_models/exceptions.py +48 -0
- docent/_llm_util/data_models/llm_output.py +331 -0
- docent/_llm_util/llm_cache.py +193 -0
- docent/_llm_util/llm_svc.py +472 -0
- docent/_llm_util/model_registry.py +134 -0
- docent/_llm_util/providers/__init__.py +0 -0
- docent/_llm_util/providers/anthropic.py +537 -0
- docent/_llm_util/providers/common.py +41 -0
- docent/_llm_util/providers/google.py +530 -0
- docent/_llm_util/providers/openai.py +745 -0
- docent/_llm_util/providers/openrouter.py +375 -0
- docent/_llm_util/providers/preference_types.py +104 -0
- docent/_llm_util/providers/provider_registry.py +164 -0
- docent/_log_util/__init__.py +3 -0
- docent/_log_util/logger.py +141 -0
- docent/data_models/__init__.py +14 -0
- docent/data_models/_tiktoken_util.py +91 -0
- docent/data_models/agent_run.py +473 -0
- docent/data_models/chat/__init__.py +37 -0
- docent/data_models/chat/content.py +56 -0
- docent/data_models/chat/message.py +191 -0
- docent/data_models/chat/tool.py +109 -0
- docent/data_models/citation.py +187 -0
- docent/data_models/formatted_objects.py +84 -0
- docent/data_models/judge.py +17 -0
- docent/data_models/metadata_util.py +16 -0
- docent/data_models/regex.py +56 -0
- docent/data_models/transcript.py +305 -0
- docent/data_models/util.py +170 -0
- docent/judges/__init__.py +23 -0
- docent/judges/analysis.py +77 -0
- docent/judges/impl.py +587 -0
- docent/judges/runner.py +129 -0
- docent/judges/stats.py +205 -0
- docent/judges/types.py +320 -0
- docent/judges/util/forgiving_json.py +108 -0
- docent/judges/util/meta_schema.json +86 -0
- docent/judges/util/meta_schema.py +29 -0
- docent/judges/util/parse_output.py +68 -0
- docent/judges/util/voting.py +139 -0
- docent/loaders/load_inspect.py +215 -0
- docent/py.typed +0 -0
- docent/samples/__init__.py +3 -0
- docent/samples/load.py +9 -0
- docent/samples/log.eval +0 -0
- docent/samples/tb_airline.json +1 -0
- docent/sdk/__init__.py +0 -0
- docent/sdk/agent_run_writer.py +317 -0
- docent/sdk/client.py +1186 -0
- docent/sdk/llm_context.py +432 -0
- docent/trace.py +2741 -0
- docent/trace_temp.py +1086 -0
- docent_python-0.1.41a0.dist-info/METADATA +33 -0
- docent_python-0.1.41a0.dist-info/RECORD +59 -0
- docent_python-0.1.41a0.dist-info/WHEEL +4 -0
- docent_python-0.1.41a0.dist-info/licenses/LICENSE.md +13 -0
docent/sdk/client.py
ADDED
@@ -0,0 +1,1186 @@

import gzip
import itertools
import json
import os
import time
import webbrowser
from pathlib import Path
from typing import Any, Iterator, Literal

import pandas as pd
import requests
from pydantic_core import to_jsonable_python
from tqdm import tqdm

from docent._log_util.logger import get_logger
from docent.data_models.agent_run import AgentRun
from docent.data_models.judge import Label
from docent.judges.util.meta_schema import validate_judge_result_schema
from docent.loaders import load_inspect
from docent.sdk.llm_context import LLMContext, LLMContextItem

MAX_AGENT_RUN_PAYLOAD_BYTES = 100 * 1024 * 1024  # 100MB backend limit
_AGENT_RUNS_PAYLOAD_PREFIX = b'{"agent_runs":['
_AGENT_RUNS_PAYLOAD_SUFFIX = b"]}"


def _serialize_agent_run(agent_run: AgentRun) -> bytes:
    """Serialize an AgentRun to compact JSON bytes."""
    return json.dumps(to_jsonable_python(agent_run), separators=(",", ":")).encode("utf-8")


def _build_agent_runs_payload(serialized_runs: list[bytes]) -> bytes:
    """Wrap serialized individual runs into the API payload envelope."""
    body = b",".join(serialized_runs)
    return _AGENT_RUNS_PAYLOAD_PREFIX + body + _AGENT_RUNS_PAYLOAD_SUFFIX


def _yield_agent_run_batches_by_size(
    agent_runs: list[AgentRun], max_payload_bytes: int
) -> Iterator[tuple[int, bytes]]:
    """Yield batches of agent runs whose serialized payloads stay within max_payload_bytes."""
    envelope_len = len(_AGENT_RUNS_PAYLOAD_PREFIX) + len(_AGENT_RUNS_PAYLOAD_SUFFIX)
    comma_len = 1

    current_serialized: list[bytes] = []
    current_size = envelope_len

    for agent_run in agent_runs:
        serialized = _serialize_agent_run(agent_run)
        serialized_len = len(serialized)

        if envelope_len + serialized_len > max_payload_bytes:
            raise ValueError(
                f"A single agent run (id={agent_run.id}) exceeds the maximum payload size of "
                f"{max_payload_bytes} bytes. Reduce the size of that run before uploading."
            )

        delimiter = 0 if not current_serialized else comma_len
        projected_size = current_size + delimiter + serialized_len

        # If adding the next run would exceed the max payload size, yield the current batch
        if current_serialized and projected_size > max_payload_bytes:
            yield len(current_serialized), _build_agent_runs_payload(current_serialized)

            # Add the "next run" as the first run in the next batch
            current_serialized = [serialized]
            current_size = envelope_len + serialized_len
        # Otherwise, add to the current batch and continue
        else:
            current_serialized.append(serialized)
            current_size = projected_size

    if current_serialized:
        yield len(current_serialized), _build_agent_runs_payload(current_serialized)


logger = get_logger(__name__)

class Docent:
    """Client for interacting with the Docent API.

    This client provides methods for creating and managing Collections,
    dimensions, agent runs, and filters in the Docent system. It handles
    authentication via API keys and provides a high-level interface for
    logging, querying, and analyzing agent traces.

    Example:
        >>> from docent import Docent
        >>> client = Docent(api_key="your-api-key")
        >>> collection_id = client.create_collection(name="My Collection")
    """

    def __init__(
        self,
        *,
        domain: str = "docent.transluce.org",
        use_https: bool = True,
        api_key: str | None = None,
        # Deprecated
        server_url: str | None = None,  # Use domain instead
        web_url: str | None = None,  # Use domain instead
    ):
        """Initialize the Docent client.

        Args:
            domain: The domain of the Docent instance. Defaults to "docent.transluce.org".
                The API and web URLs will be constructed from this domain automatically.
            use_https: Whether to use HTTPS when constructing URLs from the domain.
                Defaults to True.
            api_key: API key for authentication. If not provided, will attempt to read
                from the DOCENT_API_KEY environment variable.
            server_url: (Deprecated) Direct URL of the Docent API server. Use `domain` instead.
            web_url: (Deprecated) Direct URL of the Docent web UI. Use `domain` instead.

        Raises:
            ValueError: If no API key is provided and DOCENT_API_KEY is not set.

        Example:
            >>> client = Docent(domain="my-instance.docent.com", api_key="sk-...")
        """
        # Warn about deprecated parameters
        if server_url is not None:
            logger.warning(
                "The 'server_url' parameter is deprecated and will be removed in a future version. "
                "Please use 'domain' instead."
            )
        if web_url is not None:
            logger.warning(
                "The 'web_url' parameter is deprecated and will be removed in a future version. "
                "Please use 'domain' instead."
            )

        self._domain = domain

        # Set server URL; server_url takes precedence over domain
        prefix = "https://" if use_https else "http://"
        server_url = (server_url or f"{prefix}api.{domain}").rstrip("/")
        if not server_url.endswith("/rest"):
            server_url = f"{server_url}/rest"
        self._server_url = server_url

        # Set web URL; web_url takes precedence over domain
        self._web_url = (web_url or f"{prefix}{domain}").rstrip("/")

        # Use requests.Session for connection pooling and persistent headers
        self._session = requests.Session()

        api_key = api_key or os.getenv("DOCENT_API_KEY")

        if api_key is None:
            raise ValueError(
                "api_key is required. Please provide an "
                "api_key or set the DOCENT_API_KEY environment variable."
            )

        self._login(api_key)

    def _handle_response_errors(self, response: requests.Response):
        """Handle API response and raise informative errors."""
        if response.status_code >= 400:
            try:
                error_data = response.json()
                detail = error_data.get("detail", response.text)
            except Exception:
                detail = response.text

            raise requests.HTTPError(f"HTTP {response.status_code}: {detail}", response=response)

    def _login(self, api_key: str):
        """Attach the API key to the session and verify it against the server."""
        self._session.headers.update({"Authorization": f"Bearer {api_key}"})

        url = f"{self._server_url}/api-keys/test"
        response = self._session.get(url)
        self._handle_response_errors(response)

        logger.info("Logged in with API key")
        return

    def create_collection(
        self,
        collection_id: str | None = None,
        name: str | None = None,
        description: str | None = None,
    ) -> str:
        """Creates a new Collection.

        Creates a new Collection and sets up a default MECE dimension
        for grouping on the homepage.

        Args:
            collection_id: Optional ID for the new Collection. If not provided, one will be generated.
            name: Optional name for the Collection.
            description: Optional description for the Collection.

        Returns:
            str: The ID of the created Collection.

        Raises:
            ValueError: If the response is missing the Collection ID.
            requests.exceptions.HTTPError: If the API request fails.
        """
        url = f"{self._server_url}/create"
        payload = {
            "collection_id": collection_id,
            "name": name,
            "description": description,
        }

        response = self._session.post(url, json=payload)
        self._handle_response_errors(response)

        response_data = response.json()
        collection_id = response_data.get("collection_id")
        if collection_id is None:
            raise ValueError("Failed to create collection: 'collection_id' missing in response.")

        logger.info(f"Successfully created Collection with id='{collection_id}'")

        logger.info(
            f"Collection creation complete. Frontend available at: {self._web_url}/dashboard/{collection_id}"
        )
        return collection_id

    def update_collection(
        self,
        collection_id: str,
        name: str | None = None,
        description: str | None = None,
    ) -> None:
        """Updates a Collection's name and/or description.

        Requires WRITE permission on the collection.

        Args:
            collection_id: ID of the Collection to update.
            name: New name for the Collection. If None, the name is not included in the update.
            description: New description for the Collection. If None, the description is not
                included in the update.

        Raises:
            requests.exceptions.HTTPError: If the API request fails.
        """
        url = f"{self._server_url}/{collection_id}/collection"
        payload: dict[str, Any] = {}
        if name is not None:
            payload["name"] = name
        if description is not None:
            payload["description"] = description

        response = self._session.put(url, json=payload)
        self._handle_response_errors(response)

        logger.info(f"Successfully updated Collection '{collection_id}'")

    def add_agent_runs(
        self,
        collection_id: str,
        agent_runs: list[AgentRun],
        *,
        compression: Literal["gzip", "none"] = "gzip",
        wait: bool = True,
        poll_interval: float = 1.0,
        # Deprecated
        batch_size: int | None = None,
    ) -> dict[str, Any]:
        """Adds agent runs to a Collection.

        Agent runs represent execution traces that can be visualized and analyzed.
        Requests are automatically chunked to stay under the backend's payload limit.

        Args:
            collection_id: ID of the Collection.
            agent_runs: List of AgentRun objects to add.
            compression: Compression algorithm for request bodies. Defaults to gzip.
                Set to "none" to retain legacy behavior.
            wait: If True (default), wait for all ingestion jobs to complete before returning.
                If False, return immediately after enqueuing jobs.
            poll_interval: Seconds between status checks when wait=True. Defaults to 1.0.

        Returns:
            dict: API response data containing:
                - status: "success" if all jobs completed, "enqueued" if wait=False
                - total_runs_added: Number of agent runs submitted
                - job_ids: List of job IDs for tracking

        Raises:
            ValueError: If any single agent run exceeds the maximum payload size.
            requests.exceptions.HTTPError: If the API request fails.
            RuntimeError: If any job fails during processing (when wait=True).
        """

        if batch_size is not None:
            logger.warning(
                "The 'batch_size' parameter is deprecated and will be removed in a future version. "
                "We have transitioned to a new batching strategy based on the size of the payload."
            )

        url = f"{self._server_url}/{collection_id}/agent_runs"
        total_runs = len(agent_runs)
        job_ids: list[str] = []

        # Process agent runs in batches
        desc = f"Uploading agent runs (compression={compression})"
        with tqdm(total=total_runs, desc=desc, unit="runs") as pbar:
            for batch_size, payload_bytes in _yield_agent_run_batches_by_size(
                agent_runs, MAX_AGENT_RUN_PAYLOAD_BYTES
            ):
                request_kwargs: dict[str, Any] = {}
                if compression == "none":
                    request_kwargs["data"] = payload_bytes
                    request_kwargs["headers"] = {"Content-Type": "application/json"}
                elif compression == "gzip":
                    request_kwargs["data"] = gzip.compress(payload_bytes)
                    request_kwargs["headers"] = {
                        "Content-Type": "application/json",
                        "Content-Encoding": "gzip",
                    }
                else:
                    raise ValueError(f"Unsupported compression '{compression}'")

                response = self._session.post(url, **request_kwargs)
                self._handle_response_errors(response)

                # Server returns 202 with job_id for async processing
                response_data = response.json()
                job_id = response_data.get("job_id")
                if job_id:
                    job_ids.append(job_id)

                pbar.update(batch_size)

        if not wait:
            logger.info(
                f"Enqueued {total_runs} agent runs to Collection '{collection_id}' "
                f"({len(job_ids)} job(s)). Use get_agent_run_job_status() to check progress."
            )
            return {
                "status": "enqueued",
                "total_runs_added": total_runs,
                "job_ids": job_ids,
            }

        # Wait for all jobs to complete
        if job_ids:
            logger.info(
                f"Uploaded {total_runs} agent runs in {len(job_ids)} batch(es). "
                f"Waiting for server-side processing to complete... "
                f"(set wait=False to skip waiting)"
            )
            self._wait_for_jobs(collection_id, job_ids, poll_interval)

        logger.info(
            f"Successfully added {total_runs} agent runs to Collection '{collection_id}'. "
            f"All {len(job_ids)} job(s) completed."
        )
        return {"status": "success", "total_runs_added": total_runs, "job_ids": job_ids}

    def _wait_for_jobs(
        self,
        collection_id: str,
        job_ids: list[str],
        poll_interval: float = 1.0,
    ) -> None:
        """Wait for all jobs to complete, showing progress.

        Args:
            collection_id: ID of the Collection.
            job_ids: List of job IDs to wait for.
            poll_interval: Seconds between status checks.

        Raises:
            RuntimeError: If any job fails or is canceled.
        """
        pending_jobs = set(job_ids)
        failed_jobs: dict[str, str] = {}

        with tqdm(total=len(job_ids), desc="Waiting for server processing", unit="jobs") as pbar:
            while pending_jobs:
                statuses = self.get_agent_run_job_statuses(collection_id, list(pending_jobs))

                for job_status in statuses:
                    job_id = job_status["job_id"]
                    status = job_status["status"]

                    if status == "completed":
                        pending_jobs.discard(job_id)
                        pbar.update(1)
                    elif status == "canceled":
                        pending_jobs.discard(job_id)
                        failed_jobs[job_id] = "Job was canceled"
                        pbar.update(1)

                if pending_jobs:
                    time.sleep(poll_interval)

        if failed_jobs:
            failed_msg = ", ".join(f"{k}: {v}" for k, v in failed_jobs.items())
            raise RuntimeError(f"Some jobs failed: {failed_msg}")

    def get_agent_run_job_statuses(
        self, collection_id: str, job_ids: list[str]
    ) -> list[dict[str, Any]]:
        """Get the status of multiple agent run ingestion jobs.

        Args:
            collection_id: ID of the Collection.
            job_ids: List of job IDs to check (max 100).

        Returns:
            list: List of job status dictionaries, each containing:
                - job_id: The job ID
                - status: One of "pending", "running", "completed", "canceled"
                - type: The job type
                - created_at: ISO timestamp of job creation

        Raises:
            ValueError: If more than 100 job IDs are provided.
            requests.exceptions.HTTPError: If the API request fails.
        """
        if len(job_ids) > 100:
            raise ValueError("Cannot request more than 100 job IDs at once")

        url = f"{self._server_url}/{collection_id}/agent_runs/jobs/batch_status"
        response = self._session.post(url, json={"job_ids": job_ids})
        self._handle_response_errors(response)
        return response.json()["jobs"]

    def get_agent_run_job_status(self, collection_id: str, job_id: str) -> dict[str, Any]:
        """Get the status of an agent run ingestion job.

        Args:
            collection_id: ID of the Collection.
            job_id: The ID of the job to check.

        Returns:
            dict: Job status information including:
                - job_id: The job ID
                - status: One of "pending", "running", "completed", "canceled"
                - type: The job type
                - created_at: ISO timestamp of job creation

        Raises:
            requests.exceptions.HTTPError: If the API request fails.
        """
        url = f"{self._server_url}/{collection_id}/agent_runs/jobs/{job_id}"
        response = self._session.get(url)
        self._handle_response_errors(response)
        return response.json()

    def list_collections(self) -> list[dict[str, Any]]:
        """Lists all available Collections.

        Returns:
            list: List of Collection objects.

        Raises:
            requests.exceptions.HTTPError: If the API request fails.
        """
        url = f"{self._server_url}/collections"
        response = self._session.get(url)
        self._handle_response_errors(response)
        return response.json()

    def get_collection(self, collection_id: str) -> dict[str, Any] | None:
        """Get details about a specific Collection.

        Requires READ permission on the collection.

        Args:
            collection_id: ID of the Collection to retrieve.

        Returns:
            Collection: Collection object with id, name, description, created_at, and created_by.
                Returns None if collection not found.

        Raises:
            requests.exceptions.HTTPError: If the API request fails.
        """
        url = f"{self._server_url}/{collection_id}/collection_details"
        response = self._session.get(url)
        self._handle_response_errors(response)
        return response.json()

    def list_rubrics(self, collection_id: str) -> list[dict[str, Any]]:
        """List all rubrics for a given collection.

        Args:
            collection_id: ID of the Collection.

        Returns:
            list: List of dictionaries containing rubric information.

        Raises:
            requests.exceptions.HTTPError: If the API request fails.
        """
        url = f"{self._server_url}/rubric/{collection_id}/rubrics"
        response = self._session.get(url)
        self._handle_response_errors(response)
        return response.json()

    def get_rubric_run_state(
        self, collection_id: str, rubric_id: str, version: int | None = None
    ) -> dict[str, Any]:
        """Get rubric run state for a given collection and rubric.

        Args:
            collection_id: ID of the Collection.
            rubric_id: The ID of the rubric to get run state for.
            version: The version of the rubric to get run state for. If None, the latest version is used.

        Returns:
            dict: Dictionary containing rubric run state with results, job_id, and total_results_needed.

        Raises:
            requests.exceptions.HTTPError: If the API request fails.
        """
        url = f"{self._server_url}/rubric/{collection_id}/{rubric_id}/rubric_run_state"
        response = self._session.get(url, params={"version": version})
        self._handle_response_errors(response)
        return response.json()

    def get_clustering_state(self, collection_id: str, rubric_id: str) -> dict[str, Any]:
        """Get clustering state for a given collection and rubric.

        Args:
            collection_id: ID of the Collection.
            rubric_id: The ID of the rubric to get clustering state for.

        Returns:
            dict: Dictionary containing job_id, centroids, and assignments.

        Raises:
            requests.exceptions.HTTPError: If the API request fails.
        """
        url = f"{self._server_url}/rubric/{collection_id}/{rubric_id}/clustering_job"
        response = self._session.get(url)
        self._handle_response_errors(response)
        return response.json()

    def get_cluster_centroids(self, collection_id: str, rubric_id: str) -> list[dict[str, Any]]:
        """Get centroids for a given collection and rubric.

        Args:
            collection_id: ID of the Collection.
            rubric_id: The ID of the rubric to get centroids for.

        Returns:
            list: List of dictionaries containing centroid information.

        Raises:
            requests.exceptions.HTTPError: If the API request fails.
        """
        clustering_state = self.get_clustering_state(collection_id, rubric_id)
        return clustering_state.get("centroids", [])

    def get_cluster_assignments(self, collection_id: str, rubric_id: str) -> dict[str, list[str]]:
        """Get centroid assignments for a given rubric.

        Args:
            collection_id: ID of the Collection.
            rubric_id: The ID of the rubric to get assignments for.

        Returns:
            dict: Dictionary mapping centroid IDs to lists of judge result IDs.

        Raises:
            requests.exceptions.HTTPError: If the API request fails.
        """
        clustering_state = self.get_clustering_state(collection_id, rubric_id)
        return clustering_state.get("assignments", {})

    def create_label_set(
        self,
        collection_id: str,
        name: str,
        label_schema: dict[str, Any],
        description: str | None = None,
    ) -> str:
        """Create a new label set with a JSON schema.

        Args:
            collection_id: ID of the collection.
            name: Name of the label set.
            label_schema: JSON schema for validating labels in this set.
            description: Optional description of the label set.

        Returns:
            str: The ID of the created label set.

        Raises:
            ValueError: If the response is missing the label_set_id.
            jsonschema.ValidationError: If the label schema is invalid.
            requests.exceptions.HTTPError: If the API request fails.
        """
        validate_judge_result_schema(label_schema)

        url = f"{self._server_url}/label/{collection_id}/label_set"
        payload = {
            "name": name,
            "label_schema": label_schema,
            "description": description,
        }
        response = self._session.post(url, json=payload)
        self._handle_response_errors(response)
        return response.json()["label_set_id"]

    def add_label(
        self,
        collection_id: str,
        label: Label,
    ) -> dict[str, str]:
        """Create a label in a label set.

        Args:
            collection_id: ID of the Collection.
            label: A `Label` object that must comply with the label set's schema.

        Returns:
            dict: API response containing the label_id.

        Raises:
            requests.exceptions.HTTPError: If the API request fails or validation errors occur.
        """
        url = f"{self._server_url}/label/{collection_id}/label"
        payload = {"label": label.model_dump(mode="json")}
        response = self._session.post(url, json=payload)
        self._handle_response_errors(response)
        return response.json()

    def add_labels(
        self,
        collection_id: str,
        labels: list[Label],
    ) -> dict[str, Any]:
        """Create multiple labels.

        Args:
            collection_id: ID of the Collection.
            labels: List of `Label` objects.

        Returns:
            dict: API response containing label_ids list and optional errors list.

        Raises:
            ValueError: If no labels are provided.
            requests.exceptions.HTTPError: If the API request fails.
        """
        if not labels:
            raise ValueError("labels must contain at least one entry")

        url = f"{self._server_url}/label/{collection_id}/labels"
        payload = {"labels": [label.model_dump(mode="json") for label in labels]}
        response = self._session.post(url, json=payload)
        self._handle_response_errors(response)
        return response.json()

    def get_labels(
        self, collection_id: str, label_set_id: str, filter_valid_labels: bool = False
    ) -> list[dict[str, Any]]:
        """Retrieve all labels in a label set.

        Args:
            collection_id: ID of the Collection.
            label_set_id: ID of the label set to fetch labels for.
            filter_valid_labels: If True, only return labels that match the label set schema,
                including its required fields. Default is False (returns all labels).

        Returns:
            list: List of label dictionaries.

        Raises:
            requests.exceptions.HTTPError: If the API request fails.
        """
        url = f"{self._server_url}/label/{collection_id}/label_set/{label_set_id}/labels"
        params = {"filter_valid_labels": filter_valid_labels}
        response = self._session.get(url, params=params)
        self._handle_response_errors(response)
        return response.json()

    def tag_transcript(self, collection_id: str, agent_run_id: str, value: str) -> None:
        """Add a tag to an agent run transcript.

        Args:
            collection_id: ID of the Collection.
            agent_run_id: The agent run to tag.
            value: The tag value (max length enforced by the server).

        Raises:
            requests.exceptions.HTTPError: If the API request fails.
        """
        url = f"{self._server_url}/label/{collection_id}/tag"
        payload = {"agent_run_id": agent_run_id, "value": value}
        response = self._session.post(url, json=payload)
        self._handle_response_errors(response)

    def get_tags(self, collection_id: str, value: str | None = None) -> list[dict[str, Any]]:
        """Get all tags in a collection, optionally filtered by value."""
        url = f"{self._server_url}/label/{collection_id}/tags"
        params = {"value": value} if value is not None else None
        response = self._session.get(url, params=params)
        self._handle_response_errors(response)
        return response.json()

    def get_tags_for_agent_run(self, collection_id: str, agent_run_id: str) -> list[dict[str, Any]]:
        """Get all tags attached to a specific agent run."""
        url = f"{self._server_url}/label/{collection_id}/agent_run/{agent_run_id}/tags"
        response = self._session.get(url)
        self._handle_response_errors(response)
        return response.json()

    def delete_tag(self, collection_id: str, tag_id: str) -> None:
        """Delete a tag by ID."""
        url = f"{self._server_url}/label/{collection_id}/tag/{tag_id}"
        response = self._session.delete(url)
        self._handle_response_errors(response)

    def get_agent_run(self, collection_id: str, agent_run_id: str) -> AgentRun | None:
        """Get a specific agent run by its ID.

        Args:
            collection_id: ID of the Collection.
            agent_run_id: The ID of the agent run to retrieve.

        Returns:
            AgentRun | None: The parsed agent run, or None if it does not exist.

        Raises:
            requests.exceptions.HTTPError: If the API request fails.
        """
        url = f"{self._server_url}/{collection_id}/agent_run"
        response = self._session.get(url, params={"agent_run_id": agent_run_id})
        self._handle_response_errors(response)
        if response.json() is None:
            return None
        else:
            # We do this to avoid metadata validation failing
            # TODO(mengk): kinda hacky
            return AgentRun.model_validate(response.json())

    def get_chat_sessions(self, collection_id: str, agent_run_id: str) -> list[dict[str, Any]]:
        """Get all chat sessions for an agent run, excluding judge result sessions.

        Args:
            collection_id: ID of the Collection.
            agent_run_id: The ID of the agent run to retrieve chat sessions for.

        Returns:
            list: List of chat session dictionaries.

        Raises:
            requests.exceptions.HTTPError: If the API request fails.
        """
        url = f"{self._server_url}/chat/{collection_id}/{agent_run_id}/sessions"
        response = self._session.get(url)
        self._handle_response_errors(response)
        return response.json()

    def make_collection_public(self, collection_id: str) -> dict[str, Any]:
        """Make a collection publicly accessible to anyone with the link.

        Args:
            collection_id: ID of the Collection to make public.

        Returns:
            dict: API response data.

        Raises:
            requests.exceptions.HTTPError: If the API request fails.
        """
        url = f"{self._server_url}/{collection_id}/make_public"
        response = self._session.post(url)
        self._handle_response_errors(response)

        logger.info(f"Successfully made Collection '{collection_id}' public")
        return response.json()

    def share_collection_with_email(self, collection_id: str, email: str) -> dict[str, Any]:
        """Share a collection with a specific user by email address.

        Args:
            collection_id: ID of the Collection to share.
            email: Email address of the user to share with.

        Returns:
            dict: API response data.

        Raises:
            requests.exceptions.HTTPError: If the API request fails.
        """
        url = f"{self._server_url}/{collection_id}/share_with_email"
        payload = {"email": email}
        response = self._session.post(url, json=payload)

        self._handle_response_errors(response)

        logger.info(f"Successfully shared Collection '{collection_id}' with {email}")
        return response.json()

    def collection_exists(self, collection_id: str) -> bool:
        """Check if a collection exists without raising if it does not."""
        url = f"{self._server_url}/{collection_id}/exists"
        response = self._session.get(url)
        self._handle_response_errors(response)
        return bool(response.json())

    def has_collection_permission(self, collection_id: str, permission: str = "write") -> bool:
        """Check whether the authenticated user has a specific permission on a collection.

        Args:
            collection_id: Collection to check.
            permission: Permission level to verify (`read`, `write`, or `admin`).

        Returns:
            bool: True if the current API key has the requested permission; otherwise False.

        Raises:
            ValueError: If an unsupported permission value is provided.
            requests.exceptions.HTTPError: If the API request fails.
        """
        valid_permissions = {"read", "write", "admin"}
        if permission not in valid_permissions:
            raise ValueError(f"permission must be one of {sorted(valid_permissions)}")

        url = f"{self._server_url}/{collection_id}/has_permission"
        response = self._session.get(url, params={"permission": permission})
        self._handle_response_errors(response)

        payload = response.json()
        return bool(payload.get("has_permission", False))

    def get_dql_schema(self, collection_id: str) -> dict[str, Any]:
        """Retrieve the DQL schema for a collection.

        Args:
            collection_id: ID of the Collection.

        Returns:
            dict: Dictionary containing available tables, columns, and metadata for DQL queries.

        Raises:
            requests.exceptions.HTTPError: If the API request fails.
        """
        url = f"{self._server_url}/dql/{collection_id}/schema"
        response = self._session.get(url)
        self._handle_response_errors(response)
        return response.json()

    def execute_dql(self, collection_id: str, dql: str) -> dict[str, Any]:
        """Execute a DQL query against a collection.

        Args:
            collection_id: ID of the Collection.
            dql: The DQL query string to execute.

        Returns:
            dict: Query execution results including rows, columns, execution metadata, and selected columns.

        Raises:
            ValueError: If `dql` is empty.
            requests.exceptions.HTTPError: If the API request fails or the query is invalid.
        """
        if not dql.strip():
            raise ValueError("dql must be a non-empty string")

        url = f"{self._server_url}/dql/{collection_id}/execute"
        response = self._session.post(url, json={"dql": dql})
        self._handle_response_errors(response)
        return response.json()

    def dql_result_to_dicts(self, dql_result: dict[str, Any]) -> list[dict[str, Any]]:
        """Convert a DQL result to a list of dictionaries."""
        cols = dql_result["columns"]
        rows = dql_result["rows"]
        return [dict(zip(cols, row)) for row in rows]

    def dql_result_to_df_experimental(self, dql_result: dict[str, Any]):
        """Convert a DQL result to a pandas DataFrame. The implementation is not stable by any means!"""

        cols = dql_result["columns"]
        rows = dql_result["rows"]

        def _cast_value(v: Any) -> Any:
            """Cast a value to int, float, bool, or str as appropriate."""
            if v is None:
                return None
            if isinstance(v, (bool, int, float)):
                return v

            # If a string, try to cast into a number
            if isinstance(v, str):
                try:
                    if "." not in v:
                        return int(v)
                except (ValueError, TypeError):
                    pass

                try:
                    return float(v)
                except (ValueError, TypeError):
                    pass

            # Keep as original
            return v

        dicts: list[dict[str, Any]] = []
        for row in rows:
            combo = list(zip(cols, row))
            combo = {k: _cast_value(v) for k, v in combo}
            dicts.append(combo)

        return pd.DataFrame(dicts)

    def select_agent_run_ids(
        self,
        collection_id: str,
        where_clause: str | None = None,
        limit: int | None = None,
    ) -> list[str]:
        """Convenience helper to fetch agent run IDs via DQL.

        Args:
            collection_id: ID of the Collection to query.
            where_clause: Optional DQL WHERE clause applied to the agent_runs table.
            limit: Optional LIMIT applied to the underlying DQL query.

        Returns:
            list[str]: Agent run IDs matching the criteria.

        Raises:
            ValueError: If the inputs are invalid.
            requests.exceptions.HTTPError: If the API request fails.
        """
        query = "SELECT agent_runs.id AS agent_run_id FROM agent_runs"

        if where_clause:
            where_clause = where_clause.strip()
            if not where_clause:
                raise ValueError("where_clause must be a non-empty string when provided")
            query += f" WHERE {where_clause}"

        if limit is not None:
            if limit <= 0:
                raise ValueError("limit must be a positive integer when provided")
            query += f" LIMIT {limit}"

        result = self.execute_dql(collection_id, query)
        rows = result.get("rows", [])
        agent_run_ids = [str(row[0]) for row in rows if row]

        if result.get("truncated"):
            logger.warning(
                "DQL query truncated at applied limit %s; returning %s agent run IDs",
                result.get("applied_limit"),
                len(agent_run_ids),
            )

        return agent_run_ids

    def list_agent_run_ids(self, collection_id: str) -> list[str]:
        """Get all agent run IDs for a collection.

        Args:
            collection_id: ID of the Collection.

        Returns:
            list[str]: The agent run IDs in the collection.

        Raises:
            requests.exceptions.HTTPError: If the API request fails.
        """
        url = f"{self._server_url}/{collection_id}/agent_run_ids"
        response = self._session.get(url)
        self._handle_response_errors(response)
        return response.json()

    def recursively_ingest_inspect_logs(self, collection_id: str, fpath: str):
        """Recursively search a directory for .eval files and ingest them as agent runs.

        Args:
            collection_id: ID of the Collection to add agent runs to.
            fpath: Path to the directory to search recursively.

        Raises:
            ValueError: If the path doesn't exist or isn't a directory.
            requests.exceptions.HTTPError: If any API requests fail.
        """
        root_path = Path(fpath)
        if not root_path.exists():
            raise ValueError(f"Path does not exist: {fpath}")
        if not root_path.is_dir():
            raise ValueError(f"Path is not a directory: {fpath}")

        # Find all .eval files recursively
        eval_files = list(root_path.rglob("*.eval"))

        if not eval_files:
            logger.info(f"No .eval files found in {fpath}")
            return

        logger.info(f"Found {len(eval_files)} .eval files in {fpath}")

        total_runs_added = 0
        batch_size = 100

        # Process each .eval file
        for eval_file in tqdm(eval_files, desc="Processing .eval files", unit="files"):
            # Get total samples for progress tracking
            total_samples = load_inspect.get_total_samples(eval_file, format="eval")

            if total_samples == 0:
                logger.info(f"No samples found in {eval_file}")
                continue

            # Load runs from file
            with open(eval_file, "rb") as f:
                _, runs_generator = load_inspect.runs_from_file(f, format="eval")

                # Process runs in batches
                runs_from_file = 0
                batches = itertools.batched(runs_generator, batch_size)

                with tqdm(
                    total=total_samples,
                    desc=f"Processing {eval_file.name}",
                    unit="runs",
                    leave=False,
                ) as file_pbar:
                    for batch in batches:
                        batch_list = list(batch)  # Convert generator batch to list
                        if not batch_list:
                            break

                        # Add batch to collection
                        url = f"{self._server_url}/{collection_id}/agent_runs"
                        payload = {"agent_runs": [ar.model_dump(mode="json") for ar in batch_list]}

                        response = self._session.post(url, json=payload)
                        self._handle_response_errors(response)

                        runs_from_file += len(batch_list)
                        file_pbar.update(len(batch_list))

            total_runs_added += runs_from_file
            logger.info(f"Added {runs_from_file} runs from {eval_file}")

        logger.info(
            f"Successfully ingested {total_runs_added} total agent runs from {len(eval_files)} files"
        )

    def start_chat(
        self,
        context: LLMContext | list[LLMContextItem],
        model_string: str | None = None,
        reasoning_effort: Literal["minimal", "low", "medium", "high"] | None = None,
    ) -> str:
        """Start a chat session with the given context and open it in the browser.

        This method creates a new chat session with the provided context items (agent runs,
        transcripts, or formatted versions) and opens the chat UI in your default browser.

        Args:
            context: An LLMContext, or a list of LLMContextItem objects to include in the
                chat context. Items can include:
                - AgentRun or FormattedAgentRun instances
                - Transcript or FormattedTranscript instances
            model_string: Optional model to use for the chat. If None, uses the default.
            reasoning_effort: Optional reasoning effort level for the chat model.

        Returns:
            str: The session ID of the created chat session.

        Raises:
            ValueError: If the context is empty or contains unsupported types.
            requests.exceptions.HTTPError: If the API request fails.

        Example:
            ```python
            from docent.sdk import Docent

            client = Docent()
            run1 = client.get_agent_run(collection_id, run_id_1)
            run2 = client.get_agent_run(collection_id, run_id_2)

            session_id = client.start_chat([run1, run2])
            # Opens browser to chat UI
            ```
        """
        if isinstance(context, list):
            context = LLMContext(items=context)

        serialized_context = context.to_dict()

        url = f"{self._server_url}/chat/start"
        payload = {
            "context_serialized": serialized_context,
            "model_string": model_string,
            "reasoning_effort": reasoning_effort,
        }

        response = self._session.post(url, json=payload)
        self._handle_response_errors(response)

        response_data = response.json()
        session_id = response_data.get("session_id")
        if not session_id:
            raise ValueError("Failed to create chat session: 'session_id' missing in response")

        chat_url = f"{self._web_url}/chat/{session_id}"
        logger.info(f"Chat session created. Opening browser to: {chat_url}")

        webbrowser.open(chat_url)

        return session_id

    def open_agent_run(self, collection_id: str, agent_run_id: str) -> str:
        """Open an agent run in the browser.

        Args:
            collection_id: ID of the Collection containing the agent run.
            agent_run_id: ID of the agent run to open.

        Returns:
            str: The URL that was opened.

        Example:
            ```python
            from docent.sdk import Docent

            client = Docent()
            client.open_agent_run(collection_id, agent_run_id)
            # Opens browser to agent run page
            ```
        """
        agent_run_url = f"{self._web_url}/dashboard/{collection_id}/agent_run/{agent_run_id}"
        logger.info(f"Opening agent run in browser: {agent_run_url}")

        webbrowser.open(agent_run_url)

        return agent_run_url

    def open_rubric(
        self,
        collection_id: str,
        rubric_id: str,
        agent_run_id: str | None = None,
        judge_result_id: str | None = None,
    ) -> str:
        """Open a rubric, agent run, or judge result in the browser.

        Args:
            collection_id: ID of the Collection.
            rubric_id: ID of the rubric.
            agent_run_id: Optional ID of the agent run to view within the rubric.
            judge_result_id: Optional ID of the judge result to view. Requires agent_run_id.

        Returns:
            str: The URL that was opened.

        Raises:
            ValueError: If judge_result_id is provided without agent_run_id.

        Example:
            ```python
            from docent.sdk import Docent

            client = Docent()
            # Open rubric overview
            client.open_rubric(collection_id, rubric_id)
            # Open specific agent run within rubric
            client.open_rubric(collection_id, rubric_id, agent_run_id)
            # Open specific judge result
            client.open_rubric(collection_id, rubric_id, agent_run_id, judge_result_id)
            ```
        """
        if judge_result_id is not None and agent_run_id is None:
            raise ValueError("judge_result_id requires agent_run_id to be specified")

        url = f"{self._web_url}/dashboard/{collection_id}/rubric/{rubric_id}"
        if agent_run_id is not None:
            url += f"/agent_run/{agent_run_id}"
        if judge_result_id is not None:
            url += f"/result/{judge_result_id}"

        logger.info(f"Opening rubric in browser: {url}")
        webbrowser.open(url)

        return url