docent-python 0.1.14a0__py3-none-any.whl → 0.1.28a0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
Potentially problematic release: this version of docent-python might be problematic.
- docent/_llm_util/__init__.py +0 -0
- docent/_llm_util/data_models/__init__.py +0 -0
- docent/_llm_util/data_models/exceptions.py +48 -0
- docent/_llm_util/data_models/llm_output.py +331 -0
- docent/_llm_util/llm_cache.py +193 -0
- docent/_llm_util/llm_svc.py +472 -0
- docent/_llm_util/model_registry.py +130 -0
- docent/_llm_util/providers/__init__.py +0 -0
- docent/_llm_util/providers/anthropic.py +537 -0
- docent/_llm_util/providers/common.py +41 -0
- docent/_llm_util/providers/google.py +530 -0
- docent/_llm_util/providers/openai.py +745 -0
- docent/_llm_util/providers/openrouter.py +375 -0
- docent/_llm_util/providers/preference_types.py +104 -0
- docent/_llm_util/providers/provider_registry.py +164 -0
- docent/data_models/__init__.py +2 -0
- docent/data_models/agent_run.py +17 -29
- docent/data_models/chat/__init__.py +6 -1
- docent/data_models/chat/message.py +3 -1
- docent/data_models/citation.py +103 -22
- docent/data_models/judge.py +19 -0
- docent/data_models/metadata_util.py +16 -0
- docent/data_models/remove_invalid_citation_ranges.py +23 -10
- docent/data_models/transcript.py +25 -80
- docent/data_models/util.py +170 -0
- docent/judges/__init__.py +23 -0
- docent/judges/analysis.py +77 -0
- docent/judges/impl.py +587 -0
- docent/judges/runner.py +129 -0
- docent/judges/stats.py +205 -0
- docent/judges/types.py +311 -0
- docent/judges/util/forgiving_json.py +108 -0
- docent/judges/util/meta_schema.json +86 -0
- docent/judges/util/meta_schema.py +29 -0
- docent/judges/util/parse_output.py +87 -0
- docent/judges/util/voting.py +139 -0
- docent/sdk/agent_run_writer.py +72 -21
- docent/sdk/client.py +276 -23
- docent/trace.py +413 -90
- {docent_python-0.1.14a0.dist-info → docent_python-0.1.28a0.dist-info}/METADATA +13 -5
- docent_python-0.1.28a0.dist-info/RECORD +59 -0
- docent/data_models/metadata.py +0 -229
- docent/data_models/yaml_util.py +0 -12
- docent_python-0.1.14a0.dist-info/RECORD +0 -32
- {docent_python-0.1.14a0.dist-info → docent_python-0.1.28a0.dist-info}/WHEEL +0 -0
- {docent_python-0.1.14a0.dist-info → docent_python-0.1.28a0.dist-info}/licenses/LICENSE.md +0 -0
docent/judges/util/forgiving_json.py
ADDED
@@ -0,0 +1,108 @@
import json
from typing import Any


def _repair_json(text: str) -> str:
    """Strip leading/trailing text and fix unescaped quotes/newlines."""

    json_start = None
    for i, char in enumerate(text):
        remaining = text[i:]
        if (
            char in '[{"'
            or char.isdigit()
            or char == "-"
            or remaining.startswith("null")
            or remaining.startswith("true")
            or remaining.startswith("false")
        ):
            json_start = i
            break
    if json_start is None:
        raise ValueError("No valid JSON start found")

    result: list[str] = []
    in_string = False
    escape_next = False
    depth = 0
    started_with_container = text[json_start] in "[{"

    for i in range(json_start, len(text)):
        char = text[i]

        if escape_next:
            if in_string:
                # Check if this is a valid escape sequence
                is_valid_escape = char in '\\/bfnrt"' or (
                    char == "u"
                    and i + 4 < len(text)
                    and all(c in "0123456789abcdefABCDEF" for c in text[i + 1 : i + 5])
                )
                if not is_valid_escape:
                    # Invalid escape sequence - add another backslash to escape it
                    result.append("\\")
            result.append(char)
            escape_next = False
            continue

        if char == "\\":
            result.append(char)
            escape_next = True
            continue

        if char == '"':
            if in_string:
                # Check if quote should be escaped by looking at what follows
                remaining = text[i + 1 :].lstrip()
                if remaining and remaining[0] not in ':,}]"':
                    result.append('\\"')
                    continue
                in_string = False
                result.append(char)
                # If we're at depth 0 and closed a top-level string, we're done
                if depth == 0 and not started_with_container:
                    return "".join(result)
            else:
                in_string = True
                result.append(char)
        elif in_string and char == "\n":
            result.append("\\n")
        else:
            result.append(char)

        if not in_string:
            if char in "[{":
                depth += 1
            elif char in "]}":
                depth -= 1
                if depth == 0:
                    return "".join(result)
            # For primitives at top level (depth 0), stop at whitespace if we've consumed content
            elif depth == 0 and not started_with_container and result and char in " \t\n\r":
                # Check if this is trailing whitespace after a complete primitive
                current = "".join(result).strip()
                if current:
                    try:
                        json.loads(current)
                        return current
                    except (json.JSONDecodeError, ValueError):
                        pass

    return "".join(result)


def forgiving_json_loads(text: str) -> Any:
    """
    Parse JSON from text, applying heuristics to fix common LLM mistakes.

    Repairs applied:
    - Strip leading/trailing non-JSON text
    - Escape unescaped quotes and newlines inside strings
    - Fix invalid escape sequences inside strings
    """
    if not text or not text.strip():
        raise ValueError("Empty or whitespace-only input")

    text = _repair_json(text)

    return json.loads(text)
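For orientation, here is a minimal sketch of how the new helper behaves on a typical malformed LLM reply (the input string is invented for the example): leading prose is stripped and unescaped inner quotes are escaped before json.loads runs.

    from docent.judges.util.forgiving_json import forgiving_json_loads

    raw = 'Sure! Here is the JSON: {"score": 1, "note": "the model said "yes" twice"}'
    # The surrounding prose is stripped and the unescaped quotes around `yes`
    # are escaped, so this parses to a plain dict instead of raising.
    print(forgiving_json_loads(raw))
    # {'score': 1, 'note': 'the model said "yes" twice'}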
docent/judges/util/meta_schema.json
ADDED
@@ -0,0 +1,86 @@
{
  "$schema": "https://json-schema.org/draft/2020-12/schema",
  "$id": "https://example.com/meta/mini-schema",
  "title": "Meta-schema for Docent judge outputs. Makes some restrictions to 2020-12.",
  "type": "object",
  "additionalProperties": false,
  "properties": {
    "type": { "const": "object" },
    "additionalProperties": { "const": false },
    "required": {
      "type": "array",
      "items": { "type": "string" }
    },

    "properties": {
      "type": "object",
      "propertyNames": { "type": "string" },
      "additionalProperties": {
        "type": "object",
        "additionalProperties": false,
        "required": ["type"],

        "properties": {
          "type": {
            "type": "string",
            "enum": ["string", "integer", "number", "boolean"]
          },
          "description": {
            "type": "string"
          },
          "citations": {
            "type": "boolean"
          },
          "enum": {
            "type": "array",
            "items": {
              "type": ["string", "integer", "boolean"]
            }
          },
          "format": {
            "type": "string",
            "enum": [
              "date-time",
              "date",
              "time",
              "email",
              "hostname",
              "ipv4",
              "ipv6",
              "uri",
              "uuid"
            ]
          },
          "minLength": {
            "type": "integer",
            "minimum": 0
          },
          "maxLength": {
            "type": "integer",
            "minimum": 0
          },
          "pattern": {
            "type": "string"
          },
          "minimum": {
            "type": "number"
          },
          "maximum": {
            "type": "number"
          },
          "exclusiveMinimum": {
            "type": "number"
          },
          "exclusiveMaximum": {
            "type": "number"
          },
          "multipleOf": {
            "type": "number",
            "exclusiveMinimum": 0
          }
        }
      }
    }
  },
  "required": ["type", "properties"]
}
docent/judges/util/meta_schema.py
ADDED
@@ -0,0 +1,29 @@
import json
from pathlib import Path
from typing import Any

import jsonschema


def _load_meta_schema() -> dict[str, Any]:
    """Load the rubric meta-schema from the adjacent JSON file."""
    meta_schema_path = Path(__file__).with_suffix(".json")
    with meta_schema_path.open("r", encoding="utf-8") as f:
        return json.load(f)


_META_VALIDATOR = jsonschema.Draft202012Validator(_load_meta_schema())


def validate_judge_result_schema(schema: dict[str, Any]):
    """Validate a proposed schema against the rubric meta-schema.

    Raises:
        jsonschema.ValidationError: If the schema is invalid
        jsonschema.SchemaError: If the schema is not a valid 2020-12 schema
    """
    # First check that this is a valid 2020-12 schema
    jsonschema.Draft202012Validator.check_schema(schema)

    # Then check that it conforms to our subset of the 2020-12 schema
    _META_VALIDATOR.validate(schema)  # type: ignore
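To illustrate what the meta-schema admits, a hedged sketch of validating a judge output schema; the schema itself is invented for the example and is not part of the package.

    from docent.judges.util.meta_schema import validate_judge_result_schema

    # Illustrative judge output schema: a flat object whose properties use only
    # the primitive types and keywords the meta-schema above permits.
    example_schema = {
        "type": "object",
        "additionalProperties": False,
        "required": ["verdict", "explanation"],
        "properties": {
            "verdict": {"type": "string", "enum": ["pass", "fail"]},
            "explanation": {"type": "string", "citations": True},
            "confidence": {"type": "number", "minimum": 0.0, "maximum": 1.0},
        },
    }

    # Raises jsonschema.ValidationError / SchemaError for non-conforming schemas;
    # passes silently for this one.
    validate_judge_result_schema(example_schema)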
docent/judges/util/parse_output.py
ADDED
@@ -0,0 +1,87 @@
from typing import Any, cast

import jsonschema

from docent._llm_util.data_models.exceptions import ValidationFailedException
from docent._log_util import get_logger
from docent.data_models.agent_run import AgentRun
from docent.data_models.remove_invalid_citation_ranges import remove_invalid_citation_ranges
from docent.judges.types import traverse_schema_and_transform
from docent.judges.util.forgiving_json import forgiving_json_loads

logger = get_logger(__name__)


def _validate_rubric_output(
    output: dict[str, Any], output_schema: dict[str, Any], agent_run: AgentRun
) -> dict[str, Any]:
    """Validate and filter citation text ranges in rubric results.
    Also check that the output conforms to the output schema.

    Args:
        output: Raw results from LLM judge
        agent_run: Agent run containing transcript data for validation

    Returns:
        Validated result dict with invalid citations removed

    Raises:
        ValidationFailedException: If validation fails
    """

    def _validate_citation_string(text: str) -> str:
        validated_text = remove_invalid_citation_ranges(text, agent_run)
        if validated_text != text:
            logger.warning(
                f"Citation validation removed invalid text range from citation in judge result. "
                f"Agent run ID: {agent_run.id}, "
                f"Original text: {text}, "
                f"Validated text: {validated_text}, "
            )
        return validated_text

    try:
        jsonschema.validate(output, output_schema)
    except jsonschema.ValidationError as e:
        raise ValidationFailedException(f"Schema validation failed: {e}", failed_output=str(output))

    try:
        return traverse_schema_and_transform(output, output_schema, _validate_citation_string)
    except Exception as e:
        raise ValidationFailedException(
            f"Citation validation failed: {e}", failed_output=str(output)
        )


def parse_and_validate_output_str(
    output_str: str, output_schema: dict[str, Any], agent_run: AgentRun
) -> dict[str, Any]:
    """Parse and validate LLM output for rubric evaluation.

    Args:
        output_str: The LLM output to parse
        output_schema: The schema to validate against
        agent_run: Agent run for citation validation

    Returns:
        Validated output dict

    Raises:
        ValidationFailedException: If parsing or validation fails
    """

    try:
        output = forgiving_json_loads(output_str)
    except Exception as e:
        raise ValidationFailedException(
            f"Failed to parse JSON: {e}. Raw text: `{output_str}`",
            failed_output=output_str,
        )

    if not isinstance(output, dict):
        raise ValidationFailedException(
            f"Expected dict output, got {type(output)}. Raw text: {output_str}",
            failed_output=output_str,
        )

    return _validate_rubric_output(cast(dict[str, Any], output), output_schema, agent_run)
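A hedged sketch of how a caller might surface failures from the parser; `output_schema` and `agent_run` are assumed to come from the surrounding judge pipeline and are not constructed here.

    from docent._llm_util.data_models.exceptions import ValidationFailedException
    from docent.judges.util.parse_output import parse_and_validate_output_str

    llm_text = '{"verdict": "pass", "explanation": "No issues found."}'
    try:
        result = parse_and_validate_output_str(llm_text, output_schema, agent_run)
    except ValidationFailedException as e:
        # The rejected raw output is attached to the exception (it is passed as
        # failed_output above), so callers can log it or feed it into a retry.
        print(f"Judge output rejected: {e}")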
docent/judges/util/voting.py
ADDED
@@ -0,0 +1,139 @@
from collections import Counter
from typing import Any, TypedDict, cast

import numpy as np


class EstimateWithCI(TypedDict):
    mean: float
    var: float
    n: int
    ci_95: float


JudgeOutputDistribution = dict[str | bool | int | float, EstimateWithCI]


def get_agreement_keys(schema: dict[str, Any]) -> list[str]:
    """Get list of top-level keys in schema that we want to measure agreement on.

    This includes enum and bool fields.

    Args:
        schema: JSON schema dict

    Returns:
        List of field names (keys) that should be used for measuring agreement
    """
    agreement_keys: list[str] = []

    properties = schema.get("properties", {})
    assert isinstance(properties, dict)
    properties = cast(dict[str, Any], properties)

    for key, field_schema in properties.items():
        assert isinstance(field_schema, dict)
        field_schema = cast(dict[str, Any], field_schema)

        field_type = field_schema.get("type")
        assert isinstance(field_type, str)

        # Include boolean fields
        if field_type == "boolean":
            agreement_keys.append(key)
        # Include enum fields (strings and numbers must be in this category)
        elif "enum" in field_schema:
            agreement_keys.append(key)

    return agreement_keys


def find_modal_result(indep_results: list[dict[str, Any]], agreement_keys: list[str]):
    """Find the result that best matches modal values across agreement keys.

    Args:
        indep_results: List of independent results to analyze
        agreement_keys: Keys to measure agreement on

    Returns:
        Tuple of (max_idx, agt_key_modes_and_counts) where:
        - max_idx is the index of the result that best matches modal values
        - agt_key_modes_and_counts maps each key to (modal_value, count) or None if no values exist for that key

    Raises:
        ValueError: If no results are provided
    """
    if not indep_results:
        raise ValueError("No results to score")

    # For each agreement key, compute the mode and count (or None, if no values exist for that key)
    agt_key_modes_and_counts: dict[str, tuple[str | bool | int, int] | None] = {}
    for key in agreement_keys:
        key_modes = Counter(v for r in indep_results if (v := r.get(key)) is not None)
        if most_common_one := key_modes.most_common(1):
            agt_key_modes_and_counts[key] = most_common_one[0]
        else:
            agt_key_modes_and_counts[key] = None

    # Score each rollout based on how many agreement keys they match
    # If there is no mode for a key, or if a certain result doesn't have that key, it doesn't count.
    # TODO(mengk): This may bias towards results that have more keys.
    indep_result_scores: list[int] = []
    for r in indep_results:
        score = 0
        for key in agreement_keys:
            mode_and_count = agt_key_modes_and_counts[key]
            if mode_and_count and r.get(key) == mode_and_count[0]:
                score += 1
        indep_result_scores.append(score)

    # Argmax
    max_idx = indep_result_scores.index(max(indep_result_scores))

    return max_idx, agt_key_modes_and_counts


def compute_output_distributions(
    indep_results: list[dict[str, Any]], output_schema: dict[str, Any], agreement_keys: list[str]
):
    def _get_possible_values(key: str) -> list[str | bool | int | float]:
        if "enum" in output_schema.get("properties", {}).get(key, {}):
            return output_schema.get("properties", {}).get(key, {}).get("enum", [])
        elif output_schema.get("properties", {}).get(key, {}).get("type") == "boolean":
            return [True, False]
        else:
            return []

    raw_counts: dict[str, dict[str | bool | int | float, int]] = {
        key: {value: 0 for value in _get_possible_values(key)} for key in agreement_keys
    }
    # Collect counts for each possible value
    for result in indep_results:
        for key in agreement_keys:
            if (value := result.get(key)) is not None:  # Could be none if the key is optional
                assert (
                    value in raw_counts[key]
                ), "this should never happen; the value must be in possible values, since judge results have been validated against the schema"
                raw_counts[key][value] += 1

    distributions: dict[str, JudgeOutputDistribution] = {}
    for agt_key in agreement_keys:
        distributions[agt_key] = {}

        # First normalize the counts to get probabilities
        counts = raw_counts[agt_key]
        total = sum(counts.values())
        probs = {value: (count / total) if total > 0 else 0.0 for value, count in counts.items()}

        for output_key, value in probs.items():
            mean, estimate_var = value, (value * (1 - value))
            # TODO(mengk): change to the wilson score interval
            ci_95 = float(1.96 * np.sqrt(estimate_var / total)) if total > 0 else 0.0
            distributions[agt_key][output_key] = {
                "mean": mean,
                "var": estimate_var,
                "n": total,
                "ci_95": ci_95,
            }

    return distributions
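A small worked sketch of the voting helpers (the schema and sample results are invented): an enum field is detected as the agreement key, and the result that matches the modal verdict across independent judge samples is returned.

    from docent.judges.util.voting import find_modal_result, get_agreement_keys

    schema = {
        "type": "object",
        "properties": {
            "verdict": {"type": "string", "enum": ["pass", "fail"]},
            "explanation": {"type": "string"},
        },
    }
    samples = [
        {"verdict": "pass", "explanation": "ok"},
        {"verdict": "fail", "explanation": "missed a step"},
        {"verdict": "pass", "explanation": "ok with caveats"},
    ]

    keys = get_agreement_keys(schema)         # ["verdict"]: only enum and bool fields count
    idx, modes = find_modal_result(samples, keys)
    print(idx, modes)                         # 0 {'verdict': ('pass', 2)}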
docent/sdk/agent_run_writer.py
CHANGED
@@ -4,11 +4,12 @@ import queue
 import signal
 import threading
 import time
-from typing import Any, Callable, Coroutine, Optional
+from typing import Any, AsyncGenerator, Callable, Coroutine, Optional

 import anyio
 import backoff
 import httpx
+import orjson
 from backoff.types import Details

 from docent._log_util.logger import get_logger
@@ -19,11 +20,16 @@ logger = get_logger(__name__)


 def _giveup(exc: BaseException) -> bool:
-    """Give up on client errors."""
+    """Give up on timeouts and client errors (4xx except 429). Retry others."""
+
+    # Give up immediately on any timeout (connect/read/write/pool)
+    if isinstance(exc, httpx.TimeoutException):
+        return True

     if isinstance(exc, httpx.HTTPStatusError):
         status = exc.response.status_code
         return status < 500 and status != 429
+
     return False


@@ -33,6 +39,15 @@ def _print_backoff_message(e: Details):
     )


+async def _generate_payload_chunks(runs: list[AgentRun]) -> AsyncGenerator[bytes, None]:
+    yield b'{"agent_runs": ['
+    for i, ar in enumerate(runs):
+        if i > 0:
+            yield b","
+        yield orjson.dumps(ar.model_dump(mode="json"))
+    yield b"]}"
+
+
 class AgentRunWriter:
     """Background thread for logging agent runs.

@@ -92,7 +107,6 @@ class AgentRunWriter:
         self._thread = threading.Thread(
             target=lambda: anyio.run(self._async_main),
             name="AgentRunWriterThread",
-            daemon=True,
         )
         self._thread.start()
         logger.info("AgentRunWriter thread started")
@@ -105,9 +119,17 @@ class AgentRunWriter:
         # Register shutdown hooks
         atexit.register(self.finish)

+        def _handle_sigint(s: int, f: object) -> None:
+            self._shutdown()
+            raise KeyboardInterrupt
+
+        def _handle_sigterm(s: int, f: object) -> None:
+            self._shutdown()
+            raise SystemExit(0)
+
         # Register signal handlers for graceful shutdown
-        signal.signal(signal.SIGINT,
-        signal.signal(signal.SIGTERM,
+        signal.signal(signal.SIGINT, _handle_sigint)  # Ctrl+C
+        signal.signal(signal.SIGTERM, _handle_sigterm)  # Kill signal

     def log_agent_runs(self, agent_runs: list[AgentRun]) -> None:
         """Put a list of AgentRun objects into the queue.
@@ -163,7 +185,7 @@
         logger.info("Cancelling pending tasks...")
         self._cancel_event.set()
         n_pending = self._queue.qsize()
-        logger.info(f"Cancelled ~{n_pending} pending
+        logger.info(f"Cancelled ~{n_pending} pending runs")

         # Give a brief moment to exit
         logger.info("Waiting for thread to exit...")
@@ -171,7 +193,7 @@

     def get_post_batch_fcn(
         self, client: httpx.AsyncClient
-    ) -> Callable[[list[AgentRun]
+    ) -> Callable[[list[AgentRun]], Coroutine[Any, Any, None]]:
         """Return a function that will post a batch of agent runs to the API."""

         @backoff.on_exception(
@@ -181,34 +203,37 @@
             max_tries=self._max_retries,
             on_backoff=_print_backoff_message,
         )
-        async def _post_batch(batch: list[AgentRun]
-
-
-
-
-
-
+        async def _post_batch(batch: list[AgentRun]) -> None:
+            resp = await client.post(
+                self._endpoint,
+                content=_generate_payload_chunks(batch),
+                timeout=self._request_timeout,
+            )
+            resp.raise_for_status()

         return _post_batch

     async def _async_main(self) -> None:
         """Main async function for the AgentRunWriter thread."""

-        limiter = anyio.CapacityLimiter(self._num_workers)
-
         async with httpx.AsyncClient(base_url=self._base_url, headers=self._headers) as client:
+            _post_batch = self.get_post_batch_fcn(client)
             async with anyio.create_task_group() as tg:
-                _post_batch = self.get_post_batch_fcn(client)

-                async def
+                async def worker():
                     while not self._cancel_event.is_set():
                         batch = await self._gather_next_batch_from_queue()
                         if not batch:
                             continue
+                        try:
+                            await _post_batch(batch)
+                        except Exception as e:
+                            logger.error(
+                                f"Failed to post batch of {len(batch)} agent runs: {e.__class__.__name__}: {e}"
+                            )

-
-
-                tg.start_soon(batch_loop)
+                for _ in range(self._num_workers):
+                    tg.start_soon(worker)

     async def _gather_next_batch_from_queue(self) -> list[AgentRun]:
         """Gather a batch of agent runs from the queue.
@@ -233,6 +258,14 @@ def init(
     server_url: str = "https://api.docent.transluce.org",
     web_url: str = "https://docent.transluce.org",
     api_key: str | None = None,
+    # Writer arguments
+    num_workers: int = 4,
+    queue_maxsize: int = 20_000,
+    request_timeout: float = 30.0,
+    flush_interval: float = 1.0,
+    batch_size: int = 1_000,
+    max_retries: int = 5,
+    shutdown_timeout: int = 60,
 ):
     """Initialize the AgentRunWriter thread.

@@ -242,6 +275,16 @@ def init(
         server_url (str): URL of the Docent server.
         web_url (str): URL of the Docent web UI.
         api_key (str): API key for the Docent API.
+        num_workers (int): Max number of concurrent tasks to run,
+            managed by anyio.CapacityLimiter.
+        queue_maxsize (int): Maximum size of the queue.
+            If maxsize is <= 0, the queue size is infinite.
+        request_timeout (float): Timeout for the HTTP request.
+        flush_interval (float): Interval to flush the queue.
+        batch_size (int): Number of agent runs to batch together.
+        max_retries (int): Maximum number of retries for the HTTP request.
+        shutdown_timeout (int): Timeout to wait for the background thread to finish
+            after the main thread has requested shutdown.
     """
     api_key = api_key or os.getenv("DOCENT_API_KEY")

@@ -263,4 +306,12 @@ def init(
         api_key=api_key,
         collection_id=collection_id,
         server_url=server_url,
+        # Writer arguments
+        num_workers=num_workers,
+        queue_maxsize=queue_maxsize,
+        request_timeout=request_timeout,
+        flush_interval=flush_interval,
+        batch_size=batch_size,
+        max_retries=max_retries,
+        shutdown_timeout=shutdown_timeout,
     )
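The new writer arguments expose the AgentRunWriter knobs through init(); below is a hedged sketch of tuning them. The collection_id value is made up, and treating collection_id as an init() parameter is an assumption inferred from it being forwarded to AgentRunWriter in the hunk above; the writer arguments themselves are taken from the new signature.

    import os

    from docent.sdk.agent_run_writer import init

    os.environ.setdefault("DOCENT_API_KEY", "...")  # or pass api_key= explicitly

    init(
        collection_id="my-collection",  # assumed parameter; forwarded to AgentRunWriter above
        num_workers=8,                  # concurrent posting workers
        batch_size=500,                 # agent runs per POST request
        request_timeout=60.0,           # per-request timeout in seconds
        max_retries=3,                  # backoff retries per failed request
    )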