docent-python 0.1.50a0__tar.gz → 0.1.52a0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {docent_python-0.1.50a0 → docent_python-0.1.52a0}/PKG-INFO +1 -1
- docent_python-0.1.52a0/docent/data_models/feedback.py +393 -0
- docent_python-0.1.52a0/docent/judges/util/voting.py +351 -0
- {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/sdk/client.py +112 -0
- {docent_python-0.1.50a0 → docent_python-0.1.52a0}/pyproject.toml +1 -1
- docent_python-0.1.50a0/docent/judges/util/voting.py +0 -140
- {docent_python-0.1.50a0 → docent_python-0.1.52a0}/.gitignore +0 -0
- {docent_python-0.1.50a0 → docent_python-0.1.52a0}/LICENSE.md +0 -0
- {docent_python-0.1.50a0 → docent_python-0.1.52a0}/README.md +0 -0
- {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/__init__.py +0 -0
- {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/_llm_util/__init__.py +0 -0
- {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/_llm_util/data_models/__init__.py +0 -0
- {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/_llm_util/data_models/exceptions.py +0 -0
- {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/_llm_util/data_models/llm_output.py +0 -0
- {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/_llm_util/llm_cache.py +0 -0
- {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/_llm_util/llm_svc.py +0 -0
- {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/_llm_util/model_registry.py +0 -0
- {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/_llm_util/providers/__init__.py +0 -0
- {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/_llm_util/providers/anthropic.py +0 -0
- {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/_llm_util/providers/common.py +0 -0
- {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/_llm_util/providers/google.py +0 -0
- {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/_llm_util/providers/openai.py +0 -0
- {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/_llm_util/providers/openrouter.py +0 -0
- {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/_llm_util/providers/preference_types.py +0 -0
- {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/_llm_util/providers/provider_registry.py +0 -0
- {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/_log_util/__init__.py +0 -0
- {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/_log_util/logger.py +0 -0
- {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/data_models/__init__.py +0 -0
- {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/data_models/_tiktoken_util.py +0 -0
- {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/data_models/agent_run.py +0 -0
- {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/data_models/chat/__init__.py +0 -0
- {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/data_models/chat/content.py +0 -0
- {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/data_models/chat/message.py +0 -0
- {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/data_models/chat/response_format.py +0 -0
- {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/data_models/chat/tool.py +0 -0
- {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/data_models/citation.py +0 -0
- {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/data_models/formatted_objects.py +0 -0
- {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/data_models/judge.py +0 -0
- {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/data_models/metadata_util.py +0 -0
- {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/data_models/regex.py +0 -0
- {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/data_models/transcript.py +0 -0
- {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/data_models/util.py +0 -0
- {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/judges/__init__.py +0 -0
- {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/judges/analysis.py +0 -0
- {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/judges/impl.py +0 -0
- {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/judges/runner.py +0 -0
- {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/judges/stats.py +0 -0
- {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/judges/types.py +0 -0
- {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/judges/util/forgiving_json.py +0 -0
- {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/judges/util/meta_schema.json +0 -0
- {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/judges/util/meta_schema.py +0 -0
- {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/judges/util/parse_output.py +0 -0
- {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/judges/util/template_formatter.py +0 -0
- {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/loaders/load_inspect.py +0 -0
- {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/mcp/__init__.py +0 -0
- {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/mcp/__main__.py +0 -0
- {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/mcp/server.py +0 -0
- {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/py.typed +0 -0
- {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/samples/__init__.py +0 -0
- {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/samples/load.py +0 -0
- {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/samples/log.eval +0 -0
- {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/samples/tb_airline.json +0 -0
- {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/sdk/__init__.py +0 -0
- {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/sdk/agent_run_writer.py +0 -0
- {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/sdk/llm_context.py +0 -0
- {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/sdk/llm_request.py +0 -0
- {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/trace.py +0 -0
- {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/trace_temp.py +0 -0
- {docent_python-0.1.50a0 → docent_python-0.1.52a0}/uv.lock +0 -0
|
@@ -0,0 +1,393 @@
|
|
|
1
|
+
"""Data structures for run-centric feedback elicitation and user context inference."""
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
from collections.abc import Iterator
|
|
5
|
+
from datetime import datetime
|
|
6
|
+
from typing import Any, Literal
|
|
7
|
+
|
|
8
|
+
from pydantic import BaseModel, Field, model_validator
|
|
9
|
+
|
|
10
|
+
from docent.data_models.citation import InlineCitation
|
|
11
|
+
from docent.judges.util.voting import OutputDistribution
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def _stable_json(value: Any) -> str:
|
|
15
|
+
return json.dumps(value, sort_keys=True)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def _indent_lines(lines: list[str], indent: int) -> list[str]:
|
|
19
|
+
prefix = " " * max(0, indent)
|
|
20
|
+
return [f"{prefix}{line}" if line else "" for line in lines]
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def _tag_block(tag: str, body_lines: list[str], indent: int) -> list[str]:
    """Wrap *body_lines* in ``<tag>...</tag>``, indenting the body.

    An empty body renders as a single indented ``N/A`` line so the block is
    never empty.
    """
    content = body_lines if body_lines else ["N/A"]
    return [f"<{tag}>", *_indent_lines(content, indent), f"</{tag}>"]
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def _text_or_na(text: str | None) -> str:
|
|
34
|
+
if text is None:
|
|
35
|
+
return "N/A"
|
|
36
|
+
stripped = text.strip()
|
|
37
|
+
return stripped if stripped else "N/A"
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def _text_lines_or_na(text: str | None) -> list[str]:
    """Split normalized text (see ``_text_or_na``) into individual lines."""
    normalized = _text_or_na(text)
    return normalized.splitlines()
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def _render_citations_block(citations: list[InlineCitation], indent: int) -> list[str]:
    """Render citations as one deterministic JSON line in a <Citations> block.

    An empty citation list renders as "N/A".
    """
    if citations:
        payload = [c.model_dump(mode="json") for c in citations]
        body = _stable_json(payload)
    else:
        body = "N/A"
    return _tag_block("Citations", [body], indent)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def _render_user_distribution_block(
    user_distribution: OutputDistribution | None,
    indent: int,
) -> list[str]:
    """Render the estimated user distribution p_u as JSON ("N/A" when absent)."""
    if user_distribution is None:
        body = "N/A"
    else:
        body = _stable_json(user_distribution.model_dump(mode="json"))
    return _tag_block("Estimated user distribution p_u", [body], indent)
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def _render_user_distribution_reasoning_block(
    reasoning: str | None,
    reasoning_citations: list[InlineCitation] | None,
    indent: int,
) -> list[str]:
    """Render p_u reasoning text plus its citations inside a <p_u reasoning> block."""
    body = _text_lines_or_na(reasoning)
    citations = reasoning_citations if reasoning_citations else []
    body.extend(_render_citations_block(citations, indent))
    return _tag_block("p_u reasoning", body, indent)
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
class LabelingRequestFocusItem(BaseModel):
    """Specific rubric-related question the human labeler should inspect."""

    question: str
    citations: list[InlineCitation] = Field(default_factory=list[InlineCitation])
    sample_answers: list[str] = Field(default_factory=list[str])

    def to_str(self, indent: int = 0) -> str:
        """Render focus item in a deterministic LLM-facing format."""
        # The question text and its citations share one nested <Question> block.
        question_body = _text_lines_or_na(self.question)
        question_body += _render_citations_block(self.citations, indent)
        rendered = _tag_block("Question", question_body, indent)

        # Sample answers are enumerated from 1; an empty list renders as N/A.
        if self.sample_answers:
            answer_body = [
                f"Answer {sample_idx}: {sample_answer}"
                for sample_idx, sample_answer in enumerate(self.sample_answers, start=1)
            ]
        else:
            answer_body = ["N/A"]
        rendered += _tag_block("Sample Answers", answer_body, indent)

        return "\n".join(rendered)
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
class QAPair(BaseModel):
    """A single review-focus answer captured for one run."""

    # What the user was shown (index into LabelingRequest.review_focus)
    focus_index: int

    # What the user responded
    answer: str
    explanation: str | None = None

    # The user could have skipped this question and provided nothing
    status: Literal["answered", "skipped"]
    timestamp: datetime = Field(default_factory=datetime.now)

    def to_str(self, labeling_request: "LabelingRequest", indent: int = 0) -> str:
        """Render QA pair in a deterministic LLM-facing format.

        Raises:
            ValueError: If ``focus_index`` does not index into
                ``labeling_request.review_focus``.
        """
        n_focus = len(labeling_request.review_focus)
        if not (0 <= self.focus_index < n_focus):
            raise ValueError(
                f"focus_index={self.focus_index} is out of bounds for review_focus length "
                f"{n_focus}"
            )

        focus_item = labeling_request.review_focus[self.focus_index]
        rendered = focus_item.to_str(indent=indent).splitlines()
        rendered += [
            f"User answer: {_text_or_na(self.answer)}",
            f"User explanation: {_text_or_na(self.explanation)}",
        ]
        return "\n".join(rendered)
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
class LabelingRequest(BaseModel):
    """Structured labeling request shown to the user."""

    title: str
    review_context: str
    review_context_citations: list[InlineCitation] = Field(default_factory=list[InlineCitation])
    review_focus: list[LabelingRequestFocusItem] = Field(
        default_factory=list[LabelingRequestFocusItem]
    )
    user_distribution: OutputDistribution | None = None
    user_distribution_reasoning: str | None = None

    def to_str(self, indent: int = 0) -> str:
        """Render labeling request in a deterministic LLM-facing format."""
        body: list[str] = [f"Title: {_text_or_na(self.title)}"]

        # Review context plus its citations in one nested block.
        context_lines = _text_lines_or_na(self.review_context)
        context_lines += _render_citations_block(self.review_context_citations, indent)
        body += _tag_block("Review Context", context_lines, indent)

        # One <Focus i> block per focus item, or a single N/A placeholder.
        focus_section: list[str] = []
        if not self.review_focus:
            focus_section.append("N/A")
        else:
            for idx, item in enumerate(self.review_focus, start=1):
                item_lines = item.to_str(indent=indent).splitlines()
                focus_section += _tag_block(f"Focus {idx}", item_lines, indent)
        body += _tag_block("Review Focus", focus_section, indent)

        # Estimated user distribution and its reasoning (reasoning citations
        # are not tracked on the request, hence None).
        body += _render_user_distribution_block(self.user_distribution, indent)
        body += _render_user_distribution_reasoning_block(
            self.user_distribution_reasoning,
            reasoning_citations=None,
            indent=indent,
        )

        return "\n".join(_tag_block("Labeling Request", body, indent))
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
class LabeledRun(BaseModel):
    """A human label for one agent run."""

    agent_run_id: str
    timestamp: datetime = Field(default_factory=datetime.now)

    # What the user responded
    label_value: dict[str, Any]
    explanation: str | None = None

    def to_str(
        self,
        labeling_request: LabelingRequest | None = None,
        indent: int = 0,
    ) -> str:
        """Render user label in a deterministic LLM-facing format.

        When *labeling_request* is provided, its p_u distribution and
        reasoning are appended to the label block.
        """
        body = [
            f"User label: {_stable_json(self.label_value)}",
            f"User explanation: {_text_or_na(self.explanation)}",
        ]
        if labeling_request is not None:
            body += _render_user_distribution_block(labeling_request.user_distribution, indent)
            body += _render_user_distribution_reasoning_block(
                labeling_request.user_distribution_reasoning,
                reasoning_citations=None,
                indent=indent,
            )
        return "\n".join(_tag_block("Label", body, indent))
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
class AgentRunFeedbackContext(BaseModel):
    """All feedback collected for a single agent run."""

    # None until the context is persisted by the backend.
    feedback_context_id: str | None = None
    agent_run_id: str
    # Feedback round this context belongs to.
    round: int
    created_at: datetime = Field(default_factory=datetime.now)
    last_updated: datetime = Field(default_factory=datetime.now)

    # What the user was shown
    labeling_request: LabelingRequest

    # What the user responded
    qa_pairs: list[QAPair] = Field(default_factory=list[QAPair])
    label: LabeledRun | None = None

    @model_validator(mode="after")
    def validate_nested_agent_run_ids(self) -> "AgentRunFeedbackContext":
        """Ensure nested run IDs are consistent with the top-level run ID.

        Raises:
            ValueError: If a label is present and its agent_run_id differs.
        """
        if self.label is not None and self.label.agent_run_id != self.agent_run_id:
            raise ValueError("label.agent_run_id must match agent_run_id")
        return self

    def to_str(self, indent: int = 0) -> str:
        """Render full feedback entry in a deterministic LLM-facing format.

        Layout: the labeling request, then a <Question Answer Pairs> block,
        then a <Label> block (rendered as N/A placeholders when no label
        exists yet, so the structure is identical either way).
        """
        # Start with the labeling request the user was shown.
        lines = self.labeling_request.to_str(indent=indent).splitlines()

        qa_lines: list[str] = []
        if not self.qa_pairs:
            qa_lines.append("N/A")
        else:
            # One <QA i> block per pair, numbered from 1.
            for qa_idx, qa_pair in enumerate(self.qa_pairs, start=1):
                qa_entry_lines = qa_pair.to_str(
                    labeling_request=self.labeling_request,
                    indent=indent,
                ).splitlines()
                qa_lines.extend(_tag_block(f"QA {qa_idx}", qa_entry_lines, indent))
        lines.extend(_tag_block("Question Answer Pairs", qa_lines, indent))

        if self.label is None:
            # No label yet: emit an N/A label block but still surface the
            # request's p_u distribution/reasoning so the structure matches
            # the labeled case.
            label_body_lines = [
                "User label: N/A",
                "User explanation: N/A",
            ]
            label_body_lines.extend(
                _render_user_distribution_block(self.labeling_request.user_distribution, indent)
            )
            label_body_lines.extend(
                _render_user_distribution_reasoning_block(
                    self.labeling_request.user_distribution_reasoning,
                    reasoning_citations=None,
                    indent=indent,
                )
            )
            lines.extend(_tag_block("Label", label_body_lines, indent))
        else:
            # Delegate to the label, which appends the same p_u blocks itself.
            lines.extend(
                self.label.to_str(
                    labeling_request=self.labeling_request,
                    indent=indent,
                ).splitlines()
            )
        return "\n".join(lines)
|
|
269
|
+
|
|
270
|
+
|
|
271
|
+
class FeedbackContext(BaseModel):
    """Feedback context returned by the feedback REST API."""

    feedback_context_id: str
    feedback_session_id: str
    agent_run_id: str
    # The labeling request presented to the user for this run.
    labeling_request: LabelingRequest
    created_at: datetime
    updated_at: datetime
|
|
280
|
+
|
|
281
|
+
|
|
282
|
+
class FeedbackContextsResponse(BaseModel):
    """Round-scoped feedback contexts returned by the feedback REST API."""

    # Feedback round these contexts belong to.
    current_round: int
    contexts: list[FeedbackContext] = Field(default_factory=list[FeedbackContext])
|
|
287
|
+
|
|
288
|
+
|
|
289
|
+
# Lifecycle states for a feedback-contexts background job.
FeedbackJobStatus = Literal["pending", "running", "cancelling", "canceled", "completed"]
|
|
290
|
+
|
|
291
|
+
|
|
292
|
+
class StartFeedbackContextsJobResponse(BaseModel):
    """Response for enqueueing or reusing a feedback contexts job."""

    # Identifier of the enqueued (or reused) job.
    job_id: str
|
|
296
|
+
|
|
297
|
+
|
|
298
|
+
class FeedbackContextsJobStateResponse(BaseModel):
    """Current feedback contexts job status and round-scoped contexts."""

    # Both are None when no job exists for this session.
    job_id: str | None
    job_status: FeedbackJobStatus | None
    current_round: int
    contexts: list[FeedbackContext] = Field(default_factory=list[FeedbackContext])
|
|
305
|
+
|
|
306
|
+
|
|
307
|
+
class UserData(BaseModel):
    """User Data (U) for user-context inference and downstream evaluation."""

    initial_rubric: str
    agent_run_feedbacks: list[AgentRunFeedbackContext] = Field(
        default_factory=lambda: list[AgentRunFeedbackContext]()
    )
    created_at: datetime = Field(default_factory=datetime.now)
    last_updated: datetime = Field(default_factory=datetime.now)

    def upsert_run_feedback(self, agent_run_feedback: AgentRunFeedbackContext) -> None:
        """Insert or replace feedback for an agent run ID, updating timestamps."""
        now = datetime.now()
        # Deep-copy so later mutation of the caller's object cannot affect
        # the stored entry.
        upserted_feedback = agent_run_feedback.model_copy(deep=True)
        upserted_feedback.last_updated = now

        for idx, existing in enumerate(self.agent_run_feedbacks):
            if existing.agent_run_id != upserted_feedback.agent_run_id:
                continue
            # Replacement keeps the original creation time.
            upserted_feedback.created_at = existing.created_at
            self.agent_run_feedbacks[idx] = upserted_feedback
            self.last_updated = now
            return

        # No existing entry for this run: append as a new one.
        self.agent_run_feedbacks.append(upserted_feedback)
        self.last_updated = now

    def validate_against_agreement_keys(self, agreement_keys: set[str]) -> None:
        """Validate stored labels and p_u outcomes against rubric agreement keys.

        Raises:
            ValueError: If any label or p_u outcome uses keys outside
                *agreement_keys*, or any p_u outcome value is not a scalar
                (str/bool/int/float).
        """
        for feedback in self.agent_run_feedbacks:
            run_id = feedback.agent_run_id

            label = feedback.label
            if label is not None:
                invalid_label_keys = sorted(set(label.label_value.keys()) - agreement_keys)
                if invalid_label_keys:
                    raise ValueError(
                        "Run "
                        f"{run_id} has label_value keys outside rubric agreement keys: "
                        + ", ".join(invalid_label_keys)
                    )

            user_distribution = feedback.labeling_request.user_distribution
            if user_distribution is None:
                continue

            for outcome_idx, outcome in enumerate(user_distribution.outcomes, start=1):
                invalid_output_keys = sorted(set(outcome.output.keys()) - agreement_keys)
                if invalid_output_keys:
                    raise ValueError(
                        "Run "
                        f"{run_id} has user_distribution outcome #{outcome_idx} keys outside "
                        "rubric agreement keys: " + ", ".join(invalid_output_keys)
                    )
                # Only scalar outcome values are comparable downstream.
                for key, value in outcome.output.items():
                    if isinstance(value, (str, bool, int, float)):
                        continue
                    raise ValueError(
                        "Run "
                        f"{run_id} has user_distribution outcome #{outcome_idx} non-scalar "
                        f"value for key '{key}': {type(value).__name__}"
                    )

    def iter_answered_qa_entries(self) -> Iterator[tuple[AgentRunFeedbackContext, QAPair]]:
        """Iterate answered QA pairs with their parent run feedback."""
        for feedback in self.agent_run_feedbacks:
            for qa_pair in feedback.qa_pairs:
                if qa_pair.status == "answered":
                    yield feedback, qa_pair

    def iter_skipped_qa_entries(self) -> Iterator[tuple[AgentRunFeedbackContext, QAPair]]:
        """Iterate skipped QA pairs with their parent run feedback."""
        for feedback in self.agent_run_feedbacks:
            for qa_pair in feedback.qa_pairs:
                if qa_pair.status == "skipped":
                    yield feedback, qa_pair

    def iter_labeled_entries(self) -> Iterator[tuple[AgentRunFeedbackContext, LabeledRun]]:
        """Iterate labeled run entries with their parent run feedback."""
        for feedback in self.agent_run_feedbacks:
            if feedback.label is None:
                continue
            yield feedback, feedback.label
|
|
390
|
+
|
|
391
|
+
|
|
392
|
+
# Backward-compatible alias used by older callers/scripts; prefer
# AgentRunFeedbackContext in new code.
AgentRunFeedback = AgentRunFeedbackContext
|
|
@@ -0,0 +1,351 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import math
|
|
3
|
+
from collections import Counter
|
|
4
|
+
from itertools import product
|
|
5
|
+
from typing import Any, TypedDict, cast
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
from pydantic import BaseModel, Field
|
|
9
|
+
|
|
10
|
+
# Scalar value types over which judge/user agreement can be computed.
AgreementValue = str | bool | int | float
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class EstimateWithCI(TypedDict):
    """Point estimate with sample statistics and a 95% confidence interval."""

    mean: float  # sample mean
    var: float  # sample variance
    n: int  # number of samples
    ci_95: float  # 95% confidence-interval width around the mean
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
# Per-value estimate (with CI) of a judge's output distribution for one key.
JudgeOutputDistribution = dict[AgreementValue, EstimateWithCI]
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class DistributionOutcome(BaseModel):
    """Single outcome and probability mass for a predictive distribution."""

    # Rubric-compliant output object for this outcome.
    output: dict[str, Any]
    # Probability mass assigned to this outcome.
    probability: float
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class OutputDistribution(BaseModel):
    """Probability distribution over rubric-compliant outputs."""

    # Outcomes with their probability mass; not guaranteed normalized or
    # deduplicated — see normalize_output_distribution.
    outcomes: list[DistributionOutcome] = Field(default_factory=list[DistributionOutcome])
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def _stable_json_dict(value: dict[str, Any]) -> str:
|
|
37
|
+
return json.dumps(value, sort_keys=True, separators=(",", ":"))
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def normalize_output_distribution(distribution: OutputDistribution) -> OutputDistribution:
    """Normalize probabilities and merge duplicate outcomes by canonical JSON output.

    Outcomes whose ``output`` serializes to the same canonical JSON have their
    probability mass summed; negative masses are clamped to 0. If the total
    mass is non-positive, the result falls back to a uniform distribution over
    the merged outcomes. Outcomes are returned sorted by descending probability.

    Args:
        distribution: Raw distribution, possibly unnormalized or with duplicates.

    Returns:
        A new normalized ``OutputDistribution``; empty input yields empty output.
    """
    if not distribution.outcomes:
        return OutputDistribution()

    # Merge outcomes that serialize to the same canonical JSON key.
    merged: dict[str, DistributionOutcome] = {}
    for outcome in distribution.outcomes:
        key = _stable_json_dict(outcome.output)
        clamped = max(0.0, outcome.probability)
        existing = merged.get(key)
        if existing is None:
            merged[key] = DistributionOutcome(output=outcome.output, probability=clamped)
        else:
            existing.probability += clamped

    # `merged` is necessarily non-empty here because `distribution.outcomes`
    # was non-empty, so no separate emptiness check is needed (the original
    # re-check was unreachable).
    merged_outcomes = list(merged.values())

    total_probability = sum(item.probability for item in merged_outcomes)
    if total_probability <= 0:
        # All mass was clamped away: fall back to uniform.
        uniform_prob = 1.0 / len(merged_outcomes)
        for item in merged_outcomes:
            item.probability = uniform_prob
    else:
        for item in merged_outcomes:
            item.probability = item.probability / total_probability

    # Stable sort: equal probabilities keep first-seen order.
    merged_outcomes.sort(key=lambda item: item.probability, reverse=True)

    return OutputDistribution(outcomes=merged_outcomes)
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def assert_agreement_only_output_schema(schema: dict[str, Any]) -> list[str]:
    """Validate agreement-only schema contract for elicitation and entropy workflows.

    Contract (as currently enforced):
    - ``properties`` must be an object-valued field,
    - every top-level property must be an agreement key (enum or boolean), and
    - at least one agreement key must exist.

    NOTE(review): the ``additionalProperties is False`` requirement is currently
    NOT enforced — the check below is commented out. Confirm whether it should
    be re-enabled.

    Returns:
        The agreement keys, in schema property order.

    Raises:
        ValueError: If the properties field is missing/non-dict, no agreement
            key exists, or non-agreement top-level keys are present.
    """
    # if schema.get("additionalProperties") is not False:
    #     raise ValueError(
    #         "Rubric output_schema must set additionalProperties to false for "
    #         "agreement-only entropy workflows."
    #     )

    properties_obj = schema.get("properties")
    if not isinstance(properties_obj, dict):
        raise ValueError("Rubric output_schema must define an object-valued properties field.")
    properties = cast(dict[str, Any], properties_obj)

    agreement_keys = get_agreement_keys(schema)
    if not agreement_keys:
        raise ValueError(
            "Rubric output_schema must include at least one top-level agreement key "
            "(enum or boolean)."
        )

    non_agreement_keys = sorted(set(properties.keys()) - set(agreement_keys))
    if non_agreement_keys:
        # Build a human-readable description of each offending key's schema
        # for the error message.
        details: list[str] = []
        for key in non_agreement_keys:
            field_obj = properties.get(key, {})
            if isinstance(field_obj, dict):
                field_schema = cast(dict[str, Any], field_obj)
                field_type = field_schema.get("type")
                field_type_text = field_type if isinstance(field_type, str) else "unknown"
                has_enum = "enum" in field_schema
                details.append(f"{key} (type={field_type_text}, enum={has_enum})")
            else:
                details.append(f"{key} (type=invalid-schema)")
        raise ValueError(
            "Rubric output_schema includes non-agreement top-level keys: " + ", ".join(details)
        )

    return agreement_keys
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
def get_agreement_keys(schema: dict[str, Any]) -> list[str]:
    """Get top-level schema keys that support agreement computations.

    This includes top-level enum and boolean fields. A field qualifies when
    its ``type`` is ``"boolean"`` or when it declares an ``enum`` — JSON
    Schema does not require an enum field to carry a ``type``, so none is
    demanded here (the previous assert crashed on such fields).

    Args:
        schema: JSON-schema-like dict; ``properties`` may be absent.

    Returns:
        Qualifying keys, in schema property order.
    """
    agreement_keys: list[str] = []

    properties = schema.get("properties", {})
    assert isinstance(properties, dict)
    properties = cast(dict[str, Any], properties)

    for key, field_schema in properties.items():
        assert isinstance(field_schema, dict)
        field_schema = cast(dict[str, Any], field_schema)

        # Booleans need an explicit type; enums qualify with or without one.
        if field_schema.get("type") == "boolean":
            agreement_keys.append(key)
        elif "enum" in field_schema:
            agreement_keys.append(key)

    return agreement_keys
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
def get_agreement_key_options(
    schema: dict[str, Any],
    agreement_keys: list[str] | None = None,
) -> dict[str, list[AgreementValue]]:
    """Return possible output options for each agreement key from schema.

    Args:
        schema: JSON-schema-like dict with a top-level ``properties`` mapping.
        agreement_keys: Keys to resolve; defaults to ``get_agreement_keys(schema)``.

    Returns:
        Mapping from agreement key to its possible scalar values:
        ``[True, False]`` for booleans, the declared ``enum`` values for
        enums, and ``[]`` for keys with no recognizable options.
    """
    if agreement_keys is None:
        agreement_keys = get_agreement_keys(schema)

    properties = schema.get("properties", {})
    assert isinstance(properties, dict)
    properties = cast(dict[str, Any], properties)

    key_options: dict[str, list[AgreementValue]] = {}
    for key in agreement_keys:
        field_schema_obj = properties.get(key, {})
        assert isinstance(field_schema_obj, dict)
        field_schema = cast(dict[str, Any], field_schema_obj)

        # JSON-Schema enum fields are not required to carry a "type", so do
        # not insist on one (the previous assert crashed on such fields).
        if field_schema.get("type") == "boolean":
            key_options[key] = [True, False]
            continue

        if "enum" in field_schema:
            enum_values = field_schema.get("enum")
            assert isinstance(enum_values, list)
            options: list[AgreementValue] = []
            for enum_value in cast(list[object], enum_values):
                assert isinstance(enum_value, (str, bool, int, float))
                options.append(enum_value)
            key_options[key] = options
            continue

        # Neither boolean nor enum: no enumerable options.
        key_options[key] = []

    return key_options
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
def enumerate_agreement_output_space(
    schema: dict[str, Any],
    agreement_keys: list[str] | None = None,
    max_outcomes: int | None = None,
) -> list[dict[str, AgreementValue]]:
    """Enumerate all possible agreement-only output objects from schema key options.

    Raises:
        ValueError: If no agreement keys exist, a key has no valid options,
            ``max_outcomes`` is non-positive, or the space exceeds it.
    """
    if agreement_keys is None:
        agreement_keys = get_agreement_keys(schema)
    if not agreement_keys:
        raise ValueError("Cannot enumerate output space: no agreement keys were found.")

    if max_outcomes is not None and max_outcomes <= 0:
        raise ValueError("max_outcomes must be > 0 when provided.")

    key_options = get_agreement_key_options(schema, agreement_keys)

    per_key_choices: list[list[AgreementValue]] = []
    space_size = 1
    for key in agreement_keys:
        choices = key_options.get(key, [])
        if not choices:
            raise ValueError(
                f"Cannot enumerate output space: agreement key '{key}' has no valid options."
            )
        per_key_choices.append(choices)
        space_size *= len(choices)
        # Abort early once the running product exceeds the cap.
        if max_outcomes is not None and space_size > max_outcomes:
            raise ValueError(
                "Cannot enumerate output space: "
                f"{space_size} outcomes exceeds max_outcomes={max_outcomes}."
            )

    # Cartesian product over per-key options, in agreement-key order.
    return [dict(zip(agreement_keys, combo)) for combo in product(*per_key_choices)]
|
|
224
|
+
|
|
225
|
+
|
|
226
|
+
def find_modal_result(indep_results: list[dict[str, Any]], agreement_keys: list[str]):
    """Select the result that best agrees with the modal value of each key.

    Args:
        indep_results: Independent results to analyze.
        agreement_keys: Keys to measure agreement on.

    Returns:
        Tuple ``(max_idx, agt_key_modes_and_counts)`` where ``max_idx`` indexes
        the best-matching result and the mapping gives each key's
        ``(modal_value, count)`` — or ``None`` when no result supplied the key.

    Raises:
        ValueError: If ``indep_results`` is empty.
    """
    if not indep_results:
        raise ValueError("No results to score")

    # Per-key mode and its frequency; None when the key never appears with a value.
    modes: dict[str, tuple[str | bool | int, int] | None] = {}
    for agt_key in agreement_keys:
        observed = [r[agt_key] for r in indep_results if r.get(agt_key) is not None]
        ranked = Counter(observed).most_common(1)
        modes[agt_key] = ranked[0] if ranked else None

    def _match_count(result: dict[str, Any]) -> int:
        # One point for every key whose value equals that key's modal value.
        # Keys without a mode, or absent from the result, contribute nothing.
        # TODO(mengk): This may bias towards results that have more keys.
        points = 0
        for agt_key in agreement_keys:
            entry = modes[agt_key]
            if entry is not None and result.get(agt_key) == entry[0]:
                points += 1
        return points

    scores = [_match_count(r) for r in indep_results]
    # First index achieving the maximum score wins ties.
    best_idx = max(range(len(scores)), key=scores.__getitem__)
    return best_idx, modes
|
|
269
|
+
|
|
270
|
+
|
|
271
|
+
def compute_output_distributions(
    eq_weighted_outputs: list[dict[str, Any]],
    output_schema: dict[str, Any],
    agreement_keys: list[str],
) -> dict[str, JudgeOutputDistribution]:
    """Estimate per-key output value distributions from equally weighted outputs.

    For every key in ``agreement_keys``, the allowed values are taken from the
    schema (via ``get_agreement_key_options``), observed non-null values are
    counted across ``eq_weighted_outputs``, and counts are normalized into
    empirical probabilities. Each value receives:

    - ``mean``: empirical probability ``p``
    - ``var``: Bernoulli variance ``p * (1 - p)``
    - ``n``: number of observed (non-null) values for that key
    - ``ci_95``: normal-approximation 95% CI half-width ``1.96 * sqrt(var / n)``

    Optional fields missing from a result are skipped (not counted toward ``n``).
    Keys with no observations get all-zero probabilities and ``n=0``.

    Args:
        eq_weighted_outputs: Output objects to aggregate, each counted with equal weight.
        output_schema: JSON schema used to validate/define allowed output values.
        agreement_keys: Keys to include in the distribution computation.

    Returns:
        Mapping from agreement key to per-value distribution estimates.

    Raises:
        AssertionError: If an observed value is not one of the schema-derived options.
    """
    key_options = get_agreement_key_options(output_schema, agreement_keys)

    # Seed every allowed value with zero so unobserved values still appear.
    tallies: dict[str, dict[AgreementValue, int]] = {
        agt_key: dict.fromkeys(key_options.get(agt_key, []), 0) for agt_key in agreement_keys
    }
    for output in eq_weighted_outputs:
        for agt_key in agreement_keys:
            observed = output.get(agt_key)
            if observed is None:  # optional key absent from this output
                continue
            assert observed in tallies[agt_key], (
                "this should never happen; the value must be in possible values, since judge results have been validated against the schema"
            )
            tallies[agt_key][observed] += 1

    distributions: dict[str, JudgeOutputDistribution] = {}
    for agt_key in agreement_keys:
        counts = tallies[agt_key]
        n_observed = sum(counts.values())

        per_value: JudgeOutputDistribution = {}
        for candidate, count in counts.items():
            p = count / n_observed if n_observed > 0 else 0.0
            variance = p * (1 - p)
            # TODO(mengk): change to the wilson score interval
            half_width = float(1.96 * np.sqrt(variance / n_observed)) if n_observed > 0 else 0.0
            per_value[candidate] = {
                "mean": p,
                "var": variance,
                "n": n_observed,
                "ci_95": half_width,
            }
        distributions[agt_key] = per_value

    return distributions
|
|
338
|
+
|
|
339
|
+
|
|
340
|
+
def compute_entropy(distribution: OutputDistribution) -> float:
    """Return the Shannon entropy (in nats) of the normalized outcome distribution."""
    normalized = normalize_output_distribution(distribution)
    probabilities = [outcome.probability for outcome in normalized.outcomes]
    if not probabilities:
        return 0.0
    # Zero-probability outcomes contribute nothing (lim p->0 of -p*ln(p) = 0).
    return sum((-p * math.log(p) for p in probabilities if p > 0), 0.0)
|
|
@@ -21,6 +21,11 @@ from tqdm import tqdm
|
|
|
21
21
|
from docent._llm_util.providers.preference_types import ModelOption
|
|
22
22
|
from docent._log_util.logger import LoggerAdapter, get_logger
|
|
23
23
|
from docent.data_models.agent_run import AgentRun
|
|
24
|
+
from docent.data_models.feedback import (
|
|
25
|
+
AgentRunFeedbackContext,
|
|
26
|
+
FeedbackContextsJobStateResponse,
|
|
27
|
+
StartFeedbackContextsJobResponse,
|
|
28
|
+
)
|
|
24
29
|
from docent.data_models.judge import Label
|
|
25
30
|
from docent.judges.util.meta_schema import validate_judge_result_schema
|
|
26
31
|
from docent.loaders import load_inspect
|
|
@@ -962,6 +967,113 @@ class Docent:
|
|
|
962
967
|
clustering_state = self.get_clustering_state(collection_id, rubric_id)
|
|
963
968
|
return clustering_state.get("assignments", {})
|
|
964
969
|
|
|
970
|
+
def create_feedback_session(
    self,
    collection_id: str,
    rubric_id: str,
    rubric_version: int,
) -> str:
    """Create a feedback session tied to one rubric version.

    Args:
        collection_id: Collection the session belongs to.
        rubric_id: Rubric being reviewed.
        rubric_version: Rubric version to collect feedback on.

    Returns:
        The ID of the newly created feedback session.
    """
    resp = self._session.post(
        f"{self._api_url}/feedback/{collection_id}/session",
        json={"rubric_id": rubric_id, "rubric_version": rubric_version},
    )
    self._handle_response_errors(resp)
    return resp.json()["feedback_session_id"]
|
|
985
|
+
|
|
986
|
+
def start_feedback_contexts_job(
    self,
    collection_id: str,
    feedback_session_id: str,
    num_samples: int = 50,
    top_n: int = 20,
    seed: int = 0,
    candidate_pool_limit: int = 1_000,
    where_clause: str | None = None,
    increment_round: bool = False,
) -> StartFeedbackContextsJobResponse:
    """Start (or reuse) a background job that computes feedback contexts.

    All parameters except ``collection_id`` are forwarded verbatim in the
    request payload and interpreted server-side.

    Args:
        collection_id: Collection the session belongs to.
        feedback_session_id: Session to compute contexts for.
        num_samples: Sampling size parameter for the job.
        top_n: Number of top contexts to keep.
        seed: Seed for the job's sampling.
        candidate_pool_limit: Cap on the candidate pool considered.
        where_clause: Optional filter expression; None means no filter.
        increment_round: Whether the server should advance the session round.

    Returns:
        The server's response describing the started (or reused) job.
    """
    body = {
        "feedback_session_id": feedback_session_id,
        "num_samples": num_samples,
        "top_n": top_n,
        "seed": seed,
        "candidate_pool_limit": candidate_pool_limit,
        "where_clause": where_clause,
        "increment_round": increment_round,
    }
    resp = self._session.post(
        f"{self._api_url}/feedback/{collection_id}/contexts/start", json=body
    )
    self._handle_response_errors(resp)
    return StartFeedbackContextsJobResponse.model_validate(resp.json())
|
|
1011
|
+
|
|
1012
|
+
def get_feedback_contexts(
    self,
    collection_id: str,
    feedback_session_id: str,
) -> FeedbackContextsJobStateResponse:
    """Fetch the feedback-contexts state for a session.

    The response includes the background job status and the current
    round's data.
    """
    resp = self._session.post(
        f"{self._api_url}/feedback/{collection_id}/contexts/state",
        json={"feedback_session_id": feedback_session_id},
    )
    self._handle_response_errors(resp)
    return FeedbackContextsJobStateResponse.model_validate(resp.json())
|
|
1025
|
+
|
|
1026
|
+
def get_agent_run_feedback_contexts_by_session(
    self,
    collection_id: str,
    feedback_session_id: str,
) -> list[AgentRunFeedbackContext]:
    """Fetch every persisted AgentRun feedback context for a session.

    The server returns fully hydrated context objects — including QA pairs
    and labels when present — across all rounds in the session.
    """
    resp = self._session.get(
        f"{self._api_url}/feedback/{collection_id}/session/{feedback_session_id}/contexts"
    )
    self._handle_response_errors(resp)
    return [AgentRunFeedbackContext.model_validate(entry) for entry in resp.json()]
|
|
1040
|
+
|
|
1041
|
+
def add_feedback_qa(
    self,
    collection_id: str,
    feedback_context_id: str,
    focus_index: int,
    answer: str,
    explanation: str | None = None,
) -> None:
    """Submit one QA answer for a feedback context.

    Args:
        collection_id: Collection the context belongs to.
        feedback_context_id: Context being answered.
        focus_index: Index identifying which focus item the answer refers to.
        answer: The answer text.
        explanation: Optional free-form rationale; None omits it.
    """
    resp = self._session.post(
        f"{self._api_url}/feedback/{collection_id}/qa",
        json={
            "feedback_context_id": feedback_context_id,
            "focus_index": focus_index,
            "answer": answer,
            "explanation": explanation,
        },
    )
    self._handle_response_errors(resp)
|
|
1059
|
+
|
|
1060
|
+
def upsert_feedback_label(
    self,
    collection_id: str,
    feedback_context_id: str,
    label_value: dict[str, Any],
    explanation: str | None = None,
) -> None:
    """Create or update the label attached to a feedback context.

    Args:
        collection_id: Collection the context belongs to.
        feedback_context_id: Context being labeled.
        label_value: Label payload, forwarded verbatim to the server.
        explanation: Optional free-form rationale; None omits it.
    """
    resp = self._session.put(
        f"{self._api_url}/feedback/{collection_id}/label",
        json={
            "feedback_context_id": feedback_context_id,
            "label_value": label_value,
            "explanation": explanation,
        },
    )
    self._handle_response_errors(resp)
|
|
1076
|
+
|
|
965
1077
|
def create_label_set(
|
|
966
1078
|
self,
|
|
967
1079
|
collection_id: str,
|
|
@@ -1,140 +0,0 @@
|
|
|
1
|
-
from collections import Counter
|
|
2
|
-
from typing import Any, TypedDict, cast
|
|
3
|
-
|
|
4
|
-
import numpy as np
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
class EstimateWithCI(TypedDict):
|
|
8
|
-
mean: float
|
|
9
|
-
var: float
|
|
10
|
-
n: int
|
|
11
|
-
ci_95: float
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
JudgeOutputDistribution = dict[str | bool | int | float, EstimateWithCI]
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
def get_agreement_keys(schema: dict[str, Any]) -> list[str]:
    """Get list of top-level keys in schema that we want to measure agreement on.

    Boolean fields and enum fields (which cover enumerated strings and numbers)
    qualify; everything else is excluded.

    Args:
        schema: JSON schema dict

    Returns:
        List of field names (keys) that should be used for measuring agreement
    """
    raw_properties = schema.get("properties", {})
    assert isinstance(raw_properties, dict)
    properties = cast(dict[str, Any], raw_properties)

    keys: list[str] = []
    for name, raw_spec in properties.items():
        assert isinstance(raw_spec, dict)
        spec = cast(dict[str, Any], raw_spec)

        declared_type = spec.get("type")
        assert isinstance(declared_type, str)

        # Booleans and enum-constrained fields are the agreement-measurable ones.
        if declared_type == "boolean" or "enum" in spec:
            keys.append(name)

    return keys
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
def find_modal_result(indep_results: list[dict[str, Any]], agreement_keys: list[str]):
    """Pick the result that agrees most with the per-key modal values.

    Args:
        indep_results: Independent results to analyze.
        agreement_keys: Keys to measure agreement on.

    Returns:
        Tuple ``(max_idx, agt_key_modes_and_counts)``: the index of the result
        that best matches the modal values, and a mapping from each key to
        ``(modal_value, count)`` or ``None`` if no values exist for that key.

    Raises:
        ValueError: If ``indep_results`` is empty.
    """
    if not indep_results:
        raise ValueError("No results to score")

    # Mode (value, count) per agreement key; None when no result supplied the key.
    key_to_mode: dict[str, tuple[str | bool | int, int] | None] = {}
    for agt_key in agreement_keys:
        tally = Counter(
            result[agt_key] for result in indep_results if result.get(agt_key) is not None
        )
        top = tally.most_common(1)
        key_to_mode[agt_key] = top[0] if top else None

    # A result scores one point per key whose value equals that key's mode.
    # Keys without a mode, or missing from a result, contribute nothing.
    # TODO(mengk): This may bias towards results that have more keys.
    scores = [
        sum(
            1
            for agt_key in agreement_keys
            if key_to_mode[agt_key] is not None
            and result.get(agt_key) == key_to_mode[agt_key][0]
        )
        for result in indep_results
    ]

    # Argmax: first index attaining the maximum score.
    winner = scores.index(max(scores))
    return winner, key_to_mode
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
def compute_output_distributions(
|
|
97
|
-
indep_results: list[dict[str, Any]], output_schema: dict[str, Any], agreement_keys: list[str]
|
|
98
|
-
):
|
|
99
|
-
def _get_possible_values(key: str) -> list[str | bool | int | float]:
|
|
100
|
-
if "enum" in output_schema.get("properties", {}).get(key, {}):
|
|
101
|
-
return output_schema.get("properties", {}).get(key, {}).get("enum", [])
|
|
102
|
-
elif output_schema.get("properties", {}).get(key, {}).get("type") == "boolean":
|
|
103
|
-
return [True, False]
|
|
104
|
-
else:
|
|
105
|
-
return []
|
|
106
|
-
|
|
107
|
-
raw_counts: dict[str, dict[str | bool | int | float, int]] = {
|
|
108
|
-
key: {value: 0 for value in _get_possible_values(key)} for key in agreement_keys
|
|
109
|
-
}
|
|
110
|
-
# Collect counts for each possible value
|
|
111
|
-
for result in indep_results:
|
|
112
|
-
for key in agreement_keys:
|
|
113
|
-
if (value := result.get(key)) is not None: # Could be none if the key is optional
|
|
114
|
-
assert value in raw_counts[key], (
|
|
115
|
-
"this should never happen; the value must be in possible values, since judge results have been validated against the schema"
|
|
116
|
-
)
|
|
117
|
-
raw_counts[key][value] += 1
|
|
118
|
-
|
|
119
|
-
distributions: dict[str, JudgeOutputDistribution] = {}
|
|
120
|
-
for agt_key in agreement_keys:
|
|
121
|
-
distributions[agt_key] = {}
|
|
122
|
-
|
|
123
|
-
# First normalize the counts to get probabilities
|
|
124
|
-
counts = raw_counts[agt_key]
|
|
125
|
-
total = sum(counts.values())
|
|
126
|
-
probs = {value: (count / total) if total > 0 else 0.0 for value, count in counts.items()}
|
|
127
|
-
|
|
128
|
-
for output_key, value in probs.items():
|
|
129
|
-
mean, estimate_var = value, (value * (1 - value))
|
|
130
|
-
# TODO(mengk): change to the wilson score interval
|
|
131
|
-
ci_95 = float(1.96 * np.sqrt(estimate_var / total)) if total > 0 else 0.0
|
|
132
|
-
estimate: EstimateWithCI = {
|
|
133
|
-
"mean": mean,
|
|
134
|
-
"var": estimate_var,
|
|
135
|
-
"n": total,
|
|
136
|
-
"ci_95": ci_95,
|
|
137
|
-
}
|
|
138
|
-
distributions[agt_key][output_key] = estimate
|
|
139
|
-
|
|
140
|
-
return distributions
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/_llm_util/data_models/exceptions.py
RENAMED
|
File without changes
|
{docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/_llm_util/data_models/llm_output.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/_llm_util/providers/preference_types.py
RENAMED
|
File without changes
|
{docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/_llm_util/providers/provider_registry.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/data_models/chat/response_format.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|