docent-python 0.1.50a0__tar.gz → 0.1.52a0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. {docent_python-0.1.50a0 → docent_python-0.1.52a0}/PKG-INFO +1 -1
  2. docent_python-0.1.52a0/docent/data_models/feedback.py +393 -0
  3. docent_python-0.1.52a0/docent/judges/util/voting.py +351 -0
  4. {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/sdk/client.py +112 -0
  5. {docent_python-0.1.50a0 → docent_python-0.1.52a0}/pyproject.toml +1 -1
  6. docent_python-0.1.50a0/docent/judges/util/voting.py +0 -140
  7. {docent_python-0.1.50a0 → docent_python-0.1.52a0}/.gitignore +0 -0
  8. {docent_python-0.1.50a0 → docent_python-0.1.52a0}/LICENSE.md +0 -0
  9. {docent_python-0.1.50a0 → docent_python-0.1.52a0}/README.md +0 -0
  10. {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/__init__.py +0 -0
  11. {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/_llm_util/__init__.py +0 -0
  12. {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/_llm_util/data_models/__init__.py +0 -0
  13. {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/_llm_util/data_models/exceptions.py +0 -0
  14. {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/_llm_util/data_models/llm_output.py +0 -0
  15. {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/_llm_util/llm_cache.py +0 -0
  16. {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/_llm_util/llm_svc.py +0 -0
  17. {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/_llm_util/model_registry.py +0 -0
  18. {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/_llm_util/providers/__init__.py +0 -0
  19. {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/_llm_util/providers/anthropic.py +0 -0
  20. {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/_llm_util/providers/common.py +0 -0
  21. {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/_llm_util/providers/google.py +0 -0
  22. {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/_llm_util/providers/openai.py +0 -0
  23. {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/_llm_util/providers/openrouter.py +0 -0
  24. {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/_llm_util/providers/preference_types.py +0 -0
  25. {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/_llm_util/providers/provider_registry.py +0 -0
  26. {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/_log_util/__init__.py +0 -0
  27. {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/_log_util/logger.py +0 -0
  28. {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/data_models/__init__.py +0 -0
  29. {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/data_models/_tiktoken_util.py +0 -0
  30. {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/data_models/agent_run.py +0 -0
  31. {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/data_models/chat/__init__.py +0 -0
  32. {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/data_models/chat/content.py +0 -0
  33. {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/data_models/chat/message.py +0 -0
  34. {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/data_models/chat/response_format.py +0 -0
  35. {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/data_models/chat/tool.py +0 -0
  36. {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/data_models/citation.py +0 -0
  37. {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/data_models/formatted_objects.py +0 -0
  38. {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/data_models/judge.py +0 -0
  39. {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/data_models/metadata_util.py +0 -0
  40. {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/data_models/regex.py +0 -0
  41. {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/data_models/transcript.py +0 -0
  42. {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/data_models/util.py +0 -0
  43. {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/judges/__init__.py +0 -0
  44. {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/judges/analysis.py +0 -0
  45. {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/judges/impl.py +0 -0
  46. {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/judges/runner.py +0 -0
  47. {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/judges/stats.py +0 -0
  48. {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/judges/types.py +0 -0
  49. {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/judges/util/forgiving_json.py +0 -0
  50. {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/judges/util/meta_schema.json +0 -0
  51. {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/judges/util/meta_schema.py +0 -0
  52. {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/judges/util/parse_output.py +0 -0
  53. {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/judges/util/template_formatter.py +0 -0
  54. {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/loaders/load_inspect.py +0 -0
  55. {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/mcp/__init__.py +0 -0
  56. {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/mcp/__main__.py +0 -0
  57. {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/mcp/server.py +0 -0
  58. {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/py.typed +0 -0
  59. {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/samples/__init__.py +0 -0
  60. {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/samples/load.py +0 -0
  61. {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/samples/log.eval +0 -0
  62. {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/samples/tb_airline.json +0 -0
  63. {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/sdk/__init__.py +0 -0
  64. {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/sdk/agent_run_writer.py +0 -0
  65. {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/sdk/llm_context.py +0 -0
  66. {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/sdk/llm_request.py +0 -0
  67. {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/trace.py +0 -0
  68. {docent_python-0.1.50a0 → docent_python-0.1.52a0}/docent/trace_temp.py +0 -0
  69. {docent_python-0.1.50a0 → docent_python-0.1.52a0}/uv.lock +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: docent-python
- Version: 0.1.50a0
+ Version: 0.1.52a0
  Summary: Docent SDK
  Project-URL: Homepage, https://github.com/TransluceAI/docent
  Project-URL: Issues, https://github.com/TransluceAI/docent/issues
@@ -0,0 +1,393 @@
+ """Data structures for run-centric feedback elicitation and user context inference."""
+
+ import json
+ from collections.abc import Iterator
+ from datetime import datetime
+ from typing import Any, Literal
+
+ from pydantic import BaseModel, Field, model_validator
+
+ from docent.data_models.citation import InlineCitation
+ from docent.judges.util.voting import OutputDistribution
+
+
+ def _stable_json(value: Any) -> str:
+     return json.dumps(value, sort_keys=True)
+
+
+ def _indent_lines(lines: list[str], indent: int) -> list[str]:
+     prefix = " " * max(0, indent)
+     return [f"{prefix}{line}" if line else "" for line in lines]
+
+
+ def _tag_block(tag: str, body_lines: list[str], indent: int) -> list[str]:
+     lines = [f"<{tag}>"]
+     if body_lines:
+         lines.extend(_indent_lines(body_lines, indent))
+     else:
+         lines.extend(_indent_lines(["N/A"], indent))
+     lines.append(f"</{tag}>")
+     return lines
+
+
+ def _text_or_na(text: str | None) -> str:
+     if text is None:
+         return "N/A"
+     stripped = text.strip()
+     return stripped if stripped else "N/A"
+
+
+ def _text_lines_or_na(text: str | None) -> list[str]:
+     return _text_or_na(text).splitlines()
+
+
+ def _render_citations_block(citations: list[InlineCitation], indent: int) -> list[str]:
+     citation_payload = [citation.model_dump(mode="json") for citation in citations]
+     citation_text = _stable_json(citation_payload) if citation_payload else "N/A"
+     return _tag_block("Citations", [citation_text], indent)
+
+
+ def _render_user_distribution_block(
+     user_distribution: OutputDistribution | None,
+     indent: int,
+ ) -> list[str]:
+     distribution_text = (
+         _stable_json(user_distribution.model_dump(mode="json"))
+         if user_distribution is not None
+         else "N/A"
+     )
+     return _tag_block("Estimated user distribution p_u", [distribution_text], indent)
+
+
+ def _render_user_distribution_reasoning_block(
+     reasoning: str | None,
+     reasoning_citations: list[InlineCitation] | None,
+     indent: int,
+ ) -> list[str]:
+     body_lines = _text_lines_or_na(reasoning)
+     body_lines.extend(_render_citations_block(reasoning_citations or [], indent))
+     return _tag_block("p_u reasoning", body_lines, indent)
+
+
+ class LabelingRequestFocusItem(BaseModel):
+     """Specific rubric-related question the human labeler should inspect."""
+
+     question: str
+     citations: list[InlineCitation] = Field(default_factory=list[InlineCitation])
+     sample_answers: list[str] = Field(default_factory=list[str])
+
+     def to_str(self, indent: int = 0) -> str:
+         """Render focus item in a deterministic LLM-facing format."""
+         lines: list[str] = []
+
+         # Render the question and its citations as one nested block.
+         question_lines = _text_lines_or_na(self.question)
+         question_lines.extend(_render_citations_block(self.citations, indent))
+         lines.extend(_tag_block("Question", question_lines, indent))
+
+         sample_answers_lines = (
+             [
+                 f"Answer {sample_idx}: {sample_answer}"
+                 for sample_idx, sample_answer in enumerate(self.sample_answers, start=1)
+             ]
+             if self.sample_answers
+             else ["N/A"]
+         )
+         lines.extend(_tag_block("Sample Answers", sample_answers_lines, indent))
+         return "\n".join(lines)
+
+
+ class QAPair(BaseModel):
+     """A single review-focus answer captured for one run."""
+
+     # What the user was shown
+     focus_index: int
+
+     # What the user responded
+     answer: str
+     explanation: str | None = None
+
+     # The user could have skipped this question and provided nothing
+     status: Literal["answered", "skipped"]
+     timestamp: datetime = Field(default_factory=datetime.now)
+
+     def to_str(self, labeling_request: "LabelingRequest", indent: int = 0) -> str:
+         """Render QA pair in a deterministic LLM-facing format."""
+         if self.focus_index < 0 or self.focus_index >= len(labeling_request.review_focus):
+             raise ValueError(
+                 f"focus_index={self.focus_index} is out of bounds for review_focus length "
+                 f"{len(labeling_request.review_focus)}"
+             )
+         focus_item = labeling_request.review_focus[self.focus_index]
+         lines = focus_item.to_str(indent=indent).splitlines()
+         lines.append(f"User answer: {_text_or_na(self.answer)}")
+         lines.append(f"User explanation: {_text_or_na(self.explanation)}")
+         return "\n".join(lines)
+
+
+ class LabelingRequest(BaseModel):
+     """Structured labeling request shown to the user."""
+
+     title: str
+     review_context: str
+     review_context_citations: list[InlineCitation] = Field(default_factory=list[InlineCitation])
+     review_focus: list[LabelingRequestFocusItem] = Field(
+         default_factory=list[LabelingRequestFocusItem]
+     )
+     user_distribution: OutputDistribution | None = None
+     user_distribution_reasoning: str | None = None
+
+     def to_str(self, indent: int = 0) -> str:
+         """Render labeling request in a deterministic LLM-facing format."""
+         body_lines: list[str] = [f"Title: {_text_or_na(self.title)}"]
+
+         review_context_lines = _text_lines_or_na(self.review_context)
+         review_context_lines.extend(_render_citations_block(self.review_context_citations, indent))
+         body_lines.extend(_tag_block("Review Context", review_context_lines, indent))
+
+         review_focus_lines: list[str] = []
+         if self.review_focus:
+             for focus_idx, focus_item in enumerate(self.review_focus, start=1):
+                 focus_lines = focus_item.to_str(indent=indent).splitlines()
+                 review_focus_lines.extend(_tag_block(f"Focus {focus_idx}", focus_lines, indent))
+         else:
+             review_focus_lines.append("N/A")
+         body_lines.extend(_tag_block("Review Focus", review_focus_lines, indent))
+
+         body_lines.extend(_render_user_distribution_block(self.user_distribution, indent))
+         body_lines.extend(
+             _render_user_distribution_reasoning_block(
+                 self.user_distribution_reasoning,
+                 reasoning_citations=None,
+                 indent=indent,
+             )
+         )
+
+         lines = _tag_block("Labeling Request", body_lines, indent)
+         return "\n".join(lines)
+
+
+ class LabeledRun(BaseModel):
+     """A human label for one agent run."""
+
+     agent_run_id: str
+     timestamp: datetime = Field(default_factory=datetime.now)
+
+     # What the user responded
+     label_value: dict[str, Any]
+     explanation: str | None = None
+
+     def to_str(
+         self,
+         labeling_request: LabelingRequest | None = None,
+         indent: int = 0,
+     ) -> str:
+         """Render user label in a deterministic LLM-facing format."""
+         body_lines = [
+             f"User label: {_stable_json(self.label_value)}",
+             f"User explanation: {_text_or_na(self.explanation)}",
+         ]
+         if labeling_request is None:
+             return "\n".join(_tag_block("Label", body_lines, indent))
+
+         body_lines.extend(
+             _render_user_distribution_block(labeling_request.user_distribution, indent)
+         )
+         body_lines.extend(
+             _render_user_distribution_reasoning_block(
+                 labeling_request.user_distribution_reasoning,
+                 reasoning_citations=None,
+                 indent=indent,
+             )
+         )
+         return "\n".join(_tag_block("Label", body_lines, indent))
+
+
+ class AgentRunFeedbackContext(BaseModel):
+     """All feedback collected for a single agent run."""
+
+     feedback_context_id: str | None = None
+     agent_run_id: str
+     round: int
+     created_at: datetime = Field(default_factory=datetime.now)
+     last_updated: datetime = Field(default_factory=datetime.now)
+
+     # What the user was shown
+     labeling_request: LabelingRequest
+
+     # What the user responded
+     qa_pairs: list[QAPair] = Field(default_factory=list[QAPair])
+     label: LabeledRun | None = None
+
+     @model_validator(mode="after")
+     def validate_nested_agent_run_ids(self) -> "AgentRunFeedbackContext":
+         """Ensure nested run IDs are consistent with the top-level run ID."""
+         if self.label is not None and self.label.agent_run_id != self.agent_run_id:
+             raise ValueError("label.agent_run_id must match agent_run_id")
+         return self
+
+     def to_str(self, indent: int = 0) -> str:
+         """Render full feedback entry in a deterministic LLM-facing format."""
+         lines = self.labeling_request.to_str(indent=indent).splitlines()
+
+         qa_lines: list[str] = []
+         if not self.qa_pairs:
+             qa_lines.append("N/A")
+         else:
+             for qa_idx, qa_pair in enumerate(self.qa_pairs, start=1):
+                 qa_entry_lines = qa_pair.to_str(
+                     labeling_request=self.labeling_request,
+                     indent=indent,
+                 ).splitlines()
+                 qa_lines.extend(_tag_block(f"QA {qa_idx}", qa_entry_lines, indent))
+         lines.extend(_tag_block("Question Answer Pairs", qa_lines, indent))
+
+         if self.label is None:
+             label_body_lines = [
+                 "User label: N/A",
+                 "User explanation: N/A",
+             ]
+             label_body_lines.extend(
+                 _render_user_distribution_block(self.labeling_request.user_distribution, indent)
+             )
+             label_body_lines.extend(
+                 _render_user_distribution_reasoning_block(
+                     self.labeling_request.user_distribution_reasoning,
+                     reasoning_citations=None,
+                     indent=indent,
+                 )
+             )
+             lines.extend(_tag_block("Label", label_body_lines, indent))
+         else:
+             lines.extend(
+                 self.label.to_str(
+                     labeling_request=self.labeling_request,
+                     indent=indent,
+                 ).splitlines()
+             )
+         return "\n".join(lines)
+
+
+ class FeedbackContext(BaseModel):
+     """Feedback context returned by the feedback REST API."""
+
+     feedback_context_id: str
+     feedback_session_id: str
+     agent_run_id: str
+     labeling_request: LabelingRequest
+     created_at: datetime
+     updated_at: datetime
+
+
+ class FeedbackContextsResponse(BaseModel):
+     """Round-scoped feedback contexts returned by the feedback REST API."""
+
+     current_round: int
+     contexts: list[FeedbackContext] = Field(default_factory=list[FeedbackContext])
+
+
+ FeedbackJobStatus = Literal["pending", "running", "cancelling", "canceled", "completed"]
+
+
+ class StartFeedbackContextsJobResponse(BaseModel):
+     """Response for enqueueing or reusing a feedback contexts job."""
+
+     job_id: str
+
+
+ class FeedbackContextsJobStateResponse(BaseModel):
+     """Current feedback contexts job status and round-scoped contexts."""
+
+     job_id: str | None
+     job_status: FeedbackJobStatus | None
+     current_round: int
+     contexts: list[FeedbackContext] = Field(default_factory=list[FeedbackContext])
+
+
+ class UserData(BaseModel):
+     """User Data (U) for user-context inference and downstream evaluation."""
+
+     initial_rubric: str
+     agent_run_feedbacks: list[AgentRunFeedbackContext] = Field(
+         default_factory=lambda: list[AgentRunFeedbackContext]()
+     )
+     created_at: datetime = Field(default_factory=datetime.now)
+     last_updated: datetime = Field(default_factory=datetime.now)
+
+     def upsert_run_feedback(self, agent_run_feedback: AgentRunFeedbackContext) -> None:
+         """Insert or replace feedback for an agent run ID, updating timestamps."""
+         now = datetime.now()
+         upserted_feedback = agent_run_feedback.model_copy(deep=True)
+         upserted_feedback.last_updated = now
+
+         for idx, existing in enumerate(self.agent_run_feedbacks):
+             if existing.agent_run_id != upserted_feedback.agent_run_id:
+                 continue
+             upserted_feedback.created_at = existing.created_at
+             self.agent_run_feedbacks[idx] = upserted_feedback
+             self.last_updated = now
+             return
+
+         self.agent_run_feedbacks.append(upserted_feedback)
+         self.last_updated = now
+
+     def validate_against_agreement_keys(self, agreement_keys: set[str]) -> None:
+         """Validate stored labels and p_u outcomes against rubric agreement keys."""
+         for feedback in self.agent_run_feedbacks:
+             run_id = feedback.agent_run_id
+
+             label = feedback.label
+             if label is not None:
+                 invalid_label_keys = sorted(set(label.label_value.keys()) - agreement_keys)
+                 if invalid_label_keys:
+                     raise ValueError(
+                         "Run "
+                         f"{run_id} has label_value keys outside rubric agreement keys: "
+                         + ", ".join(invalid_label_keys)
+                     )
+
+             user_distribution = feedback.labeling_request.user_distribution
+             if user_distribution is None:
+                 continue
+
+             for outcome_idx, outcome in enumerate(user_distribution.outcomes, start=1):
+                 invalid_output_keys = sorted(set(outcome.output.keys()) - agreement_keys)
+                 if invalid_output_keys:
+                     raise ValueError(
+                         "Run "
+                         f"{run_id} has user_distribution outcome #{outcome_idx} keys outside "
+                         "rubric agreement keys: " + ", ".join(invalid_output_keys)
+                     )
+                 for key, value in outcome.output.items():
+                     if isinstance(value, (str, bool, int, float)):
+                         continue
+                     raise ValueError(
+                         "Run "
+                         f"{run_id} has user_distribution outcome #{outcome_idx} non-scalar "
+                         f"value for key '{key}': {type(value).__name__}"
+                     )
+
+     def iter_answered_qa_entries(self) -> Iterator[tuple[AgentRunFeedbackContext, QAPair]]:
+         """Iterate answered QA pairs with their parent run feedback."""
+         for feedback in self.agent_run_feedbacks:
+             for qa_pair in feedback.qa_pairs:
+                 if qa_pair.status == "answered":
+                     yield feedback, qa_pair
+
+     def iter_skipped_qa_entries(self) -> Iterator[tuple[AgentRunFeedbackContext, QAPair]]:
+         """Iterate skipped QA pairs with their parent run feedback."""
+         for feedback in self.agent_run_feedbacks:
+             for qa_pair in feedback.qa_pairs:
+                 if qa_pair.status == "skipped":
+                     yield feedback, qa_pair
+
+     def iter_labeled_entries(self) -> Iterator[tuple[AgentRunFeedbackContext, LabeledRun]]:
+         """Iterate labeled run entries with their parent run feedback."""
+         for feedback in self.agent_run_feedbacks:
+             if feedback.label is None:
+                 continue
+             yield feedback, feedback.label
+
+
+ # Backward-compatible alias used by older callers/scripts.
+ AgentRunFeedback = AgentRunFeedbackContext
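
For orientation, a minimal usage sketch of the models added above. The class and field names come from the diff; the run ID, question text, answers, and label keys are illustrative assumptions, not code shipped in the package.

from docent.data_models.feedback import (
    AgentRunFeedbackContext,
    LabeledRun,
    LabelingRequest,
    LabelingRequestFocusItem,
    QAPair,
)

# Build a labeling request, attach one answered question and a label, then render
# the whole context as the deterministic, tag-delimited text used in LLM prompts.
request = LabelingRequest(
    title="Did the agent follow the refund policy?",
    review_context="The agent handled a refund request for a delayed flight.",
    review_focus=[LabelingRequestFocusItem(question="Was the refund amount correct?")],
)
context = AgentRunFeedbackContext(
    agent_run_id="run-123",
    round=0,
    labeling_request=request,
    qa_pairs=[QAPair(focus_index=0, answer="Yes, within policy.", status="answered")],
    label=LabeledRun(agent_run_id="run-123", label_value={"followed_policy": True}),
)
print(context.to_str())

The model validator enforces that the nested label carries the same agent_run_id as the context, so mismatched IDs fail at construction time rather than downstream.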
@@ -0,0 +1,351 @@
+ import json
+ import math
+ from collections import Counter
+ from itertools import product
+ from typing import Any, TypedDict, cast
+
+ import numpy as np
+ from pydantic import BaseModel, Field
+
+ AgreementValue = str | bool | int | float
+
+
+ class EstimateWithCI(TypedDict):
+     mean: float
+     var: float
+     n: int
+     ci_95: float
+
+
+ JudgeOutputDistribution = dict[AgreementValue, EstimateWithCI]
+
+
+ class DistributionOutcome(BaseModel):
+     """Single outcome and probability mass for a predictive distribution."""
+
+     output: dict[str, Any]
+     probability: float
+
+
+ class OutputDistribution(BaseModel):
+     """Probability distribution over rubric-compliant outputs."""
+
+     outcomes: list[DistributionOutcome] = Field(default_factory=list[DistributionOutcome])
+
+
+ def _stable_json_dict(value: dict[str, Any]) -> str:
+     return json.dumps(value, sort_keys=True, separators=(",", ":"))
+
+
+ def normalize_output_distribution(distribution: OutputDistribution) -> OutputDistribution:
+     """Normalize probabilities and merge duplicate outcomes by canonical JSON output."""
+     if not distribution.outcomes:
+         return OutputDistribution()
+
+     merged: dict[str, DistributionOutcome] = {}
+     for outcome in distribution.outcomes:
+         key = _stable_json_dict(outcome.output)
+         existing = merged.get(key)
+         if existing is None:
+             merged[key] = DistributionOutcome(
+                 output=outcome.output,
+                 probability=max(0.0, outcome.probability),
+             )
+             continue
+
+         existing.probability += max(0.0, outcome.probability)
+
+     merged_outcomes = list(merged.values())
+     if not merged_outcomes:
+         return OutputDistribution()
+
+     total_probability = sum(item.probability for item in merged_outcomes)
+     if total_probability <= 0:
+         uniform_prob = 1.0 / len(merged_outcomes)
+         for item in merged_outcomes:
+             item.probability = uniform_prob
+     else:
+         for item in merged_outcomes:
+             item.probability = item.probability / total_probability
+
+     merged_outcomes.sort(key=lambda item: item.probability, reverse=True)
+
+     return OutputDistribution(outcomes=merged_outcomes)
+
+
+ def assert_agreement_only_output_schema(schema: dict[str, Any]) -> list[str]:
+     """Validate agreement-only schema contract for elicitation and entropy workflows.
+
+     Contract:
+     - Every top-level property must be an agreement key (enum or boolean),
+     - ``additionalProperties`` must be explicitly ``False``, and
+     - At least one agreement key must exist.
+     """
+     # if schema.get("additionalProperties") is not False:
+     #     raise ValueError(
+     #         "Rubric output_schema must set additionalProperties to false for "
+     #         "agreement-only entropy workflows."
+     #     )
+
+     properties_obj = schema.get("properties")
+     if not isinstance(properties_obj, dict):
+         raise ValueError("Rubric output_schema must define an object-valued properties field.")
+     properties = cast(dict[str, Any], properties_obj)
+
+     agreement_keys = get_agreement_keys(schema)
+     if not agreement_keys:
+         raise ValueError(
+             "Rubric output_schema must include at least one top-level agreement key "
+             "(enum or boolean)."
+         )
+
+     non_agreement_keys = sorted(set(properties.keys()) - set(agreement_keys))
+     if non_agreement_keys:
+         details: list[str] = []
+         for key in non_agreement_keys:
+             field_obj = properties.get(key, {})
+             if isinstance(field_obj, dict):
+                 field_schema = cast(dict[str, Any], field_obj)
+                 field_type = field_schema.get("type")
+                 field_type_text = field_type if isinstance(field_type, str) else "unknown"
+                 has_enum = "enum" in field_schema
+                 details.append(f"{key} (type={field_type_text}, enum={has_enum})")
+             else:
+                 details.append(f"{key} (type=invalid-schema)")
+         raise ValueError(
+             "Rubric output_schema includes non-agreement top-level keys: " + ", ".join(details)
+         )
+
+     return agreement_keys
+
+
+ def get_agreement_keys(schema: dict[str, Any]) -> list[str]:
+     """Get top-level schema keys that support agreement computations.
+
+     This includes top-level enum and boolean fields.
+     """
+     agreement_keys: list[str] = []
+
+     properties = schema.get("properties", {})
+     assert isinstance(properties, dict)
+     properties = cast(dict[str, Any], properties)
+
+     for key, field_schema in properties.items():
+         assert isinstance(field_schema, dict)
+         field_schema = cast(dict[str, Any], field_schema)
+
+         field_type = field_schema.get("type")
+         assert isinstance(field_type, str)
+
+         if field_type == "boolean":
+             agreement_keys.append(key)
+         elif "enum" in field_schema:
+             agreement_keys.append(key)
+
+     return agreement_keys
+
+
+ def get_agreement_key_options(
+     schema: dict[str, Any],
+     agreement_keys: list[str] | None = None,
+ ) -> dict[str, list[AgreementValue]]:
+     """Return possible output options for each agreement key from schema."""
+     if agreement_keys is None:
+         agreement_keys = get_agreement_keys(schema)
+
+     properties = schema.get("properties", {})
+     assert isinstance(properties, dict)
+     properties = cast(dict[str, Any], properties)
+
+     key_options: dict[str, list[AgreementValue]] = {}
+     for key in agreement_keys:
+         field_schema_obj = properties.get(key, {})
+         assert isinstance(field_schema_obj, dict)
+         field_schema = cast(dict[str, Any], field_schema_obj)
+
+         field_type = field_schema.get("type")
+         assert isinstance(field_type, str)
+
+         if field_type == "boolean":
+             key_options[key] = [True, False]
+             continue
+
+         if "enum" in field_schema:
+             enum_values = field_schema.get("enum")
+             assert isinstance(enum_values, list)
+             options: list[AgreementValue] = []
+             for enum_value in cast(list[object], enum_values):
+                 assert isinstance(enum_value, (str, bool, int, float))
+                 options.append(enum_value)
+             key_options[key] = options
+             continue
+
+         key_options[key] = []
+
+     return key_options
+
+
+ def enumerate_agreement_output_space(
+     schema: dict[str, Any],
+     agreement_keys: list[str] | None = None,
+     max_outcomes: int | None = None,
+ ) -> list[dict[str, AgreementValue]]:
+     """Enumerate all possible agreement-only output objects from schema key options."""
+     if agreement_keys is None:
+         agreement_keys = get_agreement_keys(schema)
+     if not agreement_keys:
+         raise ValueError("Cannot enumerate output space: no agreement keys were found.")
+
+     if max_outcomes is not None and max_outcomes <= 0:
+         raise ValueError("max_outcomes must be > 0 when provided.")
+
+     key_options = get_agreement_key_options(schema, agreement_keys)
+     ordered_options: list[list[AgreementValue]] = []
+     total_outcomes = 1
+
+     for key in agreement_keys:
+         options = key_options.get(key, [])
+         if not options:
+             raise ValueError(
+                 f"Cannot enumerate output space: agreement key '{key}' has no valid options."
+             )
+         ordered_options.append(options)
+         total_outcomes *= len(options)
+         if max_outcomes is not None and total_outcomes > max_outcomes:
+             raise ValueError(
+                 "Cannot enumerate output space: "
+                 f"{total_outcomes} outcomes exceeds max_outcomes={max_outcomes}."
+             )
+
+     return [
+         {key: value for key, value in zip(agreement_keys, combo)}
+         for combo in product(*ordered_options)
+     ]
+
+
+ def find_modal_result(indep_results: list[dict[str, Any]], agreement_keys: list[str]):
+     """Find the result that best matches modal values across agreement keys.
+
+     Args:
+         indep_results: List of independent results to analyze
+         agreement_keys: Keys to measure agreement on
+
+     Returns:
+         Tuple of (max_idx, agt_key_modes_and_counts) where:
+         - max_idx is the index of the result that best matches modal values
+         - agt_key_modes_and_counts maps each key to (modal_value, count) or None if no values exist for that key
+
+     Raises:
+         ValueError: If no results are provided
+     """
+     if not indep_results:
+         raise ValueError("No results to score")
+
+     # For each agreement key, compute the mode and count (or None, if no values exist for that key)
+     agt_key_modes_and_counts: dict[str, tuple[str | bool | int, int] | None] = {}
+     for key in agreement_keys:
+         key_modes = Counter(v for r in indep_results if (v := r.get(key)) is not None)
+         if most_common_one := key_modes.most_common(1):
+             agt_key_modes_and_counts[key] = most_common_one[0]
+         else:
+             agt_key_modes_and_counts[key] = None
+
+     # Score each rollout based on how many agreement keys they match
+     # If there is no mode for a key, or if a certain result doesn't have that key, it doesn't count.
+     # TODO(mengk): This may bias towards results that have more keys.
+     indep_result_scores: list[int] = []
+     for r in indep_results:
+         score = 0
+         for key in agreement_keys:
+             mode_and_count = agt_key_modes_and_counts[key]
+             if mode_and_count and r.get(key) == mode_and_count[0]:
+                 score += 1
+         indep_result_scores.append(score)
+
+     # Argmax
+     max_idx = indep_result_scores.index(max(indep_result_scores))
+
+     return max_idx, agt_key_modes_and_counts
+
+
+ def compute_output_distributions(
+     eq_weighted_outputs: list[dict[str, Any]],
+     output_schema: dict[str, Any],
+     agreement_keys: list[str],
+ ) -> dict[str, JudgeOutputDistribution]:
+     """Estimate per-key output value distributions from equally weighted outputs.
+
+     For each key in ``agreement_keys``, this function:
+     1. Enumerates allowed values from the schema (``enum`` or ``boolean``),
+     2. Counts observed non-null values in ``eq_weighted_outputs``,
+     3. Normalizes counts into empirical probabilities, and
+     4. Computes per-value summary statistics:
+        - ``mean``: empirical probability ``p``
+        - ``var``: Bernoulli variance ``p * (1 - p)``
+        - ``n``: number of observed (non-null) values for that key
+        - ``ci_95``: normal-approximation 95% CI half-width ``1.96 * sqrt(var / n)``
+
+     Optional fields that are missing in some results are skipped (not counted toward ``n``).
+     If no values are observed for a key, all value probabilities are set to ``0.0`` and ``n=0``.
+
+     Args:
+         eq_weighted_outputs: Output objects to aggregate, each counted with equal weight.
+         output_schema: JSON schema used to validate/define allowed output values.
+         agreement_keys: Keys to include in the distribution computation.
+
+     Returns:
+         Mapping from agreement key to per-value distribution estimates.
+
+     Raises:
+         AssertionError: If an observed value is not one of the schema-derived possible values.
+     """
+
+     key_options = get_agreement_key_options(output_schema, agreement_keys)
+     raw_counts: dict[str, dict[AgreementValue, int]] = {
+         key: {value: 0 for value in key_options.get(key, [])} for key in agreement_keys
+     }
+     # Collect counts for each possible value
+     for result in eq_weighted_outputs:
+         for key in agreement_keys:
+             if (value := result.get(key)) is not None:  # Could be none if the key is optional
+                 assert value in raw_counts[key], (
+                     "this should never happen; the value must be in possible values, since judge results have been validated against the schema"
+                 )
+                 raw_counts[key][value] += 1
+
+     distributions: dict[str, JudgeOutputDistribution] = {}
+     for agt_key in agreement_keys:
+         distributions[agt_key] = {}
+
+         # First normalize the counts to get probabilities
+         counts = raw_counts[agt_key]
+         total = sum(counts.values())
+         probs = {value: (count / total) if total > 0 else 0.0 for value, count in counts.items()}
+
+         for output_key, value in probs.items():
+             mean, estimate_var = value, (value * (1 - value))
+             # TODO(mengk): change to the wilson score interval
+             ci_95 = float(1.96 * np.sqrt(estimate_var / total)) if total > 0 else 0.0
+             estimate: EstimateWithCI = {
+                 "mean": mean,
+                 "var": estimate_var,
+                 "n": total,
+                 "ci_95": ci_95,
+             }
+             distributions[agt_key][output_key] = estimate
+
+     return distributions
+
+
+ def compute_entropy(distribution: OutputDistribution) -> float:
+     """Compute Shannon entropy over normalized outcomes in nats."""
+     normalized = normalize_output_distribution(distribution)
+     if not normalized.outcomes:
+         return 0.0
+
+     entropy = 0.0
+     for outcome in normalized.outcomes:
+         probability = outcome.probability
+         if probability > 0:
+             entropy -= probability * math.log(probability)
+     return entropy
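
A short sketch of how the new voting helpers compose, under an assumed two-key rubric schema; the schema, judge outputs, and probabilities below are made up for illustration.

from docent.judges.util.voting import (
    DistributionOutcome,
    OutputDistribution,
    compute_entropy,
    compute_output_distributions,
    get_agreement_keys,
)

# An assumed rubric output schema with two agreement keys (one enum, one boolean).
schema = {
    "type": "object",
    "properties": {
        "severity": {"type": "string", "enum": ["low", "high"]},
        "violation": {"type": "boolean"},
    },
    "additionalProperties": False,
}
keys = get_agreement_keys(schema)  # ["severity", "violation"]

# Three hypothetical judge outputs, counted with equal weight.
outputs = [
    {"severity": "high", "violation": True},
    {"severity": "high", "violation": False},
    {"severity": "low", "violation": True},
]
per_key = compute_output_distributions(outputs, schema, keys)
# e.g. per_key["severity"]["high"] -> {"mean": 0.667, "var": 0.222, "n": 3, "ci_95": ...}

# Entropy (in nats) of a predictive distribution; probabilities are normalized and
# duplicate outcomes merged before the entropy is computed.
dist = OutputDistribution(
    outcomes=[
        DistributionOutcome(output={"violation": True}, probability=2.0),
        DistributionOutcome(output={"violation": False}, probability=1.0),
    ]
)
print(compute_entropy(dist))  # approx. 0.64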
@@ -21,6 +21,11 @@ from tqdm import tqdm
  from docent._llm_util.providers.preference_types import ModelOption
  from docent._log_util.logger import LoggerAdapter, get_logger
  from docent.data_models.agent_run import AgentRun
+ from docent.data_models.feedback import (
+     AgentRunFeedbackContext,
+     FeedbackContextsJobStateResponse,
+     StartFeedbackContextsJobResponse,
+ )
  from docent.data_models.judge import Label
  from docent.judges.util.meta_schema import validate_judge_result_schema
  from docent.loaders import load_inspect
@@ -962,6 +967,113 @@ class Docent:
          clustering_state = self.get_clustering_state(collection_id, rubric_id)
          return clustering_state.get("assignments", {})

+     def create_feedback_session(
+         self,
+         collection_id: str,
+         rubric_id: str,
+         rubric_version: int,
+     ) -> str:
+         """Create a feedback session for a specific rubric version."""
+         url = f"{self._api_url}/feedback/{collection_id}/session"
+         payload = {
+             "rubric_id": rubric_id,
+             "rubric_version": rubric_version,
+         }
+         response = self._session.post(url, json=payload)
+         self._handle_response_errors(response)
+         return response.json()["feedback_session_id"]
+
+     def start_feedback_contexts_job(
+         self,
+         collection_id: str,
+         feedback_session_id: str,
+         num_samples: int = 50,
+         top_n: int = 20,
+         seed: int = 0,
+         candidate_pool_limit: int = 1_000,
+         where_clause: str | None = None,
+         increment_round: bool = False,
+     ) -> StartFeedbackContextsJobResponse:
+         """Start or reuse a background job to compute feedback contexts for a session."""
+         payload = {
+             "feedback_session_id": feedback_session_id,
+             "num_samples": num_samples,
+             "top_n": top_n,
+             "seed": seed,
+             "candidate_pool_limit": candidate_pool_limit,
+             "where_clause": where_clause,
+             "increment_round": increment_round,
+         }
+         url = f"{self._api_url}/feedback/{collection_id}/contexts/start"
+         response = self._session.post(url, json=payload)
+         self._handle_response_errors(response)
+         return StartFeedbackContextsJobResponse.model_validate(response.json())
+
+     def get_feedback_contexts(
+         self,
+         collection_id: str,
+         feedback_session_id: str,
+     ) -> FeedbackContextsJobStateResponse:
+         """Get feedback contexts state for a session, including job status and current round data."""
+         payload = {
+             "feedback_session_id": feedback_session_id,
+         }
+         url = f"{self._api_url}/feedback/{collection_id}/contexts/state"
+         response = self._session.post(url, json=payload)
+         self._handle_response_errors(response)
+         return FeedbackContextsJobStateResponse.model_validate(response.json())
+
+     def get_agent_run_feedback_contexts_by_session(
+         self,
+         collection_id: str,
+         feedback_session_id: str,
+     ) -> list[AgentRunFeedbackContext]:
+         """Get all persisted AgentRun feedback contexts for a feedback session.
+
+         Returns fully hydrated AgentRun feedback context objects from the database,
+         including QA pairs and labels (if present), across all rounds in the session.
+         """
+         url = f"{self._api_url}/feedback/{collection_id}/session/{feedback_session_id}/contexts"
+         response = self._session.get(url)
+         self._handle_response_errors(response)
+         return [AgentRunFeedbackContext.model_validate(item) for item in response.json()]
+
+     def add_feedback_qa(
+         self,
+         collection_id: str,
+         feedback_context_id: str,
+         focus_index: int,
+         answer: str,
+         explanation: str | None = None,
+     ) -> None:
+         """Submit a QA answer for a feedback context."""
+         payload = {
+             "feedback_context_id": feedback_context_id,
+             "focus_index": focus_index,
+             "answer": answer,
+             "explanation": explanation,
+         }
+         url = f"{self._api_url}/feedback/{collection_id}/qa"
+         response = self._session.post(url, json=payload)
+         self._handle_response_errors(response)
+
+     def upsert_feedback_label(
+         self,
+         collection_id: str,
+         feedback_context_id: str,
+         label_value: dict[str, Any],
+         explanation: str | None = None,
+     ) -> None:
+         """Create or update a label for a feedback context."""
+         payload = {
+             "feedback_context_id": feedback_context_id,
+             "label_value": label_value,
+             "explanation": explanation,
+         }
+         url = f"{self._api_url}/feedback/{collection_id}/label"
+         response = self._session.put(url, json=payload)
+         self._handle_response_errors(response)
+
      def create_label_set(
          self,
          collection_id: str,
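
Taken together, the new client methods suggest a round-based elicitation loop along the lines of the sketch below. All IDs are placeholders, the Docent(...) constructor arguments and the five-second polling cadence are assumptions not shown in this diff, and only the feedback methods added above are taken from the SDK.

import time

from docent.sdk.client import Docent

client = Docent(api_key="...")  # assumed constructor arguments; not part of this diff
collection_id = "your-collection-id"

session_id = client.create_feedback_session(collection_id, rubric_id="your-rubric-id", rubric_version=1)
client.start_feedback_contexts_job(collection_id, session_id, num_samples=20, top_n=5)

# Poll until the background job settles, then answer and label each suggested run.
state = client.get_feedback_contexts(collection_id, session_id)
while state.job_status in ("pending", "running"):
    time.sleep(5)
    state = client.get_feedback_contexts(collection_id, session_id)

for ctx in state.contexts:
    client.add_feedback_qa(collection_id, ctx.feedback_context_id, focus_index=0, answer="Yes")
    # label_value keys should match the rubric's agreement keys; "violation" is illustrative.
    client.upsert_feedback_label(collection_id, ctx.feedback_context_id, label_value={"violation": True})

# Later, fetch everything collected in this session across all rounds.
all_contexts = client.get_agent_run_feedback_contexts_by_session(collection_id, session_id)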
@@ -1,7 +1,7 @@
  [project]
  name = "docent-python"
  description = "Docent SDK"
- version = "0.1.50-alpha"
+ version = "0.1.52-alpha"
  authors = [
      { name="Transluce", email="info@transluce.org" },
  ]
@@ -1,140 +0,0 @@
- from collections import Counter
- from typing import Any, TypedDict, cast
-
- import numpy as np
-
-
- class EstimateWithCI(TypedDict):
-     mean: float
-     var: float
-     n: int
-     ci_95: float
-
-
- JudgeOutputDistribution = dict[str | bool | int | float, EstimateWithCI]
-
-
- def get_agreement_keys(schema: dict[str, Any]) -> list[str]:
-     """Get list of top-level keys in schema that we want to measure agreement on.
-
-     This includes enum and bool fields.
-
-     Args:
-         schema: JSON schema dict
-
-     Returns:
-         List of field names (keys) that should be used for measuring agreement
-     """
-     agreement_keys: list[str] = []
-
-     properties = schema.get("properties", {})
-     assert isinstance(properties, dict)
-     properties = cast(dict[str, Any], properties)
-
-     for key, field_schema in properties.items():
-         assert isinstance(field_schema, dict)
-         field_schema = cast(dict[str, Any], field_schema)
-
-         field_type = field_schema.get("type")
-         assert isinstance(field_type, str)
-
-         # Include boolean fields
-         if field_type == "boolean":
-             agreement_keys.append(key)
-         # Include enum fields (strings and numbers must be in this category)
-         elif "enum" in field_schema:
-             agreement_keys.append(key)
-
-     return agreement_keys
-
-
- def find_modal_result(indep_results: list[dict[str, Any]], agreement_keys: list[str]):
-     """Find the result that best matches modal values across agreement keys.
-
-     Args:
-         indep_results: List of independent results to analyze
-         agreement_keys: Keys to measure agreement on
-
-     Returns:
-         Tuple of (max_idx, agt_key_modes_and_counts) where:
-         - max_idx is the index of the result that best matches modal values
-         - agt_key_modes_and_counts maps each key to (modal_value, count) or None if no values exist for that key
-
-     Raises:
-         ValueError: If no results are provided
-     """
-     if not indep_results:
-         raise ValueError("No results to score")
-
-     # For each agreement key, compute the mode and count (or None, if no values exist for that key)
-     agt_key_modes_and_counts: dict[str, tuple[str | bool | int, int] | None] = {}
-     for key in agreement_keys:
-         key_modes = Counter(v for r in indep_results if (v := r.get(key)) is not None)
-         if most_common_one := key_modes.most_common(1):
-             agt_key_modes_and_counts[key] = most_common_one[0]
-         else:
-             agt_key_modes_and_counts[key] = None
-
-     # Score each rollout based on how many agreement keys they match
-     # If there is no mode for a key, or if a certain result doesn't have that key, it doesn't count.
-     # TODO(mengk): This may bias towards results that have more keys.
-     indep_result_scores: list[int] = []
-     for r in indep_results:
-         score = 0
-         for key in agreement_keys:
-             mode_and_count = agt_key_modes_and_counts[key]
-             if mode_and_count and r.get(key) == mode_and_count[0]:
-                 score += 1
-         indep_result_scores.append(score)
-
-     # Argmax
-     max_idx = indep_result_scores.index(max(indep_result_scores))
-
-     return max_idx, agt_key_modes_and_counts
-
-
- def compute_output_distributions(
-     indep_results: list[dict[str, Any]], output_schema: dict[str, Any], agreement_keys: list[str]
- ):
-     def _get_possible_values(key: str) -> list[str | bool | int | float]:
-         if "enum" in output_schema.get("properties", {}).get(key, {}):
-             return output_schema.get("properties", {}).get(key, {}).get("enum", [])
-         elif output_schema.get("properties", {}).get(key, {}).get("type") == "boolean":
-             return [True, False]
-         else:
-             return []
-
-     raw_counts: dict[str, dict[str | bool | int | float, int]] = {
-         key: {value: 0 for value in _get_possible_values(key)} for key in agreement_keys
-     }
-     # Collect counts for each possible value
-     for result in indep_results:
-         for key in agreement_keys:
-             if (value := result.get(key)) is not None:  # Could be none if the key is optional
-                 assert value in raw_counts[key], (
-                     "this should never happen; the value must be in possible values, since judge results have been validated against the schema"
-                 )
-                 raw_counts[key][value] += 1
-
-     distributions: dict[str, JudgeOutputDistribution] = {}
-     for agt_key in agreement_keys:
-         distributions[agt_key] = {}
-
-         # First normalize the counts to get probabilities
-         counts = raw_counts[agt_key]
-         total = sum(counts.values())
-         probs = {value: (count / total) if total > 0 else 0.0 for value, count in counts.items()}
-
-         for output_key, value in probs.items():
-             mean, estimate_var = value, (value * (1 - value))
-             # TODO(mengk): change to the wilson score interval
-             ci_95 = float(1.96 * np.sqrt(estimate_var / total)) if total > 0 else 0.0
-             estimate: EstimateWithCI = {
-                 "mean": mean,
-                 "var": estimate_var,
-                 "n": total,
-                 "ci_95": ci_95,
-             }
-             distributions[agt_key][output_key] = estimate
-
-     return distributions