docent-python 0.1.17a0__tar.gz → 0.1.19a0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of docent-python might be problematic. Click here for more details.

Files changed (36)
  1. {docent_python-0.1.17a0 → docent_python-0.1.19a0}/PKG-INFO +1 -1
  2. {docent_python-0.1.17a0 → docent_python-0.1.19a0}/docent/data_models/__init__.py +2 -0
  3. {docent_python-0.1.17a0 → docent_python-0.1.19a0}/docent/data_models/agent_run.py +5 -5
  4. {docent_python-0.1.17a0 → docent_python-0.1.19a0}/docent/data_models/chat/__init__.py +6 -1
  5. {docent_python-0.1.17a0 → docent_python-0.1.19a0}/docent/data_models/citation.py +103 -22
  6. docent_python-0.1.19a0/docent/data_models/judge.py +16 -0
  7. docent_python-0.1.19a0/docent/data_models/metadata_util.py +16 -0
  8. {docent_python-0.1.17a0 → docent_python-0.1.19a0}/docent/data_models/remove_invalid_citation_ranges.py +23 -10
  9. {docent_python-0.1.17a0 → docent_python-0.1.19a0}/docent/data_models/transcript.py +18 -16
  10. {docent_python-0.1.17a0 → docent_python-0.1.19a0}/docent/sdk/agent_run_writer.py +62 -19
  11. {docent_python-0.1.17a0 → docent_python-0.1.19a0}/docent/sdk/client.py +104 -20
  12. {docent_python-0.1.17a0 → docent_python-0.1.19a0}/docent/trace.py +54 -49
  13. {docent_python-0.1.17a0 → docent_python-0.1.19a0}/pyproject.toml +1 -1
  14. {docent_python-0.1.17a0 → docent_python-0.1.19a0}/uv.lock +1 -1
  15. docent_python-0.1.17a0/docent/data_models/metadata.py +0 -229
  16. docent_python-0.1.17a0/docent/data_models/yaml_util.py +0 -12
  17. {docent_python-0.1.17a0 → docent_python-0.1.19a0}/.gitignore +0 -0
  18. {docent_python-0.1.17a0 → docent_python-0.1.19a0}/LICENSE.md +0 -0
  19. {docent_python-0.1.17a0 → docent_python-0.1.19a0}/README.md +0 -0
  20. {docent_python-0.1.17a0 → docent_python-0.1.19a0}/docent/__init__.py +0 -0
  21. {docent_python-0.1.17a0 → docent_python-0.1.19a0}/docent/_log_util/__init__.py +0 -0
  22. {docent_python-0.1.17a0 → docent_python-0.1.19a0}/docent/_log_util/logger.py +0 -0
  23. {docent_python-0.1.17a0 → docent_python-0.1.19a0}/docent/data_models/_tiktoken_util.py +0 -0
  24. {docent_python-0.1.17a0 → docent_python-0.1.19a0}/docent/data_models/chat/content.py +0 -0
  25. {docent_python-0.1.17a0 → docent_python-0.1.19a0}/docent/data_models/chat/message.py +0 -0
  26. {docent_python-0.1.17a0 → docent_python-0.1.19a0}/docent/data_models/chat/tool.py +0 -0
  27. {docent_python-0.1.17a0 → docent_python-0.1.19a0}/docent/data_models/regex.py +0 -0
  28. {docent_python-0.1.17a0 → docent_python-0.1.19a0}/docent/data_models/shared_types.py +0 -0
  29. {docent_python-0.1.17a0 → docent_python-0.1.19a0}/docent/loaders/load_inspect.py +0 -0
  30. {docent_python-0.1.17a0 → docent_python-0.1.19a0}/docent/py.typed +0 -0
  31. {docent_python-0.1.17a0 → docent_python-0.1.19a0}/docent/samples/__init__.py +0 -0
  32. {docent_python-0.1.17a0 → docent_python-0.1.19a0}/docent/samples/load.py +0 -0
  33. {docent_python-0.1.17a0 → docent_python-0.1.19a0}/docent/samples/log.eval +0 -0
  34. {docent_python-0.1.17a0 → docent_python-0.1.19a0}/docent/samples/tb_airline.json +0 -0
  35. {docent_python-0.1.17a0 → docent_python-0.1.19a0}/docent/sdk/__init__.py +0 -0
  36. {docent_python-0.1.17a0 → docent_python-0.1.19a0}/docent/trace_temp.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: docent-python
3
- Version: 0.1.17a0
3
+ Version: 0.1.19a0
4
4
  Summary: Docent SDK
5
5
  Project-URL: Homepage, https://github.com/TransluceAI/docent
6
6
  Project-URL: Issues, https://github.com/TransluceAI/docent/issues
@@ -1,11 +1,13 @@
1
1
  from docent.data_models.agent_run import AgentRun
2
2
  from docent.data_models.citation import Citation
3
+ from docent.data_models.judge import JudgeRunLabel
3
4
  from docent.data_models.regex import RegexSnippet
4
5
  from docent.data_models.transcript import Transcript, TranscriptGroup
5
6
 
6
7
  __all__ = [
7
8
  "AgentRun",
8
9
  "Citation",
10
+ "JudgeRunLabel",
9
11
  "RegexSnippet",
10
12
  "Transcript",
11
13
  "TranscriptGroup",
@@ -17,8 +17,8 @@ from pydantic_core import to_jsonable_python
17
17
 
18
18
  from docent._log_util import get_logger
19
19
  from docent.data_models._tiktoken_util import get_token_count, group_messages_into_ranges
20
+ from docent.data_models.metadata_util import dump_metadata
20
21
  from docent.data_models.transcript import Transcript, TranscriptGroup
21
- from docent.data_models.yaml_util import yaml_dump_metadata
22
22
 
23
23
  logger = get_logger(__name__)
24
24
 
@@ -446,10 +446,10 @@ class AgentRun(BaseModel):
446
446
  text = _recurse("__global_root")
447
447
 
448
448
  # Append agent run metadata below the full content
449
- yaml_text = yaml_dump_metadata(self.metadata)
450
- if yaml_text is not None:
449
+ metadata_text = dump_metadata(self.metadata)
450
+ if metadata_text is not None:
451
451
  if indent > 0:
452
- yaml_text = textwrap.indent(yaml_text, " " * indent)
453
- text += f"\n<|agent run metadata|>\n{yaml_text}\n</|agent run metadata|>"
452
+ metadata_text = textwrap.indent(metadata_text, " " * indent)
453
+ text += f"\n<|agent run metadata|>\n{metadata_text}\n</|agent run metadata|>"
454
454
 
455
455
  return text
@@ -7,7 +7,12 @@ from docent.data_models.chat.message import (
7
7
  UserMessage,
8
8
  parse_chat_message,
9
9
  )
10
- from docent.data_models.chat.tool import ToolCall, ToolCallContent, ToolInfo, ToolParams
10
+ from docent.data_models.chat.tool import (
11
+ ToolCall,
12
+ ToolCallContent,
13
+ ToolInfo,
14
+ ToolParams,
15
+ )
11
16
 
12
17
  __all__ = [
13
18
  "ChatMessage",
@@ -1,15 +1,27 @@
1
1
  import re
2
+ from dataclasses import dataclass
2
3
 
3
4
  from pydantic import BaseModel
4
5
 
5
6
 
7
+ @dataclass
8
+ class ParsedCitation:
9
+ """Represents a parsed citation before conversion to full Citation object."""
10
+
11
+ transcript_idx: int | None
12
+ block_idx: int | None
13
+ metadata_key: str | None = None
14
+ start_pattern: str | None = None
15
+
16
+
6
17
  class Citation(BaseModel):
7
18
  start_idx: int
8
19
  end_idx: int
9
20
  agent_run_idx: int | None = None
10
21
  transcript_idx: int | None = None
11
- block_idx: int
22
+ block_idx: int | None = None
12
23
  action_unit_idx: int | None = None
24
+ metadata_key: str | None = None
13
25
  start_pattern: str | None = None
14
26
 
15
27
 
@@ -17,6 +29,9 @@ RANGE_BEGIN = "<RANGE>"
17
29
  RANGE_END = "</RANGE>"
18
30
 
19
31
  _SINGLE_RE = re.compile(r"T(\d+)B(\d+)")
32
+ _METADATA_RE = re.compile(r"^M\.([^:]+)$") # [M.key]
33
+ _TRANSCRIPT_METADATA_RE = re.compile(r"^T(\d+)M\.([^:]+)$") # [T0M.key]
34
+ _MESSAGE_METADATA_RE = re.compile(r"^T(\d+)B(\d+)M\.([^:]+)$") # [T0B1M.key]
20
35
  _RANGE_CONTENT_RE = re.compile(r":\s*" + re.escape(RANGE_BEGIN) + r".*?" + re.escape(RANGE_END))
21
36
 
22
37
 
@@ -70,41 +85,93 @@ def scan_brackets(text: str) -> list[tuple[int, int, str]]:
70
85
  return matches
71
86
 
72
87
 
73
- def parse_single_citation(part: str) -> tuple[int, int, str | None] | None:
88
+ def parse_single_citation(part: str) -> ParsedCitation | None:
74
89
  """
75
90
  Parse a single citation token inside a bracket and return its components.
76
91
 
77
- Returns (transcript_idx, block_idx, start_pattern) or None if invalid.
92
+ Returns ParsedCitation or None if invalid.
93
+ For metadata citations, transcript_idx may be None (for agent run metadata).
94
+ Supports optional text range for all valid citation kinds.
78
95
  """
79
96
  token = part.strip()
80
97
  if not token:
81
98
  return None
82
99
 
100
+ # Extract optional range part
101
+ start_pattern: str | None = None
102
+ citation_part = token
83
103
  if ":" in token:
84
- citation_part, range_part = token.split(":", 1)
85
- single_match = _SINGLE_RE.match(citation_part.strip())
86
- if not single_match:
104
+ left, right = token.split(":", 1)
105
+ citation_part = left.strip()
106
+ start_pattern = _extract_range_pattern(right)
107
+
108
+ # Try matches in order of specificity
109
+ # 1) Message metadata [T0B0M.key]
110
+ m = _MESSAGE_METADATA_RE.match(citation_part)
111
+ if m:
112
+ transcript_idx = int(m.group(1))
113
+ block_idx = int(m.group(2))
114
+ metadata_key = m.group(3)
115
+ # Disallow nested keys like status.code per instruction
116
+ if "." in metadata_key:
87
117
  return None
88
- transcript_idx = int(single_match.group(1))
89
- block_idx = int(single_match.group(2))
90
- start_pattern = _extract_range_pattern(range_part)
91
- return transcript_idx, block_idx, start_pattern
92
- else:
93
- single_match = _SINGLE_RE.match(token)
94
- if not single_match:
118
+ return ParsedCitation(
119
+ transcript_idx=transcript_idx,
120
+ block_idx=block_idx,
121
+ metadata_key=metadata_key,
122
+ start_pattern=start_pattern,
123
+ )
124
+
125
+ # 2) Transcript metadata [T0M.key]
126
+ m = _TRANSCRIPT_METADATA_RE.match(citation_part)
127
+ if m:
128
+ transcript_idx = int(m.group(1))
129
+ metadata_key = m.group(2)
130
+ if "." in metadata_key:
95
131
  return None
96
- transcript_idx = int(single_match.group(1))
97
- block_idx = int(single_match.group(2))
98
- return transcript_idx, block_idx, None
132
+ return ParsedCitation(
133
+ transcript_idx=transcript_idx,
134
+ block_idx=None,
135
+ metadata_key=metadata_key,
136
+ start_pattern=start_pattern,
137
+ )
138
+
139
+ # 3) Agent run metadata [M.key]
140
+ m = _METADATA_RE.match(citation_part)
141
+ if m:
142
+ metadata_key = m.group(1)
143
+ if "." in metadata_key:
144
+ return None
145
+ return ParsedCitation(
146
+ transcript_idx=None,
147
+ block_idx=None,
148
+ metadata_key=metadata_key,
149
+ start_pattern=start_pattern,
150
+ )
151
+
152
+ # 4) Regular transcript block [T0B0]
153
+ m = _SINGLE_RE.match(citation_part)
154
+ if m:
155
+ transcript_idx = int(m.group(1))
156
+ block_idx = int(m.group(2))
157
+ return ParsedCitation(
158
+ transcript_idx=transcript_idx, block_idx=block_idx, start_pattern=start_pattern
159
+ )
160
+
161
+ return None
99
162
 
100
163
 
101
164
  def parse_citations(text: str) -> tuple[str, list[Citation]]:
102
165
  """
103
- Parse citations from text in the format described by BLOCK_RANGE_CITE_INSTRUCTION.
166
+ Parse citations from text in the format described by TEXT_RANGE_CITE_INSTRUCTION.
104
167
 
105
168
  Supported formats:
106
169
  - Single block: [T<key>B<idx>]
107
170
  - Text range with start pattern: [T<key>B<idx>:<RANGE>start_pattern</RANGE>]
171
+ - Agent run metadata: [M.key]
172
+ - Transcript metadata: [T<key>M.key]
173
+ - Message metadata: [T<key>B<idx>M.key]
174
+ - Message metadata with text range: [T<key>B<idx>M.key:<RANGE>start_pattern</RANGE>]
108
175
 
109
176
  Args:
110
177
  text: The text to parse citations from
@@ -127,8 +194,21 @@ def parse_citations(text: str) -> tuple[str, list[Citation]]:
127
194
  # Parse a single citation token inside the bracket
128
195
  parsed = parse_single_citation(bracket_content)
129
196
  if parsed:
130
- transcript_idx, block_idx, start_pattern = parsed
131
- replacement = f"T{transcript_idx}B{block_idx}"
197
+ # Create appropriate replacement text based on citation type
198
+ if parsed.metadata_key:
199
+ if parsed.transcript_idx is None:
200
+ # Agent run metadata [M.key]
201
+ replacement = "run metadata"
202
+ elif parsed.block_idx is None:
203
+ # Transcript metadata [T0M.key]
204
+ replacement = f"T{parsed.transcript_idx}"
205
+ else:
206
+ # Message metadata [T0B1M.key]
207
+ replacement = f"T{parsed.transcript_idx}B{parsed.block_idx}"
208
+ else:
209
+ # Regular transcript block [T0B1]
210
+ replacement = f"T{parsed.transcript_idx}B{parsed.block_idx}"
211
+
132
212
  # Current absolute start position for this replacement in the cleaned text
133
213
  start_idx = len(cleaned_text)
134
214
  end_idx = start_idx + len(replacement)
@@ -137,10 +217,11 @@ def parse_citations(text: str) -> tuple[str, list[Citation]]:
137
217
  start_idx=start_idx,
138
218
  end_idx=end_idx,
139
219
  agent_run_idx=None,
140
- transcript_idx=transcript_idx,
141
- block_idx=block_idx,
220
+ transcript_idx=parsed.transcript_idx,
221
+ block_idx=parsed.block_idx,
142
222
  action_unit_idx=None,
143
- start_pattern=start_pattern,
223
+ metadata_key=parsed.metadata_key,
224
+ start_pattern=parsed.start_pattern,
144
225
  )
145
226
  )
146
227
  cleaned_text += replacement
@@ -0,0 +1,16 @@
1
+ """Judge-related data models shared across Docent components."""
2
+
3
+ from typing import Any
4
+ from uuid import uuid4
5
+
6
+ from pydantic import BaseModel, Field
7
+
8
+
9
+ class JudgeRunLabel(BaseModel):
10
+ id: str = Field(default_factory=lambda: str(uuid4()))
11
+ agent_run_id: str
12
+ rubric_id: str
13
+ label: dict[str, Any]
14
+
15
+
16
+ __all__ = ["JudgeRunLabel"]
@@ -0,0 +1,16 @@
1
+ import json
2
+ from typing import Any
3
+
4
+ from pydantic_core import to_jsonable_python
5
+
6
+
7
+ def dump_metadata(metadata: dict[str, Any]) -> str | None:
8
+ """
9
+ Dump metadata to a JSON string.
10
+ We used to use YAML to save tokens, but JSON makes it easier to find cited ranges on the frontend because the frontend uses JSON.
11
+ """
12
+ if not metadata:
13
+ return None
14
+ metadata_obj = to_jsonable_python(metadata)
15
+ text = json.dumps(metadata_obj, indent=2)
16
+ return text.strip()
@@ -1,3 +1,4 @@
1
+ import json
1
2
  import re
2
3
 
3
4
  from docent.data_models.agent_run import AgentRun
@@ -52,7 +53,7 @@ def find_citation_matches_in_text(text: str, start_pattern: str) -> list[tuple[i
52
53
 
53
54
  def get_transcript_text_for_citation(agent_run: AgentRun, citation: Citation) -> str | None:
54
55
  """
55
- Get the text content of a specific transcript block from an AgentRun,
56
+ Get the text content of a specific transcript block (or transcript/run metadata) from an AgentRun,
56
57
  using the same formatting as shown to LLMs via format_chat_message.
57
58
 
58
59
  Args:
@@ -62,19 +63,28 @@ def get_transcript_text_for_citation(agent_run: AgentRun, citation: Citation) ->
62
63
  Returns:
63
64
  Text content of the specified block (including tool calls), the JSON-encoded value of the cited metadata key for metadata citations, or None if not found
64
65
  """
65
- if citation.transcript_idx is None:
66
- return None
67
-
68
66
  try:
69
- if citation.transcript_idx >= len(agent_run.get_transcript_ids_ordered()):
67
+ if citation.transcript_idx is None:
68
+ # At the run level, can only cite metadata
69
+ if citation.metadata_key is not None:
70
+ return json.dumps(agent_run.metadata.get(citation.metadata_key))
70
71
  return None
72
+
71
73
  transcript_id = agent_run.get_transcript_ids_ordered()[citation.transcript_idx]
72
74
  transcript = agent_run.transcript_dict[transcript_id]
73
75
 
74
- if citation.block_idx >= len(transcript.messages):
76
+ if citation.block_idx is None:
77
+ # At the transcript level, can only cite metadata
78
+ if citation.metadata_key is not None:
79
+ return json.dumps(transcript.metadata.get(citation.metadata_key))
75
80
  return None
81
+
76
82
  message = transcript.messages[citation.block_idx]
77
83
 
84
+ # At the message level, can cite metadata or content
85
+ if citation.metadata_key is not None:
86
+ return json.dumps(message.metadata.get(citation.metadata_key))
87
+
78
88
  # Use the same formatting function that generates content for LLMs
79
89
  # This ensures consistent formatting between citation validation and LLM serialization
80
90
  return format_chat_message(
@@ -99,6 +109,9 @@ def validate_citation_text_range(agent_run: AgentRun, citation: Citation) -> boo
99
109
  if not citation.start_pattern:
100
110
  # Nothing to validate
101
111
  return True
112
+ if citation.metadata_key is not None:
113
+ # We don't need to remove invalid metadata citation ranges
114
+ return True
102
115
 
103
116
  text = get_transcript_text_for_citation(agent_run, citation)
104
117
  if text is None:
@@ -130,16 +143,16 @@ def remove_invalid_citation_ranges(text: str, agent_run: AgentRun) -> str:
130
143
  # Parse this bracket content to get citation info
131
144
  parsed = parse_single_citation(bracket_content)
132
145
  if parsed:
133
- transcript_idx, block_idx, start_pattern = parsed
134
146
  # The citation spans from start to end in the original text
135
147
  citation = Citation(
136
148
  start_idx=start,
137
149
  end_idx=end,
138
150
  agent_run_idx=None,
139
- transcript_idx=transcript_idx,
140
- block_idx=block_idx,
151
+ transcript_idx=parsed.transcript_idx,
152
+ block_idx=parsed.block_idx,
141
153
  action_unit_idx=None,
142
- start_pattern=start_pattern,
154
+ metadata_key=parsed.metadata_key,
155
+ start_pattern=parsed.start_pattern,
143
156
  )
144
157
  citations.append(citation)
145
158
 
@@ -15,7 +15,7 @@ from docent.data_models._tiktoken_util import (
15
15
  )
16
16
  from docent.data_models.chat import AssistantMessage, ChatMessage, ContentReasoning
17
17
  from docent.data_models.citation import RANGE_BEGIN, RANGE_END
18
- from docent.data_models.yaml_util import yaml_dump_metadata
18
+ from docent.data_models.metadata_util import dump_metadata
19
19
 
20
20
  # Template for formatting individual transcript blocks
21
21
  TRANSCRIPT_BLOCK_TEMPLATE = """
@@ -29,6 +29,12 @@ TEXT_RANGE_CITE_INSTRUCTION = f"""Anytime you quote the transcript, or refer to
29
29
 
30
30
  A citation may include a specific range of text within a block. Use {RANGE_BEGIN} and {RANGE_END} to mark the specific range of text. Add it after the block ID separated by a colon. For example, to cite the part of transcript 0, block 1, where the agent says "I understand the task", write [T0B1:{RANGE_BEGIN}I understand the task{RANGE_END}]. Citations must follow this exact format. The markers {RANGE_BEGIN} and {RANGE_END} must be used ONLY inside the brackets of a citation.
31
31
 
32
+ - You may cite a top-level key in the agent run metadata like this: [M.task_description].
33
+ - You may cite a top-level key in transcript metadata. For example, for transcript 0: [T0M.start_time].
34
+ - You may cite a top-level key in message metadata for a block. For example, for transcript 0, block 1: [T0B1M.status].
35
+ - You may not cite nested keys. For example, [T0B1M.status.code] is invalid.
36
+ - Within a top-level metadata key you may cite a range of text that appears in the value. For example, [T0B1M.status:{RANGE_BEGIN}"running":false{RANGE_END}].
37
+
32
38
  Important notes:
33
39
  - You must include the full content of the text range {RANGE_BEGIN} and {RANGE_END}, EXACTLY as it appears in the transcript, word-for-word, including any markers or punctuation that appear in the middle of the text.
34
40
  - Citations must be as specific as possible. This means you should usually cite a specific text range within a block.
@@ -73,9 +79,9 @@ def format_chat_message(
73
79
  cur_content += f"\n<tool call>\n{tool_call.function}({args})\n</tool call>"
74
80
 
75
81
  if message.metadata:
76
- metadata_yaml = yaml_dump_metadata(message.metadata)
77
- if metadata_yaml is not None:
78
- cur_content += f"\n<|message metadata|>\n{metadata_yaml}\n</|message metadata|>"
82
+ metadata_text = dump_metadata(message.metadata)
83
+ if metadata_text is not None:
84
+ cur_content += f"\n<|message metadata|>\n{metadata_text}\n</|message metadata|>"
79
85
 
80
86
  return TRANSCRIPT_BLOCK_TEMPLATE.format(
81
87
  index_label=index_label, role=message.role, content=cur_content
@@ -127,13 +133,11 @@ class TranscriptGroup(BaseModel):
127
133
  str: XML-like wrapped text including the group's metadata.
128
134
  """
129
135
  # Prepare metadata text (serialized as JSON)
130
- yaml_text = yaml_dump_metadata(self.metadata)
131
- if yaml_text is not None:
136
+ metadata_text = dump_metadata(self.metadata)
137
+ if metadata_text is not None:
132
138
  if indent > 0:
133
- yaml_text = textwrap.indent(yaml_text, " " * indent)
134
- inner = (
135
- f"{children_text}\n<|{self.name} metadata|>\n{yaml_text}\n</|{self.name} metadata|>"
136
- )
139
+ metadata_text = textwrap.indent(metadata_text, " " * indent)
140
+ inner = f"{children_text}\n<|{self.name} metadata|>\n{metadata_text}\n</|{self.name} metadata|>"
137
141
  else:
138
142
  inner = children_text
139
143
 
@@ -447,13 +451,11 @@ class Transcript(BaseModel):
447
451
  content_str = f"<|T{transcript_idx} blocks|>\n{blocks_str}\n</|T{transcript_idx} blocks|>"
448
452
 
449
453
  # Gather metadata and add to content
450
- yaml_text = yaml_dump_metadata(self.metadata)
451
- if yaml_text is not None:
454
+ metadata_text = dump_metadata(self.metadata)
455
+ if metadata_text is not None:
452
456
  if indent > 0:
453
- yaml_text = textwrap.indent(yaml_text, " " * indent)
454
- content_str += (
455
- f"\n<|T{transcript_idx} metadata|>\n{yaml_text}\n</|T{transcript_idx} metadata|>"
456
- )
457
+ metadata_text = textwrap.indent(metadata_text, " " * indent)
458
+ content_str += f"\n<|T{transcript_idx} metadata|>\n{metadata_text}\n</|T{transcript_idx} metadata|>"
457
459
 
458
460
  # Format content and return
459
461
  if indent > 0:
@@ -4,11 +4,12 @@ import queue
4
4
  import signal
5
5
  import threading
6
6
  import time
7
- from typing import Any, Callable, Coroutine, Optional
7
+ from typing import Any, AsyncGenerator, Callable, Coroutine, Optional
8
8
 
9
9
  import anyio
10
10
  import backoff
11
11
  import httpx
12
+ import orjson
12
13
  from backoff.types import Details
13
14
 
14
15
  from docent._log_util.logger import get_logger
@@ -19,11 +20,16 @@ logger = get_logger(__name__)
19
20
 
20
21
 
21
22
  def _giveup(exc: BaseException) -> bool:
22
- """Give up on client errors."""
23
+ """Give up on timeouts and client errors (4xx except 429). Retry others."""
24
+
25
+ # Give up immediately on any timeout (connect/read/write/pool)
26
+ if isinstance(exc, httpx.TimeoutException):
27
+ return True
23
28
 
24
29
  if isinstance(exc, httpx.HTTPStatusError):
25
30
  status = exc.response.status_code
26
31
  return status < 500 and status != 429
32
+
27
33
  return False
28
34
 
29
35
 
@@ -33,6 +39,15 @@ def _print_backoff_message(e: Details):
33
39
  )
34
40
 
35
41
 
42
+ async def _generate_payload_chunks(runs: list[AgentRun]) -> AsyncGenerator[bytes, None]:
43
+ yield b'{"agent_runs": ['
44
+ for i, ar in enumerate(runs):
45
+ if i > 0:
46
+ yield b","
47
+ yield orjson.dumps(ar.model_dump(mode="json"))
48
+ yield b"]}"
49
+
50
+
36
51
  class AgentRunWriter:
37
52
  """Background thread for logging agent runs.
38
53
 
@@ -92,7 +107,6 @@ class AgentRunWriter:
92
107
  self._thread = threading.Thread(
93
108
  target=lambda: anyio.run(self._async_main),
94
109
  name="AgentRunWriterThread",
95
- daemon=True,
96
110
  )
97
111
  self._thread.start()
98
112
  logger.info("AgentRunWriter thread started")
@@ -171,7 +185,7 @@ class AgentRunWriter:
171
185
  logger.info("Cancelling pending tasks...")
172
186
  self._cancel_event.set()
173
187
  n_pending = self._queue.qsize()
174
- logger.info(f"Cancelled ~{n_pending} pending tasks")
188
+ logger.info(f"Cancelled ~{n_pending} pending runs")
175
189
 
176
190
  # Give a brief moment to exit
177
191
  logger.info("Waiting for thread to exit...")
@@ -179,7 +193,7 @@ class AgentRunWriter:
179
193
 
180
194
  def get_post_batch_fcn(
181
195
  self, client: httpx.AsyncClient
182
- ) -> Callable[[list[AgentRun], anyio.CapacityLimiter], Coroutine[Any, Any, None]]:
196
+ ) -> Callable[[list[AgentRun]], Coroutine[Any, Any, None]]:
183
197
  """Return a function that will post a batch of agent runs to the API."""
184
198
 
185
199
  @backoff.on_exception(
@@ -189,34 +203,37 @@ class AgentRunWriter:
189
203
  max_tries=self._max_retries,
190
204
  on_backoff=_print_backoff_message,
191
205
  )
192
- async def _post_batch(batch: list[AgentRun], limiter: anyio.CapacityLimiter) -> None:
193
- async with limiter:
194
- payload = {"agent_runs": [ar.model_dump(mode="json") for ar in batch]}
195
- resp = await client.post(
196
- self._endpoint, json=payload, timeout=self._request_timeout
197
- )
198
- resp.raise_for_status()
206
+ async def _post_batch(batch: list[AgentRun]) -> None:
207
+ resp = await client.post(
208
+ self._endpoint,
209
+ content=_generate_payload_chunks(batch),
210
+ timeout=self._request_timeout,
211
+ )
212
+ resp.raise_for_status()
199
213
 
200
214
  return _post_batch
201
215
 
202
216
  async def _async_main(self) -> None:
203
217
  """Main async function for the AgentRunWriter thread."""
204
218
 
205
- limiter = anyio.CapacityLimiter(self._num_workers)
206
-
207
219
  async with httpx.AsyncClient(base_url=self._base_url, headers=self._headers) as client:
220
+ _post_batch = self.get_post_batch_fcn(client)
208
221
  async with anyio.create_task_group() as tg:
209
- _post_batch = self.get_post_batch_fcn(client)
210
222
 
211
- async def batch_loop() -> None:
223
+ async def worker():
212
224
  while not self._cancel_event.is_set():
213
225
  batch = await self._gather_next_batch_from_queue()
214
226
  if not batch:
215
227
  continue
228
+ try:
229
+ await _post_batch(batch)
230
+ except Exception as e:
231
+ logger.error(
232
+ f"Failed to post batch of {len(batch)} agent runs: {e.__class__.__name__}: {e}"
233
+ )
216
234
 
217
- tg.start_soon(_post_batch, batch, limiter)
218
-
219
- tg.start_soon(batch_loop)
235
+ for _ in range(self._num_workers):
236
+ tg.start_soon(worker)
220
237
 
221
238
  async def _gather_next_batch_from_queue(self) -> list[AgentRun]:
222
239
  """Gather a batch of agent runs from the queue.
@@ -241,6 +258,14 @@ def init(
241
258
  server_url: str = "https://api.docent.transluce.org",
242
259
  web_url: str = "https://docent.transluce.org",
243
260
  api_key: str | None = None,
261
+ # Writer arguments
262
+ num_workers: int = 4,
263
+ queue_maxsize: int = 20_000,
264
+ request_timeout: float = 30.0,
265
+ flush_interval: float = 1.0,
266
+ batch_size: int = 1_000,
267
+ max_retries: int = 5,
268
+ shutdown_timeout: int = 60,
244
269
  ):
245
270
  """Initialize the AgentRunWriter thread.
246
271
 
@@ -250,6 +275,16 @@ def init(
250
275
  server_url (str): URL of the Docent server.
251
276
  web_url (str): URL of the Docent web UI.
252
277
  api_key (str): API key for the Docent API.
278
+ num_workers (int): Number of concurrent worker tasks that
279
+ gather batches from the queue and post them to the API.
280
+ queue_maxsize (int): Maximum size of the queue.
281
+ If maxsize is <= 0, the queue size is infinite.
282
+ request_timeout (float): Timeout for the HTTP request.
283
+ flush_interval (float): Interval to flush the queue.
284
+ batch_size (int): Number of agent runs to batch together.
285
+ max_retries (int): Maximum number of retries for the HTTP request.
286
+ shutdown_timeout (int): Timeout to wait for the background thread to finish
287
+ after the main thread has requested shutdown.
253
288
  """
254
289
  api_key = api_key or os.getenv("DOCENT_API_KEY")
255
290
 
@@ -271,4 +306,12 @@ def init(
271
306
  api_key=api_key,
272
307
  collection_id=collection_id,
273
308
  server_url=server_url,
309
+ # Writer arguments
310
+ num_workers=num_workers,
311
+ queue_maxsize=queue_maxsize,
312
+ request_timeout=request_timeout,
313
+ flush_interval=flush_interval,
314
+ batch_size=batch_size,
315
+ max_retries=max_retries,
316
+ shutdown_timeout=shutdown_timeout,
274
317
  )