docent-python 0.1.41a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of docent-python might be problematic. Click here for more details.

Files changed (59) hide show
  1. docent/__init__.py +4 -0
  2. docent/_llm_util/__init__.py +0 -0
  3. docent/_llm_util/data_models/__init__.py +0 -0
  4. docent/_llm_util/data_models/exceptions.py +48 -0
  5. docent/_llm_util/data_models/llm_output.py +331 -0
  6. docent/_llm_util/llm_cache.py +193 -0
  7. docent/_llm_util/llm_svc.py +472 -0
  8. docent/_llm_util/model_registry.py +134 -0
  9. docent/_llm_util/providers/__init__.py +0 -0
  10. docent/_llm_util/providers/anthropic.py +537 -0
  11. docent/_llm_util/providers/common.py +41 -0
  12. docent/_llm_util/providers/google.py +530 -0
  13. docent/_llm_util/providers/openai.py +745 -0
  14. docent/_llm_util/providers/openrouter.py +375 -0
  15. docent/_llm_util/providers/preference_types.py +104 -0
  16. docent/_llm_util/providers/provider_registry.py +164 -0
  17. docent/_log_util/__init__.py +3 -0
  18. docent/_log_util/logger.py +141 -0
  19. docent/data_models/__init__.py +14 -0
  20. docent/data_models/_tiktoken_util.py +91 -0
  21. docent/data_models/agent_run.py +473 -0
  22. docent/data_models/chat/__init__.py +37 -0
  23. docent/data_models/chat/content.py +56 -0
  24. docent/data_models/chat/message.py +191 -0
  25. docent/data_models/chat/tool.py +109 -0
  26. docent/data_models/citation.py +187 -0
  27. docent/data_models/formatted_objects.py +84 -0
  28. docent/data_models/judge.py +17 -0
  29. docent/data_models/metadata_util.py +16 -0
  30. docent/data_models/regex.py +56 -0
  31. docent/data_models/transcript.py +305 -0
  32. docent/data_models/util.py +170 -0
  33. docent/judges/__init__.py +23 -0
  34. docent/judges/analysis.py +77 -0
  35. docent/judges/impl.py +587 -0
  36. docent/judges/runner.py +129 -0
  37. docent/judges/stats.py +205 -0
  38. docent/judges/types.py +320 -0
  39. docent/judges/util/forgiving_json.py +108 -0
  40. docent/judges/util/meta_schema.json +86 -0
  41. docent/judges/util/meta_schema.py +29 -0
  42. docent/judges/util/parse_output.py +68 -0
  43. docent/judges/util/voting.py +139 -0
  44. docent/loaders/load_inspect.py +215 -0
  45. docent/py.typed +0 -0
  46. docent/samples/__init__.py +3 -0
  47. docent/samples/load.py +9 -0
  48. docent/samples/log.eval +0 -0
  49. docent/samples/tb_airline.json +1 -0
  50. docent/sdk/__init__.py +0 -0
  51. docent/sdk/agent_run_writer.py +317 -0
  52. docent/sdk/client.py +1186 -0
  53. docent/sdk/llm_context.py +432 -0
  54. docent/trace.py +2741 -0
  55. docent/trace_temp.py +1086 -0
  56. docent_python-0.1.41a0.dist-info/METADATA +33 -0
  57. docent_python-0.1.41a0.dist-info/RECORD +59 -0
  58. docent_python-0.1.41a0.dist-info/WHEEL +4 -0
  59. docent_python-0.1.41a0.dist-info/licenses/LICENSE.md +13 -0
docent/sdk/client.py ADDED
@@ -0,0 +1,1186 @@
1
+ import gzip
2
+ import itertools
3
+ import json
4
+ import os
5
+ import time
6
+ import webbrowser
7
+ from pathlib import Path
8
+ from typing import Any, Iterator, Literal
9
+
10
+ import pandas as pd
11
+ import requests
12
+ from pydantic_core import to_jsonable_python
13
+ from tqdm import tqdm
14
+
15
+ from docent._log_util.logger import get_logger
16
+ from docent.data_models.agent_run import AgentRun
17
+ from docent.data_models.judge import Label
18
+ from docent.judges.util.meta_schema import validate_judge_result_schema
19
+ from docent.loaders import load_inspect
20
+ from docent.sdk.llm_context import LLMContext, LLMContextItem
21
+
22
# Largest request body, in bytes, that one agent-run upload may carry.
MAX_AGENT_RUN_PAYLOAD_BYTES = 100 * 1024 * 1024  # 100MB backend limit
# Fixed JSON envelope placed around the comma-joined, pre-serialized runs.
_AGENT_RUNS_PAYLOAD_PREFIX = b'{"agent_runs":['
_AGENT_RUNS_PAYLOAD_SUFFIX = b"]}"
25
+
26
+
27
def _serialize_agent_run(agent_run: AgentRun) -> bytes:
    """Encode a single AgentRun as compact (separator-only) UTF-8 JSON bytes."""
    jsonable = to_jsonable_python(agent_run)
    compact_text = json.dumps(jsonable, separators=(",", ":"))
    return compact_text.encode("utf-8")
30
+
31
+
32
def _build_agent_runs_payload(serialized_runs: list[bytes]) -> bytes:
    """Join already-serialized runs with commas and wrap them in the API envelope."""
    return b"".join(
        (
            _AGENT_RUNS_PAYLOAD_PREFIX,
            b",".join(serialized_runs),
            _AGENT_RUNS_PAYLOAD_SUFFIX,
        )
    )
36
+
37
+
38
def _yield_agent_run_batches_by_size(
    agent_runs: list[AgentRun], max_payload_bytes: int
) -> Iterator[tuple[int, bytes]]:
    """Group agent runs into upload payloads no larger than ``max_payload_bytes``.

    Runs are serialized one at a time and packed greedily, in order, into the
    current batch. Each yielded item is ``(number_of_runs, payload_bytes)``.

    Raises:
        ValueError: If one run, plus the envelope, is already over the limit.
    """
    overhead = len(_AGENT_RUNS_PAYLOAD_PREFIX) + len(_AGENT_RUNS_PAYLOAD_SUFFIX)

    batch: list[bytes] = []
    batch_bytes = overhead

    for agent_run in agent_runs:
        blob = _serialize_agent_run(agent_run)
        blob_len = len(blob)

        if overhead + blob_len > max_payload_bytes:
            raise ValueError(
                f"A single agent run (id={agent_run.id}) exceeds the maximum payload size of "
                f"{max_payload_bytes} bytes. Reduce the size of that run before uploading."
            )

        # The first run of a batch needs no separating comma (1 byte otherwise).
        grown_bytes = batch_bytes + blob_len + (1 if batch else 0)

        if batch and grown_bytes > max_payload_bytes:
            # The batch is full: flush it and start the next one with this run.
            yield len(batch), _build_agent_runs_payload(batch)
            batch = [blob]
            batch_bytes = overhead + blob_len
            continue

        batch.append(blob)
        batch_bytes = grown_bytes

    # Flush whatever is left over after the last run.
    if batch:
        yield len(batch), _build_agent_runs_payload(batch)
75
+
76
+
77
# Module-level logger used by the client for progress, warning and success messages.
logger = get_logger(__name__)
78
+
79
+
80
+ class Docent:
81
+ """Client for interacting with the Docent API.
82
+
83
+ This client provides methods for creating and managing Collections,
84
+ dimensions, agent runs, and filters in the Docent system. It handles
85
+ authentication via API keys and provides a high-level interface for
86
+ logging, querying, and analyzing agent traces.
87
+
88
+ Example:
89
+ >>> from docent import Docent
90
+ >>> client = Docent(api_key="your-api-key")
91
+ >>> collection_id = client.create_collection(name="My Collection")
92
+ """
93
+
94
def __init__(
    self,
    *,
    domain: str = "docent.transluce.org",
    use_https: bool = True,
    api_key: str | None = None,
    # Deprecated
    server_url: str | None = None,  # Use domain instead
    web_url: str | None = None,  # Use domain instead
):
    """Initialize the Docent client.

    Args:
        domain: The domain of the Docent instance. Defaults to "docent.transluce.org".
            The API URL (``api.<domain>/rest``) and web URL are built from it.
        use_https: If True (default), derived URLs use ``https://``; otherwise ``http://``.
            Has no effect on an explicitly passed ``server_url`` / ``web_url``.
        api_key: API key for authentication. If not provided, will attempt to read
            from the DOCENT_API_KEY environment variable.
        server_url: (Deprecated) Direct URL of the Docent API server. Use `domain` instead.
        web_url: (Deprecated) Direct URL of the Docent web UI. Use `domain` instead.

    Raises:
        ValueError: If no (non-empty) API key is provided and DOCENT_API_KEY is not set.
        requests.exceptions.HTTPError: If the API key check request fails.

    Example:
        >>> client = Docent(domain="my-instance.docent.com", api_key="sk-...")
    """
    # Warn about deprecated parameters
    if server_url is not None:
        logger.warning(
            "The 'server_url' parameter is deprecated and will be removed in a future version. "
            "Please use 'domain' instead."
        )
    if web_url is not None:
        logger.warning(
            "The 'web_url' parameter is deprecated and will be removed in a future version. "
            "Please use 'domain' instead."
        )

    # Resolve and validate the key before creating the session, so a missing
    # key fails fast. An empty string counts as missing, not as a credential.
    api_key = api_key or os.getenv("DOCENT_API_KEY")
    if not api_key:
        raise ValueError(
            "api_key is required. Please provide an "
            "api_key or set the DOCENT_API_KEY environment variable."
        )

    self._domain = domain

    # Set server URL; server_url takes precedence over domain
    prefix = "https://" if use_https else "http://"
    server_url = (server_url or f"{prefix}api.{domain}").rstrip("/")
    # All REST endpoints are addressed under the "/rest" path.
    if not server_url.endswith("/rest"):
        server_url = f"{server_url}/rest"
    self._server_url = server_url

    # Set web URL; web_url takes precedence over domain
    self._web_url = (web_url or f"{prefix}{domain}").rstrip("/")

    # Use requests.Session for connection pooling and persistent headers
    self._session = requests.Session()

    self._login(api_key)
156
+
157
def _handle_response_errors(self, response: requests.Response):
    """Raise ``requests.HTTPError`` with a readable message for 4xx/5xx responses.

    The message uses the ``detail`` field of a JSON error body when one is
    available and falls back to the raw response text otherwise. Responses
    with a status code below 400 are ignored.
    """
    if response.status_code < 400:
        return

    try:
        detail = response.json().get("detail", response.text)
    except Exception:
        # Body is not JSON (or not a JSON object): report the raw text instead.
        detail = response.text

    raise requests.HTTPError(f"HTTP {response.status_code}: {detail}", response=response)
167
+
168
def _login(self, api_key: str):
    """Authenticate the session with an API key and verify the key works.

    The key is installed as a ``Bearer`` Authorization header on the shared
    session, then checked with a GET to the ``/api-keys/test`` endpoint.

    Args:
        api_key: The API key to authenticate with.

    Raises:
        requests.exceptions.HTTPError: If the key check request fails.
    """
    self._session.headers.update({"Authorization": f"Bearer {api_key}"})

    url = f"{self._server_url}/api-keys/test"
    response = self._session.get(url)
    self._handle_response_errors(response)

    logger.info("Logged in with API key")
178
+
179
def create_collection(
    self,
    collection_id: str | None = None,
    name: str | None = None,
    description: str | None = None,
) -> str:
    """Create a new Collection and return its ID.

    Args:
        collection_id: Optional ID for the new Collection. If not provided, one will be generated.
        name: Optional name for the Collection.
        description: Optional description for the Collection.

    Returns:
        str: The ID of the created Collection, as reported by the server.

    Raises:
        ValueError: If the response is missing the Collection ID.
        requests.exceptions.HTTPError: If the API request fails.
    """
    endpoint = f"{self._server_url}/create"
    body = {"collection_id": collection_id, "name": name, "description": description}

    resp = self._session.post(endpoint, json=body)
    self._handle_response_errors(resp)

    # Trust the ID the server returns rather than the one that was requested.
    created_id = resp.json().get("collection_id")
    if created_id is None:
        raise ValueError("Failed to create collection: 'collection_id' missing in response.")

    logger.info(f"Successfully created Collection with id='{created_id}'")
    logger.info(
        f"Collection creation complete. Frontend available at: {self._web_url}/dashboard/{created_id}"
    )
    return created_id
223
+
224
def update_collection(
    self,
    collection_id: str,
    name: str | None = None,
    description: str | None = None,
) -> None:
    """Updates a Collection's name and/or description.

    Requires WRITE permission on the collection.

    Only fields that are not None are included in the request body; a field
    left as None is omitted rather than sent as null, so this method cannot
    be used to clear a value.

    Args:
        collection_id: ID of the Collection to update.
        name: New name for the Collection. If None, the name is not sent.
        description: New description for the Collection. If None, the description is not sent.

    Raises:
        requests.exceptions.HTTPError: If the API request fails.
    """
    url = f"{self._server_url}/{collection_id}/collection"
    # Build a partial-update payload containing only the provided fields.
    payload: dict[str, Any] = {}
    if name is not None:
        payload["name"] = name
    if description is not None:
        payload["description"] = description

    response = self._session.put(url, json=payload)
    self._handle_response_errors(response)

    logger.info(f"Successfully updated Collection '{collection_id}'")
253
+
254
def add_agent_runs(
    self,
    collection_id: str,
    agent_runs: list[AgentRun],
    *,
    compression: Literal["gzip", "none"] = "gzip",
    wait: bool = True,
    poll_interval: float = 1.0,
    # Deprecated
    batch_size: int | None = None,
) -> dict[str, Any]:
    """Adds agent runs to a Collection.

    Agent runs represent execution traces that can be visualized and analyzed.
    Requests are automatically chunked to stay under the backend's payload limit.

    Args:
        collection_id: ID of the Collection.
        agent_runs: List of AgentRun objects to add.
        compression: Compression algorithm for request bodies. Defaults to gzip.
            Set to "none" to retain legacy behavior.
        wait: If True (default), wait for all ingestion jobs to complete before returning.
            If False, return immediately after enqueuing jobs.
        poll_interval: Seconds between status checks when wait=True. Defaults to 1.0.
        batch_size: (Deprecated) Ignored; batching is driven by payload size.

    Returns:
        dict: API response data containing:
            - status: "success" if all jobs completed, "enqueued" if wait=False
            - total_runs_added: Number of agent runs submitted
            - job_ids: List of job IDs for tracking

    Raises:
        ValueError: If `compression` is unsupported, or if any single agent run
            exceeds the maximum payload size.
        requests.exceptions.HTTPError: If the API request fails.
        RuntimeError: If any job fails during processing (when wait=True).
    """
    if batch_size is not None:
        logger.warning(
            "The 'batch_size' parameter is deprecated and will be removed in a future version. "
            "We have transitioned to a new batching strategy based on the size of the payload."
        )

    # Validate once, up front, so a bad value is rejected even when there is
    # nothing to upload (the per-batch loop would never run in that case).
    if compression not in ("gzip", "none"):
        raise ValueError(f"Unsupported compression '{compression}'")

    use_gzip = compression == "gzip"
    headers = {"Content-Type": "application/json"}
    if use_gzip:
        headers["Content-Encoding"] = "gzip"

    url = f"{self._server_url}/{collection_id}/agent_runs"
    total_runs = len(agent_runs)
    job_ids: list[str] = []

    # Process agent runs in batches
    desc = f"Uploading agent runs (compression={compression})"
    with tqdm(total=total_runs, desc=desc, unit="runs") as pbar:
        # Named to avoid shadowing the deprecated `batch_size` parameter.
        for runs_in_batch, payload_bytes in _yield_agent_run_batches_by_size(
            agent_runs, MAX_AGENT_RUN_PAYLOAD_BYTES
        ):
            body = gzip.compress(payload_bytes) if use_gzip else payload_bytes

            response = self._session.post(url, data=body, headers=headers)
            self._handle_response_errors(response)

            # Server returns 202 with job_id for async processing
            job_id = response.json().get("job_id")
            if job_id:
                job_ids.append(job_id)

            pbar.update(runs_in_batch)

    if not wait:
        logger.info(
            f"Enqueued {total_runs} agent runs to Collection '{collection_id}' "
            f"({len(job_ids)} job(s)). Use get_agent_run_job_status() to check progress."
        )
        return {
            "status": "enqueued",
            "total_runs_added": total_runs,
            "job_ids": job_ids,
        }

    # Wait for all jobs to complete
    if job_ids:
        logger.info(
            f"Uploaded {total_runs} agent runs in {len(job_ids)} batch(es). "
            f"Waiting for server-side processing to complete... "
            f"(set wait=False to skip waiting)"
        )
        self._wait_for_jobs(collection_id, job_ids, poll_interval)

    logger.info(
        f"Successfully added {total_runs} agent runs to Collection '{collection_id}'. "
        f"All {len(job_ids)} job(s) completed."
    )
    return {"status": "success", "total_runs_added": total_runs, "job_ids": job_ids}
356
+
357
def _wait_for_jobs(
    self,
    collection_id: str,
    job_ids: list[str],
    poll_interval: float = 1.0,
) -> None:
    """Wait for all jobs to complete, showing progress.

    Pending jobs are polled through ``get_agent_run_job_statuses`` in chunks
    of at most 100 IDs, because that method rejects larger requests. A job
    leaves the pending set only when its status is "completed" or "canceled";
    any other status keeps it pending and it is polled again.

    Args:
        collection_id: ID of the Collection.
        job_ids: List of job IDs to wait for.
        poll_interval: Seconds between status checks.

    Raises:
        RuntimeError: If any job fails or is canceled.
        requests.exceptions.HTTPError: If a status request fails.
    """
    max_ids_per_request = 100  # limit enforced by get_agent_run_job_statuses
    pending_jobs = set(job_ids)
    failed_jobs: dict[str, str] = {}

    with tqdm(total=len(job_ids), desc="Waiting for server processing", unit="jobs") as pbar:
        while pending_jobs:
            # Snapshot the set: it is mutated while the statuses are processed.
            to_check = list(pending_jobs)

            for start in range(0, len(to_check), max_ids_per_request):
                chunk = to_check[start : start + max_ids_per_request]
                statuses = self.get_agent_run_job_statuses(collection_id, chunk)

                for job_status in statuses:
                    job_id = job_status["job_id"]
                    status = job_status["status"]

                    if status == "completed":
                        pending_jobs.discard(job_id)
                        pbar.update(1)
                    elif status == "canceled":
                        pending_jobs.discard(job_id)
                        failed_jobs[job_id] = "Job was canceled"
                        pbar.update(1)

            if pending_jobs:
                time.sleep(poll_interval)

    if failed_jobs:
        failed_msg = ", ".join(f"{k}: {v}" for k, v in failed_jobs.items())
        raise RuntimeError(f"Some jobs failed: {failed_msg}")
398
+
399
def get_agent_run_job_statuses(
    self, collection_id: str, job_ids: list[str]
) -> list[dict[str, Any]]:
    """Fetch the status of several agent run ingestion jobs in one request.

    Args:
        collection_id: ID of the Collection.
        job_ids: Job IDs to look up (at most 100).

    Returns:
        list: The ``jobs`` list from the response. Per the API, each entry holds
            job_id, status ("pending", "running", "completed" or "canceled"),
            type, and created_at (ISO timestamp).

    Raises:
        ValueError: If more than 100 job IDs are provided.
        requests.exceptions.HTTPError: If the API request fails.
    """
    if len(job_ids) > 100:
        raise ValueError("Cannot request more than 100 job IDs at once")

    endpoint = f"{self._server_url}/{collection_id}/agent_runs/jobs/batch_status"
    resp = self._session.post(endpoint, json={"job_ids": job_ids})
    self._handle_response_errors(resp)
    body = resp.json()
    return body["jobs"]
426
+
427
def get_agent_run_job_status(self, collection_id: str, job_id: str) -> dict[str, Any]:
    """Fetch the status of one agent run ingestion job.

    Args:
        collection_id: ID of the Collection.
        job_id: The ID of the job to check.

    Returns:
        dict: The decoded response. Per the API it includes job_id, status
            ("pending", "running", "completed" or "canceled"), type, and
            created_at (ISO timestamp).

    Raises:
        requests.exceptions.HTTPError: If the API request fails.
    """
    endpoint = f"{self._server_url}/{collection_id}/agent_runs/jobs/{job_id}"
    resp = self._session.get(endpoint)
    self._handle_response_errors(resp)
    return resp.json()
448
+
449
def list_collections(self) -> list[dict[str, Any]]:
    """Return every Collection available to the caller.

    Returns:
        list: List of Collection objects, as decoded from the response.

    Raises:
        requests.exceptions.HTTPError: If the API request fails.
    """
    resp = self._session.get(f"{self._server_url}/collections")
    self._handle_response_errors(resp)
    return resp.json()
462
+
463
+ def get_collection(self, collection_id: str) -> dict[str, Any] | None:
464
+ """Get details about a specific Collection.
465
+
466
+ Requires READ permission on the collection.
467
+
468
+ Args:
469
+ collection_id: ID of the Collection to retrieve.
470
+
471
+ Returns:
472
+ Collection: Collection object with id, name, description, created_at, and created_by.
473
+ Returns None if collection not found.
474
+
475
+ Raises:
476
+ requests.exceptions.HTTPError: If the API request fails.
477
+ """
478
+ url = f"{self._server_url}/{collection_id}/collection_details"
479
+ response = self._session.get(url)
480
+ self._handle_response_errors(response)
481
+ return response.json()
482
+
483
def list_rubrics(self, collection_id: str) -> list[dict[str, Any]]:
    """Return all rubrics defined for a collection.

    Args:
        collection_id: ID of the Collection.

    Returns:
        list: List of dictionaries containing rubric information.

    Raises:
        requests.exceptions.HTTPError: If the API request fails.
    """
    endpoint = f"{self._server_url}/rubric/{collection_id}/rubrics"
    resp = self._session.get(endpoint)
    self._handle_response_errors(resp)
    return resp.json()
499
+
500
+ def get_rubric_run_state(
501
+ self, collection_id: str, rubric_id: str, version: int | None = None
502
+ ) -> dict[str, Any]:
503
+ """Get rubric run state for a given collection and rubric.
504
+
505
+ Args:
506
+ collection_id: ID of the Collection.
507
+ rubric_id: The ID of the rubric to get run state for.
508
+ version: The version of the rubric to get run state for. If None, the latest version is used.
509
+
510
+ Returns:
511
+ dict: Dictionary containing rubric run state with results, job_id, and total_results_needed.
512
+
513
+ Raises:
514
+ requests.exceptions.HTTPError: If the API request fails.
515
+ """
516
+ url = f"{self._server_url}/rubric/{collection_id}/{rubric_id}/rubric_run_state"
517
+ response = self._session.get(url, params={"version": version})
518
+ self._handle_response_errors(response)
519
+ return response.json()
520
+
521
def get_clustering_state(self, collection_id: str, rubric_id: str) -> dict[str, Any]:
    """Fetch the clustering job state of a rubric in a collection.

    Args:
        collection_id: ID of the Collection.
        rubric_id: The ID of the rubric to get clustering state for.

    Returns:
        dict: Dictionary containing job_id, centroids, and assignments.

    Raises:
        requests.exceptions.HTTPError: If the API request fails.
    """
    endpoint = f"{self._server_url}/rubric/{collection_id}/{rubric_id}/clustering_job"
    resp = self._session.get(endpoint)
    self._handle_response_errors(resp)
    return resp.json()
538
+
539
def get_cluster_centroids(self, collection_id: str, rubric_id: str) -> list[dict[str, Any]]:
    """Return the ``centroids`` entry of a rubric's clustering state.

    Args:
        collection_id: ID of the Collection.
        rubric_id: The ID of the rubric to get centroids for.

    Returns:
        list: List of dictionaries containing centroid information; an empty
            list when the clustering state has no "centroids" key.

    Raises:
        requests.exceptions.HTTPError: If the API request fails.
    """
    return self.get_clustering_state(collection_id, rubric_id).get("centroids", [])
554
+
555
def get_cluster_assignments(self, collection_id: str, rubric_id: str) -> dict[str, list[str]]:
    """Return the ``assignments`` entry of a rubric's clustering state.

    Args:
        collection_id: ID of the Collection.
        rubric_id: The ID of the rubric to get assignments for.

    Returns:
        dict: Dictionary mapping centroid IDs to lists of judge result IDs; an
            empty dict when the clustering state has no "assignments" key.

    Raises:
        requests.exceptions.HTTPError: If the API request fails.
    """
    return self.get_clustering_state(collection_id, rubric_id).get("assignments", {})
570
+
571
def create_label_set(
    self,
    collection_id: str,
    name: str,
    label_schema: dict[str, Any],
    description: str | None = None,
) -> str:
    """Create a new label set with a JSON schema.

    The schema is validated locally with ``validate_judge_result_schema``
    before any request is sent.

    Args:
        collection_id: ID of the collection.
        name: Name of the label set.
        label_schema: JSON schema for validating labels in this set.
        description: Optional description of the label set.

    Returns:
        str: The ID of the created label set.

    Raises:
        ValueError: If the response is missing the label_set_id.
        jsonschema.ValidationError: If the label schema is invalid.
        requests.exceptions.HTTPError: If the API request fails.
    """
    validate_judge_result_schema(label_schema)

    url = f"{self._server_url}/label/{collection_id}/label_set"
    payload = {
        "name": name,
        "label_schema": label_schema,
        "description": description,
    }
    response = self._session.post(url, json=payload)
    self._handle_response_errors(response)

    # Raise the documented ValueError (as create_collection does) instead of
    # letting a bare KeyError escape when the server omits the ID.
    label_set_id = response.json().get("label_set_id")
    if label_set_id is None:
        raise ValueError("Failed to create label set: 'label_set_id' missing in response.")
    return label_set_id
605
+
606
def add_label(
    self,
    collection_id: str,
    label: Label,
) -> dict[str, str]:
    """Create one label in a label set.

    Args:
        collection_id: ID of the Collection.
        label: A `Label` object that must comply with the label set's schema.

    Returns:
        dict: API response containing the label_id.

    Raises:
        requests.exceptions.HTTPError: If the API request fails or validation errors occur.
    """
    endpoint = f"{self._server_url}/label/{collection_id}/label"
    label_json = label.model_dump(mode="json")
    resp = self._session.post(endpoint, json={"label": label_json})
    self._handle_response_errors(resp)
    return resp.json()
628
+
629
def add_labels(
    self,
    collection_id: str,
    labels: list[Label],
) -> dict[str, Any]:
    """Create several labels with a single request.

    Args:
        collection_id: ID of the Collection.
        labels: List of `Label` objects.

    Returns:
        dict: API response containing label_ids list and optional errors list.

    Raises:
        ValueError: If no labels are provided.
        requests.exceptions.HTTPError: If the API request fails.
    """
    if not labels:
        raise ValueError("labels must contain at least one entry")

    endpoint = f"{self._server_url}/label/{collection_id}/labels"
    dumped = [item.model_dump(mode="json") for item in labels]
    resp = self._session.post(endpoint, json={"labels": dumped})
    self._handle_response_errors(resp)
    return resp.json()
655
+
656
def get_labels(
    self, collection_id: str, label_set_id: str, filter_valid_labels: bool = False
) -> list[dict[str, Any]]:
    """Fetch the labels that belong to a label set.

    Args:
        collection_id: ID of the Collection.
        label_set_id: ID of the label set to fetch labels for.
        filter_valid_labels: If True, only return labels that match the label set schema
            INCLUDING requirements. Default is False (returns all labels).

    Returns:
        list: List of label dictionaries.

    Raises:
        requests.exceptions.HTTPError: If the API request fails.
    """
    endpoint = f"{self._server_url}/label/{collection_id}/label_set/{label_set_id}/labels"
    query = {"filter_valid_labels": filter_valid_labels}
    resp = self._session.get(endpoint, params=query)
    self._handle_response_errors(resp)
    return resp.json()
678
+
679
def tag_transcript(self, collection_id: str, agent_run_id: str, value: str) -> None:
    """Attach a tag to an agent run transcript.

    Args:
        collection_id: ID of the Collection.
        agent_run_id: The agent run to tag.
        value: The tag value (max length enforced by the server).

    Raises:
        requests.exceptions.HTTPError: If the API request fails.
    """
    endpoint = f"{self._server_url}/label/{collection_id}/tag"
    body = {"agent_run_id": agent_run_id, "value": value}
    self._handle_response_errors(self._session.post(endpoint, json=body))
694
+
695
+ def get_tags(self, collection_id: str, value: str | None = None) -> list[dict[str, Any]]:
696
+ """Get all tags in a collection, optionally filtered by value."""
697
+ url = f"{self._server_url}/label/{collection_id}/tags"
698
+ params = {"value": value} if value is not None else None
699
+ response = self._session.get(url, params=params)
700
+ self._handle_response_errors(response)
701
+ return response.json()
702
+
703
def get_tags_for_agent_run(self, collection_id: str, agent_run_id: str) -> list[dict[str, Any]]:
    """Return every tag attached to one agent run.

    Args:
        collection_id: ID of the Collection.
        agent_run_id: The agent run whose tags are requested.

    Returns:
        list: The decoded response body.
    """
    endpoint = f"{self._server_url}/label/{collection_id}/agent_run/{agent_run_id}/tags"
    resp = self._session.get(endpoint)
    self._handle_response_errors(resp)
    return resp.json()
709
+
710
def delete_tag(self, collection_id: str, tag_id: str) -> None:
    """Delete the tag with the given ID from a collection.

    Args:
        collection_id: ID of the Collection.
        tag_id: ID of the tag to delete.
    """
    endpoint = f"{self._server_url}/label/{collection_id}/tag/{tag_id}"
    self._handle_response_errors(self._session.delete(endpoint))
715
+
716
def get_agent_run(self, collection_id: str, agent_run_id: str) -> AgentRun | None:
    """Get a specific agent run by its ID.

    Args:
        collection_id: ID of the Collection.
        agent_run_id: The ID of the agent run to retrieve.

    Returns:
        AgentRun | None: The agent run validated into an `AgentRun` model, or
            None when the server responds with a JSON null.

    Raises:
        requests.exceptions.HTTPError: If the API request fails.
    """
    url = f"{self._server_url}/{collection_id}/agent_run"
    response = self._session.get(url, params={"agent_run_id": agent_run_id})
    self._handle_response_errors(response)

    # Decode the body once; agent runs can be large and each .json() call re-parses it.
    data = response.json()
    if data is None:
        return None

    # We do this to avoid metadata validation failing
    # TODO(mengk): kinda hacky
    return AgentRun.model_validate(data)
738
+
739
def get_chat_sessions(self, collection_id: str, agent_run_id: str) -> list[dict[str, Any]]:
    """Fetch the chat sessions of an agent run, excluding judge result sessions.

    Args:
        collection_id: ID of the Collection.
        agent_run_id: The ID of the agent run to retrieve chat sessions for.

    Returns:
        list: List of chat session dictionaries.

    Raises:
        requests.exceptions.HTTPError: If the API request fails.
    """
    endpoint = f"{self._server_url}/chat/{collection_id}/{agent_run_id}/sessions"
    resp = self._session.get(endpoint)
    self._handle_response_errors(resp)
    return resp.json()
756
+
757
def make_collection_public(self, collection_id: str) -> dict[str, Any]:
    """Make a collection accessible to anyone who has the link.

    Args:
        collection_id: ID of the Collection to make public.

    Returns:
        dict: API response data.

    Raises:
        requests.exceptions.HTTPError: If the API request fails.
    """
    resp = self._session.post(f"{self._server_url}/{collection_id}/make_public")
    self._handle_response_errors(resp)

    logger.info(f"Successfully made Collection '{collection_id}' public")
    return resp.json()
775
+
776
def share_collection_with_email(self, collection_id: str, email: str) -> dict[str, Any]:
    """Share a collection with the user identified by an email address.

    Args:
        collection_id: ID of the Collection to share.
        email: Email address of the user to share with.

    Returns:
        dict: API response data.

    Raises:
        requests.exceptions.HTTPError: If the API request fails.
    """
    endpoint = f"{self._server_url}/{collection_id}/share_with_email"
    resp = self._session.post(endpoint, json={"email": email})
    self._handle_response_errors(resp)

    logger.info(f"Successfully shared Collection '{collection_id}' with {email}")
    return resp.json()
797
+
798
def collection_exists(self, collection_id: str) -> bool:
    """Report whether a collection exists.

    The result is the truthiness of the decoded response body. HTTP error
    responses still raise through ``_handle_response_errors``.

    Args:
        collection_id: ID of the Collection to look for.
    """
    endpoint = f"{self._server_url}/{collection_id}/exists"
    resp = self._session.get(endpoint)
    self._handle_response_errors(resp)
    found = resp.json()
    return bool(found)
804
+
805
def has_collection_permission(self, collection_id: str, permission: str = "write") -> bool:
    """Ask the server whether the caller holds a permission on a collection.

    Args:
        collection_id: Collection to check.
        permission: Permission level to verify (`read`, `write`, or `admin`).

    Returns:
        bool: Truthiness of the `has_permission` field of the response; False
            when the field is absent.

    Raises:
        ValueError: If an unsupported permission value is provided.
        requests.exceptions.HTTPError: If the API request fails.
    """
    allowed = {"admin", "read", "write"}
    if permission not in allowed:
        # Sorted so the message is deterministic regardless of set ordering.
        raise ValueError(f"permission must be one of {sorted(allowed)}")

    endpoint = f"{self._server_url}/{collection_id}/has_permission"
    resp = self._session.get(endpoint, params={"permission": permission})
    self._handle_response_errors(resp)

    body = resp.json()
    granted = body.get("has_permission", False)
    return bool(granted)
829
+
830
def get_dql_schema(self, collection_id: str) -> dict[str, Any]:
    """Fetch the DQL schema describing what can be queried in a collection.

    Args:
        collection_id: ID of the Collection.

    Returns:
        dict: Dictionary containing available tables, columns, and metadata for DQL queries.

    Raises:
        requests.exceptions.HTTPError: If the API request fails.
    """
    endpoint = f"{self._server_url}/dql/{collection_id}/schema"
    resp = self._session.get(endpoint)
    self._handle_response_errors(resp)
    schema = resp.json()
    return schema
846
+
847
def execute_dql(self, collection_id: str, dql: str) -> dict[str, Any]:
    """Run a DQL query against a collection and return the server's result.

    Args:
        collection_id: ID of the Collection.
        dql: The DQL query string to execute.

    Returns:
        dict: Query execution results including rows, columns, execution metadata, and selected columns.

    Raises:
        ValueError: If `dql` is empty or contains only whitespace.
        requests.exceptions.HTTPError: If the API request fails or the query is invalid.
    """
    query_is_blank = not dql.strip()
    if query_is_blank:
        raise ValueError("dql must be a non-empty string")

    endpoint = f"{self._server_url}/dql/{collection_id}/execute"
    # The query is sent exactly as given; stripping is used for validation only.
    resp = self._session.post(endpoint, json={"dql": dql})
    self._handle_response_errors(resp)
    return resp.json()
868
+
869
def dql_result_to_dicts(self, dql_result: dict[str, Any]) -> list[dict[str, Any]]:
    """Turn a DQL result into one dictionary per row, keyed by column name.

    Reads `dql_result["columns"]` and `dql_result["rows"]`. Values are paired
    with column names positionally; a row shorter or longer than the column
    list is truncated to the shorter of the two.
    """
    column_names = dql_result["columns"]
    records: list[dict[str, Any]] = []
    for values in dql_result["rows"]:
        records.append(dict(zip(column_names, values)))
    return records
874
+
875
def dql_result_to_df_experimental(self, dql_result: dict[str, Any]):
    """Convert a DQL result into a pandas DataFrame, coercing numeric-looking strings.

    The implementation is not stable by any means!

    String cells are converted to `int` when they parse as one (and contain no
    "."), otherwise to `float` when they parse as one, otherwise left as-is.
    This is lossy for strings that merely look numeric: "007" becomes 7 and
    "nan" becomes a float NaN. Non-string cells are passed through untouched.

    Args:
        dql_result: Mapping with a "columns" list and a "rows" list of row
            sequences, as consumed by `dql_result_to_dicts`.

    Returns:
        pandas.DataFrame: One row per result row. When the result has no rows,
            an empty frame that still carries the result's column names.

    Raises:
        KeyError: If "columns" or "rows" is missing from `dql_result`.
    """
    cols = dql_result["columns"]
    rows = dql_result["rows"]

    def _cast_value(v: Any) -> Any:
        """Coerce a string to int or float when it parses; pass anything else through."""
        if not isinstance(v, str):
            # None, bool, int, float and any other non-string type stay unchanged.
            return v

        # Only attempt int for strings without a decimal point so "1.0" stays a float.
        if "." not in v:
            try:
                return int(v)
            except (ValueError, TypeError):
                pass

        try:
            return float(v)
        except (ValueError, TypeError):
            # Not numeric: keep the original string.
            return v

    records: list[dict[str, Any]] = [
        {col: _cast_value(val) for col, val in zip(cols, row)} for row in rows
    ]

    if not records:
        # pd.DataFrame([]) has no columns at all; keep the schema so callers can
        # still select columns on an empty result.
        return pd.DataFrame(columns=cols)

    return pd.DataFrame(records)
911
+
912
+ def select_agent_run_ids(
913
+ self,
914
+ collection_id: str,
915
+ where_clause: str | None = None,
916
+ limit: int | None = None,
917
+ ) -> list[str]:
918
+ """Convenience helper to fetch agent run IDs via DQL.
919
+
920
+ Args:
921
+ collection_id: ID of the Collection to query.
922
+ where_clause: Optional DQL WHERE clause applied to the agent_runs table.
923
+ limit: Optional LIMIT applied to the underlying DQL query.
924
+
925
+ Returns:
926
+ list[str]: Agent run IDs matching the criteria.
927
+
928
+ Raises:
929
+ ValueError: If the inputs are invalid.
930
+ requests.exceptions.HTTPError: If the API request fails.
931
+ """
932
+ query = "SELECT agent_runs.id AS agent_run_id FROM agent_runs"
933
+
934
+ if where_clause:
935
+ where_clause = where_clause.strip()
936
+ if not where_clause:
937
+ raise ValueError("where_clause must be a non-empty string when provided")
938
+ query += f" WHERE {where_clause}"
939
+
940
+ if limit is not None:
941
+ if limit <= 0:
942
+ raise ValueError("limit must be a positive integer when provided")
943
+ query += f" LIMIT {limit}"
944
+
945
+ result = self.execute_dql(collection_id, query)
946
+ rows = result.get("rows", [])
947
+ agent_run_ids = [str(row[0]) for row in rows if row]
948
+
949
+ if result.get("truncated"):
950
+ logger.warning(
951
+ "DQL query truncated at applied limit %s; returning %s agent run IDs",
952
+ result.get("applied_limit"),
953
+ len(agent_run_ids),
954
+ )
955
+
956
+ return agent_run_ids
957
+
958
def list_agent_run_ids(self, collection_id: str) -> list[str]:
    """Fetch the IDs of every agent run in a collection.

    Args:
        collection_id: ID of the Collection.

    Returns:
        list[str]: The decoded JSON response body, expected to be the list of
            agent run IDs.

    Raises:
        requests.exceptions.HTTPError: If the API request fails.
    """
    endpoint = f"{self._server_url}/{collection_id}/agent_run_ids"
    resp = self._session.get(endpoint)
    self._handle_response_errors(resp)
    run_ids = resp.json()
    return run_ids
974
+
975
def recursively_ingest_inspect_logs(self, collection_id: str, fpath: str):
    """Find every .eval file under a directory and upload its runs to a collection.

    Each file's runs are posted to the collection's `agent_runs` endpoint in
    batches of 100, with progress bars per file and per run.

    Args:
        collection_id: ID of the Collection to add agent runs to.
        fpath: Path to directory to search recursively.

    Raises:
        ValueError: If the path doesn't exist or isn't a directory.
        requests.exceptions.HTTPError: If any API requests fail.
    """
    search_root = Path(fpath)
    if not search_root.exists():
        raise ValueError(f"Path does not exist: {fpath}")
    if not search_root.is_dir():
        raise ValueError(f"Path is not a directory: {fpath}")

    eval_paths = list(search_root.rglob("*.eval"))
    if len(eval_paths) == 0:
        logger.info(f"No .eval files found in {fpath}")
        return

    logger.info(f"Found {len(eval_paths)} .eval files in {fpath}")

    chunk_size = 100
    grand_total = 0

    for eval_path in tqdm(eval_paths, desc="Processing .eval files", unit="files"):
        # The sample count only sizes the per-file progress bar.
        sample_count = load_inspect.get_total_samples(eval_path, format="eval")
        if sample_count == 0:
            logger.info(f"No samples found in {eval_path}")
            continue

        with open(eval_path, "rb") as handle:
            _, run_iter = load_inspect.runs_from_file(handle, format="eval")

            # NOTE(review): run_iter presumably reads lazily from `handle`, so it is
            # consumed while the file is still open -- confirm against load_inspect.
            uploaded = 0
            chunks = itertools.batched(run_iter, chunk_size)

            with tqdm(
                total=sample_count,
                desc=f"Processing {eval_path.name}",
                unit="runs",
                leave=False,
            ) as progress:
                for chunk in chunks:
                    runs = list(chunk)
                    if len(runs) == 0:
                        break

                    endpoint = f"{self._server_url}/{collection_id}/agent_runs"
                    body = {"agent_runs": [run.model_dump(mode="json") for run in runs]}

                    resp = self._session.post(endpoint, json=body)
                    self._handle_response_errors(resp)

                    uploaded += len(runs)
                    progress.update(len(runs))

        grand_total += uploaded
        logger.info(f"Added {uploaded} runs from {eval_path}")

    logger.info(
        f"Successfully ingested {grand_total} total agent runs from {len(eval_paths)} files"
    )
1048
+
1049
def start_chat(
    self,
    context: LLMContext | list[LLMContextItem],
    model_string: str | None = None,
    reasoning_effort: Literal["minimal", "low", "medium", "high"] | None = None,
) -> str:
    """Create a chat session for the given context and open it in the browser.

    The context is serialized with `LLMContext.to_dict()` and posted to the
    server's `/chat/start` endpoint; the resulting session's chat page is then
    opened with `webbrowser.open`.

    Args:
        context: Either a ready-made `LLMContext`, or a list of
            `LLMContextItem` objects, which is wrapped as `LLMContext(items=...)`.
        model_string: Optional model identifier sent to the server as
            `model_string`. None is sent as-is (presumably letting the server
            pick its default -- confirm against the server API).
        reasoning_effort: Optional reasoning effort sent to the server as
            `reasoning_effort`; one of "minimal", "low", "medium", "high".

    Returns:
        str: The session ID of the created chat session.

    Raises:
        ValueError: If the server response does not contain a `session_id`.
        requests.exceptions.HTTPError: If the API request fails.

    Example:
        ```python
        from docent.sdk import Docent

        client = Docent()
        run1 = client.get_agent_run(collection_id, run_id_1)
        run2 = client.get_agent_run(collection_id, run_id_2)

        session_id = client.start_chat([run1, run2])
        # Opens browser to chat UI
        ```
    """
    # Accept a bare list of items for convenience; an LLMContext is used as given.
    if isinstance(context, list):
        context = LLMContext(items=context)

    payload = {
        "context_serialized": context.to_dict(),
        "model_string": model_string,
        "reasoning_effort": reasoning_effort,
    }

    response = self._session.post(f"{self._server_url}/chat/start", json=payload)
    self._handle_response_errors(response)

    session_id = response.json().get("session_id")
    if not session_id:
        raise ValueError("Failed to create chat session: 'session_id' missing in response")

    chat_url = f"{self._web_url}/chat/{session_id}"
    logger.info(f"Chat session created. Opening browser to: {chat_url}")

    webbrowser.open(chat_url)

    return session_id
1113
+
1114
def open_agent_run(self, collection_id: str, agent_run_id: str) -> str:
    """Open the dashboard page of an agent run in the browser.

    Args:
        collection_id: ID of the Collection containing the agent run.
        agent_run_id: ID of the agent run to open.

    Returns:
        str: The URL that was opened.

    Example:
        ```python
        from docent.sdk import Docent

        client = Docent()
        client.open_agent_run(collection_id, agent_run_id)
        # Opens browser to agent run page
        ```
    """
    target = f"{self._web_url}/dashboard/{collection_id}/agent_run/{agent_run_id}"
    logger.info(f"Opening agent run in browser: {target}")
    webbrowser.open(target)
    return target
1139
+
1140
+ def open_rubric(
1141
+ self,
1142
+ collection_id: str,
1143
+ rubric_id: str,
1144
+ agent_run_id: str | None = None,
1145
+ judge_result_id: str | None = None,
1146
+ ) -> str:
1147
+ """Open a rubric, agent run, or judge result in the browser.
1148
+
1149
+ Args:
1150
+ collection_id: ID of the Collection.
1151
+ rubric_id: ID of the rubric.
1152
+ agent_run_id: Optional ID of the agent run to view within the rubric.
1153
+ judge_result_id: Optional ID of the judge result to view. Requires agent_run_id.
1154
+
1155
+ Returns:
1156
+ str: The URL that was opened.
1157
+
1158
+ Raises:
1159
+ ValueError: If judge_result_id is provided without agent_run_id.
1160
+
1161
+ Example:
1162
+ ```python
1163
+ from docent.sdk import Docent
1164
+
1165
+ client = Docent()
1166
+ # Open rubric overview
1167
+ client.open_rubric(collection_id, rubric_id)
1168
+ # Open specific agent run within rubric
1169
+ client.open_rubric(collection_id, rubric_id, agent_run_id)
1170
+ # Open specific judge result
1171
+ client.open_rubric(collection_id, rubric_id, agent_run_id, judge_result_id)
1172
+ ```
1173
+ """
1174
+ if judge_result_id is not None and agent_run_id is None:
1175
+ raise ValueError("judge_result_id requires agent_run_id to be specified")
1176
+
1177
+ url = f"{self._web_url}/dashboard/{collection_id}/rubric/{rubric_id}"
1178
+ if agent_run_id is not None:
1179
+ url += f"/agent_run/{agent_run_id}"
1180
+ if judge_result_id is not None:
1181
+ url += f"/result/{judge_result_id}"
1182
+
1183
+ logger.info(f"Opening rubric in browser: {url}")
1184
+ webbrowser.open(url)
1185
+
1186
+ return url