docent-python 0.1.41a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of docent-python might be problematic.

Files changed (59)
  1. docent/__init__.py +4 -0
  2. docent/_llm_util/__init__.py +0 -0
  3. docent/_llm_util/data_models/__init__.py +0 -0
  4. docent/_llm_util/data_models/exceptions.py +48 -0
  5. docent/_llm_util/data_models/llm_output.py +331 -0
  6. docent/_llm_util/llm_cache.py +193 -0
  7. docent/_llm_util/llm_svc.py +472 -0
  8. docent/_llm_util/model_registry.py +134 -0
  9. docent/_llm_util/providers/__init__.py +0 -0
  10. docent/_llm_util/providers/anthropic.py +537 -0
  11. docent/_llm_util/providers/common.py +41 -0
  12. docent/_llm_util/providers/google.py +530 -0
  13. docent/_llm_util/providers/openai.py +745 -0
  14. docent/_llm_util/providers/openrouter.py +375 -0
  15. docent/_llm_util/providers/preference_types.py +104 -0
  16. docent/_llm_util/providers/provider_registry.py +164 -0
  17. docent/_log_util/__init__.py +3 -0
  18. docent/_log_util/logger.py +141 -0
  19. docent/data_models/__init__.py +14 -0
  20. docent/data_models/_tiktoken_util.py +91 -0
  21. docent/data_models/agent_run.py +473 -0
  22. docent/data_models/chat/__init__.py +37 -0
  23. docent/data_models/chat/content.py +56 -0
  24. docent/data_models/chat/message.py +191 -0
  25. docent/data_models/chat/tool.py +109 -0
  26. docent/data_models/citation.py +187 -0
  27. docent/data_models/formatted_objects.py +84 -0
  28. docent/data_models/judge.py +17 -0
  29. docent/data_models/metadata_util.py +16 -0
  30. docent/data_models/regex.py +56 -0
  31. docent/data_models/transcript.py +305 -0
  32. docent/data_models/util.py +170 -0
  33. docent/judges/__init__.py +23 -0
  34. docent/judges/analysis.py +77 -0
  35. docent/judges/impl.py +587 -0
  36. docent/judges/runner.py +129 -0
  37. docent/judges/stats.py +205 -0
  38. docent/judges/types.py +320 -0
  39. docent/judges/util/forgiving_json.py +108 -0
  40. docent/judges/util/meta_schema.json +86 -0
  41. docent/judges/util/meta_schema.py +29 -0
  42. docent/judges/util/parse_output.py +68 -0
  43. docent/judges/util/voting.py +139 -0
  44. docent/loaders/load_inspect.py +215 -0
  45. docent/py.typed +0 -0
  46. docent/samples/__init__.py +3 -0
  47. docent/samples/load.py +9 -0
  48. docent/samples/log.eval +0 -0
  49. docent/samples/tb_airline.json +1 -0
  50. docent/sdk/__init__.py +0 -0
  51. docent/sdk/agent_run_writer.py +317 -0
  52. docent/sdk/client.py +1186 -0
  53. docent/sdk/llm_context.py +432 -0
  54. docent/trace.py +2741 -0
  55. docent/trace_temp.py +1086 -0
  56. docent_python-0.1.41a0.dist-info/METADATA +33 -0
  57. docent_python-0.1.41a0.dist-info/RECORD +59 -0
  58. docent_python-0.1.41a0.dist-info/WHEEL +4 -0
  59. docent_python-0.1.41a0.dist-info/licenses/LICENSE.md +13 -0
docent/data_models/agent_run.py

@@ -0,0 +1,473 @@
+ import sys
+ import textwrap
+ from collections import deque
+ from datetime import datetime
+ from typing import Any, Literal, TypedDict, cast
+ from uuid import uuid4
+
+ import yaml
+ from pydantic import (
+     BaseModel,
+     Field,
+     PrivateAttr,
+     field_validator,
+     model_validator,
+ )
+ from pydantic_core import to_jsonable_python
+
+ from docent._log_util import get_logger
+ from docent.data_models._tiktoken_util import get_token_count, group_messages_into_ranges
+ from docent.data_models.metadata_util import dump_metadata
+ from docent.data_models.transcript import Transcript, TranscriptGroup
+
+ logger = get_logger(__name__)
+
+
+ class FilterableField(TypedDict):
+     name: str
+     type: Literal["str", "bool", "int", "float"]
+
+
+ class AgentRun(BaseModel):
+     """Represents a complete run of an agent with transcripts and metadata.
+
+     An AgentRun encapsulates the execution of an agent, storing all communication
+     transcripts and associated metadata. It must contain at least one transcript.
+
+     Attributes:
+         id: Unique identifier for the agent run, auto-generated by default.
+         name: Optional human-readable name for the agent run.
+         description: Optional description of the agent run.
+         transcripts: List of Transcript objects.
+         transcript_groups: List of TranscriptGroup objects.
+         metadata: Additional structured metadata about the agent run as a JSON-serializable dictionary.
+     """
+
+     id: str = Field(default_factory=lambda: str(uuid4()))
+     name: str | None = None
+     description: str | None = None
+
+     transcripts: list[Transcript]
+     transcript_groups: list[TranscriptGroup] = Field(default_factory=list)
+     metadata: dict[str, Any] = Field(default_factory=dict)
+
+     @field_validator("transcripts", mode="before")
+     @classmethod
+     def _validate_transcripts_type(cls, v: Any) -> Any:
+         if isinstance(v, dict):
+             logger.warning(
+                 "dict[str, Transcript] for transcripts is deprecated. Use list[Transcript] instead."
+             )
+             v = cast(dict[str, Transcript], v)
+             return [Transcript.model_validate(t) for t in v.values()]
+         return v
+
+     @field_validator("transcript_groups", mode="before")
+     @classmethod
+     def _validate_transcript_groups_type(cls, v: Any) -> Any:
+         if isinstance(v, dict):
+             logger.warning(
+                 "dict[str, TranscriptGroup] for transcript_groups is deprecated. Use list[TranscriptGroup] instead."
+             )
+             v = cast(dict[str, TranscriptGroup], v)
+             return [TranscriptGroup.model_validate(tg) for tg in v.values()]
+         return v
+
+     @model_validator(mode="after")
+     def _validate_transcripts_not_empty(self):
+         """Validates that the agent run contains at least one transcript.
+
+         Raises:
+             ValueError: If the transcripts list is empty.
+
+         Returns:
+             AgentRun: The validated AgentRun instance.
+         """
+         if len(self.transcripts) == 0:
+             raise ValueError("AgentRun must have at least one transcript")
+         return self
+
+     def get_filterable_fields(self, max_depth: int = 1) -> list[FilterableField]:
+         """Returns a list of all fields that can be used to filter the agent run,
+         by recursively exploring the model_dump() for singleton types in dictionaries.
+
+         Returns:
+             list[FilterableField]: A list of filterable fields, where each field is a
+                 dictionary containing its 'name' (path) and 'type'.
+         """
+
+         result: list[FilterableField] = []
+
+         def _explore_dict(d: dict[str, Any], prefix: str, depth: int):
+             nonlocal result
+
+             if depth > max_depth:
+                 return
+
+             for k, v in d.items():
+                 if isinstance(v, (str, int, float, bool)):
+                     result.append(
+                         {
+                             "name": f"{prefix}.{k}",
+                             "type": cast(Literal["str", "bool", "int", "float"], type(v).__name__),
+                         }
+                     )
+                 elif isinstance(v, dict):
+                     _explore_dict(cast(dict[str, Any], v), f"{prefix}.{k}", depth + 1)
+
+         # Look at the agent run metadata
+         _explore_dict(to_jsonable_python(self.metadata), "metadata", 0)
+         # Look at the transcript metadata
+         # TODO(mengk): restore this later when we have the ability to integrate with SQL.
+         # for t_id, t in self.transcripts.items():
+         #     _explore_dict(
+         #         t.metadata.model_dump(strip_internal_fields=True), f"transcript.{t_id}.metadata", 0
+         #     )
+
+         # Append the id and text fields
+         result.append({"name": "agent_run_id", "type": "str"})
+         result.append({"name": "text", "type": "str"})
+
+         return result
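A minimal sketch of what `get_filterable_fields` returns, assuming a hypothetical `run: AgentRun` whose metadata is `{"task": "swe", "scores": {"passed": True}}` (the names and values are illustrative, not from this package):

    # Hypothetical run; metadata values are illustrative.
    fields = run.get_filterable_fields(max_depth=1)
    # [{"name": "metadata.task", "type": "str"},
    #  {"name": "metadata.scores.passed", "type": "bool"},
    #  {"name": "agent_run_id", "type": "str"},
    #  {"name": "text", "type": "str"}]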
+
+     ######################
+     # Converting to text #
+     ######################
+
+     def _to_text_impl(self, token_limit: int = sys.maxsize) -> list[str]:
+         """
+         Core implementation for converting an agent run to its text representation.
+
+         Args:
+             token_limit: Maximum tokens per returned string, under the GPT-4 tokenization scheme.
+
+         Returns:
+             List of strings, each at most token_limit tokens.
+         """
+         # Render each transcript to a string
+         transcript_strs: list[str] = []
+         for i, t in enumerate(self.transcripts):
+             transcript_content = t.to_str(
+                 token_limit=sys.maxsize,
+                 transcript_idx=i,
+             )[0]
+             transcript_strs.append(f"<transcript>\n{transcript_content}\n</transcript>")
+
+         transcripts_str = "\n\n".join(transcript_strs)
+
+         # Gather metadata
+         metadata_obj = to_jsonable_python(self.metadata)
+         if self.name is not None:
+             metadata_obj["name"] = self.name
+         if self.description is not None:
+             metadata_obj["description"] = self.description
+
+         yaml_width = float("inf")
+         transcripts_str = (
+             f"Here is a complete agent run for analysis purposes only:\n{transcripts_str}\n\n"
+         )
+         metadata_str = f"Metadata about the complete agent run:\n<agent run metadata>\n{yaml.dump(metadata_obj, width=yaml_width)}\n</agent run metadata>"
+
+         if token_limit == sys.maxsize:
+             return [f"{transcripts_str}{metadata_str}"]
+
+         # Compute token counts; if everything fits, return the full transcripts and metadata
+         transcript_str_tokens = get_token_count(transcripts_str)
+         metadata_str_tokens = get_token_count(metadata_str)
+         if transcript_str_tokens + metadata_str_tokens <= token_limit:
+             return [f"{transcripts_str}{metadata_str}"]
+
+         # Otherwise, split up the transcripts and metadata into chunks
+         else:
+             results: list[str] = []
+             transcript_token_counts = [get_token_count(t) for t in transcript_strs]
+             ranges = group_messages_into_ranges(
+                 transcript_token_counts, metadata_str_tokens, token_limit - 50
+             )
+             for msg_range in ranges:
+                 if msg_range.include_metadata:
+                     cur_transcript_str = "\n\n".join(
+                         transcript_strs[msg_range.start : msg_range.end]
+                     )
+                     results.append(
+                         f"Here is a partial agent run for analysis purposes only:\n{cur_transcript_str}"
+                         f"{metadata_str}"
+                     )
+                 else:
+                     assert (
+                         msg_range.end == msg_range.start + 1
+                     ), "Ranges without metadata should be a single message"
+                     t = self.transcripts[msg_range.start]
+                     if msg_range.num_tokens < token_limit - 50:
+                         transcript = (
+                             f"<transcript>\n{t.to_str(token_limit=sys.maxsize)[0]}\n</transcript>"
+                         )
+                         result = (
+                             f"Here is a partial agent run for analysis purposes only:\n{transcript}"
+                         )
+                         results.append(result)
+                     else:
+                         transcript_fragments: list[str] = t.to_str(
+                             token_limit=token_limit - 50,
+                         )
+                         for fragment in transcript_fragments:
+                             result = f"<transcript>\n{fragment}\n</transcript>"
+                             result = (
+                                 f"Here is a partial agent run for analysis purposes only:\n{result}"
+                             )
+                             results.append(result)
+             return results
+
+     @property
+     def text(self) -> str:
+         """Renders the complete agent run as a single string, including all
+         transcripts followed by the run metadata.
+
+         Returns:
+             str: A text representation of the complete agent run.
+         """
+         return self._to_text_impl(token_limit=sys.maxsize)[0]
+
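To summarize the contract above, a short sketch assuming a hypothetical `run: AgentRun` (`_to_text_impl` is private; `text` is the public entry point):

    # Full rendering: all transcripts plus the YAML-dumped metadata, one string.
    full = run.text  # equivalent to run._to_text_impl(token_limit=sys.maxsize)[0]

    # Bounded rendering: splits into multiple strings, reserving ~50 tokens per
    # chunk for the "partial agent run" framing text around each piece.
    chunks = run._to_text_impl(token_limit=4000)
    # each chunk should fit within the 4000-token budget per the docstring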
+     ##############################
+     # New text rendering methods #
+     ##############################
+
+     # Transcript ID -> Transcript
+     _transcript_dict: dict[str, Transcript] | None = PrivateAttr(default=None)
+     # Transcript Group ID -> Transcript Group
+     _transcript_group_dict: dict[str, TranscriptGroup] | None = PrivateAttr(default=None)
+     # Canonical tree cache keyed by full_tree flag
+     _canonical_tree_cache: dict[bool, dict[str | None, list[tuple[Literal["t", "tg"], str]]]] = (
+         PrivateAttr(default_factory=dict)
+     )
+     # Transcript IDs (depth-first) cache keyed by full_tree flag
+     _transcript_ids_ordered_cache: dict[bool, list[str]] = PrivateAttr(default_factory=dict)
+
+     @property
+     def transcript_dict(self) -> dict[str, Transcript]:
+         """Lazily compute and cache a mapping from transcript ID to Transcript."""
+         if self._transcript_dict is None:
+             self._transcript_dict = {t.id: t for t in self.transcripts}
+         return self._transcript_dict
+
+     @property
+     def transcript_group_dict(self) -> dict[str, TranscriptGroup]:
+         """Lazily compute and cache a mapping from transcript group ID to TranscriptGroup."""
+         if self._transcript_group_dict is None:
+             self._transcript_group_dict = {tg.id: tg for tg in self.transcript_groups}
+         return self._transcript_group_dict
+
+     def _invalidate_caches(self) -> None:
+         """Reset cached lookups after mutating transcripts or transcript groups."""
+         self._transcript_dict = None
+         self._transcript_group_dict = None
+         self._canonical_tree_cache.clear()
+         self._transcript_ids_ordered_cache.clear()
+
+     def get_canonical_tree(
+         self, full_tree: bool = False
+     ) -> dict[str | None, list[tuple[Literal["t", "tg"], str]]]:
+         """Compute and cache the canonical, sorted transcript group tree.
+
+         Args:
+             full_tree: If True, include all transcript groups regardless of whether
+                 they contain transcripts. If False, include only the minimal tree
+                 that connects relevant groups and transcripts.
+
+         Returns:
+             Canonical tree mapping parent group id (or "__global_root") to a list of
+             children (type, id) tuples sorted by creation time.
+         """
+         if (
+             full_tree not in self._canonical_tree_cache
+             or full_tree not in self._transcript_ids_ordered_cache
+         ):
+             canonical_tree, transcript_idx_map = self._build_canonical_tree(full_tree=full_tree)
+             self._canonical_tree_cache[full_tree] = canonical_tree
+             self._transcript_ids_ordered_cache[full_tree] = list(transcript_idx_map.keys())
+         return self._canonical_tree_cache[full_tree]
+
+     def get_transcript_ids_ordered(self, full_tree: bool = False) -> list[str]:
+         """Compute and cache the depth-first transcript id ordering.
+
+         Args:
+             full_tree: Whether to compute based on the full tree or the minimal tree.
+
+         Returns:
+             List of transcript ids in depth-first order.
+         """
+         if (
+             full_tree not in self._transcript_ids_ordered_cache
+             or full_tree not in self._canonical_tree_cache
+         ):
+             canonical_tree, transcript_idx_map = self._build_canonical_tree(full_tree=full_tree)
+             self._canonical_tree_cache[full_tree] = canonical_tree
+             self._transcript_ids_ordered_cache[full_tree] = list(transcript_idx_map.keys())
+         return self._transcript_ids_ordered_cache[full_tree]
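The cached shapes are easier to read with a concrete case in mind. A sketch, assuming a hypothetical run with one group `tg1` (created before anything else) holding transcripts `t1` and `t2`, plus an ungrouped transcript `t3`:

    # run.get_canonical_tree() -> parent id -> children sorted by created_at
    # {"__global_root": [("tg", "tg1"), ("t", "t3")],
    #  "tg1": [("t", "t1"), ("t", "t2")]}
    #
    # run.get_transcript_ids_ordered() -> depth-first: ["t1", "t2", "t3"]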
+
+     def _build_canonical_tree(self, full_tree: bool = False):
+         t_dict = self.transcript_dict
+         tg_dict = self.transcript_group_dict
+
+         # Find all transcript groups that have direct transcript children
+         # Also keep track of transcripts that are not in a group
+         tgs_to_transcripts: dict[str, set[str]] = {}
+         for transcript in t_dict.values():
+             if transcript.transcript_group_id is None:
+                 tgs_to_transcripts.setdefault("__global_root", set()).add(transcript.id)
+             else:
+                 tgs_to_transcripts.setdefault(transcript.transcript_group_id, set()).add(
+                     transcript.id
+                 )
+
+         # tg_tree maps from parent -> children. A child can be a group or a transcript.
+         # A parent must be a group (or None, for transcripts that are not in a group).
+         tg_tree: dict[str, set[tuple[Literal["t", "tg"], str]]] = {}
+
+         if full_tree:
+             for tg_id, tg in tg_dict.items():
+                 tg_tree.setdefault(tg.parent_transcript_group_id or "__global_root", set()).add(
+                     ("tg", tg_id)
+                 )
+                 for t_id in tgs_to_transcripts.get(tg_id, []):
+                     tg_tree.setdefault(tg_id, set()).add(("t", t_id))
+             for t_id, t in t_dict.items():
+                 tg_tree.setdefault(t.transcript_group_id or "__global_root", set()).add(("t", t_id))
+         else:
+             # Initialize q with "important" tgs
+             q, seen = deque(tgs_to_transcripts.keys()), set(tgs_to_transcripts.keys())
+
+             # Do an "upwards BFS" from leaves up to the root. Builds a tree of only relevant nodes.
+             while q:
+                 u_id = q.popleft()
+                 u = tg_dict.get(u_id)  # None if __global_root
+
+                 # Add the transcripts under this tg
+                 for t_id in tgs_to_transcripts.get(u_id, []):
+                     tg_tree.setdefault(u_id, set()).add(("t", t_id))
+
+                 # Add an edge from the parent
+                 if u is not None:
+                     par_id = u.parent_transcript_group_id or "__global_root"
+                     # Mark u as a child of par
+                     tg_tree.setdefault(par_id, set()).add(("tg", u_id))
+                     # If we haven't investigated the parent before, add to q
+                     if par_id not in seen:
+                         q.append(par_id)
+                         seen.add(par_id)
+
+         # For each node, sort by created_at timestamp
+
+         def _cmp(element: tuple[Literal["t", "tg"], str]) -> datetime:
+             obj_type, obj_id = element
+             if obj_type == "tg":
+                 return tg_dict[obj_id].created_at or datetime.max
+             else:
+                 return t_dict[obj_id].created_at or datetime.max
+
+         c_tree: dict[str | None, list[tuple[Literal["t", "tg"], str]]] = {}
+         for tg_id in tg_tree:
+             children_ids = list(set(tg_tree[tg_id]))
+             sorted_children_ids = sorted(children_ids, key=_cmp)
+             c_tree[tg_id] = sorted_children_ids
+
+         # Compute transcript indices as the depth-first traversal index
+         transcript_idx_map: dict[str, int] = {}
+
+         def _assign_transcript_indices(cur_tg_id: str, next_idx: int) -> int:
+             children = c_tree.get(cur_tg_id, [])
+             for child_type, child_id in children:
+                 if child_type == "tg":
+                     next_idx = _assign_transcript_indices(child_id, next_idx)
+                 else:
+                     transcript_idx_map[child_id] = next_idx
+                     next_idx += 1
+             return next_idx
+
+         _assign_transcript_indices("__global_root", 0)
+
+         return c_tree, transcript_idx_map
+
+     def delete_transcript_group_subtree(self, transcript_group_id: str) -> None:
+         """Delete a transcript group and all descendant groups/transcripts using the canonical tree."""
+         if transcript_group_id == "__global_root":
+             raise ValueError("Cannot delete the global root sentinel")
+         if transcript_group_id not in self.transcript_group_dict:
+             raise ValueError(
+                 f"Transcript group '{transcript_group_id}' does not exist on this run."
+             )
+
+         canonical_tree = self.get_canonical_tree(full_tree=True)
+         groups_to_delete: set[str] = set()
+         transcripts_to_delete: set[str] = set()
+
+         queue: deque[str] = deque([transcript_group_id])
+         while queue:
+             current_group = queue.popleft()
+             groups_to_delete.add(current_group)
+             for child_type, child_id in canonical_tree.get(current_group, []):
+                 if child_type == "tg":
+                     queue.append(child_id)
+                 else:
+                     transcripts_to_delete.add(child_id)
+
+         if groups_to_delete:
+             self.transcript_groups = [
+                 tg for tg in self.transcript_groups if tg.id not in groups_to_delete
+             ]
+         if transcripts_to_delete:
+             self.transcripts = [t for t in self.transcripts if t.id not in transcripts_to_delete]
+
+         self._invalidate_caches()
+
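A usage sketch for the deletion helper, continuing the hypothetical `tg1`/`t1`/`t2`/`t3` tree from above:

    # Removes tg1 and its transcripts t1 and t2; t3 survives, and the cached
    # tree and orderings are invalidated, to be rebuilt on next access.
    run.delete_transcript_group_subtree("tg1")
    assert [t.id for t in run.transcripts] == ["t3"]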
+     def to_text_new(
+         self,
+         agent_run_alias: int | str = 0,
+         t_idx_map: dict[str, int] | None = None,
+         indent: int = 0,
+         full_tree: bool = False,
+     ):
+         if isinstance(agent_run_alias, int):
+             agent_run_alias = f"R{agent_run_alias}"
+
+         c_tree = self.get_canonical_tree(full_tree=full_tree)
+         t_ids_ordered = self.get_transcript_ids_ordered(full_tree=full_tree)
+         if t_idx_map is None:
+             t_idx_map = {t_id: i for i, t_id in enumerate(t_ids_ordered)}
+         t_dict = self.transcript_dict
+         tg_dict = self.transcript_group_dict
+
+         # Traverse the tree and render the string
+         def _recurse(tg_id: str) -> str:
+             children_ids = c_tree.get(tg_id, [])
+             children_texts: list[str] = []
+             for child_type, child_id in children_ids:
+                 if child_type == "tg":
+                     children_texts.append(_recurse(child_id))
+                 else:
+                     cur_text = t_dict[child_id].to_text_new(
+                         transcript_alias=t_idx_map[child_id],
+                         indent=indent,
+                     )
+                     children_texts.append(cur_text)
+             children_text = "\n".join(children_texts)
+
+             # No wrapper for global root
+             if tg_id == "__global_root":
+                 return children_text
+             # Delegate rendering to TranscriptGroup
+             else:
+                 tg = tg_dict[tg_id]
+                 return tg.to_text_new(children_text=children_text, indent=indent)
+
+         text = _recurse("__global_root")
+
+         # Append agent run metadata below the full content
+         metadata_text = dump_metadata(self.metadata)
+         if metadata_text is not None:
+             if indent > 0:
+                 metadata_text = textwrap.indent(metadata_text, " " * indent)
+             metadata_alias = f"{agent_run_alias}M"
+             text += f"\n<|agent run metadata {metadata_alias}|>\n{metadata_text}\n</|agent run metadata {metadata_alias}|>"
+
+         return f"<|agent run {agent_run_alias}|>\n{text}\n</|agent run {agent_run_alias}|>\n"
docent/data_models/chat/__init__.py

@@ -0,0 +1,37 @@
+ from docent.data_models.chat.content import Content, ContentReasoning, ContentText
+ from docent.data_models.chat.message import (
+     AssistantMessage,
+     ChatMessage,
+     DocentAssistantMessage,
+     DocentChatMessage,
+     SystemMessage,
+     ToolMessage,
+     UserMessage,
+     parse_chat_message,
+     parse_docent_chat_message,
+ )
+ from docent.data_models.chat.tool import (
+     ToolCall,
+     ToolCallContent,
+     ToolInfo,
+     ToolParams,
+ )
+
+ __all__ = [
+     "ChatMessage",
+     "DocentChatMessage",
+     "AssistantMessage",
+     "DocentAssistantMessage",
+     "SystemMessage",
+     "ToolMessage",
+     "UserMessage",
+     "Content",
+     "ContentReasoning",
+     "ContentText",
+     "ToolCall",
+     "ToolCallContent",
+     "ToolInfo",
+     "ToolParams",
+     "parse_chat_message",
+     "parse_docent_chat_message",
+ ]
docent/data_models/chat/content.py

@@ -0,0 +1,56 @@
+ from typing import Annotated, Literal
+
+ from pydantic import BaseModel, Discriminator
+
+
+ class BaseContent(BaseModel):
+     """Base class for all content types in chat messages.
+
+     Provides the foundation for different content types with a discriminator field.
+
+     Attributes:
+         type: The content type identifier, used for discriminating between content types.
+     """
+
+     type: Literal["text", "reasoning", "image", "audio", "video"]
+
+
+ class ContentText(BaseContent):
+     """Text content for chat messages.
+
+     Represents plain text content in a chat message.
+
+     Attributes:
+         type: Fixed as "text" to identify this content type.
+         text: The actual text content.
+         refusal: Optional flag indicating if this is a refusal message.
+     """
+
+     type: Literal["text"] = "text"  # type: ignore
+     text: str
+     refusal: bool | None = None
+
+
+ class ContentReasoning(BaseContent):
+     """Reasoning content for chat messages.
+
+     Represents reasoning or thought process content in a chat message.
+
+     Attributes:
+         type: Fixed as "reasoning" to identify this content type.
+         reasoning: The actual reasoning text.
+         signature: Optional signature associated with the reasoning.
+         redacted: Flag indicating if the reasoning has been redacted.
+     """
+
+     type: Literal["reasoning"] = "reasoning"  # type: ignore
+     reasoning: str
+     signature: str | None = None
+     redacted: bool = False
+
+
+ # Content type discriminated union
+ Content = Annotated[ContentText | ContentReasoning, Discriminator("type")]
+ """Discriminated union of possible content types using the 'type' field.
+ Can be either ContentText or ContentReasoning.
+ """