docent-python 0.1.26a0__tar.gz → 0.1.28a0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. {docent_python-0.1.26a0 → docent_python-0.1.28a0}/.gitignore +2 -0
  2. {docent_python-0.1.26a0 → docent_python-0.1.28a0}/PKG-INFO +1 -1
  3. {docent_python-0.1.26a0 → docent_python-0.1.28a0}/docent/_llm_util/llm_svc.py +3 -3
  4. {docent_python-0.1.26a0 → docent_python-0.1.28a0}/docent/_llm_util/model_registry.py +4 -0
  5. {docent_python-0.1.26a0 → docent_python-0.1.28a0}/docent/_llm_util/providers/anthropic.py +1 -1
  6. {docent_python-0.1.26a0 → docent_python-0.1.28a0}/docent/_llm_util/providers/openai.py +1 -6
  7. {docent_python-0.1.26a0 → docent_python-0.1.28a0}/docent/data_models/agent_run.py +1 -0
  8. docent_python-0.1.28a0/docent/judges/runner.py +129 -0
  9. {docent_python-0.1.26a0 → docent_python-0.1.28a0}/docent/sdk/client.py +118 -1
  10. {docent_python-0.1.26a0 → docent_python-0.1.28a0}/docent/trace.py +268 -103
  11. {docent_python-0.1.26a0 → docent_python-0.1.28a0}/pyproject.toml +1 -1
  12. docent_python-0.1.26a0/docent/judges/runner.py +0 -66
  13. {docent_python-0.1.26a0 → docent_python-0.1.28a0}/LICENSE.md +0 -0
  14. {docent_python-0.1.26a0 → docent_python-0.1.28a0}/README.md +0 -0
  15. {docent_python-0.1.26a0 → docent_python-0.1.28a0}/docent/__init__.py +0 -0
  16. {docent_python-0.1.26a0 → docent_python-0.1.28a0}/docent/_llm_util/__init__.py +0 -0
  17. {docent_python-0.1.26a0 → docent_python-0.1.28a0}/docent/_llm_util/data_models/__init__.py +0 -0
  18. {docent_python-0.1.26a0 → docent_python-0.1.28a0}/docent/_llm_util/data_models/exceptions.py +0 -0
  19. {docent_python-0.1.26a0 → docent_python-0.1.28a0}/docent/_llm_util/data_models/llm_output.py +0 -0
  20. {docent_python-0.1.26a0 → docent_python-0.1.28a0}/docent/_llm_util/llm_cache.py +0 -0
  21. {docent_python-0.1.26a0 → docent_python-0.1.28a0}/docent/_llm_util/providers/__init__.py +0 -0
  22. {docent_python-0.1.26a0 → docent_python-0.1.28a0}/docent/_llm_util/providers/common.py +0 -0
  23. {docent_python-0.1.26a0 → docent_python-0.1.28a0}/docent/_llm_util/providers/google.py +0 -0
  24. {docent_python-0.1.26a0 → docent_python-0.1.28a0}/docent/_llm_util/providers/openrouter.py +0 -0
  25. {docent_python-0.1.26a0 → docent_python-0.1.28a0}/docent/_llm_util/providers/preference_types.py +0 -0
  26. {docent_python-0.1.26a0 → docent_python-0.1.28a0}/docent/_llm_util/providers/provider_registry.py +0 -0
  27. {docent_python-0.1.26a0 → docent_python-0.1.28a0}/docent/_log_util/__init__.py +0 -0
  28. {docent_python-0.1.26a0 → docent_python-0.1.28a0}/docent/_log_util/logger.py +0 -0
  29. {docent_python-0.1.26a0 → docent_python-0.1.28a0}/docent/data_models/__init__.py +0 -0
  30. {docent_python-0.1.26a0 → docent_python-0.1.28a0}/docent/data_models/_tiktoken_util.py +0 -0
  31. {docent_python-0.1.26a0 → docent_python-0.1.28a0}/docent/data_models/chat/__init__.py +0 -0
  32. {docent_python-0.1.26a0 → docent_python-0.1.28a0}/docent/data_models/chat/content.py +0 -0
  33. {docent_python-0.1.26a0 → docent_python-0.1.28a0}/docent/data_models/chat/message.py +0 -0
  34. {docent_python-0.1.26a0 → docent_python-0.1.28a0}/docent/data_models/chat/tool.py +0 -0
  35. {docent_python-0.1.26a0 → docent_python-0.1.28a0}/docent/data_models/citation.py +0 -0
  36. {docent_python-0.1.26a0 → docent_python-0.1.28a0}/docent/data_models/judge.py +0 -0
  37. {docent_python-0.1.26a0 → docent_python-0.1.28a0}/docent/data_models/metadata_util.py +0 -0
  38. {docent_python-0.1.26a0 → docent_python-0.1.28a0}/docent/data_models/regex.py +0 -0
  39. {docent_python-0.1.26a0 → docent_python-0.1.28a0}/docent/data_models/remove_invalid_citation_ranges.py +0 -0
  40. {docent_python-0.1.26a0 → docent_python-0.1.28a0}/docent/data_models/shared_types.py +0 -0
  41. {docent_python-0.1.26a0 → docent_python-0.1.28a0}/docent/data_models/transcript.py +0 -0
  42. {docent_python-0.1.26a0 → docent_python-0.1.28a0}/docent/data_models/util.py +0 -0
  43. {docent_python-0.1.26a0 → docent_python-0.1.28a0}/docent/judges/__init__.py +0 -0
  44. {docent_python-0.1.26a0 → docent_python-0.1.28a0}/docent/judges/analysis.py +0 -0
  45. {docent_python-0.1.26a0 → docent_python-0.1.28a0}/docent/judges/impl.py +0 -0
  46. {docent_python-0.1.26a0 → docent_python-0.1.28a0}/docent/judges/stats.py +0 -0
  47. {docent_python-0.1.26a0 → docent_python-0.1.28a0}/docent/judges/types.py +0 -0
  48. {docent_python-0.1.26a0 → docent_python-0.1.28a0}/docent/judges/util/forgiving_json.py +0 -0
  49. {docent_python-0.1.26a0 → docent_python-0.1.28a0}/docent/judges/util/meta_schema.json +0 -0
  50. {docent_python-0.1.26a0 → docent_python-0.1.28a0}/docent/judges/util/meta_schema.py +0 -0
  51. {docent_python-0.1.26a0 → docent_python-0.1.28a0}/docent/judges/util/parse_output.py +0 -0
  52. {docent_python-0.1.26a0 → docent_python-0.1.28a0}/docent/judges/util/voting.py +0 -0
  53. {docent_python-0.1.26a0 → docent_python-0.1.28a0}/docent/loaders/load_inspect.py +0 -0
  54. {docent_python-0.1.26a0 → docent_python-0.1.28a0}/docent/py.typed +0 -0
  55. {docent_python-0.1.26a0 → docent_python-0.1.28a0}/docent/samples/__init__.py +0 -0
  56. {docent_python-0.1.26a0 → docent_python-0.1.28a0}/docent/samples/load.py +0 -0
  57. {docent_python-0.1.26a0 → docent_python-0.1.28a0}/docent/samples/log.eval +0 -0
  58. {docent_python-0.1.26a0 → docent_python-0.1.28a0}/docent/samples/tb_airline.json +0 -0
  59. {docent_python-0.1.26a0 → docent_python-0.1.28a0}/docent/sdk/__init__.py +0 -0
  60. {docent_python-0.1.26a0 → docent_python-0.1.28a0}/docent/sdk/agent_run_writer.py +0 -0
  61. {docent_python-0.1.26a0 → docent_python-0.1.28a0}/docent/trace_temp.py +0 -0
  62. {docent_python-0.1.26a0 → docent_python-0.1.28a0}/uv.lock +0 -0
@@ -13,6 +13,8 @@
  */.terraform/
  */*.terraform.*

+ .idea/
+
  # Byte-compiled / optimized / DLL files
  __pycache__/
  *.py[cod]
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: docent-python
- Version: 0.1.26a0
+ Version: 0.1.28a0
  Summary: Docent SDK
  Project-URL: Homepage, https://github.com/TransluceAI/docent
  Project-URL: Issues, https://github.com/TransluceAI/docent/issues
@@ -75,7 +75,7 @@ async def _parallelize_calls(
      completion_callback: AsyncLLMOutputStreamingCallback | None,
      # Arguments for the individual completion getter
      client: Any,
-     inputs: list[MessagesInput],
+     inputs: Sequence[MessagesInput],
      model_name: str,
      tools: list[ToolInfo] | None,
      tool_choice: Literal["auto", "required"] | None,
@@ -306,7 +306,7 @@ async def _parallelize_calls(

  class BaseLLMService:
      def __init__(self, max_concurrency: int = DEFAULT_SVC_MAX_CONCURRENCY):
-         self._semaphore = Semaphore(max_concurrency)
+         self.max_concurrency, self._semaphore = max_concurrency, Semaphore(max_concurrency)
          self._client_cache: dict[tuple[str, str | None], Any] = {} # (provider, api_key) -> client
          self._client_cache_lock = Lock()

@@ -326,7 +326,7 @@ class BaseLLMService:
      async def get_completions(
          self,
          *,
-         inputs: list[MessagesInput],
+         inputs: Sequence[MessagesInput],
          model_options: list[ModelOption],
          tools: list[ToolInfo] | None = None,
          tool_choice: Literal["auto", "required"] | None = None,
@@ -54,6 +54,10 @@ _REGISTRY: list[tuple[str, ModelInfo]] = [
          "claude-sonnet-4",
          ModelInfo(rate={"input": 3.0, "output": 15.0}, context_window=200_000),
      ),
+     (
+         "claude-haiku-4-5",
+         ModelInfo(rate={"input": 1.0, "output": 5.0}, context_window=200_000),
+     ),
      (
          "gemini-2.5-flash-lite",
          ModelInfo(
@@ -178,7 +178,7 @@ def _parse_tool_choice(tool_choice: Literal["auto", "required"] | None) -> ToolC

  def _convert_anthropic_error(e: Exception):
      if isinstance(e, BadRequestError):
-         if "context limit" in e.message.lower():
+         if "context limit" in e.message.lower() or "prompt is too long" in e.message.lower():
              return ContextWindowException()
      if isinstance(e, RateLimitError):
          return RateLimitException(e)
@@ -18,13 +18,8 @@ from openai import (
      PermissionDeniedError,
      RateLimitError,
      UnprocessableEntityError,
+     omit,
  )
- try:
-     from openai import omit
- except ImportError:
-     from openai import Omit as _OpenAIOmit
-
-     omit = _OpenAIOmit()
  from openai.types.chat import (
      ChatCompletion,
      ChatCompletionAssistantMessageParam,
@@ -125,6 +125,7 @@ class AgentRun(BaseModel):
          # )

          # Append the text field
+         result.append({"name": "agent_run_id", "type": "str"})
          result.append({"name": "text", "type": "str"})

          return result
@@ -0,0 +1,129 @@
+ from typing import Protocol, Sequence, runtime_checkable
+
+ import anyio
+ from tqdm.auto import tqdm
+
+ from docent._llm_util.llm_svc import BaseLLMService
+ from docent._log_util import get_logger
+ from docent.data_models.agent_run import AgentRun
+ from docent.judges import (
+     JudgeResult,
+     JudgeResultCompletionCallback,
+     Rubric,
+ )
+ from docent.judges.impl import build_judge
+
+ logger = get_logger(__name__)
+
+
+ @runtime_checkable
+ class AgentRunResolver(Protocol):
+     async def __call__(self) -> AgentRun | None: ...
+
+
+ AgentRunInput = AgentRun | AgentRunResolver
+
+
+ async def _resolve_agent_run(agent_run_input: AgentRunInput) -> AgentRun | None:
+     if isinstance(agent_run_input, AgentRun):
+         return agent_run_input
+     else:
+         return await agent_run_input()
+
+
+ async def run_rubric(
+     agent_runs: Sequence[AgentRunInput],
+     rubric: Rubric,
+     llm_svc: BaseLLMService,
+     callback: JudgeResultCompletionCallback | None = None,
+     *,
+     n_rollouts_per_input: int | list[int] = 1,
+     show_progress: bool = True,
+ ) -> list[JudgeResult | None]:
+     if not agent_runs:
+         raise ValueError("agent_runs must be a non-empty sequence")
+     if rubric.n_rollouts_per_input <= 0:
+         raise ValueError("rubric.n_rollouts_per_input must be greater than 0")
+
+     # Normalize n_rollouts_per_input to a list
+     if isinstance(n_rollouts_per_input, int):
+         if n_rollouts_per_input < 0:
+             raise ValueError("n_rollouts_per_input must be non-negative")
+         rollouts_per_run = [n_rollouts_per_input] * len(agent_runs)
+     else:
+         rollouts_per_run = n_rollouts_per_input
+         if len(rollouts_per_run) != len(agent_runs):
+             raise ValueError("n_rollouts_per_input list must match agent_runs length")
+         if any(n < 0 for n in rollouts_per_run):
+             raise ValueError("All values in n_rollouts_per_input must be non-negative")
+
+     judge = build_judge(rubric, llm_svc)
+
+     total_rollouts = sum(rollouts_per_run)
+     logger.info(
+         "Running rubric %s version %s against %d agent runs with %d total rollouts",
+         rubric.id,
+         rubric.version,
+         len(agent_runs),
+         total_rollouts,
+     )
+
+     agent_results: list[list[JudgeResult | None]] = [[] for _ in agent_runs]
+     progress_bar = tqdm(
+         total=total_rollouts,
+         desc=f"Rubric {rubric.id}",
+         disable=not show_progress,
+     )
+
+     # NOTE(mengk): using a (2 * llm max concurrency) semaphore is a hack to avoid
+     # hammering _resolve_agent_run, which makes expensive DB calls, when they aren't going to be
+     # immediately processed by the LLMService anyways.
+     # TODO(mengk): We should eventually implement a more idiomatic solution to this.
+     # It's related to the idea of a global concurrency limiter.
+     run_judge_semaphore = anyio.Semaphore(llm_svc.max_concurrency * 2)
+
+     async def _run_single_judge(index: int, agent_run_input: AgentRunInput):
+         async with run_judge_semaphore:
+             rollout_results: list[JudgeResult | None] = []
+
+             if rollouts_per_run[index] == 0:
+                 agent_results[index] = []
+                 if callback is not None:
+                     await callback(index, None)
+                 return
+
+             agent_run = await _resolve_agent_run(agent_run_input)
+             if agent_run is None:
+                 if callback is not None:
+                     await callback(index, None)
+                 return
+
+             for _ in range(rollouts_per_run[index]):
+                 result = await judge(agent_run)
+                 rollout_results.append(result)
+                 progress_bar.update()
+
+             agent_results[index] = rollout_results
+
+             if callback is not None:
+                 # Filter out None results for the callback
+                 valid_results = [r for r in rollout_results if r is not None]
+                 await callback(index, valid_results if valid_results else None)
+
+     try:
+         async with anyio.create_task_group() as tg:
+             for index, agent_run in enumerate(agent_runs):
+                 tg.start_soon(_run_single_judge, index, agent_run)
+     finally:
+         progress_bar.close()
+
+     flattened_results = [result for rollouts in agent_results for result in rollouts]
+     successful = sum(result is not None for result in flattened_results)
+     logger.info(
+         "Finished rubric %s: produced %d/%d judge results",
+         rubric.id,
+         successful,
+         len(flattened_results),
+     )
+
+     return flattened_results
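
A minimal usage sketch (not part of the diff) may help clarify the new per-input rollout control. The agent_runs, rubric, and llm_svc objects are assumed to be constructed elsewhere; only the run_rubric signature comes from the code above.

    # Illustrative sketch only: n_rollouts_per_input may be an int or a per-run list.
    from docent.judges.runner import run_rubric

    async def score_runs(agent_runs, rubric, llm_svc):
        # Hypothetical counts: two rollouts for the first run, one for each remaining run.
        per_input = [2] + [1] * (len(agent_runs) - 1)
        results = await run_rubric(
            agent_runs,
            rubric,
            llm_svc,
            n_rollouts_per_input=per_input,
        )
        # Results are flattened across rollouts; None marks a rollout with no judgment.
        print(sum(r is not None for r in results), "of", len(results), "rollouts succeeded")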
@@ -200,7 +200,7 @@ class Docent:
              version: The version of the rubric to get run state for. If None, the latest version is used.

          Returns:
-             dict: Dictionary containing rubric run state with results, job_id, and total_agent_runs.
+             dict: Dictionary containing rubric run state with results, job_id, and total_results_needed.

          Raises:
              requests.exceptions.HTTPError: If the API request fails.
@@ -450,6 +450,123 @@ class Docent:
          logger.info(f"Successfully shared Collection '{collection_id}' with {email}")
          return response.json()

+     def collection_exists(self, collection_id: str) -> bool:
+         """Check if a collection exists without raising if it does not."""
+         url = f"{self._server_url}/{collection_id}/exists"
+         response = self._session.get(url)
+         self._handle_response_errors(response)
+         return bool(response.json())
+
+     def has_collection_permission(self, collection_id: str, permission: str = "write") -> bool:
+         """Check whether the authenticated user has a specific permission on a collection.
+
+         Args:
+             collection_id: Collection to check.
+             permission: Permission level to verify (`read`, `write`, or `admin`).
+
+         Returns:
+             bool: True if the current API key has the requested permission; otherwise False.
+
+         Raises:
+             ValueError: If an unsupported permission value is provided.
+             requests.exceptions.HTTPError: If the API request fails.
+         """
+         valid_permissions = {"read", "write", "admin"}
+         if permission not in valid_permissions:
+             raise ValueError(f"permission must be one of {sorted(valid_permissions)}")
+
+         url = f"{self._server_url}/{collection_id}/has_permission"
+         response = self._session.get(url, params={"permission": permission})
+         self._handle_response_errors(response)
+
+         payload = response.json()
+         return bool(payload.get("has_permission", False))
+
+     def get_dql_schema(self, collection_id: str) -> dict[str, Any]:
+         """Retrieve the DQL schema for a collection.
+
+         Args:
+             collection_id: ID of the Collection.
+
+         Returns:
+             dict: Dictionary containing available tables, columns, and metadata for DQL queries.
+
+         Raises:
+             requests.exceptions.HTTPError: If the API request fails.
+         """
+         url = f"{self._server_url}/dql/{collection_id}/schema"
+         response = self._session.get(url)
+         self._handle_response_errors(response)
+         return response.json()
+
+     def execute_dql(self, collection_id: str, dql: str) -> dict[str, Any]:
+         """Execute a DQL query against a collection.
+
+         Args:
+             collection_id: ID of the Collection.
+             dql: The DQL query string to execute.
+
+         Returns:
+             dict: Query execution results including rows, columns, execution metadata, and selected columns.
+
+         Raises:
+             ValueError: If `dql` is empty.
+             requests.exceptions.HTTPError: If the API request fails or the query is invalid.
+         """
+         if not dql.strip():
+             raise ValueError("dql must be a non-empty string")
+
+         url = f"{self._server_url}/dql/{collection_id}/execute"
+         response = self._session.post(url, json={"dql": dql})
+         self._handle_response_errors(response)
+         return response.json()
+
+     def select_agent_run_ids(
+         self,
+         collection_id: str,
+         where_clause: str | None = None,
+         limit: int | None = None,
+     ) -> list[str]:
+         """Convenience helper to fetch agent run IDs via DQL.
+
+         Args:
+             collection_id: ID of the Collection to query.
+             where_clause: Optional DQL WHERE clause applied to the agent_runs table.
+             limit: Optional LIMIT applied to the underlying DQL query.
+
+         Returns:
+             list[str]: Agent run IDs matching the criteria.
+
+         Raises:
+             ValueError: If the inputs are invalid.
+             requests.exceptions.HTTPError: If the API request fails.
+         """
+         query = "SELECT agent_runs.id AS agent_run_id FROM agent_runs"
+
+         if where_clause:
+             where_clause = where_clause.strip()
+             if not where_clause:
+                 raise ValueError("where_clause must be a non-empty string when provided")
+             query += f" WHERE {where_clause}"
+
+         if limit is not None:
+             if limit <= 0:
+                 raise ValueError("limit must be a positive integer when provided")
+             query += f" LIMIT {limit}"
+
+         result = self.execute_dql(collection_id, query)
+         rows = result.get("rows", [])
+         agent_run_ids = [str(row[0]) for row in rows if row]
+
+         if result.get("truncated"):
+             logger.warning(
+                 "DQL query truncated at applied limit %s; returning %s agent run IDs",
+                 result.get("applied_limit"),
+                 len(agent_run_ids),
+             )
+
+         return agent_run_ids
+
      def list_agent_run_ids(self, collection_id: str) -> list[str]:
          """Get all agent run IDs for a collection.

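
A hedged usage sketch (not from the package) of the new DQL helpers follows; the client construction, collection ID, and WHERE-clause syntax are assumptions, while the method names and signatures come from the diff above.

    # Illustrative sketch only.
    from docent.sdk.client import Docent

    client = Docent()  # assumes default construction resolves the API key as in prior releases
    collection_id = "your-collection-id"  # placeholder

    if client.collection_exists(collection_id) and client.has_collection_permission(collection_id, "read"):
        schema = client.get_dql_schema(collection_id)  # inspect available tables/columns first
        run_ids = client.select_agent_run_ids(
            collection_id,
            where_clause="agent_runs.text LIKE '%timeout%'",  # hypothetical filter
            limit=100,
        )
        print(f"{len(run_ids)} matching agent runs")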
@@ -1,4 +1,3 @@
- import asyncio
  import atexit
  import contextvars
  import itertools
@@ -9,13 +8,24 @@ import sys
  import threading
  import uuid
  from collections import defaultdict
- from concurrent.futures import Future, ThreadPoolExecutor
  from contextlib import asynccontextmanager, contextmanager
  from contextvars import ContextVar, Token
  from datetime import datetime, timezone
  from enum import Enum
  from importlib.metadata import Distribution, distributions
- from typing import Any, AsyncIterator, Callable, Dict, Iterator, List, Optional, Set, Union
+ from typing import (
+     Any,
+     AsyncIterator,
+     Callable,
+     Dict,
+     Iterator,
+     List,
+     Mapping,
+     Optional,
+     Set,
+     Union,
+     cast,
+ )

  import requests
  from opentelemetry import trace
@@ -31,12 +41,23 @@ from opentelemetry.sdk.trace.export import (
      SimpleSpanProcessor,
  )
  from opentelemetry.trace import Span
+ from requests import Response

  logger = logging.getLogger(__name__)

  # Default configuration
  DEFAULT_ENDPOINT = "https://api.docent.transluce.org/rest/telemetry"
  DEFAULT_COLLECTION_NAME = "default-collection-name"
+ ERROR_DETAIL_MAX_CHARS = 500
+
+ # Sentinel values for when tracing is disabled
+ DISABLED_AGENT_RUN_ID = "disabled"
+ DISABLED_TRANSCRIPT_ID = "disabled"
+ DISABLED_TRANSCRIPT_GROUP_ID = "disabled"
+
+
+ class DocentTelemetryRequestError(RuntimeError):
+     """Raised when the Docent telemetry backend rejects a client request."""


  class Instruments(Enum):
@@ -135,10 +156,6 @@ class DocentTracer:
          self._transcript_group_states: dict[str, dict[str, Optional[str]]] = {}
          self._transcript_group_state_lock = threading.Lock()
          self._flush_lock = threading.Lock()
-         self._http_executor: Optional[ThreadPoolExecutor] = None
-         self._http_executor_lock = threading.Lock()
-         self._pending_http_futures: Set[Future[Any]] = set()
-         self._pending_http_lock = threading.Lock()

      def get_current_agent_run_id(self) -> Optional[str]:
          """
@@ -448,12 +465,6 @@ class DocentTracer:
          try:
              self.flush()

-             if self._http_executor:
-                 self._http_executor.shutdown(wait=True)
-                 self._http_executor = None
-             with self._pending_http_lock:
-                 self._pending_http_futures.clear()
-
              if self._tracer_provider:
                  self._tracer_provider.shutdown()
                  self._tracer_provider = None
@@ -484,7 +495,6 @@ class DocentTracer:
                  if hasattr(processor, "force_flush"):
                      logger.debug(f"Flushing span processor {i}")
                      processor.force_flush(timeout_millis=50)
-             self._wait_for_http_requests()
              logger.debug("Span flush completed")
          except Exception as e:
              logger.error(f"Error during flush: {e}")
@@ -503,6 +513,24 @@ class DocentTracer:
          """Verify if the manager is properly initialized."""
          return self._initialized

+     def get_disabled_agent_run_id(self, agent_run_id: Optional[str]) -> str:
+         """Return sentinel value for agent run ID when tracing is disabled."""
+         if agent_run_id is None:
+             return DISABLED_AGENT_RUN_ID
+         return agent_run_id
+
+     def get_disabled_transcript_id(self, transcript_id: Optional[str]) -> str:
+         """Return sentinel value for transcript ID when tracing is disabled."""
+         if transcript_id is None:
+             return DISABLED_TRANSCRIPT_ID
+         return transcript_id
+
+     def get_disabled_transcript_group_id(self, transcript_group_id: Optional[str]) -> str:
+         """Return sentinel value for transcript group ID when tracing is disabled."""
+         if transcript_group_id is None:
+             return DISABLED_TRANSCRIPT_GROUP_ID
+         return transcript_group_id
+
      @contextmanager
      def agent_run_context(
          self,
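
A small sketch (not from the diff) of what the sentinel change means for callers: with tracing disabled, contexts now yield the stable string "disabled" instead of fresh UUIDs. It assumes get_tracer is importable from docent.trace, as it is used elsewhere in this module, and that tracing was disabled at initialization.

    # Illustrative sketch only.
    from docent.trace import get_tracer

    tracer = get_tracer()
    with tracer.agent_run_context(None, None) as (agent_run_id, transcript_id):
        # With tracing disabled, both IDs are the literal string "disabled".
        print(agent_run_id, transcript_id)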
@@ -524,11 +552,8 @@ class DocentTracer:
              Tuple of (agent_run_id, transcript_id)
          """
          if self._disabled:
-             # Return dummy IDs when tracing is disabled
-             if agent_run_id is None:
-                 agent_run_id = str(uuid.uuid4())
-             if transcript_id is None:
-                 transcript_id = str(uuid.uuid4())
+             agent_run_id = self.get_disabled_agent_run_id(agent_run_id)
+             transcript_id = self.get_disabled_transcript_id(transcript_id)
              yield agent_run_id, transcript_id
              return

@@ -551,7 +576,7 @@ class DocentTracer:
              try:
                  self.send_agent_run_metadata(agent_run_id, metadata)
              except Exception as e:
-                 logger.warning(f"Failed sending agent run metadata: {e}")
+                 logger.error(f"Failed sending agent run metadata: {e}")

              yield agent_run_id, transcript_id
          finally:
@@ -581,11 +606,8 @@ class DocentTracer:
              Tuple of (agent_run_id, transcript_id)
          """
          if self._disabled:
-             # Return dummy IDs when tracing is disabled
-             if agent_run_id is None:
-                 agent_run_id = str(uuid.uuid4())
-             if transcript_id is None:
-                 transcript_id = str(uuid.uuid4())
+             agent_run_id = self.get_disabled_agent_run_id(agent_run_id)
+             transcript_id = self.get_disabled_transcript_id(transcript_id)
              yield agent_run_id, transcript_id
              return

@@ -631,48 +653,6 @@ class DocentTracer:

          return headers

-     def _get_http_executor(self) -> ThreadPoolExecutor:
-         with self._http_executor_lock:
-             if self._http_executor is None:
-                 self._http_executor = ThreadPoolExecutor(
-                     max_workers=4, thread_name_prefix="docent-http"
-                 )
-             return self._http_executor
-
-     def _should_run_http_in_background(self) -> bool:
-         try:
-             loop = asyncio.get_running_loop()
-         except RuntimeError:
-             return False
-         return loop.is_running()
-
-     def _on_http_future_done(self, future: Future[Any]) -> None:
-         with self._pending_http_lock:
-             self._pending_http_futures.discard(future)
-         try:
-             future.result()
-         except Exception as exc:  # pragma: no cover - defensive logging
-             logger.error(f"Background HTTP request failed: {exc}")
-
-     def _schedule_background_post(self, task: Callable[[], None]) -> None:
-         executor = self._get_http_executor()
-         future = executor.submit(task)
-         with self._pending_http_lock:
-             self._pending_http_futures.add(future)
-         future.add_done_callback(self._on_http_future_done)
-
-     def _wait_for_http_requests(self) -> None:
-         while True:
-             with self._pending_http_lock:
-                 pending = list(self._pending_http_futures)
-             if not pending:
-                 break
-             for future in pending:
-                 try:
-                     future.result()
-                 except Exception as exc:  # pragma: no cover - defensive logging
-                     logger.error(f"Background HTTP request failed: {exc}")
-
      def _ensure_json_serializable_metadata(self, metadata: Dict[str, Any], context: str) -> None:
          """
          Validate that metadata can be serialized to JSON before sending it to the backend.
@@ -681,13 +661,14 @@ class DocentTracer:
              json.dumps(metadata)
          except (TypeError, ValueError) as exc:
              raise TypeError(f"{context} metadata must be JSON serializable") from exc
+         offending_path = self._find_null_character_path(metadata)
+         if offending_path is not None:
+             raise ValueError(
+                 f"{context} metadata cannot contain null characters (found at {offending_path}). "
+                 "Remove or replace '\\u0000' before calling Docent tracing APIs."
+             )

-     def _post_json(
-         self, path: str, data: Dict[str, Any], *, allow_background: bool = False
-     ) -> None:
-         if allow_background and self._should_run_http_in_background():
-             self._schedule_background_post(lambda: self._post_json_sync(path, data))
-             return
+     def _post_json(self, path: str, data: Dict[str, Any]) -> None:
          self._post_json_sync(path, data)

      def _post_json_sync(self, path: str, data: Dict[str, Any]) -> None:
@@ -697,8 +678,159 @@
          try:
              resp = requests.post(url, json=data, headers=self._api_headers(), timeout=(10, 60))
              resp.raise_for_status()
-         except requests.exceptions.RequestException as e:
-             logger.error(f"Failed POST {url}: {e}")
+         except requests.exceptions.RequestException as exc:
+             message = self._format_request_exception(url, exc)
+             raise DocentTelemetryRequestError(message) from exc
+
+     def _format_request_exception(self, url: str, exc: requests.exceptions.RequestException) -> str:
+         response: Optional[Response] = getattr(exc, "response", None)
+         message_parts: List[str] = [f"Failed POST {url}"]
+         suggestion: Optional[str]
+
+         if response is not None:
+             status_phrase = f"HTTP {response.status_code}"
+             if response.reason:
+                 status_phrase = f"{status_phrase} {response.reason}"
+             message_parts.append(f"({status_phrase})")
+
+             detail = self._extract_response_detail(response)
+             if detail:
+                 message_parts.append(f"- Backend detail: {detail}")
+
+             request_id = response.headers.get("x-request-id")
+             if request_id:
+                 message_parts.append(f"(request-id: {request_id})")
+
+             suggestion = self._suggest_fix_for_status(response.status_code)
+         else:
+             message_parts.append(f"- {exc}")
+             suggestion = self._suggest_fix_for_status(None)
+
+         if suggestion:
+             message_parts.append(suggestion)
+
+         return " ".join(part for part in message_parts if part)
+
+     def _extract_response_detail(self, response: Response) -> Optional[str]:
+         try:
+             body = response.json()
+         except ValueError:
+             text = response.text.strip()
+             if not text:
+                 return None
+             normalized = " ".join(text.split())
+             return self._truncate_error_message(normalized)
+
+         if isinstance(body, dict):
+             typed_body = cast(Dict[str, Any], body)
+             structured_message = self._structured_detail_message(typed_body)
+             if structured_message:
+                 return self._truncate_error_message(structured_message)
+             return self._truncate_error_message(self._normalize_error_value(typed_body))
+
+         return self._truncate_error_message(self._normalize_error_value(body))
+
+     def _structured_detail_message(self, data: Dict[str, Any]) -> Optional[str]:
+         for key in ("detail", "message", "error"):
+             if key in data:
+                 structured_value = self._structured_detail_value(data[key])
+                 if structured_value:
+                     return structured_value
+         return self._structured_detail_value(data)
+
+     def _structured_detail_value(self, value: Any) -> Optional[str]:
+         if isinstance(value, Mapping):
+             mapping_value = cast(Mapping[str, Any], value)
+             message = mapping_value.get("message")
+             hint = mapping_value.get("hint")
+             error_code = mapping_value.get("error_code")
+             request_id = mapping_value.get("request_id")
+             fallback_detail = mapping_value.get("detail")
+
+             parts: List[str] = []
+             if isinstance(message, str) and message.strip():
+                 parts.append(message.strip())
+             elif isinstance(fallback_detail, str) and fallback_detail.strip():
+                 parts.append(fallback_detail.strip())
+
+             if isinstance(hint, str) and hint.strip():
+                 parts.append(f"(hint: {hint.strip()})")
+             if isinstance(error_code, str) and error_code.strip():
+                 parts.append(f"[code: {error_code.strip()}]")
+             if isinstance(request_id, str) and request_id.strip():
+                 parts.append(f"(request-id: {request_id.strip()})")
+
+             return " ".join(parts) if parts else None
+
+         if isinstance(value, str) and value.strip():
+             return value.strip()
+
+         return None
+
+     def _normalize_error_value(self, value: Any) -> str:
+         if isinstance(value, str):
+             return " ".join(value.split())
+
+         try:
+             serialized = json.dumps(value)
+         except (TypeError, ValueError):
+             serialized = str(value)
+
+         return " ".join(serialized.split())
+
+     def _truncate_error_message(self, message: str) -> str:
+         message = message.strip()
+         if len(message) <= ERROR_DETAIL_MAX_CHARS:
+             return message
+         return f"{message[:ERROR_DETAIL_MAX_CHARS]}..."
+
+     def _suggest_fix_for_status(self, status_code: Optional[int]) -> Optional[str]:
+         if status_code in (401, 403):
+             return (
+                 "Verify that the Authorization header or DOCENT_API_KEY grants write access to the "
+                 "target collection."
+             )
+         if status_code == 404:
+             return (
+                 "Ensure the tracing endpoint passed to initialize_tracing matches the Docent server's "
+                 "/rest/telemetry route."
+             )
+         if status_code in (400, 422):
+             return (
+                 "Confirm the payload includes collection_id, agent_run_id, metadata, and timestamp in "
+                 "the expected format."
+             )
+         if status_code and status_code >= 500:
+             return "Inspect the Docent backend logs for the referenced request."
+         if status_code is None:
+             return "Confirm the Docent telemetry endpoint is reachable from this process."
+         return None
+
+     def _find_null_character_path(self, value: Any, path: str = "") -> Optional[str]:
+         """Backend rejects NUL bytes, so detect them before we send metadata to the backend."""
+         if isinstance(value, str):
+             if "\x00" in value or "\\u0000" in value or "\\x00" in value:
+                 return path or "<root>"
+             return None
+
+         if isinstance(value, dict):
+             for key, item in value.items():
+                 next_path = f"{path}.{key}" if path else str(key)
+                 result = self._find_null_character_path(item, next_path)
+                 if result:
+                     return result
+             return None
+
+         if isinstance(value, (list, tuple)):
+             for index, item in enumerate(value):
+                 next_path = f"{path}[{index}]" if path else f"[{index}]"
+                 result = self._find_null_character_path(item, next_path)
+                 if result:
+                     return result
+             return None
+
+         return None

      def send_agent_run_score(
          self,
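
Because _post_json now raises instead of logging and swallowing failures, direct callers of the tracer's send_* methods can surface the backend's explanation. A hedged sketch, with placeholder IDs and metadata:

    # Illustrative sketch only.
    from docent.trace import DocentTelemetryRequestError, get_tracer

    tracer = get_tracer()
    try:
        tracer.send_agent_run_metadata("run-123", {"experiment": "ab-test"})  # placeholder values
    except DocentTelemetryRequestError as err:
        # The message includes the HTTP status, backend detail, request-id,
        # and a suggested fix when one is available.
        print(f"telemetry request rejected: {err}")

The context-manager paths still catch these exceptions and log them as errors, so the change mainly affects code that calls the tracer's send_* helpers directly.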
@@ -744,7 +876,7 @@ class DocentTracer:
              "metadata": metadata,
              "timestamp": datetime.now(timezone.utc).isoformat(),
          }
-         self._post_json("/v1/agent-run-metadata", payload, allow_background=True)
+         self._post_json("/v1/agent-run-metadata", payload)

      def send_transcript_metadata(
          self,
@@ -834,9 +966,7 @@ class DocentTracer:
              The transcript ID
          """
          if self._disabled:
-             # Return dummy ID when tracing is disabled
-             if transcript_id is None:
-                 transcript_id = str(uuid.uuid4())
+             transcript_id = self.get_disabled_transcript_id(transcript_id)
              yield transcript_id
              return

@@ -866,7 +996,7 @@ class DocentTracer:
                      transcript_id, name, description, transcript_group_id, metadata
                  )
              except Exception as e:
-                 logger.warning(f"Failed sending transcript data: {e}")
+                 logger.error(f"Failed sending transcript data: {e}")

              yield transcript_id
          finally:
@@ -896,9 +1026,7 @@ class DocentTracer:
              The transcript ID
          """
          if self._disabled:
-             # Return dummy ID when tracing is disabled
-             if transcript_id is None:
-                 transcript_id = str(uuid.uuid4())
+             transcript_id = self.get_disabled_transcript_id(transcript_id)
              yield transcript_id
              return

@@ -928,7 +1056,7 @@ class DocentTracer:
                      transcript_id, name, description, transcript_group_id, metadata
                  )
              except Exception as e:
-                 logger.warning(f"Failed sending transcript data: {e}")
+                 logger.error(f"Failed sending transcript data: {e}")

              yield transcript_id
          finally:
@@ -1029,9 +1157,7 @@ class DocentTracer:
              The transcript group ID
          """
          if self._disabled:
-             # Return dummy ID when tracing is disabled
-             if transcript_group_id is None:
-                 transcript_group_id = str(uuid.uuid4())
+             transcript_group_id = self.get_disabled_transcript_group_id(transcript_group_id)
              yield transcript_group_id
              return

@@ -1063,7 +1189,7 @@ class DocentTracer:
                      transcript_group_id, name, description, parent_transcript_group_id, metadata
                  )
              except Exception as e:
-                 logger.warning(f"Failed sending transcript group data: {e}")
+                 logger.error(f"Failed sending transcript group data: {e}")

              yield transcript_group_id
          finally:
@@ -1093,9 +1219,7 @@ class DocentTracer:
              The transcript group ID
          """
          if self._disabled:
-             # Return dummy ID when tracing is disabled
-             if transcript_group_id is None:
-                 transcript_group_id = str(uuid.uuid4())
+             transcript_group_id = self.get_disabled_transcript_group_id(transcript_group_id)
              yield transcript_group_id
              return

@@ -1127,7 +1251,7 @@ class DocentTracer:
                      transcript_group_id, name, description, parent_transcript_group_id, metadata
                  )
              except Exception as e:
-                 logger.warning(f"Failed sending transcript group data: {e}")
+                 logger.error(f"Failed sending transcript group data: {e}")

              yield transcript_group_id
          finally:
@@ -1331,28 +1455,33 @@ def agent_run_metadata(metadata: Dict[str, Any]) -> None:

          tracer.send_agent_run_metadata(agent_run_id, metadata)
      except Exception as e:
-         logger.error(f"Failed to send metadata: {e}")
+         logger.error(f"Failed to send agent run metadata: {e}")


  def transcript_metadata(
+     metadata: Dict[str, Any],
+     *,
      name: Optional[str] = None,
      description: Optional[str] = None,
      transcript_group_id: Optional[str] = None,
-     metadata: Optional[Dict[str, Any]] = None,
  ) -> None:
      """
      Send transcript metadata directly to the backend for the current transcript.

      Args:
+         metadata: Dictionary of metadata to attach to the current transcript (required)
          name: Optional transcript name
          description: Optional transcript description
-         parent_transcript_id: Optional parent transcript ID
-         metadata: Optional metadata to send
+         transcript_group_id: Optional transcript group ID to associate with

      Example:
-         transcript_metadata(name="data_processing", description="Process user data")
-         transcript_metadata(metadata={"user": "John", "model": "gpt-4"})
-         transcript_metadata(name="validation", parent_transcript_id="parent-123")
+         transcript_metadata({"user": "John", "model": "gpt-4"})
+         transcript_metadata({"env": "prod"}, name="data_processing")
+         transcript_metadata(
+             {"team": "search"},
+             name="validation",
+             transcript_group_id="group-123",
+         )
      """
      try:
          tracer = get_tracer()
@@ -1371,23 +1500,29 @@ def transcript_metadata(


  def transcript_group_metadata(
+     metadata: Dict[str, Any],
+     *,
      name: Optional[str] = None,
      description: Optional[str] = None,
      parent_transcript_group_id: Optional[str] = None,
-     metadata: Optional[Dict[str, Any]] = None,
  ) -> None:
      """
      Send transcript group metadata directly to the backend for the current transcript group.

      Args:
+         metadata: Dictionary of metadata to attach to the current transcript group (required)
          name: Optional transcript group name
          description: Optional transcript group description
          parent_transcript_group_id: Optional parent transcript group ID
-         metadata: Optional metadata to send

      Example:
-         transcript_group_metadata(name="pipeline", description="Main processing pipeline")
-         transcript_group_metadata(metadata={"team": "search", "env": "prod"})
+         transcript_group_metadata({"team": "search", "env": "prod"})
+         transcript_group_metadata({"env": "prod"}, name="pipeline")
+         transcript_group_metadata(
+             {"team": "search"},
+             name="pipeline",
+             parent_transcript_group_id="root-group",
+         )
      """
      try:
          tracer = get_tracer()
@@ -1424,6 +1559,11 @@ class AgentRunContext:

      def __enter__(self) -> tuple[str, str]:
          """Sync context manager entry."""
+         if is_disabled():
+             tracer = get_tracer()
+             self.agent_run_id = tracer.get_disabled_agent_run_id(self.agent_run_id)
+             self.transcript_id = tracer.get_disabled_transcript_id(self.transcript_id)
+             return self.agent_run_id, self.transcript_id
          self._sync_context = get_tracer().agent_run_context(
              self.agent_run_id, self.transcript_id, metadata=self.metadata, **self.attributes
          )
@@ -1436,6 +1576,11 @@ class AgentRunContext:

      async def __aenter__(self) -> tuple[str, str]:
          """Async context manager entry."""
+         if is_disabled():
+             tracer = get_tracer()
+             self.agent_run_id = tracer.get_disabled_agent_run_id(self.agent_run_id)
+             self.transcript_id = tracer.get_disabled_transcript_id(self.transcript_id)
+             return self.agent_run_id, self.transcript_id
          self._async_context = get_tracer().async_agent_run_context(
              self.agent_run_id, self.transcript_id, metadata=self.metadata, **self.attributes
          )
@@ -1576,6 +1721,10 @@ class TranscriptContext:

      def __enter__(self) -> str:
          """Sync context manager entry."""
+         if is_disabled():
+             tracer = get_tracer()
+             self.transcript_id = tracer.get_disabled_transcript_id(self.transcript_id)
+             return self.transcript_id
          self._sync_context = get_tracer().transcript_context(
              name=self.name,
              transcript_id=self.transcript_id,
@@ -1592,6 +1741,10 @@ class TranscriptContext:

      async def __aenter__(self) -> str:
          """Async context manager entry."""
+         if is_disabled():
+             tracer = get_tracer()
+             self.transcript_id = tracer.get_disabled_transcript_id(self.transcript_id)
+             return self.transcript_id
          self._async_context = get_tracer().async_transcript_context(
              name=self.name,
              transcript_id=self.transcript_id,
@@ -1753,6 +1906,12 @@ class TranscriptGroupContext:

      def __enter__(self) -> str:
          """Sync context manager entry."""
+         if is_disabled():
+             tracer = get_tracer()
+             self.transcript_group_id = tracer.get_disabled_transcript_group_id(
+                 self.transcript_group_id
+             )
+             return self.transcript_group_id
          self._sync_context = get_tracer().transcript_group_context(
              name=self.name,
              transcript_group_id=self.transcript_group_id,
@@ -1769,6 +1928,12 @@ class TranscriptGroupContext:

      async def __aenter__(self) -> str:
          """Async context manager entry."""
+         if is_disabled():
+             tracer = get_tracer()
+             self.transcript_group_id = tracer.get_disabled_transcript_group_id(
+                 self.transcript_group_id
+             )
+             return self.transcript_group_id
          self._async_context = get_tracer().async_transcript_group_context(
              name=self.name,
              transcript_group_id=self.transcript_group_id,
@@ -1,7 +1,7 @@
  [project]
  name = "docent-python"
  description = "Docent SDK"
- version = "0.1.26-alpha"
+ version = "0.1.28-alpha"
  authors = [
      { name="Transluce", email="info@transluce.org" },
  ]
@@ -1,66 +0,0 @@
- import anyio
- from tqdm.auto import tqdm
-
- from docent._llm_util.llm_svc import BaseLLMService
- from docent._log_util import get_logger
- from docent.data_models.agent_run import AgentRun
- from docent.judges import (
-     JudgeResult,
-     JudgeResultCompletionCallback,
-     Rubric,
- )
- from docent.judges.impl import build_judge
-
- logger = get_logger(__name__)
-
-
- async def run_rubric(
-     agent_runs: list[AgentRun],
-     rubric: Rubric,
-     llm_svc: BaseLLMService,
-     callback: JudgeResultCompletionCallback | None = None,
-     *,
-     show_progress: bool = True,
- ) -> list[JudgeResult | None]:
-     if not agent_runs:
-         raise ValueError("agent_runs must be a non-empty sequence")
-     if rubric.n_rollouts_per_input <= 0:
-         raise ValueError("rubric.n_rollouts_per_input must be greater than 0")
-
-     judge = build_judge(rubric, llm_svc)
-
-     logger.info(
-         "Running rubric %s version %s against %d agent runs",
-         rubric.id,
-         rubric.version,
-         len(agent_runs),
-     )
-
-     agent_results: list[JudgeResult | None] = [None for _ in agent_runs]
-     progress_bar = tqdm(
-         total=len(agent_runs), desc=f"Rubric {rubric.id}", disable=not show_progress
-     )
-
-     async def _run_single_judge(index: int, agent_run: AgentRun):
-         agent_results[index] = result = await judge(agent_run)
-
-         if callback is not None:
-             await callback(index, [result] if result is not None else None)
-         progress_bar.update()
-
-     try:
-         async with anyio.create_task_group() as tg:
-             for index, agent_run in enumerate(agent_runs):
-                 tg.start_soon(_run_single_judge, index, agent_run)
-     finally:
-         progress_bar.close()
-
-     successful = sum(result is not None for result in agent_results)
-     logger.info(
-         "Finished rubric %s: produced %d/%d judge results",
-         rubric.id,
-         successful,
-         len(agent_results),
-     )
-
-     return agent_results