docent-python 0.1.41a0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of docent-python might be problematic. Click here for more details.
- docent/__init__.py +4 -0
- docent/_llm_util/__init__.py +0 -0
- docent/_llm_util/data_models/__init__.py +0 -0
- docent/_llm_util/data_models/exceptions.py +48 -0
- docent/_llm_util/data_models/llm_output.py +331 -0
- docent/_llm_util/llm_cache.py +193 -0
- docent/_llm_util/llm_svc.py +472 -0
- docent/_llm_util/model_registry.py +134 -0
- docent/_llm_util/providers/__init__.py +0 -0
- docent/_llm_util/providers/anthropic.py +537 -0
- docent/_llm_util/providers/common.py +41 -0
- docent/_llm_util/providers/google.py +530 -0
- docent/_llm_util/providers/openai.py +745 -0
- docent/_llm_util/providers/openrouter.py +375 -0
- docent/_llm_util/providers/preference_types.py +104 -0
- docent/_llm_util/providers/provider_registry.py +164 -0
- docent/_log_util/__init__.py +3 -0
- docent/_log_util/logger.py +141 -0
- docent/data_models/__init__.py +14 -0
- docent/data_models/_tiktoken_util.py +91 -0
- docent/data_models/agent_run.py +473 -0
- docent/data_models/chat/__init__.py +37 -0
- docent/data_models/chat/content.py +56 -0
- docent/data_models/chat/message.py +191 -0
- docent/data_models/chat/tool.py +109 -0
- docent/data_models/citation.py +187 -0
- docent/data_models/formatted_objects.py +84 -0
- docent/data_models/judge.py +17 -0
- docent/data_models/metadata_util.py +16 -0
- docent/data_models/regex.py +56 -0
- docent/data_models/transcript.py +305 -0
- docent/data_models/util.py +170 -0
- docent/judges/__init__.py +23 -0
- docent/judges/analysis.py +77 -0
- docent/judges/impl.py +587 -0
- docent/judges/runner.py +129 -0
- docent/judges/stats.py +205 -0
- docent/judges/types.py +320 -0
- docent/judges/util/forgiving_json.py +108 -0
- docent/judges/util/meta_schema.json +86 -0
- docent/judges/util/meta_schema.py +29 -0
- docent/judges/util/parse_output.py +68 -0
- docent/judges/util/voting.py +139 -0
- docent/loaders/load_inspect.py +215 -0
- docent/py.typed +0 -0
- docent/samples/__init__.py +3 -0
- docent/samples/load.py +9 -0
- docent/samples/log.eval +0 -0
- docent/samples/tb_airline.json +1 -0
- docent/sdk/__init__.py +0 -0
- docent/sdk/agent_run_writer.py +317 -0
- docent/sdk/client.py +1186 -0
- docent/sdk/llm_context.py +432 -0
- docent/trace.py +2741 -0
- docent/trace_temp.py +1086 -0
- docent_python-0.1.41a0.dist-info/METADATA +33 -0
- docent_python-0.1.41a0.dist-info/RECORD +59 -0
- docent_python-0.1.41a0.dist-info/WHEEL +4 -0
- docent_python-0.1.41a0.dist-info/licenses/LICENSE.md +13 -0
|
@@ -0,0 +1,108 @@
|
|
|
1
|
+
import json
|
|
2
|
+
from typing import Any
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
# Short escapes JSON defines for control characters; any other character below
# U+0020 is written as \uXXXX. JSON forbids raw control characters inside strings.
_JSON_CONTROL_ESCAPES = {"\n": "\\n", "\r": "\\r", "\t": "\\t", "\b": "\\b", "\f": "\\f"}
_HEX_DIGITS = "0123456789abcdefABCDEF"
_JSON_LITERALS = ("null", "true", "false")


def _escape_control_char(char: str) -> str:
    """Return the JSON escape sequence for a raw control character (below U+0020)."""
    return _JSON_CONTROL_ESCAPES.get(char, f"\\u{ord(char):04x}")


def _repair_json(text: str) -> str:
    """Extract the first JSON value from ``text`` and repair common LLM mistakes.

    Text before the value and anything after it is dropped. Inside string literals:
    quotes that do not look like a closing quote are escaped, raw control characters
    (newline, carriage return, tab, ...) are escaped, and invalid backslash escapes
    are turned into literal backslashes.

    Args:
        text: Raw text expected to contain a JSON value somewhere.

    Returns:
        A candidate JSON string. It is not guaranteed to parse; callers still run it
        through ``json.loads``.

    Raises:
        ValueError: If no character that could start a JSON value is found.
    """
    # Locate the first character that can begin a JSON value. JSON numbers are
    # ASCII-only, so str.isdigit() (true for e.g. "²") would be too permissive.
    json_start = None
    for i, char in enumerate(text):
        if char in '[{"-0123456789' or text.startswith(_JSON_LITERALS, i):
            json_start = i
            break
    if json_start is None:
        raise ValueError("No valid JSON start found")

    n = len(text)
    result: list[str] = []
    in_string = False
    escape_next = False
    depth = 0
    started_with_container = text[json_start] in "[{"

    for i in range(json_start, n):
        char = text[i]

        if escape_next:
            escape_next = False
            if in_string:
                # Only \\ \/ \b \f \n \r \t \" and \uXXXX are legal JSON escapes.
                is_valid_escape = char in '\\/bfnrt"' or (
                    char == "u"
                    and i + 4 < n
                    and all(c in _HEX_DIGITS for c in text[i + 1 : i + 5])
                )
                if not is_valid_escape:
                    # Invalid escape sequence - add another backslash so it is kept literally.
                    result.append("\\")
                    if char < " ":
                        # e.g. a backslash before a raw newline: the control character
                        # itself must still be escaped for the string to be valid.
                        result.append(_escape_control_char(char))
                        continue
            result.append(char)
            continue

        if char == "\\":
            result.append(char)
            escape_next = True
            continue

        if char == '"':
            if in_string:
                # A real closing quote is followed (after optional whitespace) by
                # ':', ',', '}', ']', '"' or the end of the text; anything else means
                # the quote belongs to the string and must be escaped. Scanning forward
                # avoids copying the rest of the text for every quote.
                j = i + 1
                while j < n and text[j].isspace():
                    j += 1
                if j < n and text[j] not in ':,}]"':
                    result.append('\\"')
                    continue
                in_string = False
                result.append(char)
                # A top-level string (no enclosing container) is complete once closed.
                if depth == 0 and not started_with_container:
                    return "".join(result)
            else:
                in_string = True
                result.append(char)
        elif in_string and char < " ":
            # Raw newlines, tabs, carriage returns etc. are invalid inside JSON strings.
            result.append(_escape_control_char(char))
        else:
            result.append(char)

        if not in_string:
            if char in "[{":
                depth += 1
            elif char in "]}":
                depth -= 1
                if depth == 0:
                    return "".join(result)
            elif depth == 0 and not started_with_container and char in " \t\n\r":
                # A top-level primitive may end at whitespace: stop as soon as the
                # text consumed so far parses on its own.
                current = "".join(result).strip()
                if current:
                    try:
                        json.loads(current)
                        return current
                    except ValueError:  # json.JSONDecodeError subclasses ValueError
                        pass

    return "".join(result)


def forgiving_json_loads(text: str) -> Any:
    """
    Parse JSON from text, applying heuristics to fix common LLM mistakes.

    Repairs applied:
    - Strip leading/trailing non-JSON text
    - Escape unescaped quotes and raw control characters (newlines, carriage
      returns, tabs, ...) inside strings
    - Fix invalid escape sequences inside strings

    Args:
        text: Raw text (typically LLM output) containing a JSON value.

    Returns:
        The decoded JSON value (object, array, string, number, bool or None).

    Raises:
        ValueError: If ``text`` is empty or whitespace-only, contains no JSON start,
            or still fails to parse after repair (``json.JSONDecodeError`` is a
            ``ValueError``).
    """
    if not text or not text.strip():
        raise ValueError("Empty or whitespace-only input")

    return json.loads(_repair_json(text))
|
|
@@ -0,0 +1,86 @@
|
|
|
1
|
+
{
|
|
2
|
+
"$schema": "https://json-schema.org/draft/2020-12/schema",
|
|
3
|
+
"$id": "https://example.com/meta/mini-schema",
|
|
4
|
+
"title": "Meta-schema for Docent judge outputs. Makes some restrictions to 2020-12.",
|
|
5
|
+
"type": "object",
|
|
6
|
+
"additionalProperties": false,
|
|
7
|
+
"properties": {
|
|
8
|
+
"type": { "const": "object" },
|
|
9
|
+
"additionalProperties": { "const": false },
|
|
10
|
+
"required": {
|
|
11
|
+
"type": "array",
|
|
12
|
+
"items": { "type": "string" }
|
|
13
|
+
},
|
|
14
|
+
|
|
15
|
+
"properties": {
|
|
16
|
+
"type": "object",
|
|
17
|
+
"propertyNames": { "type": "string" },
|
|
18
|
+
"additionalProperties": {
|
|
19
|
+
"type": "object",
|
|
20
|
+
"additionalProperties": false,
|
|
21
|
+
"required": ["type"],
|
|
22
|
+
|
|
23
|
+
"properties": {
|
|
24
|
+
"type": {
|
|
25
|
+
"type": "string",
|
|
26
|
+
"enum": ["string", "integer", "number", "boolean"]
|
|
27
|
+
},
|
|
28
|
+
"description": {
|
|
29
|
+
"type": "string"
|
|
30
|
+
},
|
|
31
|
+
"citations": {
|
|
32
|
+
"type": "boolean"
|
|
33
|
+
},
|
|
34
|
+
"enum": {
|
|
35
|
+
"type": "array",
|
|
36
|
+
"items": {
|
|
37
|
+
"type": ["string", "integer", "boolean"]
|
|
38
|
+
}
|
|
39
|
+
},
|
|
40
|
+
"format": {
|
|
41
|
+
"type": "string",
|
|
42
|
+
"enum": [
|
|
43
|
+
"date-time",
|
|
44
|
+
"date",
|
|
45
|
+
"time",
|
|
46
|
+
"email",
|
|
47
|
+
"hostname",
|
|
48
|
+
"ipv4",
|
|
49
|
+
"ipv6",
|
|
50
|
+
"uri",
|
|
51
|
+
"uuid"
|
|
52
|
+
]
|
|
53
|
+
},
|
|
54
|
+
"minLength": {
|
|
55
|
+
"type": "integer",
|
|
56
|
+
"minimum": 0
|
|
57
|
+
},
|
|
58
|
+
"maxLength": {
|
|
59
|
+
"type": "integer",
|
|
60
|
+
"minimum": 0
|
|
61
|
+
},
|
|
62
|
+
"pattern": {
|
|
63
|
+
"type": "string"
|
|
64
|
+
},
|
|
65
|
+
"minimum": {
|
|
66
|
+
"type": "number"
|
|
67
|
+
},
|
|
68
|
+
"maximum": {
|
|
69
|
+
"type": "number"
|
|
70
|
+
},
|
|
71
|
+
"exclusiveMinimum": {
|
|
72
|
+
"type": "number"
|
|
73
|
+
},
|
|
74
|
+
"exclusiveMaximum": {
|
|
75
|
+
"type": "number"
|
|
76
|
+
},
|
|
77
|
+
"multipleOf": {
|
|
78
|
+
"type": "number",
|
|
79
|
+
"exclusiveMinimum": 0
|
|
80
|
+
}
|
|
81
|
+
}
|
|
82
|
+
}
|
|
83
|
+
}
|
|
84
|
+
},
|
|
85
|
+
"required": ["type", "properties"]
|
|
86
|
+
}
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
import json
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
import jsonschema
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def _load_meta_schema() -> dict[str, Any]:
    """Read and return the rubric meta-schema stored next to this module.

    The schema lives in a JSON file that shares this module's path, with the
    ``.py`` suffix swapped for ``.json``.
    """
    schema_file = Path(__file__).with_suffix(".json")
    raw_schema = schema_file.read_text(encoding="utf-8")
    return json.loads(raw_schema)


# Built once at import time so every validation call reuses the same validator.
_META_VALIDATOR = jsonschema.Draft202012Validator(_load_meta_schema())
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def validate_judge_result_schema(schema: dict[str, Any]):
    """Check a proposed judge output schema in two stages.

    1. The schema must be a well-formed JSON Schema under draft 2020-12.
    2. It must also fit Docent's restricted subset, as described by the rubric
       meta-schema held in ``_META_VALIDATOR``.

    Raises:
        jsonschema.SchemaError: If the schema is not a valid 2020-12 schema
        jsonschema.ValidationError: If the schema is invalid
    """
    draft_validator_cls = jsonschema.Draft202012Validator
    draft_validator_cls.check_schema(schema)  # stage 1: generic draft 2020-12 check
    _META_VALIDATOR.validate(schema)  # type: ignore  # stage 2: Docent subset
|
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
from typing import Any, cast
|
|
2
|
+
|
|
3
|
+
import jsonschema
|
|
4
|
+
|
|
5
|
+
from docent._llm_util.data_models.exceptions import ValidationFailedException
|
|
6
|
+
from docent._log_util import get_logger
|
|
7
|
+
from docent.data_models.agent_run import AgentRun
|
|
8
|
+
from docent.judges.util.forgiving_json import forgiving_json_loads
|
|
9
|
+
|
|
10
|
+
logger = get_logger(__name__)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def _validate_rubric_output(
    output: dict[str, Any], output_schema: dict[str, Any], agent_run: AgentRun
) -> dict[str, Any]:
    """Validate that the output conforms to the output schema.

    Args:
        output: Parsed results from the LLM judge.
        output_schema: Schema to validate against.
        agent_run: Agent run (unused, kept for backwards compatibility).

    Returns:
        ``output`` itself, unchanged, once it has passed validation.

    Raises:
        ValidationFailedException: If ``output`` does not satisfy ``output_schema``.
            The underlying ``jsonschema.ValidationError`` is chained as the cause.
    """
    # NOTE(review): jsonschema.validate may also raise jsonschema.SchemaError if
    # output_schema itself is malformed; that error is not caught here.
    try:
        jsonschema.validate(output, output_schema)
    except jsonschema.ValidationError as e:
        raise ValidationFailedException(
            f"Schema validation failed: {e}", failed_output=str(output)
        ) from e

    return output
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def parse_and_validate_output_str(
    output_str: str, output_schema: dict[str, Any], agent_run: AgentRun
) -> dict[str, Any]:
    """Parse raw LLM judge text as JSON and validate it against the output schema.

    Args:
        output_str: Raw LLM output text. The JSON value is extracted and repaired
            with ``forgiving_json_loads``.
        output_schema: The schema to validate against.
        agent_run: Agent run, forwarded to ``_validate_rubric_output``, which does
            not use it.

    Returns:
        Validated output dict.

    Raises:
        ValidationFailedException: If the text cannot be parsed as JSON, does not
            decode to a JSON object, or fails schema validation.
    """
    try:
        output = forgiving_json_loads(output_str)
    except Exception as e:
        # Any parse failure is reported uniformly; the original error is chained.
        raise ValidationFailedException(
            f"Failed to parse JSON: {e}. Raw text: `{output_str}`",
            failed_output=output_str,
        ) from e

    # The judge must produce an object; arrays and primitives are rejected.
    if not isinstance(output, dict):
        raise ValidationFailedException(
            f"Expected dict output, got {type(output)}. Raw text: {output_str}",
            failed_output=output_str,
        )

    return _validate_rubric_output(cast(dict[str, Any], output), output_schema, agent_run)
|
|
@@ -0,0 +1,139 @@
|
|
|
1
|
+
from collections import Counter
|
|
2
|
+
from typing import Any, TypedDict, cast
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class EstimateWithCI(TypedDict):
|
|
8
|
+
mean: float
|
|
9
|
+
var: float
|
|
10
|
+
n: int
|
|
11
|
+
ci_95: float
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
JudgeOutputDistribution = dict[str | bool | int | float, EstimateWithCI]
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def get_agreement_keys(schema: dict[str, Any]) -> list[str]:
    """Return the top-level schema keys whose values can be compared for agreement.

    A key qualifies when its field is a boolean or declares an ``enum``; string and
    number fields are only comparable when they are enumerated.

    Args:
        schema: JSON schema dict with an optional top-level ``properties`` mapping.

    Returns:
        Qualifying field names, in the order they appear in ``properties``.
    """
    props = schema.get("properties", {})
    assert isinstance(props, dict)
    props = cast(dict[str, Any], props)

    selected: list[str] = []
    for name, spec in props.items():
        assert isinstance(spec, dict)
        spec = cast(dict[str, Any], spec)

        spec_type = spec.get("type")
        assert isinstance(spec_type, str)

        if spec_type == "boolean" or "enum" in spec:
            selected.append(name)

    return selected
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def find_modal_result(indep_results: list[dict[str, Any]], agreement_keys: list[str]):
    """Pick the independent result that agrees with the majority on the most keys.

    For each agreement key, the most common non-None value across ``indep_results``
    is that key's mode (ties go to the value encountered first). Each result earns
    one point per key whose value equals the mode; a missing value, or a key with no
    mode, earns nothing. The first result with the highest score wins.

    Args:
        indep_results: Independent results to compare.
        agreement_keys: Keys on which agreement is measured.

    Returns:
        Tuple ``(max_idx, agt_key_modes_and_counts)``: the winning result's index, and
        a mapping from each key to ``(modal_value, count)``, or None if no result has
        a value for that key.

    Raises:
        ValueError: If no results are provided
    """
    if not indep_results:
        raise ValueError("No results to score")

    modes: dict[str, tuple[str | bool | int, int] | None] = {}
    for key in agreement_keys:
        observed = [r.get(key) for r in indep_results]
        top = Counter(v for v in observed if v is not None).most_common(1)
        modes[key] = top[0] if top else None

    # TODO(mengk): This may bias towards results that have more keys.
    def _agreement_score(result: dict[str, Any]) -> int:
        return sum(
            1
            for key in agreement_keys
            if modes[key] is not None and result.get(key) == modes[key][0]
        )

    scores = [_agreement_score(r) for r in indep_results]
    best_idx = max(range(len(scores)), key=scores.__getitem__)
    return best_idx, modes
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def compute_output_distributions(
|
|
97
|
+
indep_results: list[dict[str, Any]], output_schema: dict[str, Any], agreement_keys: list[str]
|
|
98
|
+
):
|
|
99
|
+
def _get_possible_values(key: str) -> list[str | bool | int | float]:
|
|
100
|
+
if "enum" in output_schema.get("properties", {}).get(key, {}):
|
|
101
|
+
return output_schema.get("properties", {}).get(key, {}).get("enum", [])
|
|
102
|
+
elif output_schema.get("properties", {}).get(key, {}).get("type") == "boolean":
|
|
103
|
+
return [True, False]
|
|
104
|
+
else:
|
|
105
|
+
return []
|
|
106
|
+
|
|
107
|
+
raw_counts: dict[str, dict[str | bool | int | float, int]] = {
|
|
108
|
+
key: {value: 0 for value in _get_possible_values(key)} for key in agreement_keys
|
|
109
|
+
}
|
|
110
|
+
# Collect counts for each possible value
|
|
111
|
+
for result in indep_results:
|
|
112
|
+
for key in agreement_keys:
|
|
113
|
+
if (value := result.get(key)) is not None: # Could be none if the key is optional
|
|
114
|
+
assert (
|
|
115
|
+
value in raw_counts[key]
|
|
116
|
+
), "this should never happen; the value must be in possible values, since judge results have been validated against the schema"
|
|
117
|
+
raw_counts[key][value] += 1
|
|
118
|
+
|
|
119
|
+
distributions: dict[str, JudgeOutputDistribution] = {}
|
|
120
|
+
for agt_key in agreement_keys:
|
|
121
|
+
distributions[agt_key] = {}
|
|
122
|
+
|
|
123
|
+
# First normalize the counts to get probabilities
|
|
124
|
+
counts = raw_counts[agt_key]
|
|
125
|
+
total = sum(counts.values())
|
|
126
|
+
probs = {value: (count / total) if total > 0 else 0.0 for value, count in counts.items()}
|
|
127
|
+
|
|
128
|
+
for output_key, value in probs.items():
|
|
129
|
+
mean, estimate_var = value, (value * (1 - value))
|
|
130
|
+
# TODO(mengk): change to the wilson score interval
|
|
131
|
+
ci_95 = float(1.96 * np.sqrt(estimate_var / total)) if total > 0 else 0.0
|
|
132
|
+
distributions[agt_key][output_key] = {
|
|
133
|
+
"mean": mean,
|
|
134
|
+
"var": estimate_var,
|
|
135
|
+
"n": total,
|
|
136
|
+
"ci_95": ci_95,
|
|
137
|
+
}
|
|
138
|
+
|
|
139
|
+
return distributions
|
|
@@ -0,0 +1,215 @@
|
|
|
1
|
+
import json
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
from typing import Any, BinaryIO, Generator, Tuple
|
|
4
|
+
from zipfile import ZipFile
|
|
5
|
+
|
|
6
|
+
from inspect_ai.log import EvalLog
|
|
7
|
+
from inspect_ai.scorer import CORRECT, INCORRECT, NOANSWER, PARTIAL, Score
|
|
8
|
+
|
|
9
|
+
from docent._log_util.logger import get_logger
|
|
10
|
+
from docent.data_models import AgentRun, Transcript
|
|
11
|
+
from docent.data_models.chat import parse_chat_message
|
|
12
|
+
|
|
13
|
+
logger = get_logger(__name__)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def _normalize_inspect_score(score: Score | dict[str, Any]) -> Any:
    """
    Normalize an inspect score value. Logic mirrors inspect_ai.scorer._metric.value_to_float.

    Scalar leaves are mapped as follows:
    - ints, floats and bools -> float
    - CORRECT -> 1.0, PARTIAL -> 0.5, INCORRECT / NOANSWER -> 0
    - "yes"/"true" -> 1.0 and "no"/"false" -> 0.0 (compared after lower-casing)
    - strings that look numeric once "." is removed, and that float() accepts -> float
    - anything else -> its lower-cased string form
    List and dict values are normalized element-wise.

    Args:
        score: The inspect score to normalize, either a ``Score`` or its serialized
            dict form (only the ``"value"`` entry is read).

    Returns:
        The normalized value (scalar, list or dict as described above), or None when
        the score's value is None.

    Raises:
        AssertionError: If the value is not None, a scalar, a list, or a dict.
    """

    def _leaf_normalize(value: Any) -> Any:
        if value is None:
            return None
        if isinstance(value, int | float | bool):
            return float(value)
        if value == CORRECT:
            return 1.0
        if value == PARTIAL:
            return 0.5
        if value in [INCORRECT, NOANSWER]:
            return 0
        value = str(value).lower()
        if value in ["yes", "true"]:
            return 1.0
        if value in ["no", "false"]:
            return 0.0
        if value.replace(".", "").isnumeric():
            try:
                return float(value)
            except ValueError:
                # isnumeric() also accepts strings float() rejects, e.g. "1.2.3" or "½";
                # fall through and keep them as strings instead of crashing.
                pass
        return value

    value = score["value"] if isinstance(score, dict) else score.value

    if value is None or isinstance(value, int | float | bool | str):
        return _leaf_normalize(value)
    if isinstance(value, list):
        return [_leaf_normalize(v) for v in value]  # type: ignore
    assert isinstance(value, dict), "Inspect score must be leaf value, list, or dict"
    return {k: _leaf_normalize(v) for k, v in value.items()}  # type: ignore
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def load_inspect_log(log: EvalLog) -> list[AgentRun]:
    """Build one ``AgentRun`` per sample of an in-memory Inspect ``EvalLog``.

    Each run holds a single transcript made from the sample's messages. Its metadata
    records the task, the sample id (as a string), the epoch, the model, the sample's
    own metadata, normalized score values, and the raw score objects.

    Returns an empty list when the log has no samples.
    """
    if log.samples is None:
        return []

    # TODO(vincent): fix this
    runs: list[AgentRun] = []
    for sample in log.samples:
        raw_scores = sample.scores
        normalized = (
            {}
            if raw_scores is None
            else {name: _normalize_inspect_score(val) for name, val in raw_scores.items()}
        )

        run_metadata = {
            "task_id": log.eval.task,
            "sample_id": str(sample.id),
            "epoch_id": sample.epoch,
            "model": log.eval.model,
            "additional_metadata": sample.metadata,
            "scores": normalized,
            # Scores could have answers, explanations, and other metadata besides the values we extract
            "scoring_metadata": raw_scores,
        }

        chat = [parse_chat_message(msg.model_dump()) for msg in sample.messages]
        transcript = Transcript(messages=chat, metadata={})
        runs.append(AgentRun(transcripts=[transcript], metadata=run_metadata))

    return runs
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def _read_sample_as_run(
    data: dict[str, Any], header_metadata: dict[str, Any] | None = None
) -> AgentRun:
    """Convert one serialized Inspect sample (a decoded JSON dict) into an ``AgentRun``.

    Args:
        data: The sample dict. ``messages`` is required; ``id``, ``epoch``,
            ``target``, ``scores`` and ``metadata`` are optional, and ``scores`` /
            ``metadata`` may also be null.
        header_metadata: Log-level metadata merged into the run's metadata.
            Defaults to no extra metadata.

    Returns:
        An ``AgentRun`` with a single transcript built from ``data["messages"]``.

    Raises:
        KeyError: If ``data`` has no ``"messages"`` entry.
    """
    # None default instead of a shared mutable {} default argument.
    if header_metadata is None:
        header_metadata = {}

    raw_scores = data.get("scores")
    # Treat a null "scores" entry like a missing one (load_inspect_log likewise
    # guards against scores being None).
    normalized_scores = (
        {k: _normalize_inspect_score(v) for k, v in raw_scores.items()} if raw_scores else {}
    )
    sample_metadata = data.get("metadata") or {}

    run_metadata: dict[str, Any] = {
        "sample_id": data.get("id"),
        "epoch": data.get("epoch"),
        "target": data.get("target"),
        # Scores could have answers, explanations, and other metadata besides the values we extract
        "scoring_metadata": raw_scores,
        "scores": normalized_scores,
        # If a key exists in header and sample, sample takes precedence.
        # Both can also override the extracted keys above.
        **header_metadata,
        **sample_metadata,
    }

    transcript = Transcript(
        messages=[parse_chat_message(m) for m in data["messages"]], metadata={}
    )
    return AgentRun(transcripts=[transcript], metadata=run_metadata)
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
def _run_metadata_from_header(header: dict[str, Any]) -> dict[str, Any]:
    """
    Pick out the few header fields worth attaching to every run.

    Inspect log headers carry a lot of metadata; to avoid clutter, only the task
    and model from the ``eval`` section are kept. An empty dict is returned when
    the ``eval`` section is missing or empty.
    """
    eval_section = header.get("eval")
    if not eval_section:
        return {}
    return {"task": eval_section["task"], "model": eval_section["model"]}
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
def get_total_samples(file_path: Path, format: str = "json") -> int:
    """Count the samples stored in an Inspect log file.

    For ``"json"`` logs the whole file is parsed and the length of the top-level
    ``samples`` list is returned (0 if the key is absent). ``"eval"`` logs are zip
    archives; every member under ``samples/`` ending in ``.json`` counts as one.

    Raises:
        ValueError: If ``format`` is neither ``"json"`` nor ``"eval"``.
    """
    with open(file_path, "rb") as fh:
        if format == "eval":
            with ZipFile(fh, mode="r") as archive:
                members = archive.namelist()
            sample_members = [
                m for m in members if m.startswith("samples/") and m.endswith(".json")
            ]
            return len(sample_members)
        if format == "json":
            return len(json.load(fh).get("samples", []))
        raise ValueError(f"Format must be 'json' or 'eval': {format}")
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
def _runs_from_eval_file(
    file: BinaryIO,
) -> Tuple[dict[str, Any], Generator[AgentRun, None, None]]:
    """Read an Inspect ``.eval`` (zip) log: header metadata now, runs lazily.

    ``header.json`` is read immediately to build the metadata shared by every run.
    Sample members (``samples/*.json``) are only read as the returned generator is
    consumed. Once iteration has started, the archive is closed when the generator
    finishes or is closed.
    NOTE(review): if the generator is never iterated, the archive is not explicitly
    closed here.

    Args:
        file: Binary stream containing the zip archive.

    Returns:
        ``(header_metadata, runs)``. ``header_metadata`` is ``{}`` (and a warning is
        logged) when the header is missing or incomplete.
    """
    archive = ZipFile(file, mode="r")
    try:
        with archive.open("header.json", "r") as header_file:
            header: dict[str, Any] = json.load(header_file)
        header_metadata = _run_metadata_from_header(header)
    except KeyError:
        # Raised when header.json is absent, but also when the header lacks
        # eval.task / eval.model. Streams such as io.BytesIO have no ``name``.
        file_name = getattr(file, "name", "<unnamed stream>")
        logger.warning(f"No header found in {file_name} file")
        header_metadata = {}
    except BaseException:
        # Don't leak the open archive if the header exists but cannot be read.
        archive.close()
        raise

    def _iter_runs() -> Generator[AgentRun, None, None]:
        try:
            for sample_file in archive.namelist():
                if not (sample_file.startswith("samples/") and sample_file.endswith(".json")):
                    continue
                with archive.open(sample_file, "r") as f:
                    data = json.load(f)
                yield _read_sample_as_run(data, header_metadata)
        finally:
            archive.close()

    return header_metadata, _iter_runs()
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
def _runs_from_json_file(
    file: BinaryIO,
) -> Tuple[dict[str, Any], Generator[AgentRun, None, None]]:
    """Read a plain-JSON Inspect log: header metadata plus a lazy iterator of runs.

    The whole file is parsed up front; ``AgentRun`` objects are only built as the
    returned generator is consumed.

    Args:
        file: Binary stream containing the JSON log.

    Returns:
        ``(header_metadata, runs)``; ``runs`` yields one ``AgentRun`` per sample and
        yields nothing when the log has no samples.
    """
    data = json.load(file)
    header_metadata = _run_metadata_from_header(data)
    # Tolerate logs without samples (missing key or null) instead of raising KeyError
    # mid-iteration; get_total_samples likewise counts a missing key as 0 samples.
    samples = data.get("samples") or []

    def _iter_runs() -> Generator[AgentRun, None, None]:
        for sample in samples:
            yield _read_sample_as_run(sample, header_metadata)

    return header_metadata, _iter_runs()
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
def runs_from_file(
|
|
207
|
+
file: BinaryIO, format: str = "json"
|
|
208
|
+
) -> Tuple[dict[str, Any], Generator[AgentRun, None, None]]:
|
|
209
|
+
if format == "json":
|
|
210
|
+
result = _runs_from_json_file(file)
|
|
211
|
+
elif format == "eval":
|
|
212
|
+
result = _runs_from_eval_file(file)
|
|
213
|
+
else:
|
|
214
|
+
raise ValueError(f"Format must be 'json' or 'eval': {format}")
|
|
215
|
+
return result
|
docent/py.typed
ADDED
|
File without changes
|
docent/samples/load.py
ADDED
docent/samples/log.eval
ADDED
|
Binary file
|