docent-python 0.1.17a0__py3-none-any.whl → 0.1.27a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of docent-python might be problematic.

Files changed (45)
  1. docent/_llm_util/__init__.py +0 -0
  2. docent/_llm_util/data_models/__init__.py +0 -0
  3. docent/_llm_util/data_models/exceptions.py +48 -0
  4. docent/_llm_util/data_models/llm_output.py +331 -0
  5. docent/_llm_util/llm_cache.py +193 -0
  6. docent/_llm_util/llm_svc.py +472 -0
  7. docent/_llm_util/model_registry.py +130 -0
  8. docent/_llm_util/providers/__init__.py +0 -0
  9. docent/_llm_util/providers/anthropic.py +537 -0
  10. docent/_llm_util/providers/common.py +41 -0
  11. docent/_llm_util/providers/google.py +530 -0
  12. docent/_llm_util/providers/openai.py +745 -0
  13. docent/_llm_util/providers/openrouter.py +375 -0
  14. docent/_llm_util/providers/preference_types.py +104 -0
  15. docent/_llm_util/providers/provider_registry.py +164 -0
  16. docent/data_models/__init__.py +2 -0
  17. docent/data_models/agent_run.py +6 -5
  18. docent/data_models/chat/__init__.py +6 -1
  19. docent/data_models/citation.py +103 -22
  20. docent/data_models/judge.py +19 -0
  21. docent/data_models/metadata_util.py +16 -0
  22. docent/data_models/remove_invalid_citation_ranges.py +23 -10
  23. docent/data_models/transcript.py +20 -16
  24. docent/data_models/util.py +170 -0
  25. docent/judges/__init__.py +23 -0
  26. docent/judges/analysis.py +77 -0
  27. docent/judges/impl.py +587 -0
  28. docent/judges/runner.py +129 -0
  29. docent/judges/stats.py +205 -0
  30. docent/judges/types.py +311 -0
  31. docent/judges/util/forgiving_json.py +108 -0
  32. docent/judges/util/meta_schema.json +86 -0
  33. docent/judges/util/meta_schema.py +29 -0
  34. docent/judges/util/parse_output.py +87 -0
  35. docent/judges/util/voting.py +139 -0
  36. docent/sdk/agent_run_writer.py +62 -19
  37. docent/sdk/client.py +244 -23
  38. docent/trace.py +413 -90
  39. {docent_python-0.1.17a0.dist-info → docent_python-0.1.27a0.dist-info}/METADATA +11 -5
  40. docent_python-0.1.27a0.dist-info/RECORD +59 -0
  41. docent/data_models/metadata.py +0 -229
  42. docent/data_models/yaml_util.py +0 -12
  43. docent_python-0.1.17a0.dist-info/RECORD +0 -32
  44. {docent_python-0.1.17a0.dist-info → docent_python-0.1.27a0.dist-info}/WHEEL +0 -0
  45. {docent_python-0.1.17a0.dist-info → docent_python-0.1.27a0.dist-info}/licenses/LICENSE.md +0 -0
@@ -0,0 +1,108 @@
1
+ import json
2
+ from typing import Any
3
+
4
+
5
+ def _repair_json(text: str) -> str:
6
+ """Strip leading/trailing text and fix unescaped quotes/newlines."""
7
+
8
+ json_start = None
9
+ for i, char in enumerate(text):
10
+ remaining = text[i:]
11
+ if (
12
+ char in '[{"'
13
+ or char.isdigit()
14
+ or char == "-"
15
+ or remaining.startswith("null")
16
+ or remaining.startswith("true")
17
+ or remaining.startswith("false")
18
+ ):
19
+ json_start = i
20
+ break
21
+ if json_start is None:
22
+ raise ValueError("No valid JSON start found")
23
+
24
+ result: list[str] = []
25
+ in_string = False
26
+ escape_next = False
27
+ depth = 0
28
+ started_with_container = text[json_start] in "[{"
29
+
30
+ for i in range(json_start, len(text)):
31
+ char = text[i]
32
+
33
+ if escape_next:
34
+ if in_string:
35
+ # Check if this is a valid escape sequence
36
+ is_valid_escape = char in '\\/bfnrt"' or (
37
+ char == "u"
38
+ and i + 4 < len(text)
39
+ and all(c in "0123456789abcdefABCDEF" for c in text[i + 1 : i + 5])
40
+ )
41
+ if not is_valid_escape:
42
+ # Invalid escape sequence - add another backslash to escape it
43
+ result.append("\\")
44
+ result.append(char)
45
+ escape_next = False
46
+ continue
47
+
48
+ if char == "\\":
49
+ result.append(char)
50
+ escape_next = True
51
+ continue
52
+
53
+ if char == '"':
54
+ if in_string:
55
+ # Check if quote should be escaped by looking at what follows
56
+ remaining = text[i + 1 :].lstrip()
57
+ if remaining and remaining[0] not in ':,}]"':
58
+ result.append('\\"')
59
+ continue
60
+ in_string = False
61
+ result.append(char)
62
+ # If we're at depth 0 and closed a top-level string, we're done
63
+ if depth == 0 and not started_with_container:
64
+ return "".join(result)
65
+ else:
66
+ in_string = True
67
+ result.append(char)
68
+ elif in_string and char == "\n":
69
+ result.append("\\n")
70
+ else:
71
+ result.append(char)
72
+
73
+ if not in_string:
74
+ if char in "[{":
75
+ depth += 1
76
+ elif char in "]}":
77
+ depth -= 1
78
+ if depth == 0:
79
+ return "".join(result)
80
+ # For primitives at top level (depth 0), stop at whitespace if we've consumed content
81
+ elif depth == 0 and not started_with_container and result and char in " \t\n\r":
82
+ # Check if this is trailing whitespace after a complete primitive
83
+ current = "".join(result).strip()
84
+ if current:
85
+ try:
86
+ json.loads(current)
87
+ return current
88
+ except (json.JSONDecodeError, ValueError):
89
+ pass
90
+
91
+ return "".join(result)
92
+
93
+
94
def forgiving_json_loads(text: str) -> Any:
    """Decode JSON embedded in free-form text, tolerating typical LLM slips.

    The input is run through ``_repair_json`` before decoding, which:
    - discards non-JSON text before and after the first JSON value,
    - escapes stray quotes and newlines that appear inside strings,
    - fixes invalid backslash escape sequences inside strings.

    Raises:
        ValueError: If ``text`` is empty or whitespace-only, or if no JSON
            value start can be located.
        json.JSONDecodeError: If the repaired text still is not valid JSON.
    """
    if not (text and text.strip()):
        raise ValueError("Empty or whitespace-only input")

    return json.loads(_repair_json(text))
@@ -0,0 +1,86 @@
1
+ {
2
+ "$schema": "https://json-schema.org/draft/2020-12/schema",
3
+ "$id": "https://example.com/meta/mini-schema",
4
+ "title": "Meta-schema for Docent judge outputs. Makes some restrictions to 2020-12.",
5
+ "type": "object",
6
+ "additionalProperties": false,
7
+ "properties": {
8
+ "type": { "const": "object" },
9
+ "additionalProperties": { "const": false },
10
+ "required": {
11
+ "type": "array",
12
+ "items": { "type": "string" }
13
+ },
14
+
15
+ "properties": {
16
+ "type": "object",
17
+ "propertyNames": { "type": "string" },
18
+ "additionalProperties": {
19
+ "type": "object",
20
+ "additionalProperties": false,
21
+ "required": ["type"],
22
+
23
+ "properties": {
24
+ "type": {
25
+ "type": "string",
26
+ "enum": ["string", "integer", "number", "boolean"]
27
+ },
28
+ "description": {
29
+ "type": "string"
30
+ },
31
+ "citations": {
32
+ "type": "boolean"
33
+ },
34
+ "enum": {
35
+ "type": "array",
36
+ "items": {
37
+ "type": ["string", "integer", "boolean"]
38
+ }
39
+ },
40
+ "format": {
41
+ "type": "string",
42
+ "enum": [
43
+ "date-time",
44
+ "date",
45
+ "time",
46
+ "email",
47
+ "hostname",
48
+ "ipv4",
49
+ "ipv6",
50
+ "uri",
51
+ "uuid"
52
+ ]
53
+ },
54
+ "minLength": {
55
+ "type": "integer",
56
+ "minimum": 0
57
+ },
58
+ "maxLength": {
59
+ "type": "integer",
60
+ "minimum": 0
61
+ },
62
+ "pattern": {
63
+ "type": "string"
64
+ },
65
+ "minimum": {
66
+ "type": "number"
67
+ },
68
+ "maximum": {
69
+ "type": "number"
70
+ },
71
+ "exclusiveMinimum": {
72
+ "type": "number"
73
+ },
74
+ "exclusiveMaximum": {
75
+ "type": "number"
76
+ },
77
+ "multipleOf": {
78
+ "type": "number",
79
+ "exclusiveMinimum": 0
80
+ }
81
+ }
82
+ }
83
+ }
84
+ },
85
+ "required": ["type", "properties"]
86
+ }
@@ -0,0 +1,29 @@
1
+ import json
2
+ from pathlib import Path
3
+ from typing import Any
4
+
5
+ import jsonschema
6
+
7
+
8
def _load_meta_schema() -> dict[str, Any]:
    """Read and decode the meta-schema stored alongside this module.

    The schema file shares this module's path stem, with a ``.json`` suffix
    in place of ``.py``.
    """
    schema_file = Path(__file__).with_suffix(".json")
    raw_text = schema_file.read_text(encoding="utf-8")
    return json.loads(raw_text)
13
+
14
+
15
# Validator for the judge-output meta-schema, built once at import time and
# reused by every validate_judge_result_schema call. Because
# _load_meta_schema() runs here, importing this module fails if the adjacent
# .json schema file is missing or not valid JSON.
_META_VALIDATOR = jsonschema.Draft202012Validator(_load_meta_schema())
16
+
17
+
18
def validate_judge_result_schema(schema: dict[str, Any]) -> None:
    """Check that ``schema`` is an acceptable judge output schema.

    Two checks run, in this order:

    1. ``schema`` must itself be a well-formed draft 2020-12 JSON Schema.
    2. ``schema`` must fit the restricted subset described by the module's
       meta-schema: a top-level ``object`` whose properties are simple
       string/integer/number/boolean fields.

    Raises:
        jsonschema.SchemaError: If the schema is not a valid 2020-12 schema.
        jsonschema.ValidationError: If the schema is valid 2020-12 but falls
            outside the allowed subset.
    """
    draft_cls = jsonschema.Draft202012Validator
    draft_cls.check_schema(schema)

    _META_VALIDATOR.validate(schema)  # type: ignore
@@ -0,0 +1,87 @@
1
+ from typing import Any, cast
2
+
3
+ import jsonschema
4
+
5
+ from docent._llm_util.data_models.exceptions import ValidationFailedException
6
+ from docent._log_util import get_logger
7
+ from docent.data_models.agent_run import AgentRun
8
+ from docent.data_models.remove_invalid_citation_ranges import remove_invalid_citation_ranges
9
+ from docent.judges.types import traverse_schema_and_transform
10
+ from docent.judges.util.forgiving_json import forgiving_json_loads
11
+
12
# Module logger; _validate_rubric_output warns through it whenever citation
# ranges are removed from a judge result.
logger = get_logger(__name__)
13
+
14
+
15
def _validate_rubric_output(
    output: dict[str, Any], output_schema: dict[str, Any], agent_run: AgentRun
) -> dict[str, Any]:
    """Validate a parsed judge result against its schema and clean its citations.

    The result is first checked against ``output_schema`` with
    ``jsonschema.validate``. Then it is passed through
    ``traverse_schema_and_transform``, which applies
    ``remove_invalid_citation_ranges`` to the strings it selects (presumably
    those whose schema enables citations; verify in docent.judges.types). A
    warning is logged whenever that changes a string.

    Args:
        output: Parsed result from the LLM judge.
        output_schema: JSON schema the result must conform to.
        agent_run: Agent run containing the transcript data citations are
            checked against.

    Returns:
        The transformed result with invalid citation ranges removed.

    Raises:
        ValidationFailedException: If the result does not match the schema or
            citation processing raises; the original error is chained as
            ``__cause__``.
        jsonschema.SchemaError: Propagated unchanged if ``output_schema``
            itself is not a valid schema (only ValidationError is caught).
    """

    def _validate_citation_string(text: str) -> str:
        validated_text = remove_invalid_citation_ranges(text, agent_run)
        if validated_text != text:
            logger.warning(
                "Citation validation removed invalid text range from citation in judge result. "
                f"Agent run ID: {agent_run.id}, "
                f"Original text: {text}, "
                f"Validated text: {validated_text}, "
            )
        return validated_text

    try:
        jsonschema.validate(output, output_schema)
    except jsonschema.ValidationError as e:
        raise ValidationFailedException(
            f"Schema validation failed: {e}", failed_output=str(output)
        ) from e

    try:
        return traverse_schema_and_transform(output, output_schema, _validate_citation_string)
    except Exception as e:
        # Any failure while rewriting citations is reported as a validation
        # failure; keep the original error as the cause for debugging.
        raise ValidationFailedException(
            f"Citation validation failed: {e}", failed_output=str(output)
        ) from e
54
+
55
+
56
def parse_and_validate_output_str(
    output_str: str, output_schema: dict[str, Any], agent_run: AgentRun
) -> dict[str, Any]:
    """Parse raw LLM judge output and validate it for rubric evaluation.

    The text is decoded with ``forgiving_json_loads``, which tolerates
    surrounding prose and some malformed strings. The result must be a JSON
    object, and it is then checked by ``_validate_rubric_output``.

    Args:
        output_str: Raw text returned by the LLM judge.
        output_schema: JSON schema the parsed output must conform to.
        agent_run: Agent run used for citation validation.

    Returns:
        The validated output dict.

    Raises:
        ValidationFailedException: If the text cannot be parsed as JSON (the
            parse error is chained as ``__cause__``), does not decode to a JSON
            object, or fails schema/citation validation.
    """

    try:
        output = forgiving_json_loads(output_str)
    except Exception as e:
        raise ValidationFailedException(
            f"Failed to parse JSON: {e}. Raw text: `{output_str}`",
            failed_output=output_str,
        ) from e

    if not isinstance(output, dict):
        raise ValidationFailedException(
            f"Expected dict output, got {type(output)}. Raw text: {output_str}",
            failed_output=output_str,
        )

    return _validate_rubric_output(cast(dict[str, Any], output), output_schema, agent_run)
@@ -0,0 +1,139 @@
1
+ from collections import Counter
2
+ from typing import Any, TypedDict, cast
3
+
4
+ import numpy as np
5
+
6
+
7
class EstimateWithCI(TypedDict):
    """Estimated probability of one output value, with a 95% CI half-width.

    compute_output_distributions builds one of these for each possible value
    of each agreement key.
    """

    mean: float  # fraction of counted results that produced this value (0.0 if none counted)
    var: float  # mean * (1 - mean), the per-sample Bernoulli variance
    n: int  # number of results with a non-None value for the key
    ci_95: float  # normal-approximation half-width 1.96 * sqrt(var / n); 0.0 when n == 0


# Maps each possible output value of a single agreement key to its estimate.
JudgeOutputDistribution = dict[str | bool | int | float, EstimateWithCI]
15
+
16
+
17
def get_agreement_keys(schema: dict[str, Any]) -> list[str]:
    """Return the top-level schema keys on which judge agreement is measured.

    A property qualifies when it is a boolean field or declares an ``enum``.
    String and number fields can only qualify through an enum.

    Args:
        schema: JSON schema dict whose ``properties`` describe the judge output.

    Returns:
        Qualifying property names, in the order they appear in ``properties``.

    Raises:
        AssertionError: (when assertions are enabled) if ``properties`` is not a
            dict, or a property schema is not a dict or lacks a string ``type``.
    """
    props = schema.get("properties", {})
    assert isinstance(props, dict)
    props = cast(dict[str, Any], props)

    selected: list[str] = []
    for name, prop_schema in props.items():
        assert isinstance(prop_schema, dict)
        prop_schema = cast(dict[str, Any], prop_schema)

        prop_type = prop_schema.get("type")
        assert isinstance(prop_type, str)

        if prop_type == "boolean" or "enum" in prop_schema:
            selected.append(name)

    return selected
49
+
50
+
51
def find_modal_result(indep_results: list[dict[str, Any]], agreement_keys: list[str]):
    """Pick the independent result that agrees most often with the per-key modes.

    For every key in ``agreement_keys``, the most common non-None value across
    ``indep_results`` is found; ties go to the value seen first. Each result
    then scores one point per key where its value equals that mode. Keys with
    no values at all never score. The first result with the highest score
    wins.

    TODO(mengk): This may bias towards results that have more keys.

    Args:
        indep_results: Independent judge results to compare.
        agreement_keys: Keys on which agreement is measured.

    Returns:
        ``(max_idx, agt_key_modes_and_counts)``: the index of the winning result,
        and a dict mapping each key to ``(modal_value, count)``, or to ``None``
        when no result had a value for that key.

    Raises:
        ValueError: If ``indep_results`` is empty.
    """
    if not indep_results:
        raise ValueError("No results to score")

    modes: dict[str, tuple[str | bool | int, int] | None] = {}
    for key in agreement_keys:
        observed = [value for r in indep_results if (value := r.get(key)) is not None]
        top = Counter(observed).most_common(1)
        modes[key] = top[0] if top else None

    def _agreement_score(result: dict[str, Any]) -> int:
        score = 0
        for key in agreement_keys:
            entry = modes[key]
            if entry is not None and result.get(key) == entry[0]:
                score += 1
        return score

    scores = [_agreement_score(r) for r in indep_results]
    # max() returns the first index among equal scores, matching list.index(max(...)).
    best_idx = max(range(len(scores)), key=scores.__getitem__)

    return best_idx, modes
94
+
95
+
96
+ def compute_output_distributions(
97
+ indep_results: list[dict[str, Any]], output_schema: dict[str, Any], agreement_keys: list[str]
98
+ ):
99
+ def _get_possible_values(key: str) -> list[str | bool | int | float]:
100
+ if "enum" in output_schema.get("properties", {}).get(key, {}):
101
+ return output_schema.get("properties", {}).get(key, {}).get("enum", [])
102
+ elif output_schema.get("properties", {}).get(key, {}).get("type") == "boolean":
103
+ return [True, False]
104
+ else:
105
+ return []
106
+
107
+ raw_counts: dict[str, dict[str | bool | int | float, int]] = {
108
+ key: {value: 0 for value in _get_possible_values(key)} for key in agreement_keys
109
+ }
110
+ # Collect counts for each possible value
111
+ for result in indep_results:
112
+ for key in agreement_keys:
113
+ if (value := result.get(key)) is not None: # Could be none if the key is optional
114
+ assert (
115
+ value in raw_counts[key]
116
+ ), "this should never happen; the value must be in possible values, since judge results have been validated against the schema"
117
+ raw_counts[key][value] += 1
118
+
119
+ distributions: dict[str, JudgeOutputDistribution] = {}
120
+ for agt_key in agreement_keys:
121
+ distributions[agt_key] = {}
122
+
123
+ # First normalize the counts to get probabilities
124
+ counts = raw_counts[agt_key]
125
+ total = sum(counts.values())
126
+ probs = {value: (count / total) if total > 0 else 0.0 for value, count in counts.items()}
127
+
128
+ for output_key, value in probs.items():
129
+ mean, estimate_var = value, (value * (1 - value))
130
+ # TODO(mengk): change to the wilson score interval
131
+ ci_95 = float(1.96 * np.sqrt(estimate_var / total)) if total > 0 else 0.0
132
+ distributions[agt_key][output_key] = {
133
+ "mean": mean,
134
+ "var": estimate_var,
135
+ "n": total,
136
+ "ci_95": ci_95,
137
+ }
138
+
139
+ return distributions
@@ -4,11 +4,12 @@ import queue
4
4
  import signal
5
5
  import threading
6
6
  import time
7
- from typing import Any, Callable, Coroutine, Optional
7
+ from typing import Any, AsyncGenerator, Callable, Coroutine, Optional
8
8
 
9
9
  import anyio
10
10
  import backoff
11
11
  import httpx
12
+ import orjson
12
13
  from backoff.types import Details
13
14
 
14
15
  from docent._log_util.logger import get_logger
@@ -19,11 +20,16 @@ logger = get_logger(__name__)
19
20
 
20
21
 
21
22
  def _giveup(exc: BaseException) -> bool:
22
- """Give up on client errors."""
23
+ """Give up on timeouts and client errors (4xx except 429). Retry others."""
24
+
25
+ # Give up immediately on any timeout (connect/read/write/pool)
26
+ if isinstance(exc, httpx.TimeoutException):
27
+ return True
23
28
 
24
29
  if isinstance(exc, httpx.HTTPStatusError):
25
30
  status = exc.response.status_code
26
31
  return status < 500 and status != 429
32
+
27
33
  return False
28
34
 
29
35
 
@@ -33,6 +39,15 @@ def _print_backoff_message(e: Details):
33
39
  )
34
40
 
35
41
 
42
async def _generate_payload_chunks(runs: list[AgentRun]) -> AsyncGenerator[bytes, None]:
    """Stream the JSON body ``{"agent_runs": [...]}`` piece by piece.

    Yields, in order:
    - the opening bytes;
    - each run's ``model_dump(mode="json")`` serialized with orjson, with a
      separate ``,`` chunk between consecutive runs;
    - the closing bytes.

    The whole payload is never concatenated into one bytes object here.
    """
    head, sep, tail = b'{"agent_runs": [', b",", b"]}"
    yield head
    needs_sep = False
    for run in runs:
        if needs_sep:
            yield sep
        needs_sep = True
        yield orjson.dumps(run.model_dump(mode="json"))
    yield tail
49
+
50
+
36
51
  class AgentRunWriter:
37
52
  """Background thread for logging agent runs.
38
53
 
@@ -92,7 +107,6 @@ class AgentRunWriter:
92
107
  self._thread = threading.Thread(
93
108
  target=lambda: anyio.run(self._async_main),
94
109
  name="AgentRunWriterThread",
95
- daemon=True,
96
110
  )
97
111
  self._thread.start()
98
112
  logger.info("AgentRunWriter thread started")
@@ -171,7 +185,7 @@ class AgentRunWriter:
171
185
  logger.info("Cancelling pending tasks...")
172
186
  self._cancel_event.set()
173
187
  n_pending = self._queue.qsize()
174
- logger.info(f"Cancelled ~{n_pending} pending tasks")
188
+ logger.info(f"Cancelled ~{n_pending} pending runs")
175
189
 
176
190
  # Give a brief moment to exit
177
191
  logger.info("Waiting for thread to exit...")
@@ -179,7 +193,7 @@ class AgentRunWriter:
179
193
 
180
194
  def get_post_batch_fcn(
181
195
  self, client: httpx.AsyncClient
182
- ) -> Callable[[list[AgentRun], anyio.CapacityLimiter], Coroutine[Any, Any, None]]:
196
+ ) -> Callable[[list[AgentRun]], Coroutine[Any, Any, None]]:
183
197
  """Return a function that will post a batch of agent runs to the API."""
184
198
 
185
199
  @backoff.on_exception(
@@ -189,34 +203,37 @@ class AgentRunWriter:
189
203
  max_tries=self._max_retries,
190
204
  on_backoff=_print_backoff_message,
191
205
  )
192
- async def _post_batch(batch: list[AgentRun], limiter: anyio.CapacityLimiter) -> None:
193
- async with limiter:
194
- payload = {"agent_runs": [ar.model_dump(mode="json") for ar in batch]}
195
- resp = await client.post(
196
- self._endpoint, json=payload, timeout=self._request_timeout
197
- )
198
- resp.raise_for_status()
206
+ async def _post_batch(batch: list[AgentRun]) -> None:
207
+ resp = await client.post(
208
+ self._endpoint,
209
+ content=_generate_payload_chunks(batch),
210
+ timeout=self._request_timeout,
211
+ )
212
+ resp.raise_for_status()
199
213
 
200
214
  return _post_batch
201
215
 
202
216
  async def _async_main(self) -> None:
      """Main async function for the AgentRunWriter thread.

      Opens one httpx.AsyncClient for the lifetime of the loop. Inside an anyio
      task group it starts ``self._num_workers`` concurrent workers. Each worker
      repeatedly gathers a batch from the queue and posts it, until
      ``self._cancel_event`` is set.
      """

      async with httpx.AsyncClient(base_url=self._base_url, headers=self._headers) as client:
          # A single backoff-wrapped post function shared by every worker; all
          # requests go through the same client.
          _post_batch = self.get_post_batch_fcn(client)
          async with anyio.create_task_group() as tg:

              async def worker():
                  while not self._cancel_event.is_set():
                      batch = await self._gather_next_batch_from_queue()
                      if not batch:
                          # Nothing gathered this round; loop to re-check cancellation.
                          continue
                      try:
                          await _post_batch(batch)
                      except Exception as e:
                          # Errors escaping _post_batch (presumably after its backoff
                          # retries give up) are logged and the batch is dropped, so
                          # this worker keeps draining the queue.
                          logger.error(
                              f"Failed to post batch of {len(batch)} agent runs: {e.__class__.__name__}: {e}"
                          )

              # The task group only exits after every worker returns, which they
              # do once the cancel event is set.
              for _ in range(self._num_workers):
                  tg.start_soon(worker)
220
237
 
221
238
  async def _gather_next_batch_from_queue(self) -> list[AgentRun]:
222
239
  """Gather a batch of agent runs from the queue.
@@ -241,6 +258,14 @@ def init(
241
258
  server_url: str = "https://api.docent.transluce.org",
242
259
  web_url: str = "https://docent.transluce.org",
243
260
  api_key: str | None = None,
261
+ # Writer arguments
262
+ num_workers: int = 4,
263
+ queue_maxsize: int = 20_000,
264
+ request_timeout: float = 30.0,
265
+ flush_interval: float = 1.0,
266
+ batch_size: int = 1_000,
267
+ max_retries: int = 5,
268
+ shutdown_timeout: int = 60,
244
269
  ):
245
270
  """Initialize the AgentRunWriter thread.
246
271
 
@@ -250,6 +275,16 @@ def init(
250
275
  server_url (str): URL of the Docent server.
251
276
  web_url (str): URL of the Docent web UI.
252
277
  api_key (str): API key for the Docent API.
278
+ num_workers (int): Max number of concurrent tasks to run,
279
+ managed by anyio.CapacityLimiter.
280
+ queue_maxsize (int): Maximum size of the queue.
281
+ If maxsize is <= 0, the queue size is infinite.
282
+ request_timeout (float): Timeout for the HTTP request.
283
+ flush_interval (float): Interval to flush the queue.
284
+ batch_size (int): Number of agent runs to batch together.
285
+ max_retries (int): Maximum number of retries for the HTTP request.
286
+ shutdown_timeout (int): Timeout to wait for the background thread to finish
287
+ after the main thread has requested shutdown.
253
288
  """
254
289
  api_key = api_key or os.getenv("DOCENT_API_KEY")
255
290
 
@@ -271,4 +306,12 @@ def init(
271
306
  api_key=api_key,
272
307
  collection_id=collection_id,
273
308
  server_url=server_url,
309
+ # Writer arguments
310
+ num_workers=num_workers,
311
+ queue_maxsize=queue_maxsize,
312
+ request_timeout=request_timeout,
313
+ flush_interval=flush_interval,
314
+ batch_size=batch_size,
315
+ max_retries=max_retries,
316
+ shutdown_timeout=shutdown_timeout,
274
317
  )