haiku.rag-slim 0.16.0__py3-none-any.whl → 0.24.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of haiku.rag-slim might be problematic. Click here for more details.
- haiku/rag/app.py +430 -72
- haiku/rag/chunkers/__init__.py +31 -0
- haiku/rag/chunkers/base.py +31 -0
- haiku/rag/chunkers/docling_local.py +164 -0
- haiku/rag/chunkers/docling_serve.py +179 -0
- haiku/rag/cli.py +207 -24
- haiku/rag/cli_chat.py +489 -0
- haiku/rag/client.py +1251 -266
- haiku/rag/config/__init__.py +16 -10
- haiku/rag/config/loader.py +5 -44
- haiku/rag/config/models.py +126 -17
- haiku/rag/converters/__init__.py +31 -0
- haiku/rag/converters/base.py +63 -0
- haiku/rag/converters/docling_local.py +193 -0
- haiku/rag/converters/docling_serve.py +229 -0
- haiku/rag/converters/text_utils.py +237 -0
- haiku/rag/embeddings/__init__.py +123 -24
- haiku/rag/embeddings/voyageai.py +175 -20
- haiku/rag/graph/__init__.py +0 -11
- haiku/rag/graph/agui/__init__.py +8 -2
- haiku/rag/graph/agui/cli_renderer.py +1 -1
- haiku/rag/graph/agui/emitter.py +219 -31
- haiku/rag/graph/agui/server.py +20 -62
- haiku/rag/graph/agui/stream.py +1 -2
- haiku/rag/graph/research/__init__.py +5 -2
- haiku/rag/graph/research/dependencies.py +12 -126
- haiku/rag/graph/research/graph.py +390 -135
- haiku/rag/graph/research/models.py +91 -112
- haiku/rag/graph/research/prompts.py +99 -91
- haiku/rag/graph/research/state.py +35 -27
- haiku/rag/inspector/__init__.py +8 -0
- haiku/rag/inspector/app.py +259 -0
- haiku/rag/inspector/widgets/__init__.py +6 -0
- haiku/rag/inspector/widgets/chunk_list.py +100 -0
- haiku/rag/inspector/widgets/context_modal.py +89 -0
- haiku/rag/inspector/widgets/detail_view.py +130 -0
- haiku/rag/inspector/widgets/document_list.py +75 -0
- haiku/rag/inspector/widgets/info_modal.py +209 -0
- haiku/rag/inspector/widgets/search_modal.py +183 -0
- haiku/rag/inspector/widgets/visual_modal.py +126 -0
- haiku/rag/mcp.py +106 -102
- haiku/rag/monitor.py +33 -9
- haiku/rag/providers/__init__.py +5 -0
- haiku/rag/providers/docling_serve.py +108 -0
- haiku/rag/qa/__init__.py +12 -10
- haiku/rag/qa/agent.py +43 -61
- haiku/rag/qa/prompts.py +35 -57
- haiku/rag/reranking/__init__.py +9 -6
- haiku/rag/reranking/base.py +1 -1
- haiku/rag/reranking/cohere.py +5 -4
- haiku/rag/reranking/mxbai.py +5 -2
- haiku/rag/reranking/vllm.py +3 -4
- haiku/rag/reranking/zeroentropy.py +6 -5
- haiku/rag/store/__init__.py +2 -1
- haiku/rag/store/engine.py +242 -42
- haiku/rag/store/exceptions.py +4 -0
- haiku/rag/store/models/__init__.py +8 -2
- haiku/rag/store/models/chunk.py +190 -0
- haiku/rag/store/models/document.py +46 -0
- haiku/rag/store/repositories/chunk.py +141 -121
- haiku/rag/store/repositories/document.py +25 -84
- haiku/rag/store/repositories/settings.py +11 -14
- haiku/rag/store/upgrades/__init__.py +19 -3
- haiku/rag/store/upgrades/v0_10_1.py +1 -1
- haiku/rag/store/upgrades/v0_19_6.py +65 -0
- haiku/rag/store/upgrades/v0_20_0.py +68 -0
- haiku/rag/store/upgrades/v0_23_1.py +100 -0
- haiku/rag/store/upgrades/v0_9_3.py +3 -3
- haiku/rag/utils.py +371 -146
- {haiku_rag_slim-0.16.0.dist-info → haiku_rag_slim-0.24.0.dist-info}/METADATA +15 -12
- haiku_rag_slim-0.24.0.dist-info/RECORD +78 -0
- {haiku_rag_slim-0.16.0.dist-info → haiku_rag_slim-0.24.0.dist-info}/WHEEL +1 -1
- haiku/rag/chunker.py +0 -65
- haiku/rag/embeddings/base.py +0 -25
- haiku/rag/embeddings/ollama.py +0 -28
- haiku/rag/embeddings/openai.py +0 -26
- haiku/rag/embeddings/vllm.py +0 -29
- haiku/rag/graph/agui/events.py +0 -254
- haiku/rag/graph/common/__init__.py +0 -5
- haiku/rag/graph/common/models.py +0 -42
- haiku/rag/graph/common/nodes.py +0 -265
- haiku/rag/graph/common/prompts.py +0 -46
- haiku/rag/graph/common/utils.py +0 -44
- haiku/rag/graph/deep_qa/__init__.py +0 -1
- haiku/rag/graph/deep_qa/dependencies.py +0 -27
- haiku/rag/graph/deep_qa/graph.py +0 -243
- haiku/rag/graph/deep_qa/models.py +0 -20
- haiku/rag/graph/deep_qa/prompts.py +0 -59
- haiku/rag/graph/deep_qa/state.py +0 -56
- haiku/rag/graph/research/common.py +0 -87
- haiku/rag/reader.py +0 -135
- haiku_rag_slim-0.16.0.dist-info/RECORD +0 -71
- {haiku_rag_slim-0.16.0.dist-info → haiku_rag_slim-0.24.0.dist-info}/entry_points.txt +0 -0
- {haiku_rag_slim-0.16.0.dist-info → haiku_rag_slim-0.24.0.dist-info}/licenses/LICENSE +0 -0
haiku/rag/embeddings/voyageai.py
CHANGED
|
@@ -1,27 +1,182 @@
|
|
|
1
|
+
from collections.abc import Sequence
|
|
2
|
+
from dataclasses import dataclass, field
|
|
3
|
+
from typing import Literal, cast
|
|
4
|
+
|
|
5
|
+
from pydantic_ai.embeddings.base import EmbeddingModel
|
|
6
|
+
from pydantic_ai.embeddings.result import EmbeddingResult, EmbedInputType
|
|
7
|
+
from pydantic_ai.embeddings.settings import EmbeddingSettings
|
|
8
|
+
from pydantic_ai.exceptions import ModelAPIError
|
|
9
|
+
from pydantic_ai.usage import RequestUsage
|
|
10
|
+
|
|
1
11
|
try:
|
|
2
|
-
from
|
|
12
|
+
from voyageai.client_async import AsyncClient
|
|
13
|
+
from voyageai.error import VoyageError
|
|
14
|
+
except ImportError as _import_error:
|
|
15
|
+
raise ImportError(
|
|
16
|
+
"Please install `voyageai` to use the VoyageAI embeddings model, "
|
|
17
|
+
"you can use — `pip install voyageai`"
|
|
18
|
+
) from _import_error
|
|
19
|
+
|
|
20
|
+
LatestVoyageAIEmbeddingModelNames = Literal[
|
|
21
|
+
"voyage-3-large",
|
|
22
|
+
"voyage-3.5",
|
|
23
|
+
"voyage-3.5-lite",
|
|
24
|
+
"voyage-code-3",
|
|
25
|
+
"voyage-finance-2",
|
|
26
|
+
"voyage-law-2",
|
|
27
|
+
"voyage-code-2",
|
|
28
|
+
]
|
|
29
|
+
"""Latest VoyageAI embedding models.
|
|
30
|
+
|
|
31
|
+
See [VoyageAI Embeddings](https://docs.voyageai.com/docs/embeddings)
|
|
32
|
+
for available models and their capabilities.
|
|
33
|
+
"""
|
|
34
|
+
|
|
35
|
+
VoyageAIEmbeddingModelName = str | LatestVoyageAIEmbeddingModelNames
|
|
36
|
+
"""Possible VoyageAI embedding model names."""
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class VoyageAIEmbeddingSettings(EmbeddingSettings, total=False):
    """Per-request settings understood by the VoyageAI embedding model.

    Accepts every key of [`EmbeddingSettings`][pydantic_ai.embeddings.EmbeddingSettings]
    and adds the VoyageAI-only options below.
    """

    # Every provider-specific key must keep the `voyageai_` prefix so these
    # settings can be merged with other models' settings without collisions.

    # True (the default applied by `VoyageAIEmbeddingModel.embed`) truncates
    # inputs longer than the model's context length; False makes such inputs
    # raise an error instead.
    voyageai_truncation: bool

    # Numeric format of the returned embeddings:
    #   "float" (default)  -> 32-bit floats
    #   "int8" / "uint8"   -> signed / unsigned 8-bit integers (quantized)
    #   "binary"           -> binary embeddings
    #   "ubinary"          -> unsigned binary embeddings
    voyageai_output_dtype: Literal["float", "int8", "uint8", "binary", "ubinary"]
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
# Maximum input length, in tokens, for each known model. A model missing from
# this table makes `VoyageAIEmbeddingModel.max_input_tokens` return None.
_MAX_INPUT_TOKENS: dict[VoyageAIEmbeddingModelName, int] = {
    **dict.fromkeys(
        (
            "voyage-3-large",
            "voyage-3.5",
            "voyage-3.5-lite",
            "voyage-code-3",
            "voyage-finance-2",
        ),
        32000,
    ),
    **dict.fromkeys(("voyage-law-2", "voyage-code-2"), 16000),
}
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
@dataclass(init=False)
class VoyageAIEmbeddingModel(EmbeddingModel):
    """VoyageAI embedding model implementation.

    VoyageAI provides state-of-the-art embedding models optimized for
    retrieval, with specialized models for code, finance, and legal domains.

    Example:
    ```python
    from haiku.rag.embeddings.voyageai import VoyageAIEmbeddingModel

    model = VoyageAIEmbeddingModel('voyage-3.5')
    ```
    """

    _model_name: VoyageAIEmbeddingModelName = field(repr=False)
    _client: AsyncClient = field(repr=False)

    def __init__(
        self,
        model_name: VoyageAIEmbeddingModelName,
        *,
        api_key: str | None = None,
        max_retries: int = 0,
        timeout: int | None = None,
        settings: EmbeddingSettings | None = None,
    ):
        """Initialize a VoyageAI embedding model.

        Args:
            model_name: The name of the VoyageAI model to use.
                See [VoyageAI models](https://docs.voyageai.com/docs/embeddings)
                for available options.
            api_key: The VoyageAI API key. If not provided, uses the
                `VOYAGE_API_KEY` environment variable.
            max_retries: Maximum number of retries for failed requests.
            timeout: Request timeout in seconds.
            settings: Model-specific [`EmbeddingSettings`][pydantic_ai.embeddings.EmbeddingSettings]
                to use as defaults for this model.
        """
        self._model_name = model_name
        self._client = AsyncClient(
            api_key=api_key,
            max_retries=max_retries,
            timeout=timeout,
        )

        super().__init__(settings=settings)

    @property
    def model_name(self) -> VoyageAIEmbeddingModelName:
        """The embedding model name."""
        return self._model_name

    @property
    def system(self) -> str:
        """The embedding model provider."""
        return "voyageai"

    async def embed(
        self,
        inputs: str | Sequence[str],
        *,
        input_type: EmbedInputType,
        settings: EmbeddingSettings | None = None,
    ) -> EmbeddingResult:
        """Embed one or more texts with the configured VoyageAI model.

        Args:
            inputs: A single text or a sequence of texts. These are passed through
                `prepare_embed` together with `settings` before use.
            input_type: `"document"` is forwarded to VoyageAI as `"document"`;
                every other value is sent as `"query"`.
            settings: Optional per-call settings. The keys `voyageai_truncation`
                (default True), `voyageai_output_dtype` (default "float") and
                `dimensions` are read.

        Returns:
            An `EmbeddingResult` built from the response embeddings and token usage.
            If there are no inputs, no request is made and the result has no
            embeddings and zero usage.

        Raises:
            ModelAPIError: If the VoyageAI client raises a `VoyageError`.
        """
        inputs, settings = self.prepare_embed(inputs, settings)
        settings = cast(VoyageAIEmbeddingSettings, settings)

        if not inputs:
            # Nothing to embed: skip the API round-trip entirely and report
            # zero token usage.
            return EmbeddingResult(
                embeddings=[],
                inputs=inputs,
                input_type=input_type,
                usage=RequestUsage(),
                model_name=self.model_name,
                provider_name=self.system,
            )

        voyageai_input_type = "document" if input_type == "document" else "query"

        try:
            response = await self._client.embed(
                texts=list(inputs),
                model=self.model_name,
                input_type=voyageai_input_type,
                truncation=settings.get("voyageai_truncation", True),
                output_dtype=settings.get("voyageai_output_dtype", "float"),
                output_dimension=settings.get("dimensions"),
            )
        except VoyageError as e:
            # Re-raise as the provider-agnostic pydantic-ai error, keeping the cause.
            raise ModelAPIError(model_name=self.model_name, message=str(e)) from e

        return EmbeddingResult(
            embeddings=response.embeddings,
            inputs=inputs,
            input_type=input_type,
            usage=_map_usage(response.total_tokens, self.model_name),
            model_name=self.model_name,
            provider_name=self.system,
        )

    async def max_input_tokens(self) -> int | None:
        """Return the model's maximum input length in tokens.

        Returns None when the model name has no entry in `_MAX_INPUT_TOKENS`.
        """
        return _MAX_INPUT_TOKENS.get(self.model_name)
|
|
11
170
|
|
|
12
|
-
@overload
|
|
13
|
-
async def embed(self, text: list[str]) -> list[list[float]]: ...
|
|
14
171
|
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
return []
|
|
19
|
-
if isinstance(text, str):
|
|
20
|
-
res = client.embed([text], model=self._model, output_dtype="float")
|
|
21
|
-
return res.embeddings[0] # type: ignore[return-value]
|
|
22
|
-
else:
|
|
23
|
-
res = client.embed(text, model=self._model, output_dtype="float")
|
|
24
|
-
return res.embeddings # type: ignore[return-value]
|
|
172
|
+
def _map_usage(total_tokens: int, model: str) -> RequestUsage:
    """Convert a VoyageAI token count into a pydantic-ai `RequestUsage`.

    Only the total token count is forwarded. It is wrapped in a minimal
    response-shaped dict (`model` plus `usage.total_tokens`) so that
    `RequestUsage.extract` can process it as embeddings-API usage.
    """
    return RequestUsage.extract(
        {"model": model, "usage": {"total_tokens": total_tokens}},
        provider="voyageai",
        provider_url="https://api.voyageai.com",
        provider_fallback="voyageai",
        api_flavor="embeddings",
    )
|
haiku/rag/graph/__init__.py
CHANGED
|
@@ -1,25 +1,14 @@
|
|
|
1
|
-
"""Graph module for haiku.rag.
|
|
2
|
-
|
|
3
|
-
This module contains all graph-related functionality including:
|
|
4
|
-
- AG-UI protocol for graph streaming
|
|
5
|
-
- Common graph utilities and models
|
|
6
|
-
- Research graph implementation
|
|
7
|
-
- Deep QA graph implementation
|
|
8
|
-
"""
|
|
9
|
-
|
|
10
1
|
from haiku.rag.graph.agui import (
|
|
11
2
|
AGUIConsoleRenderer,
|
|
12
3
|
AGUIEmitter,
|
|
13
4
|
create_agui_server,
|
|
14
5
|
stream_graph,
|
|
15
6
|
)
|
|
16
|
-
from haiku.rag.graph.deep_qa.graph import build_deep_qa_graph
|
|
17
7
|
from haiku.rag.graph.research.graph import build_research_graph
|
|
18
8
|
|
|
19
9
|
__all__ = [
|
|
20
10
|
"AGUIConsoleRenderer",
|
|
21
11
|
"AGUIEmitter",
|
|
22
|
-
"build_deep_qa_graph",
|
|
23
12
|
"build_research_graph",
|
|
24
13
|
"create_agui_server",
|
|
25
14
|
"stream_graph",
|
haiku/rag/graph/agui/__init__.py
CHANGED
|
@@ -1,8 +1,8 @@
|
|
|
1
1
|
"""Generic AG-UI protocol support for haiku.rag graphs."""
|
|
2
2
|
|
|
3
3
|
from haiku.rag.graph.agui.cli_renderer import AGUIConsoleRenderer
|
|
4
|
-
from haiku.rag.graph.agui.emitter import
|
|
5
|
-
|
|
4
|
+
from haiku.rag.graph.agui.emitter import (
|
|
5
|
+
AGUIEmitter,
|
|
6
6
|
AGUIEvent,
|
|
7
7
|
emit_activity,
|
|
8
8
|
emit_activity_delta,
|
|
@@ -17,6 +17,9 @@ from haiku.rag.graph.agui.events import (
|
|
|
17
17
|
emit_text_message_content,
|
|
18
18
|
emit_text_message_end,
|
|
19
19
|
emit_text_message_start,
|
|
20
|
+
emit_tool_call_args,
|
|
21
|
+
emit_tool_call_end,
|
|
22
|
+
emit_tool_call_start,
|
|
20
23
|
)
|
|
21
24
|
from haiku.rag.graph.agui.server import (
|
|
22
25
|
RunAgentInput,
|
|
@@ -48,6 +51,9 @@ __all__ = [
|
|
|
48
51
|
"emit_text_message_content",
|
|
49
52
|
"emit_text_message_end",
|
|
50
53
|
"emit_text_message_start",
|
|
54
|
+
"emit_tool_call_args",
|
|
55
|
+
"emit_tool_call_end",
|
|
56
|
+
"emit_tool_call_start",
|
|
51
57
|
"format_sse_event",
|
|
52
58
|
"stream_graph",
|
|
53
59
|
]
|
haiku/rag/graph/agui/emitter.py
CHANGED
|
@@ -2,23 +2,39 @@
|
|
|
2
2
|
|
|
3
3
|
import asyncio
|
|
4
4
|
import hashlib
|
|
5
|
+
import json
|
|
5
6
|
from collections.abc import AsyncIterator
|
|
7
|
+
from typing import Any
|
|
6
8
|
from uuid import uuid4
|
|
7
9
|
|
|
10
|
+
from ag_ui.core import (
|
|
11
|
+
ActivitySnapshotEvent,
|
|
12
|
+
BaseEvent,
|
|
13
|
+
RunErrorEvent,
|
|
14
|
+
RunFinishedEvent,
|
|
15
|
+
RunStartedEvent,
|
|
16
|
+
StateDeltaEvent,
|
|
17
|
+
StateSnapshotEvent,
|
|
18
|
+
StepFinishedEvent,
|
|
19
|
+
StepStartedEvent,
|
|
20
|
+
TextMessageChunkEvent,
|
|
21
|
+
TextMessageContentEvent,
|
|
22
|
+
TextMessageEndEvent,
|
|
23
|
+
TextMessageStartEvent,
|
|
24
|
+
ToolCallArgsEvent,
|
|
25
|
+
ToolCallEndEvent,
|
|
26
|
+
ToolCallStartEvent,
|
|
27
|
+
)
|
|
8
28
|
from pydantic import BaseModel
|
|
9
29
|
|
|
10
|
-
from haiku.rag.graph.agui.
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
emit_step_finished,
|
|
19
|
-
emit_step_started,
|
|
20
|
-
emit_text_message,
|
|
21
|
-
)
|
|
30
|
+
from haiku.rag.graph.agui.state import compute_state_delta
|
|
31
|
+
|
|
32
|
+
# Serialized AG-UI event: a plain dict as produced by `_serialize_event`.
AGUIEvent = dict[str, Any]


def _serialize_event(event: BaseEvent) -> AGUIEvent:
    """Dump an ag_ui event model to a plain dict ready for streaming.

    The keys are the field aliases (camelCase) and the values are coerced to
    JSON-compatible types. Fields whose value is None are dropped.
    """
    payload: AGUIEvent = event.model_dump(
        mode="json",
        exclude_none=True,
        by_alias=True,
    )
    return payload
|
|
22
38
|
|
|
23
39
|
|
|
24
40
|
class AGUIEmitter[StateT: BaseModel, ResultT]:
|
|
@@ -54,7 +70,7 @@ class AGUIEmitter[StateT: BaseModel, ResultT]:
|
|
|
54
70
|
self._thread_id = thread_id or str(uuid4())
|
|
55
71
|
self._run_id = run_id or str(uuid4())
|
|
56
72
|
self._last_state: StateT | None = None
|
|
57
|
-
self.
|
|
73
|
+
self._active_steps: set[str] = set()
|
|
58
74
|
self._use_deltas = use_deltas
|
|
59
75
|
|
|
60
76
|
@property
|
|
@@ -79,8 +95,14 @@ class AGUIEmitter[StateT: BaseModel, ResultT]:
|
|
|
79
95
|
self._thread_id = self._generate_thread_id(state_json)
|
|
80
96
|
|
|
81
97
|
# RunStarted (state snapshot follows immediately with full state)
|
|
82
|
-
self.
|
|
83
|
-
|
|
98
|
+
self.emit(
|
|
99
|
+
_serialize_event(
|
|
100
|
+
RunStartedEvent(thread_id=self._thread_id, run_id=self._run_id)
|
|
101
|
+
)
|
|
102
|
+
)
|
|
103
|
+
self.emit(
|
|
104
|
+
_serialize_event(StateSnapshotEvent(snapshot=initial_state.model_dump()))
|
|
105
|
+
)
|
|
84
106
|
# Store a deep copy to detect future changes
|
|
85
107
|
self._last_state = initial_state.model_copy(deep=True)
|
|
86
108
|
|
|
@@ -90,14 +112,17 @@ class AGUIEmitter[StateT: BaseModel, ResultT]:
|
|
|
90
112
|
Args:
|
|
91
113
|
step_name: Name of the step being started
|
|
92
114
|
"""
|
|
93
|
-
self.
|
|
94
|
-
self.
|
|
115
|
+
self._active_steps.add(step_name)
|
|
116
|
+
self.emit(_serialize_event(StepStartedEvent(step_name=step_name)))
|
|
95
117
|
|
|
96
|
-
def finish_step(self) -> None:
|
|
97
|
-
"""Emit StepFinished event for the
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
118
|
+
def finish_step(self, step_name: str) -> None:
    """Close a step: drop it from the active set and emit StepFinished.

    The event is emitted even if `step_name` was not recorded as active.

    Args:
        step_name: Name of the step being finished
    """
    self._active_steps.discard(step_name)
    finished = StepFinishedEvent(step_name=step_name)
    self.emit(_serialize_event(finished))
|
|
101
126
|
|
|
102
127
|
def log(self, message: str, role: str = "assistant") -> None:
|
|
103
128
|
"""Emit a text message event.
|
|
@@ -106,7 +131,16 @@ class AGUIEmitter[StateT: BaseModel, ResultT]:
|
|
|
106
131
|
message: The message content
|
|
107
132
|
role: The role of the sender (default: assistant)
|
|
108
133
|
"""
|
|
109
|
-
|
|
134
|
+
message_id = str(uuid4())
|
|
135
|
+
self.emit(
|
|
136
|
+
_serialize_event(
|
|
137
|
+
TextMessageChunkEvent(
|
|
138
|
+
message_id=message_id,
|
|
139
|
+
role=role, # type: ignore[arg-type]
|
|
140
|
+
delta=message,
|
|
141
|
+
)
|
|
142
|
+
)
|
|
143
|
+
)
|
|
110
144
|
|
|
111
145
|
def update_state(self, new_state: StateT) -> None:
|
|
112
146
|
"""Emit StateDelta or StateSnapshot for state change.
|
|
@@ -116,26 +150,40 @@ class AGUIEmitter[StateT: BaseModel, ResultT]:
|
|
|
116
150
|
"""
|
|
117
151
|
if self._use_deltas and self._last_state is not None:
|
|
118
152
|
# Emit delta for incremental updates
|
|
119
|
-
|
|
153
|
+
delta = compute_state_delta(self._last_state, new_state)
|
|
154
|
+
self.emit(_serialize_event(StateDeltaEvent(delta=delta)))
|
|
120
155
|
else:
|
|
121
156
|
# Emit full snapshot for initial state or when deltas disabled
|
|
122
|
-
self.
|
|
157
|
+
self.emit(
|
|
158
|
+
_serialize_event(StateSnapshotEvent(snapshot=new_state.model_dump()))
|
|
159
|
+
)
|
|
123
160
|
# Store a deep copy to detect future changes
|
|
124
161
|
self._last_state = new_state.model_copy(deep=True)
|
|
125
162
|
|
|
126
163
|
def update_activity(
    self,
    activity_type: str,
    content: dict[str, Any],
    message_id: str | None = None,
) -> None:
    """Publish an ActivitySnapshot event.

    Args:
        activity_type: Kind of activity (e.g., "planning", "searching")
        content: Structured payload describing the current activity state
        message_id: Message ID to attach the activity to. When None, a fresh
            UUID4 string is generated.
    """
    snapshot = ActivitySnapshotEvent(
        message_id=str(uuid4()) if message_id is None else message_id,
        activity_type=activity_type,
        content=content,
    )
    self.emit(_serialize_event(snapshot))
|
|
139
187
|
|
|
140
188
|
def finish_run(self, result: ResultT) -> None:
|
|
141
189
|
"""Emit RunFinished event.
|
|
@@ -143,7 +191,18 @@ class AGUIEmitter[StateT: BaseModel, ResultT]:
|
|
|
143
191
|
Args:
|
|
144
192
|
result: The final result from the graph
|
|
145
193
|
"""
|
|
146
|
-
|
|
194
|
+
# Convert result to dict if it's a Pydantic model
|
|
195
|
+
result_data: Any = result
|
|
196
|
+
if hasattr(result, "model_dump"):
|
|
197
|
+
result_data = result.model_dump() # type: ignore[union-attr]
|
|
198
|
+
|
|
199
|
+
self.emit(
|
|
200
|
+
_serialize_event(
|
|
201
|
+
RunFinishedEvent(
|
|
202
|
+
thread_id=self._thread_id, run_id=self._run_id, result=result_data
|
|
203
|
+
)
|
|
204
|
+
)
|
|
205
|
+
)
|
|
147
206
|
|
|
148
207
|
def error(self, error: Exception, code: str | None = None) -> None:
|
|
149
208
|
"""Emit RunError event.
|
|
@@ -152,9 +211,9 @@ class AGUIEmitter[StateT: BaseModel, ResultT]:
|
|
|
152
211
|
error: The exception that occurred
|
|
153
212
|
code: Optional error code
|
|
154
213
|
"""
|
|
155
|
-
self.
|
|
214
|
+
self.emit(_serialize_event(RunErrorEvent(message=str(error), code=code)))
|
|
156
215
|
|
|
157
|
-
def
|
|
216
|
+
def emit(self, event: AGUIEvent) -> None:
|
|
158
217
|
"""Put event in queue.
|
|
159
218
|
|
|
160
219
|
Args:
|
|
@@ -195,3 +254,132 @@ class AGUIEmitter[StateT: BaseModel, ResultT]:
|
|
|
195
254
|
# Use hash of input for deterministic thread ID
|
|
196
255
|
hash_obj = hashlib.sha256(input_data.encode("utf-8"))
|
|
197
256
|
return hash_obj.hexdigest()[:16]
|
|
257
|
+
|
|
258
|
+
|
|
259
|
+
def emit_text_message_start(message_id: str, role: str = "assistant") -> AGUIEvent:
    """Build a serialized TextMessageStart event.

    Args:
        message_id: Identifier of the message being started.
        role: Sender role (default: "assistant").

    Returns:
        The event as a dict produced by `_serialize_event`.
    """
    start = TextMessageStartEvent(message_id=message_id, role=role)  # type: ignore[arg-type]
    return _serialize_event(start)
|
|
264
|
+
|
|
265
|
+
|
|
266
|
+
def emit_text_message_content(message_id: str, delta: str) -> AGUIEvent:
    """Build a serialized TextMessageContent event carrying `delta` for `message_id`."""
    content_event = TextMessageContentEvent(message_id=message_id, delta=delta)
    return _serialize_event(content_event)
|
|
269
|
+
|
|
270
|
+
|
|
271
|
+
def emit_text_message_end(message_id: str) -> AGUIEvent:
    """Build a serialized TextMessageEnd event closing the message `message_id`."""
    end_event = TextMessageEndEvent(message_id=message_id)
    return _serialize_event(end_event)
|
|
274
|
+
|
|
275
|
+
|
|
276
|
+
def emit_tool_call_start(
    tool_call_id: str,
    tool_name: str,
    parent_message_id: str | None = None,
) -> AGUIEvent:
    """Build a serialized ToolCallStart event.

    Args:
        tool_call_id: Identifier of the tool call.
        tool_name: Name of the invoked tool (sent as `tool_call_name`).
        parent_message_id: Optional message the call belongs to. It is left
            out of the serialized dict when None.
    """
    start = ToolCallStartEvent(
        tool_call_id=tool_call_id,
        tool_call_name=tool_name,
        parent_message_id=parent_message_id,
    )
    return _serialize_event(start)
|
|
289
|
+
|
|
290
|
+
|
|
291
|
+
def emit_tool_call_args(tool_call_id: str, args: dict[str, Any]) -> AGUIEvent:
    """Build a serialized ToolCallArgs event.

    The whole `args` mapping is JSON-encoded with `json.dumps` defaults and
    sent as a single `delta` string. `json.dumps` raises `TypeError` for
    values it cannot serialize.
    """
    encoded_args = json.dumps(args)
    args_event = ToolCallArgsEvent(tool_call_id=tool_call_id, delta=encoded_args)
    return _serialize_event(args_event)
|
|
296
|
+
|
|
297
|
+
|
|
298
|
+
def emit_tool_call_end(tool_call_id: str) -> AGUIEvent:
    """Build a serialized ToolCallEnd event for `tool_call_id`."""
    end_event = ToolCallEndEvent(tool_call_id=tool_call_id)
    return _serialize_event(end_event)
|
|
301
|
+
|
|
302
|
+
|
|
303
|
+
def emit_run_started(thread_id: str, run_id: str) -> AGUIEvent:
    """Build a serialized RunStarted event for the given thread and run IDs."""
    started = RunStartedEvent(thread_id=thread_id, run_id=run_id)
    return _serialize_event(started)
|
|
306
|
+
|
|
307
|
+
|
|
308
|
+
def emit_run_finished(thread_id: str, run_id: str, result: Any) -> AGUIEvent:
    """Build a serialized RunFinished event.

    If `result` has a `model_dump` method (e.g. a Pydantic model), it is
    dumped first. Any other value is passed through unchanged.
    """
    payload = result.model_dump() if hasattr(result, "model_dump") else result
    finished = RunFinishedEvent(thread_id=thread_id, run_id=run_id, result=payload)
    return _serialize_event(finished)
|
|
316
|
+
|
|
317
|
+
|
|
318
|
+
def emit_run_error(message: str, code: str | None = None) -> AGUIEvent:
    """Build a serialized RunError event.

    `code` is left out of the serialized dict when None.
    """
    error_event = RunErrorEvent(message=message, code=code)
    return _serialize_event(error_event)
|
|
321
|
+
|
|
322
|
+
|
|
323
|
+
def emit_step_started(step_name: str) -> AGUIEvent:
    """Build a serialized StepStarted event for `step_name`."""
    started = StepStartedEvent(step_name=step_name)
    return _serialize_event(started)
|
|
326
|
+
|
|
327
|
+
|
|
328
|
+
def emit_step_finished(step_name: str) -> AGUIEvent:
    """Build a serialized StepFinished event for `step_name`."""
    finished = StepFinishedEvent(step_name=step_name)
    return _serialize_event(finished)
|
|
331
|
+
|
|
332
|
+
|
|
333
|
+
def emit_text_message(content: str, role: str = "assistant") -> AGUIEvent:
    """Build a serialized TextMessageChunk event for a whole message.

    This is a convenience wrapper. A fresh UUID4 string becomes the message
    ID, and `content` is sent as the chunk's `delta`.
    """
    chunk = TextMessageChunkEvent(
        message_id=str(uuid4()),
        role=role,  # type: ignore[arg-type]
        delta=content,
    )
    return _serialize_event(chunk)
|
|
343
|
+
|
|
344
|
+
|
|
345
|
+
def emit_state_snapshot(state: BaseModel) -> AGUIEvent:
    """Build a serialized StateSnapshot event holding `state.model_dump()`."""
    snapshot = state.model_dump()
    return _serialize_event(StateSnapshotEvent(snapshot=snapshot))
|
|
348
|
+
|
|
349
|
+
|
|
350
|
+
def emit_state_delta(old_state: BaseModel, new_state: BaseModel) -> AGUIEvent:
    """Build a serialized StateDelta event.

    The delta holds the JSON Patch operations that `compute_state_delta`
    derives from `old_state` and `new_state`.
    """
    return _serialize_event(
        StateDeltaEvent(delta=compute_state_delta(old_state, new_state))
    )
|
|
354
|
+
|
|
355
|
+
|
|
356
|
+
def emit_activity(
    message_id: str,
    activity_type: str,
    content: dict[str, Any],
) -> AGUIEvent:
    """Build a serialized ActivitySnapshot event.

    Args:
        message_id: Message ID the activity is attached to.
        activity_type: Kind of activity (e.g. "planning", "searching").
        content: Structured payload describing the activity state.
    """
    snapshot = ActivitySnapshotEvent(
        message_id=message_id,
        activity_type=activity_type,
        content=content,
    )
    return _serialize_event(snapshot)
|
|
369
|
+
|
|
370
|
+
|
|
371
|
+
def emit_activity_delta(
    message_id: str,
    activity_type: str,
    patch: list[dict[str, Any]],
) -> AGUIEvent:
    """Build a serialized ActivityDelta event.

    Args:
        message_id: Message ID the activity is attached to.
        activity_type: Kind of activity being updated.
        patch: JSON Patch operations to apply to the activity content.

    Note:
        `ActivityDeltaEvent` is imported at call time. The other ag_ui event
        classes in this module are imported at module level.
    """
    from ag_ui.core import ActivityDeltaEvent

    delta_event = ActivityDeltaEvent(
        message_id=message_id,
        activity_type=activity_type,
        patch=patch,
    )
    return _serialize_event(delta_event)
|