docent-python 0.1.20a0__py3-none-any.whl → 0.1.22a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

docent/judges/impl.py ADDED
@@ -0,0 +1,232 @@
+import json
+from abc import ABC, abstractmethod
+from typing import Any
+
+from docent._llm_util.data_models.llm_output import LLMOutput
+from docent._llm_util.data_models.simple_svc import BaseLLMService, SimpleLLMService
+from docent._log_util import get_logger
+from docent.data_models.agent_run import AgentRun
+from docent.judges.types import JudgeResult, ResultType, Rubric
+from docent.judges.util.parse_output import parse_and_validate_llm_output
+from docent.judges.util.voting import (
+    compute_output_distribution,
+    find_modal_result,
+    get_agreement_keys,
+)
+
+logger = get_logger(__name__)
+
+
+class BaseJudge(ABC):
+    def __init__(self, cfg: Rubric, llm_svc: BaseLLMService):
+        self.cfg = cfg
+        self.llm_svc = llm_svc
+
+    @abstractmethod
+    async def __call__(self, agent_run: AgentRun, *args: Any, **kwargs: Any) -> JudgeResult | None:
+        """Returns None if all rollouts failed to produce a valid output."""
+
+
+class MajorityVotingJudge(BaseJudge):
+    """Rolls out the judge multiple times, then uses majority voting to determine the final result."""
+
+    def __init__(
+        self,
+        cfg: Rubric,
+        n_rollouts_per_input: int,
+        llm_svc: BaseLLMService = SimpleLLMService(),
+    ):
+        super().__init__(cfg, llm_svc)
+        self.n_rollouts_per_input = n_rollouts_per_input
+
+    async def __call__(
+        self,
+        agent_run: AgentRun,
+        max_concurrency: int = 10,
+    ) -> JudgeResult | None:
+        async def _validation_callback(batch_index: int, llm_output: LLMOutput):
+            parse_and_validate_llm_output(llm_output, self.cfg.output_schema, agent_run)
+
+        prompt = [{"role": "user", "content": self.cfg.materialize_system_prompt(agent_run)}]
+        outputs = await self.llm_svc.get_completions(
+            inputs=[prompt for _ in range(self.n_rollouts_per_input)],
+            model_options=[self.cfg.judge_model],
+            max_new_tokens=16384,
+            timeout=180.0,
+            use_cache=False,
+            validation_callback=_validation_callback,
+            max_concurrency=max_concurrency,
+        )
+
+        # Process each rollout independently
+        indep_results: list[dict[str, Any]] = []
+        for output in outputs:
+            if validated_output := parse_and_validate_llm_output(
+                output, self.cfg.output_schema, agent_run
+            ):
+                indep_results.append(validated_output)
+
+        if not indep_results:
+            return None
+
+        # Get a list of the keys that we want to measure agreement on
+        agreement_keys = get_agreement_keys(self.cfg.output_schema)
+
+        # Find the result that best matches modal values
+        final_max_idx, final_agt_key_modes_and_counts = find_modal_result(
+            indep_results, agreement_keys
+        )
+        final_output = indep_results[final_max_idx]
+
+        # Compute the distribution of the output across the agreement keys
+        final_output_distribution = compute_output_distribution(
+            indep_results, self.cfg.output_schema, agreement_keys
+        )
+
+        return JudgeResult(
+            agent_run_id=agent_run.id,
+            rubric_id=self.cfg.id,
+            rubric_version=self.cfg.version,
+            output=final_output,
+            result_metadata={
+                "agt_keys": agreement_keys,
+                # Final measurements
+                "final_results": indep_results,
+                "final_agt_key_modes_and_counts": final_agt_key_modes_and_counts,
+                "final_max_idx": final_max_idx,
+                "final_output_distribution": final_output_distribution,
+            },
+            result_type=ResultType.DIRECT_RESULT,
+        )
+
+
+class MultiReflectionJudge(BaseJudge):
+    """Rolls out the judge multiple times, then uses reflection to determine the final result."""
+
+    def __init__(
+        self,
+        cfg: Rubric,
+        n_rollouts_per_input: int,
+        llm_svc: BaseLLMService = SimpleLLMService(),
+    ):
+        super().__init__(cfg, llm_svc)
+        self.n_rollouts_per_input = n_rollouts_per_input
+
+    async def __call__(
+        self,
+        agent_run: AgentRun,
+        max_concurrency: int = 10,
+    ) -> JudgeResult | None:
+        rubric = self.cfg
+
+        async def _validation_callback(batch_index: int, llm_output: LLMOutput):
+            parse_and_validate_llm_output(llm_output, rubric.output_schema, agent_run)
+
+        # Run several independent rollouts
+        prompt = [{"role": "user", "content": self.cfg.materialize_system_prompt(agent_run)}]
+        outputs = await self.llm_svc.get_completions(
+            inputs=[prompt for _ in range(self.n_rollouts_per_input)],
+            model_options=[rubric.judge_model],
+            max_new_tokens=16384,
+            timeout=180.0,
+            use_cache=False,
+            validation_callback=_validation_callback,
+            max_concurrency=max_concurrency,
+        )
+
+        # Process each rollout
+        indep_results: list[dict[str, Any]] = []
+        for output in outputs:
+            if output.first_text is None:
+                continue
+            if v_output := parse_and_validate_llm_output(output, rubric.output_schema, agent_run):
+                indep_results.append(v_output)
+
+        if not indep_results:
+            return None
+
+        # Compute initial modes
+        agreement_keys = get_agreement_keys(rubric.output_schema)
+        indep_max_idx, indep_agt_key_modes_and_counts = find_modal_result(
+            indep_results, agreement_keys
+        )
+
+        def _get_reflection_prompt(cur_index: int):
+            # Current result
+            result = indep_results[cur_index]
+            # Get other results (excluding the current one)
+            other_results = [r for j, r in enumerate(indep_results) if j != cur_index]
+
+            # Create the reflection message
+            other_results_text = "\n\n".join(
+                [f"Answer {j+1}:\n{json.dumps(r, indent=2)}" for j, r in enumerate(other_results)]
+            )
+
+            reflection_instruction = (
+                f"Here are {len(other_results)} other independent answers to the same rubric evaluation:\n\n"
+                f"{other_results_text}\n\n"
+                f"Please reflect on these other answers and your own answer. "
+                f"Consider if any of them have identified important aspects you missed, or if there are disagreements that should be resolved. "
+                f"Then provide your final answer in the same JSON format as before."
+            )
+
+            # Construct the multi-message prompt
+            # 1. Original user message
+            # 2. Assistant message with the rollout's result
+            # 3. New user message asking for reflection
+            return [
+                *prompt,  # Original user message(s)
+                {"role": "assistant", "content": json.dumps(result, indent=2)},
+                {"role": "user", "content": reflection_instruction},
+            ]
+
+        final_results = indep_results.copy()  # Shallow copy
+        if len(indep_results) > 1:
+            # Ask the judge to reflect on the others' results
+            reflection_outputs = await self.llm_svc.get_completions(
+                inputs=[_get_reflection_prompt(i) for i in range(len(indep_results))],
+                model_options=[rubric.judge_model],
+                max_new_tokens=16384,
+                timeout=180.0,
+                use_cache=False,
+                validation_callback=_validation_callback,
+                max_concurrency=max_concurrency,
+            )
+
+            # Process reflection outputs in the same way as the initial rollouts
+            reflected_results: list[dict[str, Any]] = []
+            for output in reflection_outputs:
+                if output.first_text is None:
+                    continue
+                if v_output := parse_and_validate_llm_output(
+                    output, rubric.output_schema, agent_run
+                ):
+                    reflected_results.append(v_output)
+
+            # Use reflected results if we got any, otherwise fall back to original results
+            if reflected_results:
+                final_results = reflected_results
+            else:
+                logger.warning("No reflected results found, falling back to original results")
+
+        final_max_idx, final_agt_key_modes_and_counts = find_modal_result(
+            final_results, agreement_keys
+        )
+        return JudgeResult(
+            agent_run_id=agent_run.id,
+            rubric_id=rubric.id,
+            rubric_version=rubric.version,
+            output=final_results[final_max_idx],
+            result_metadata={
+                "agt_keys": agreement_keys,
+                # Final measurements
+                "final_results": final_results,
+                "final_agt_key_modes_and_counts": final_agt_key_modes_and_counts,
+                "final_max_idx": final_max_idx,
+                # Also include initial measurements
+                "indep_results": indep_results,
+                "indep_max_idx": indep_max_idx,
+                "indep_agt_key_modes_and_counts": indep_agt_key_modes_and_counts,
+            },
+            result_type=ResultType.DIRECT_RESULT,
+        )
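
For orientation, a minimal usage sketch of the majority-voting judge, based only on the signatures added above. `judge_agent_run` is a hypothetical helper name, and the rubric text is illustrative; constructing the `AgentRun` and configuring provider credentials for `SimpleLLMService` are assumed to happen elsewhere.

from docent.data_models.agent_run import AgentRun
from docent.judges.impl import MajorityVotingJudge
from docent.judges.types import JudgeResult, Rubric


async def judge_agent_run(agent_run: AgentRun) -> JudgeResult | None:
    # With the default output schema, `output` is {"label": "match" | "no match", "explanation": "..."}
    rubric = Rubric(rubric_text="The agent fabricates a tool result at some point in the run.")
    judge = MajorityVotingJudge(cfg=rubric, n_rollouts_per_input=5)

    # Returns None if every rollout failed to produce a valid, schema-conforming output;
    # per-key vote counts are stored in result_metadata["final_agt_key_modes_and_counts"].
    return await judge(agent_run, max_concurrency=10)

`MultiReflectionJudge` is invoked the same way; it only adds a reflection round over the independent rollouts before the modal result is selected.
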
docent/judges/types.py ADDED
@@ -0,0 +1,240 @@
+import enum
+import json
+from string import Formatter
+from typing import Any, Callable, Protocol
+from uuid import uuid4
+
+from pydantic import BaseModel, Field, field_serializer, field_validator
+
+from docent._llm_util.providers.preference_types import PUBLIC_PROVIDER_PREFERENCES, ModelOption
+from docent._log_util import get_logger
+from docent.data_models.agent_run import AgentRun
+from docent.data_models.citation import parse_citations
+from docent.data_models.transcript import TEXT_RANGE_CITE_INSTRUCTION
+from docent.judges.util.meta_schema import validate_judge_result_schema
+
+logger = get_logger(__name__)
+
+DEFAULT_JUDGE_SYSTEM_PROMPT_TEMPLATE = """
+Here is a rubric that we are using to judge transcripts of AI agent runs.
+
+Rubric:
+{rubric}
+
+Agent run:
+{agent_run}
+
+Your response should convey your judgment of the agent run according to the criteria given in the rubric provided above. Your entire response must be a valid JSON string which can be parsed with python `json.loads` without any additional processing. Double quotes (`"`) in the middle of a string in the JSON object must be escaped with a backslash.
+
+The JSON object you produce must adhere to the following schema:
+{output_schema}
+
+{citation_instructions}
+""".strip()
+
+DEFAULT_CITATION_INSTRUCTIONS = f"""
+For strings which require citations (according to the `citations: True` property), you must also follow these instructions:
+{TEXT_RANGE_CITE_INSTRUCTION}
+""".strip()
+
+DEFAULT_JUDGE_OUTPUT_SCHEMA = {
+    "type": "object",
+    "properties": {
+        "label": {"type": "string", "enum": ["match", "no match"]},
+        "explanation": {"type": "string", "citations": True},
+    },
+    # Require these properties to be present
+    "required": ["label", "explanation"],
+    # Allow additional properties though, as their presence is not breaking
+}
+
+DEFAULT_JUDGE_MODEL = PUBLIC_PROVIDER_PREFERENCES.default_judge_models[0]
+
+
+class Rubric(BaseModel):
+    """TODO(mengk): this should really be called JudgeConfig,
+    but temporarily keeping this for consistency with docent_core."""
+
+    class Config:
+        frozen = True
+
+    # Primary key
+    id: str = Field(default_factory=lambda: str(uuid4()))
+    version: int = 1
+
+    # What the judge actually does
+    rubric_text: str
+
+    # Default instructions for the judge
+    system_prompt_template: str = DEFAULT_JUDGE_SYSTEM_PROMPT_TEMPLATE
+    citation_instructions: str = DEFAULT_CITATION_INSTRUCTIONS
+    output_schema: dict[str, Any] = DEFAULT_JUDGE_OUTPUT_SCHEMA
+
+    # How to run the judge
+    judge_model: ModelOption = DEFAULT_JUDGE_MODEL
+
+    def materialize_system_prompt(self, agent_run: AgentRun) -> str:
+        """Construct the full prompt text for rubric evaluation.
+
+        This is the canonical implementation of prompt construction - use this function
+        anywhere you need to construct a rubric evaluation prompt (including cost estimation).
+        """
+
+        output_schema_text = json.dumps(self.output_schema, indent=2)
+
+        # We've already validated that the system prompt template has these keys
+        prompt = self.system_prompt_template.format(
+            rubric=self.rubric_text,
+            agent_run=agent_run.to_text_new(),
+            output_schema=output_schema_text,
+            # Only include citation instructions if the schema requests citations
+            citation_instructions=(
+                self.citation_instructions if _schema_requests_citations(self.output_schema) else ""
+            ),
+        ).strip()
+
+        return prompt
+
+    @field_validator("system_prompt_template")
+    @classmethod
+    def validate_system_prompt_template(cls, system_prompt_template: str):
+        # Extract all field names from the template
+        formatter = Formatter()
+        field_names = {
+            field_name
+            for _, field_name, _, _ in formatter.parse(system_prompt_template)
+            if field_name is not None
+        }
+
+        # Check for required fields
+        required_fields = {"agent_run", "output_schema", "rubric", "citation_instructions"}
+        missing_fields = required_fields - field_names
+
+        if missing_fields:
+            raise ValueError(
+                f"system_prompt_template must contain the following placeholders: {missing_fields}"
+            )
+
+        return system_prompt_template
+
+    @field_validator("output_schema")
+    @classmethod
+    def validate_output_schema(cls, output_schema: dict[str, Any]):
+        """
+        Raises:
+            jsonschema.ValidationError: If the schema is invalid
+            jsonschema.SchemaError: If the schema is not a valid 2020-12 schema
+        """
+        validate_judge_result_schema(output_schema)
+        return output_schema
+
+
+class ResultType(enum.Enum):
+    """Enum for the type of result that a judge result can have."""
+
+    DIRECT_RESULT = "direct_result"
+    NEAR_MISS = "near_miss"
+
+
+class JudgeResult(BaseModel):
+    class Config:
+        frozen = True
+
+    id: str = Field(default_factory=lambda: str(uuid4()))
+    agent_run_id: str
+    rubric_id: str
+    rubric_version: int
+
+    # Outputs
+    output: dict[str, Any]
+    result_metadata: dict[str, Any] | None = None
+    result_type: ResultType
+
+    # Deprecated
+    value: str | None = None
+
+    @field_serializer("result_type")
+    def serialize_result_type(self, result_type: ResultType) -> str:
+        return result_type.value
+
+
+class JudgeResultWithCitations(JudgeResult):
+    @classmethod
+    def from_judge_result(
+        cls, result: JudgeResult, schema: dict[str, Any]
+    ) -> "JudgeResultWithCitations":
+        """Judge result must be validated against the schema before calling this function!"""
+
+        def _parse_citation_string(output: str) -> dict[str, Any]:
+            text, citations = parse_citations(output)
+            return {"text": text, "citations": citations}
+
+        data = result.model_dump()
+        try:
+            data["output"] = traverse_schema_and_transform(
+                data["output"], schema, _parse_citation_string
+            )
+        except Exception as e:
+            logger.error(f"Failed to parse citations: {e}")
+            logger.error(f"Output: {data['output']}")
+            data["output"] = {"raw": data["output"]}
+        return cls(**data)
+
+
+class JudgeResultCompletionCallback(Protocol):
+    """Called when some batch of judge results is completed.
+    Supports batched calls for cases where many results are pre-computed.
+    This avoids invoking the callback separately for each datapoint.
+    """
+
+    async def __call__(
+        self,
+        batch_index: int,
+        judge_results: list[JudgeResult] | None,
+    ) -> None: ...
+
+
+def traverse_schema_and_transform(
+    output: Any,
+    schema: dict[str, Any],
+    citation_string_handler: Callable[[str], Any],
+) -> Any:
+    """Recursively traverse output based on schema, applying citation_string_handler to citation strings."""
+    if schema.get("type") == "string" and schema.get("citations"):  # type: ignore
+        return citation_string_handler(output)
+    elif schema.get("type") == "object":
+        properties: dict[str, Any] = schema.get("properties", {})
+        result: dict[str, Any] = {}
+        for key in properties:
+            if key in output:
+                result[key] = traverse_schema_and_transform(
+                    output[key], properties[key], citation_string_handler
+                )
+        return result
+    elif schema.get("type") == "array":
+        item_schema: dict[str, Any] = schema.get("items", {})
+        return [
+            traverse_schema_and_transform(item, item_schema, citation_string_handler)
+            for item in output
+        ]
+    else:
+        return output
+
+
+def _schema_requests_citations(schema: dict[str, Any]) -> bool:
+    """Check if any field in the schema requests citations by having 'citations': 'true'."""
+
+    def _check_field(field_schema: Any) -> bool:
+        if isinstance(field_schema, dict):
+            if field_schema.get("citations"):  # type: ignore
+                return True
+            for value in field_schema.values():  # type: ignore
+                if isinstance(value, dict) and _check_field(value):
+                    return True
+                elif isinstance(value, list):
+                    for item in value:  # type: ignore
+                        if isinstance(item, dict) and _check_field(item):
+                            return True
+        return False
+
+    return _check_field(schema)
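
A short sketch of how the pieces in this file compose, using only names defined here. `build_prompt` and `extract_citations` are hypothetical helper names; the `JudgeResult` passed to `extract_citations` is assumed to come from one of the judges in `impl.py` and to already validate against the rubric's output schema.

from docent.data_models.agent_run import AgentRun
from docent.judges.types import JudgeResult, JudgeResultWithCitations, Rubric

# Uses the default output schema: {"label": "match" | "no match", "explanation": <cited string>}
rubric = Rubric(rubric_text="The agent leaks credentials into its output.")


def build_prompt(agent_run: AgentRun) -> str:
    # Embeds the rubric text, the agent run text, the serialized output schema, and,
    # because the default schema marks "explanation" with "citations": True,
    # the citation instructions.
    return rubric.materialize_system_prompt(agent_run)


def extract_citations(result: JudgeResult) -> JudgeResultWithCitations:
    # Cited strings in the validated output become {"text": ..., "citations": [...]}.
    return JudgeResultWithCitations.from_judge_result(result, rubric.output_schema)
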
@@ -0,0 +1,108 @@
+import json
+from typing import Any
+
+
+def _repair_json(text: str) -> str:
+    """Strip leading/trailing text and fix unescaped quotes/newlines."""
+
+    json_start = None
+    for i, char in enumerate(text):
+        remaining = text[i:]
+        if (
+            char in '[{"'
+            or char.isdigit()
+            or char == "-"
+            or remaining.startswith("null")
+            or remaining.startswith("true")
+            or remaining.startswith("false")
+        ):
+            json_start = i
+            break
+    if json_start is None:
+        raise ValueError("No valid JSON start found")
+
+    result: list[str] = []
+    in_string = False
+    escape_next = False
+    depth = 0
+    started_with_container = text[json_start] in "[{"
+
+    for i in range(json_start, len(text)):
+        char = text[i]
+
+        if escape_next:
+            if in_string:
+                # Check if this is a valid escape sequence
+                is_valid_escape = char in '\\/bfnrt"' or (
+                    char == "u"
+                    and i + 4 < len(text)
+                    and all(c in "0123456789abcdefABCDEF" for c in text[i + 1 : i + 5])
+                )
+                if not is_valid_escape:
+                    # Invalid escape sequence - add another backslash to escape it
+                    result.append("\\")
+            result.append(char)
+            escape_next = False
+            continue
+
+        if char == "\\":
+            result.append(char)
+            escape_next = True
+            continue
+
+        if char == '"':
+            if in_string:
+                # Check if quote should be escaped by looking at what follows
+                remaining = text[i + 1 :].lstrip()
+                if remaining and remaining[0] not in ':,}]"':
+                    result.append('\\"')
+                    continue
+                in_string = False
+                result.append(char)
+                # If we're at depth 0 and closed a top-level string, we're done
+                if depth == 0 and not started_with_container:
+                    return "".join(result)
+            else:
+                in_string = True
+                result.append(char)
+        elif in_string and char == "\n":
+            result.append("\\n")
+        else:
+            result.append(char)
+
+        if not in_string:
+            if char in "[{":
+                depth += 1
+            elif char in "]}":
+                depth -= 1
+                if depth == 0:
+                    return "".join(result)
+            # For primitives at top level (depth 0), stop at whitespace if we've consumed content
+            elif depth == 0 and not started_with_container and result and char in " \t\n\r":
+                # Check if this is trailing whitespace after a complete primitive
+                current = "".join(result).strip()
+                if current:
+                    try:
+                        json.loads(current)
+                        return current
+                    except (json.JSONDecodeError, ValueError):
+                        pass
+
+    return "".join(result)
+
+
+def forgiving_json_loads(text: str) -> Any:
+    """
+    Parse JSON from text, applying heuristics to fix common LLM mistakes.
+
+    Repairs applied:
+    - Strip leading/trailing non-JSON text
+    - Escape unescaped quotes and newlines inside strings
+    - Fix invalid escape sequences inside strings
+    """
+    if not text or not text.strip():
+        raise ValueError("Empty or whitespace-only input")
+
+    text = _repair_json(text)
+
+    return json.loads(text)
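
A few illustrative calls showing the repairs described in the docstring above; the inputs are invented, and the expected values follow from the parsing logic in this file. The hunk header omits this file's path, so `forgiving_json_loads` is assumed to be imported from wherever the module lives.

# Leading/trailing chatter around the JSON object is stripped
assert forgiving_json_loads(
    'Sure! Here is the result: {"label": "match"} Hope that helps.'
) == {"label": "match"}

# Raw newlines inside strings are escaped before parsing
assert forgiving_json_loads('{"explanation": "line one\nline two"}') == {
    "explanation": "line one\nline two"
}

# Unescaped inner double quotes are repaired based on the character that follows them
assert forgiving_json_loads('{"note": "she said "hi" to me"}') == {
    "note": 'she said "hi" to me'
}
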
@@ -0,0 +1,84 @@
+{
+  "$schema": "https://json-schema.org/draft/2020-12/schema",
+  "$id": "https://example.com/meta/mini-schema",
+  "title": "Meta-schema for Docent judge outputs. Makes some restrictions to 2020-12.",
+  "type": "object",
+  "additionalProperties": false,
+  "properties": {
+    "type": { "const": "object" },
+    "additionalProperties": { "const": false },
+    "required": {
+      "type": "array",
+      "items": { "type": "string" }
+    },
+
+    "properties": {
+      "type": "object",
+      "propertyNames": { "type": "string" },
+      "additionalProperties": {
+        "type": "object",
+        "additionalProperties": false,
+        "required": ["type"],
+
+        "properties": {
+          "type": {
+            "type": "string",
+            "enum": ["string", "integer", "number", "boolean"]
+          },
+          "description": {
+            "type": "string"
+          },
+          "citations": {
+            "type": "boolean"
+          },
+          "enum": {
+            "type": "array",
+            "items": { "type": "string" }
+          },
+          "format": {
+            "type": "string",
+            "enum": [
+              "date-time",
+              "date",
+              "time",
+              "email",
+              "hostname",
+              "ipv4",
+              "ipv6",
+              "uri",
+              "uuid"
+            ]
+          },
+          "minLength": {
+            "type": "integer",
+            "minimum": 0
+          },
+          "maxLength": {
+            "type": "integer",
+            "minimum": 0
+          },
+          "pattern": {
+            "type": "string"
+          },
+          "minimum": {
+            "type": "number"
+          },
+          "maximum": {
+            "type": "number"
+          },
+          "exclusiveMinimum": {
+            "type": "number"
+          },
+          "exclusiveMaximum": {
+            "type": "number"
+          },
+          "multipleOf": {
+            "type": "number",
+            "exclusiveMinimum": 0
+          }
+        }
+      }
+    }
+  },
+  "required": ["type", "properties"]
+}
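
Read as a whitelist, this meta-schema admits only flat object schemas: top-level keys limited to `type`, `properties`, `required`, and `additionalProperties` (which must be `false` if present), and each property a primitive (`string`, `integer`, `number`, `boolean`) constrained by the keywords listed above. An illustrative custom judge output schema that it accepts, written as the Python dict a `Rubric.output_schema` expects:

# Illustrative only: a flat object whose properties are primitives using whitelisted keywords
severity_schema = {
    "type": "object",
    "additionalProperties": False,
    "properties": {
        "label": {"type": "string", "enum": ["match", "no match"]},
        "severity": {"type": "integer", "minimum": 0, "maximum": 5},
        "explanation": {"type": "string", "citations": True},
    },
    "required": ["label", "severity", "explanation"],
}
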
@@ -0,0 +1,29 @@
+import json
+from pathlib import Path
+from typing import Any
+
+import jsonschema
+
+
+def _load_meta_schema() -> dict[str, Any]:
+    """Load the rubric meta-schema from the adjacent JSON file."""
+    meta_schema_path = Path(__file__).with_suffix(".json")
+    with meta_schema_path.open("r", encoding="utf-8") as f:
+        return json.load(f)
+
+
+_META_VALIDATOR = jsonschema.Draft202012Validator(_load_meta_schema())
+
+
+def validate_judge_result_schema(schema: dict[str, Any]):
+    """Validate a proposed schema against the rubric meta-schema.
+
+    Raises:
+        jsonschema.ValidationError: If the schema is invalid
+        jsonschema.SchemaError: If the schema is not a valid 2020-12 schema
+    """
+    # First check that this is a valid 2020-12 schema
+    jsonschema.Draft202012Validator.check_schema(schema)
+
+    # Then check that it conforms to our subset of the 2020-12 schema
+    _META_VALIDATOR.validate(schema)  # type: ignore
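
A sketch of how validation behaves, with one accepted flat schema and one rejected schema (both illustrative); the second fails because `array` is not among the property types the meta-schema allows.

import jsonschema

from docent.judges.util.meta_schema import validate_judge_result_schema

# Accepted: flat object with primitive properties
validate_judge_result_schema(
    {
        "type": "object",
        "properties": {"label": {"type": "string", "enum": ["match", "no match"]}},
        "required": ["label"],
    }
)

# Rejected: property types are limited to string/integer/number/boolean
try:
    validate_judge_result_schema(
        {
            "type": "object",
            "properties": {"steps": {"type": "array", "items": {"type": "string"}}},
            "required": ["steps"],
        }
    )
except jsonschema.ValidationError as e:
    print(f"Schema rejected: {e.message}")
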