haiku.rag-slim 0.16.0__py3-none-any.whl → 0.24.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of haiku.rag-slim might be problematic. Click here for more details.

Files changed (94) hide show
  1. haiku/rag/app.py +430 -72
  2. haiku/rag/chunkers/__init__.py +31 -0
  3. haiku/rag/chunkers/base.py +31 -0
  4. haiku/rag/chunkers/docling_local.py +164 -0
  5. haiku/rag/chunkers/docling_serve.py +179 -0
  6. haiku/rag/cli.py +207 -24
  7. haiku/rag/cli_chat.py +489 -0
  8. haiku/rag/client.py +1251 -266
  9. haiku/rag/config/__init__.py +16 -10
  10. haiku/rag/config/loader.py +5 -44
  11. haiku/rag/config/models.py +126 -17
  12. haiku/rag/converters/__init__.py +31 -0
  13. haiku/rag/converters/base.py +63 -0
  14. haiku/rag/converters/docling_local.py +193 -0
  15. haiku/rag/converters/docling_serve.py +229 -0
  16. haiku/rag/converters/text_utils.py +237 -0
  17. haiku/rag/embeddings/__init__.py +123 -24
  18. haiku/rag/embeddings/voyageai.py +175 -20
  19. haiku/rag/graph/__init__.py +0 -11
  20. haiku/rag/graph/agui/__init__.py +8 -2
  21. haiku/rag/graph/agui/cli_renderer.py +1 -1
  22. haiku/rag/graph/agui/emitter.py +219 -31
  23. haiku/rag/graph/agui/server.py +20 -62
  24. haiku/rag/graph/agui/stream.py +1 -2
  25. haiku/rag/graph/research/__init__.py +5 -2
  26. haiku/rag/graph/research/dependencies.py +12 -126
  27. haiku/rag/graph/research/graph.py +390 -135
  28. haiku/rag/graph/research/models.py +91 -112
  29. haiku/rag/graph/research/prompts.py +99 -91
  30. haiku/rag/graph/research/state.py +35 -27
  31. haiku/rag/inspector/__init__.py +8 -0
  32. haiku/rag/inspector/app.py +259 -0
  33. haiku/rag/inspector/widgets/__init__.py +6 -0
  34. haiku/rag/inspector/widgets/chunk_list.py +100 -0
  35. haiku/rag/inspector/widgets/context_modal.py +89 -0
  36. haiku/rag/inspector/widgets/detail_view.py +130 -0
  37. haiku/rag/inspector/widgets/document_list.py +75 -0
  38. haiku/rag/inspector/widgets/info_modal.py +209 -0
  39. haiku/rag/inspector/widgets/search_modal.py +183 -0
  40. haiku/rag/inspector/widgets/visual_modal.py +126 -0
  41. haiku/rag/mcp.py +106 -102
  42. haiku/rag/monitor.py +33 -9
  43. haiku/rag/providers/__init__.py +5 -0
  44. haiku/rag/providers/docling_serve.py +108 -0
  45. haiku/rag/qa/__init__.py +12 -10
  46. haiku/rag/qa/agent.py +43 -61
  47. haiku/rag/qa/prompts.py +35 -57
  48. haiku/rag/reranking/__init__.py +9 -6
  49. haiku/rag/reranking/base.py +1 -1
  50. haiku/rag/reranking/cohere.py +5 -4
  51. haiku/rag/reranking/mxbai.py +5 -2
  52. haiku/rag/reranking/vllm.py +3 -4
  53. haiku/rag/reranking/zeroentropy.py +6 -5
  54. haiku/rag/store/__init__.py +2 -1
  55. haiku/rag/store/engine.py +242 -42
  56. haiku/rag/store/exceptions.py +4 -0
  57. haiku/rag/store/models/__init__.py +8 -2
  58. haiku/rag/store/models/chunk.py +190 -0
  59. haiku/rag/store/models/document.py +46 -0
  60. haiku/rag/store/repositories/chunk.py +141 -121
  61. haiku/rag/store/repositories/document.py +25 -84
  62. haiku/rag/store/repositories/settings.py +11 -14
  63. haiku/rag/store/upgrades/__init__.py +19 -3
  64. haiku/rag/store/upgrades/v0_10_1.py +1 -1
  65. haiku/rag/store/upgrades/v0_19_6.py +65 -0
  66. haiku/rag/store/upgrades/v0_20_0.py +68 -0
  67. haiku/rag/store/upgrades/v0_23_1.py +100 -0
  68. haiku/rag/store/upgrades/v0_9_3.py +3 -3
  69. haiku/rag/utils.py +371 -146
  70. {haiku_rag_slim-0.16.0.dist-info → haiku_rag_slim-0.24.0.dist-info}/METADATA +15 -12
  71. haiku_rag_slim-0.24.0.dist-info/RECORD +78 -0
  72. {haiku_rag_slim-0.16.0.dist-info → haiku_rag_slim-0.24.0.dist-info}/WHEEL +1 -1
  73. haiku/rag/chunker.py +0 -65
  74. haiku/rag/embeddings/base.py +0 -25
  75. haiku/rag/embeddings/ollama.py +0 -28
  76. haiku/rag/embeddings/openai.py +0 -26
  77. haiku/rag/embeddings/vllm.py +0 -29
  78. haiku/rag/graph/agui/events.py +0 -254
  79. haiku/rag/graph/common/__init__.py +0 -5
  80. haiku/rag/graph/common/models.py +0 -42
  81. haiku/rag/graph/common/nodes.py +0 -265
  82. haiku/rag/graph/common/prompts.py +0 -46
  83. haiku/rag/graph/common/utils.py +0 -44
  84. haiku/rag/graph/deep_qa/__init__.py +0 -1
  85. haiku/rag/graph/deep_qa/dependencies.py +0 -27
  86. haiku/rag/graph/deep_qa/graph.py +0 -243
  87. haiku/rag/graph/deep_qa/models.py +0 -20
  88. haiku/rag/graph/deep_qa/prompts.py +0 -59
  89. haiku/rag/graph/deep_qa/state.py +0 -56
  90. haiku/rag/graph/research/common.py +0 -87
  91. haiku/rag/reader.py +0 -135
  92. haiku_rag_slim-0.16.0.dist-info/RECORD +0 -71
  93. {haiku_rag_slim-0.16.0.dist-info → haiku_rag_slim-0.24.0.dist-info}/entry_points.txt +0 -0
  94. {haiku_rag_slim-0.16.0.dist-info → haiku_rag_slim-0.24.0.dist-info}/licenses/LICENSE +0 -0
@@ -1,27 +1,182 @@
1
+ from collections.abc import Sequence
2
+ from dataclasses import dataclass, field
3
+ from typing import Literal, cast
4
+
5
+ from pydantic_ai.embeddings.base import EmbeddingModel
6
+ from pydantic_ai.embeddings.result import EmbeddingResult, EmbedInputType
7
+ from pydantic_ai.embeddings.settings import EmbeddingSettings
8
+ from pydantic_ai.exceptions import ModelAPIError
9
+ from pydantic_ai.usage import RequestUsage
10
+
1
11
  try:
2
- from typing import overload
12
+ from voyageai.client_async import AsyncClient
13
+ from voyageai.error import VoyageError
14
+ except ImportError as _import_error:
15
+ raise ImportError(
16
+ "Please install `voyageai` to use the VoyageAI embeddings model, "
17
+ "you can use — `pip install voyageai`"
18
+ ) from _import_error
19
+
20
+ LatestVoyageAIEmbeddingModelNames = Literal[
21
+ "voyage-3-large",
22
+ "voyage-3.5",
23
+ "voyage-3.5-lite",
24
+ "voyage-code-3",
25
+ "voyage-finance-2",
26
+ "voyage-law-2",
27
+ "voyage-code-2",
28
+ ]
29
+ """Latest VoyageAI embedding models.
30
+
31
+ See [VoyageAI Embeddings](https://docs.voyageai.com/docs/embeddings)
32
+ for available models and their capabilities.
33
+ """
34
+
35
+ VoyageAIEmbeddingModelName = str | LatestVoyageAIEmbeddingModelNames
36
+ """Possible VoyageAI embedding model names."""
37
+
38
+
39
+ class VoyageAIEmbeddingSettings(EmbeddingSettings, total=False):
40
+ """Settings used for a VoyageAI embedding model request.
41
+
42
+ All fields from [`EmbeddingSettings`][pydantic_ai.embeddings.EmbeddingSettings] are supported,
43
+ plus VoyageAI-specific settings prefixed with `voyageai_`.
44
+ """
45
+
46
+ # ALL FIELDS MUST BE `voyageai_` PREFIXED SO YOU CAN MERGE THEM WITH OTHER MODELS.
47
+
48
+ voyageai_truncation: bool
49
+ """Whether to truncate inputs that exceed the model's context length.
50
+
51
+ Defaults to True. If False, an error is raised for inputs that are too long.
52
+ """
53
+
54
+ voyageai_output_dtype: Literal["float", "int8", "uint8", "binary", "ubinary"]
55
+ """The output data type for embeddings.
56
+
57
+ - `'float'` (default): 32-bit floats
58
+ - `'int8'`: Signed 8-bit integers (quantized)
59
+ - `'uint8'`: Unsigned 8-bit integers (quantized)
60
+ - `'binary'`: Binary embeddings
61
+ - `'ubinary'`: Unsigned binary embeddings
62
+ """
63
+
64
+
65
+ _MAX_INPUT_TOKENS: dict[VoyageAIEmbeddingModelName, int] = {
66
+ "voyage-3-large": 32000,
67
+ "voyage-3.5": 32000,
68
+ "voyage-3.5-lite": 32000,
69
+ "voyage-code-3": 32000,
70
+ "voyage-finance-2": 32000,
71
+ "voyage-law-2": 16000,
72
+ "voyage-code-2": 16000,
73
+ }
74
+
75
+
76
+ @dataclass(init=False)
77
+ class VoyageAIEmbeddingModel(EmbeddingModel):
78
+ """VoyageAI embedding model implementation.
79
+
80
+ VoyageAI provides state-of-the-art embedding models optimized for
81
+ retrieval, with specialized models for code, finance, and legal domains.
82
+
83
+ Example:
84
+ ```python
85
+ from pydantic_ai.embeddings.voyageai import VoyageAIEmbeddingModel
86
+
87
+ model = VoyageAIEmbeddingModel('voyage-3.5')
88
+ ```
89
+ """
90
+
91
+ _model_name: VoyageAIEmbeddingModelName = field(repr=False)
92
+ _client: AsyncClient = field(repr=False)
93
+
94
+ def __init__(
95
+ self,
96
+ model_name: VoyageAIEmbeddingModelName,
97
+ *,
98
+ api_key: str | None = None,
99
+ max_retries: int = 0,
100
+ timeout: int | None = None,
101
+ settings: EmbeddingSettings | None = None,
102
+ ):
103
+ """Initialize a VoyageAI embedding model.
104
+
105
+ Args:
106
+ model_name: The name of the VoyageAI model to use.
107
+ See [VoyageAI models](https://docs.voyageai.com/docs/embeddings)
108
+ for available options.
109
+ api_key: The VoyageAI API key. If not provided, uses the
110
+ `VOYAGE_API_KEY` environment variable.
111
+ max_retries: Maximum number of retries for failed requests.
112
+ timeout: Request timeout in seconds.
113
+ settings: Model-specific [`EmbeddingSettings`][pydantic_ai.embeddings.EmbeddingSettings]
114
+ to use as defaults for this model.
115
+ """
116
+ self._model_name = model_name
117
+ self._client = AsyncClient(
118
+ api_key=api_key,
119
+ max_retries=max_retries,
120
+ timeout=timeout,
121
+ )
122
+
123
+ super().__init__(settings=settings)
124
+
125
+ @property
126
+ def model_name(self) -> VoyageAIEmbeddingModelName:
127
+ """The embedding model name."""
128
+ return self._model_name
129
+
130
+ @property
131
+ def system(self) -> str:
132
+ """The embedding model provider."""
133
+ return "voyageai"
134
+
135
+ async def embed(
136
+ self,
137
+ inputs: str | Sequence[str],
138
+ *,
139
+ input_type: EmbedInputType,
140
+ settings: EmbeddingSettings | None = None,
141
+ ) -> EmbeddingResult:
142
+ inputs, settings = self.prepare_embed(inputs, settings)
143
+ settings = cast(VoyageAIEmbeddingSettings, settings)
144
+
145
+ voyageai_input_type = "document" if input_type == "document" else "query"
3
146
 
4
- from voyageai.client import Client # type: ignore
147
+ try:
148
+ response = await self._client.embed(
149
+ texts=list(inputs),
150
+ model=self.model_name,
151
+ input_type=voyageai_input_type,
152
+ truncation=settings.get("voyageai_truncation", True),
153
+ output_dtype=settings.get("voyageai_output_dtype", "float"),
154
+ output_dimension=settings.get("dimensions"),
155
+ )
156
+ except VoyageError as e:
157
+ raise ModelAPIError(model_name=self.model_name, message=str(e)) from e
5
158
 
6
- from haiku.rag.embeddings.base import EmbedderBase
159
+ return EmbeddingResult(
160
+ embeddings=response.embeddings,
161
+ inputs=inputs,
162
+ input_type=input_type,
163
+ usage=_map_usage(response.total_tokens, self.model_name),
164
+ model_name=self.model_name,
165
+ provider_name=self.system,
166
+ )
7
167
 
8
- class Embedder(EmbedderBase):
9
- @overload
10
- async def embed(self, text: str) -> list[float]: ...
168
+ async def max_input_tokens(self) -> int | None:
169
+ return _MAX_INPUT_TOKENS.get(self.model_name)
11
170
 
12
- @overload
13
- async def embed(self, text: list[str]) -> list[list[float]]: ...
14
171
 
15
- async def embed(self, text: str | list[str]) -> list[float] | list[list[float]]:
16
- client = Client()
17
- if not text:
18
- return []
19
- if isinstance(text, str):
20
- res = client.embed([text], model=self._model, output_dtype="float")
21
- return res.embeddings[0] # type: ignore[return-value]
22
- else:
23
- res = client.embed(text, model=self._model, output_dtype="float")
24
- return res.embeddings # type: ignore[return-value]
172
+ def _map_usage(total_tokens: int, model: str) -> RequestUsage:
173
+ usage_data = {"total_tokens": total_tokens}
174
+ response_data = {"model": model, "usage": usage_data}
25
175
 
26
- except ImportError:
27
- pass
176
+ return RequestUsage.extract(
177
+ response_data,
178
+ provider="voyageai",
179
+ provider_url="https://api.voyageai.com",
180
+ provider_fallback="voyageai",
181
+ api_flavor="embeddings",
182
+ )
@@ -1,25 +1,14 @@
1
- """Graph module for haiku.rag.
2
-
3
- This module contains all graph-related functionality including:
4
- - AG-UI protocol for graph streaming
5
- - Common graph utilities and models
6
- - Research graph implementation
7
- - Deep QA graph implementation
8
- """
9
-
10
1
  from haiku.rag.graph.agui import (
11
2
  AGUIConsoleRenderer,
12
3
  AGUIEmitter,
13
4
  create_agui_server,
14
5
  stream_graph,
15
6
  )
16
- from haiku.rag.graph.deep_qa.graph import build_deep_qa_graph
17
7
  from haiku.rag.graph.research.graph import build_research_graph
18
8
 
19
9
  __all__ = [
20
10
  "AGUIConsoleRenderer",
21
11
  "AGUIEmitter",
22
- "build_deep_qa_graph",
23
12
  "build_research_graph",
24
13
  "create_agui_server",
25
14
  "stream_graph",
@@ -1,8 +1,8 @@
1
1
  """Generic AG-UI protocol support for haiku.rag graphs."""
2
2
 
3
3
  from haiku.rag.graph.agui.cli_renderer import AGUIConsoleRenderer
4
- from haiku.rag.graph.agui.emitter import AGUIEmitter
5
- from haiku.rag.graph.agui.events import (
4
+ from haiku.rag.graph.agui.emitter import (
5
+ AGUIEmitter,
6
6
  AGUIEvent,
7
7
  emit_activity,
8
8
  emit_activity_delta,
@@ -17,6 +17,9 @@ from haiku.rag.graph.agui.events import (
17
17
  emit_text_message_content,
18
18
  emit_text_message_end,
19
19
  emit_text_message_start,
20
+ emit_tool_call_args,
21
+ emit_tool_call_end,
22
+ emit_tool_call_start,
20
23
  )
21
24
  from haiku.rag.graph.agui.server import (
22
25
  RunAgentInput,
@@ -48,6 +51,9 @@ __all__ = [
48
51
  "emit_text_message_content",
49
52
  "emit_text_message_end",
50
53
  "emit_text_message_start",
54
+ "emit_tool_call_args",
55
+ "emit_tool_call_end",
56
+ "emit_tool_call_start",
51
57
  "format_sse_event",
52
58
  "stream_graph",
53
59
  ]
@@ -5,7 +5,7 @@ from typing import Any
5
5
 
6
6
  from rich.console import Console
7
7
 
8
- from haiku.rag.graph.agui.events import AGUIEvent
8
+ from haiku.rag.graph.agui.emitter import AGUIEvent
9
9
 
10
10
 
11
11
  class AGUIConsoleRenderer:
@@ -2,23 +2,39 @@
2
2
 
3
3
  import asyncio
4
4
  import hashlib
5
+ import json
5
6
  from collections.abc import AsyncIterator
7
+ from typing import Any
6
8
  from uuid import uuid4
7
9
 
10
+ from ag_ui.core import (
11
+ ActivitySnapshotEvent,
12
+ BaseEvent,
13
+ RunErrorEvent,
14
+ RunFinishedEvent,
15
+ RunStartedEvent,
16
+ StateDeltaEvent,
17
+ StateSnapshotEvent,
18
+ StepFinishedEvent,
19
+ StepStartedEvent,
20
+ TextMessageChunkEvent,
21
+ TextMessageContentEvent,
22
+ TextMessageEndEvent,
23
+ TextMessageStartEvent,
24
+ ToolCallArgsEvent,
25
+ ToolCallEndEvent,
26
+ ToolCallStartEvent,
27
+ )
8
28
  from pydantic import BaseModel
9
29
 
10
- from haiku.rag.graph.agui.events import (
11
- AGUIEvent,
12
- emit_activity,
13
- emit_run_error,
14
- emit_run_finished,
15
- emit_run_started,
16
- emit_state_delta,
17
- emit_state_snapshot,
18
- emit_step_finished,
19
- emit_step_started,
20
- emit_text_message,
21
- )
30
+ from haiku.rag.graph.agui.state import compute_state_delta
31
+
32
+ AGUIEvent = dict[str, Any]
33
+
34
+
35
+ def _serialize_event(event: BaseEvent) -> AGUIEvent:
36
+ """Serialize an ag_ui event to a dict with camelCase keys."""
37
+ return event.model_dump(mode="json", by_alias=True, exclude_none=True)
22
38
 
23
39
 
24
40
  class AGUIEmitter[StateT: BaseModel, ResultT]:
@@ -54,7 +70,7 @@ class AGUIEmitter[StateT: BaseModel, ResultT]:
54
70
  self._thread_id = thread_id or str(uuid4())
55
71
  self._run_id = run_id or str(uuid4())
56
72
  self._last_state: StateT | None = None
57
- self._current_step: str | None = None
73
+ self._active_steps: set[str] = set()
58
74
  self._use_deltas = use_deltas
59
75
 
60
76
  @property
@@ -79,8 +95,14 @@ class AGUIEmitter[StateT: BaseModel, ResultT]:
79
95
  self._thread_id = self._generate_thread_id(state_json)
80
96
 
81
97
  # RunStarted (state snapshot follows immediately with full state)
82
- self._emit(emit_run_started(self._thread_id, self._run_id))
83
- self._emit(emit_state_snapshot(initial_state))
98
+ self.emit(
99
+ _serialize_event(
100
+ RunStartedEvent(thread_id=self._thread_id, run_id=self._run_id)
101
+ )
102
+ )
103
+ self.emit(
104
+ _serialize_event(StateSnapshotEvent(snapshot=initial_state.model_dump()))
105
+ )
84
106
  # Store a deep copy to detect future changes
85
107
  self._last_state = initial_state.model_copy(deep=True)
86
108
 
@@ -90,14 +112,17 @@ class AGUIEmitter[StateT: BaseModel, ResultT]:
90
112
  Args:
91
113
  step_name: Name of the step being started
92
114
  """
93
- self._current_step = step_name
94
- self._emit(emit_step_started(step_name))
115
+ self._active_steps.add(step_name)
116
+ self.emit(_serialize_event(StepStartedEvent(step_name=step_name)))
95
117
 
96
- def finish_step(self) -> None:
97
- """Emit StepFinished event for the current step."""
98
- if self._current_step:
99
- self._emit(emit_step_finished(self._current_step))
100
- self._current_step = None
118
+ def finish_step(self, step_name: str) -> None:
119
+ """Emit StepFinished event for the specified step.
120
+
121
+ Args:
122
+ step_name: Name of the step being finished
123
+ """
124
+ self._active_steps.discard(step_name)
125
+ self.emit(_serialize_event(StepFinishedEvent(step_name=step_name)))
101
126
 
102
127
  def log(self, message: str, role: str = "assistant") -> None:
103
128
  """Emit a text message event.
@@ -106,7 +131,16 @@ class AGUIEmitter[StateT: BaseModel, ResultT]:
106
131
  message: The message content
107
132
  role: The role of the sender (default: assistant)
108
133
  """
109
- self._emit(emit_text_message(message, role))
134
+ message_id = str(uuid4())
135
+ self.emit(
136
+ _serialize_event(
137
+ TextMessageChunkEvent(
138
+ message_id=message_id,
139
+ role=role, # type: ignore[arg-type]
140
+ delta=message,
141
+ )
142
+ )
143
+ )
110
144
 
111
145
  def update_state(self, new_state: StateT) -> None:
112
146
  """Emit StateDelta or StateSnapshot for state change.
@@ -116,26 +150,40 @@ class AGUIEmitter[StateT: BaseModel, ResultT]:
116
150
  """
117
151
  if self._use_deltas and self._last_state is not None:
118
152
  # Emit delta for incremental updates
119
- self._emit(emit_state_delta(self._last_state, new_state))
153
+ delta = compute_state_delta(self._last_state, new_state)
154
+ self.emit(_serialize_event(StateDeltaEvent(delta=delta)))
120
155
  else:
121
156
  # Emit full snapshot for initial state or when deltas disabled
122
- self._emit(emit_state_snapshot(new_state))
157
+ self.emit(
158
+ _serialize_event(StateSnapshotEvent(snapshot=new_state.model_dump()))
159
+ )
123
160
  # Store a deep copy to detect future changes
124
161
  self._last_state = new_state.model_copy(deep=True)
125
162
 
126
163
  def update_activity(
127
- self, activity_type: str, content: str, message_id: str | None = None
164
+ self,
165
+ activity_type: str,
166
+ content: dict[str, Any],
167
+ message_id: str | None = None,
128
168
  ) -> None:
129
169
  """Emit ActivitySnapshot event.
130
170
 
131
171
  Args:
132
172
  activity_type: Type of activity (e.g., "planning", "searching")
133
- content: Description of the activity
173
+ content: Structured payload representing the activity state
134
174
  message_id: Optional message ID to associate activity with (auto-generated if None)
135
175
  """
136
176
  if message_id is None:
137
177
  message_id = str(uuid4())
138
- self._emit(emit_activity(message_id, activity_type, content))
178
+ self.emit(
179
+ _serialize_event(
180
+ ActivitySnapshotEvent(
181
+ message_id=message_id,
182
+ activity_type=activity_type,
183
+ content=content,
184
+ )
185
+ )
186
+ )
139
187
 
140
188
  def finish_run(self, result: ResultT) -> None:
141
189
  """Emit RunFinished event.
@@ -143,7 +191,18 @@ class AGUIEmitter[StateT: BaseModel, ResultT]:
143
191
  Args:
144
192
  result: The final result from the graph
145
193
  """
146
- self._emit(emit_run_finished(self._thread_id, self._run_id, result))
194
+ # Convert result to dict if it's a Pydantic model
195
+ result_data: Any = result
196
+ if hasattr(result, "model_dump"):
197
+ result_data = result.model_dump() # type: ignore[union-attr]
198
+
199
+ self.emit(
200
+ _serialize_event(
201
+ RunFinishedEvent(
202
+ thread_id=self._thread_id, run_id=self._run_id, result=result_data
203
+ )
204
+ )
205
+ )
147
206
 
148
207
  def error(self, error: Exception, code: str | None = None) -> None:
149
208
  """Emit RunError event.
@@ -152,9 +211,9 @@ class AGUIEmitter[StateT: BaseModel, ResultT]:
152
211
  error: The exception that occurred
153
212
  code: Optional error code
154
213
  """
155
- self._emit(emit_run_error(str(error), code))
214
+ self.emit(_serialize_event(RunErrorEvent(message=str(error), code=code)))
156
215
 
157
- def _emit(self, event: AGUIEvent) -> None:
216
+ def emit(self, event: AGUIEvent) -> None:
158
217
  """Put event in queue.
159
218
 
160
219
  Args:
@@ -195,3 +254,132 @@ class AGUIEmitter[StateT: BaseModel, ResultT]:
195
254
  # Use hash of input for deterministic thread ID
196
255
  hash_obj = hashlib.sha256(input_data.encode("utf-8"))
197
256
  return hash_obj.hexdigest()[:16]
257
+
258
+
259
+ def emit_text_message_start(message_id: str, role: str = "assistant") -> AGUIEvent:
260
+ """Create a TextMessageStart event."""
261
+ return _serialize_event(
262
+ TextMessageStartEvent(message_id=message_id, role=role) # type: ignore[arg-type]
263
+ )
264
+
265
+
266
+ def emit_text_message_content(message_id: str, delta: str) -> AGUIEvent:
267
+ """Create a TextMessageContent event."""
268
+ return _serialize_event(TextMessageContentEvent(message_id=message_id, delta=delta))
269
+
270
+
271
+ def emit_text_message_end(message_id: str) -> AGUIEvent:
272
+ """Create a TextMessageEnd event."""
273
+ return _serialize_event(TextMessageEndEvent(message_id=message_id))
274
+
275
+
276
+ def emit_tool_call_start(
277
+ tool_call_id: str,
278
+ tool_name: str,
279
+ parent_message_id: str | None = None,
280
+ ) -> AGUIEvent:
281
+ """Create a ToolCallStart event."""
282
+ return _serialize_event(
283
+ ToolCallStartEvent(
284
+ tool_call_id=tool_call_id,
285
+ tool_call_name=tool_name,
286
+ parent_message_id=parent_message_id,
287
+ )
288
+ )
289
+
290
+
291
+ def emit_tool_call_args(tool_call_id: str, args: dict[str, Any]) -> AGUIEvent:
292
+ """Create a ToolCallArgs event."""
293
+ return _serialize_event(
294
+ ToolCallArgsEvent(tool_call_id=tool_call_id, delta=json.dumps(args))
295
+ )
296
+
297
+
298
+ def emit_tool_call_end(tool_call_id: str) -> AGUIEvent:
299
+ """Create a ToolCallEnd event."""
300
+ return _serialize_event(ToolCallEndEvent(tool_call_id=tool_call_id))
301
+
302
+
303
+ def emit_run_started(thread_id: str, run_id: str) -> AGUIEvent:
304
+ """Create a RunStarted event."""
305
+ return _serialize_event(RunStartedEvent(thread_id=thread_id, run_id=run_id))
306
+
307
+
308
+ def emit_run_finished(thread_id: str, run_id: str, result: Any) -> AGUIEvent:
309
+ """Create a RunFinished event."""
310
+ # Convert result to dict if it's a Pydantic model
311
+ if hasattr(result, "model_dump"):
312
+ result = result.model_dump()
313
+ return _serialize_event(
314
+ RunFinishedEvent(thread_id=thread_id, run_id=run_id, result=result)
315
+ )
316
+
317
+
318
+ def emit_run_error(message: str, code: str | None = None) -> AGUIEvent:
319
+ """Create a RunError event."""
320
+ return _serialize_event(RunErrorEvent(message=message, code=code))
321
+
322
+
323
+ def emit_step_started(step_name: str) -> AGUIEvent:
324
+ """Create a StepStarted event."""
325
+ return _serialize_event(StepStartedEvent(step_name=step_name))
326
+
327
+
328
+ def emit_step_finished(step_name: str) -> AGUIEvent:
329
+ """Create a StepFinished event."""
330
+ return _serialize_event(StepFinishedEvent(step_name=step_name))
331
+
332
+
333
+ def emit_text_message(content: str, role: str = "assistant") -> AGUIEvent:
334
+ """Create a TextMessageChunk event (convenience wrapper)."""
335
+ message_id = str(uuid4())
336
+ return _serialize_event(
337
+ TextMessageChunkEvent(
338
+ message_id=message_id,
339
+ role=role, # type: ignore[arg-type]
340
+ delta=content,
341
+ )
342
+ )
343
+
344
+
345
+ def emit_state_snapshot(state: BaseModel) -> AGUIEvent:
346
+ """Create a StateSnapshot event."""
347
+ return _serialize_event(StateSnapshotEvent(snapshot=state.model_dump()))
348
+
349
+
350
+ def emit_state_delta(old_state: BaseModel, new_state: BaseModel) -> AGUIEvent:
351
+ """Create a StateDelta event with JSON Patch operations."""
352
+ delta = compute_state_delta(old_state, new_state)
353
+ return _serialize_event(StateDeltaEvent(delta=delta))
354
+
355
+
356
+ def emit_activity(
357
+ message_id: str,
358
+ activity_type: str,
359
+ content: dict[str, Any],
360
+ ) -> AGUIEvent:
361
+ """Create an ActivitySnapshot event."""
362
+ return _serialize_event(
363
+ ActivitySnapshotEvent(
364
+ message_id=message_id,
365
+ activity_type=activity_type,
366
+ content=content,
367
+ )
368
+ )
369
+
370
+
371
+ def emit_activity_delta(
372
+ message_id: str,
373
+ activity_type: str,
374
+ patch: list[dict[str, Any]],
375
+ ) -> AGUIEvent:
376
+ """Create an ActivityDelta event with JSON Patch operations."""
377
+ from ag_ui.core import ActivityDeltaEvent
378
+
379
+ return _serialize_event(
380
+ ActivityDeltaEvent(
381
+ message_id=message_id,
382
+ activity_type=activity_type,
383
+ patch=patch,
384
+ )
385
+ )