klea-utils 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- klea_utils/__init__.py +0 -0
- klea_utils/api.py +50 -0
- klea_utils/errors.py +17 -0
- klea_utils/graph/__init__.py +0 -0
- klea_utils/graph/base.py +403 -0
- klea_utils/llm.py +462 -0
- klea_utils/nodes/__init__.py +0 -0
- klea_utils/nodes/abstract.py +187 -0
- klea_utils/nodes/answer_general.py +93 -0
- klea_utils/nodes/base.py +262 -0
- klea_utils/nodes/fixed_answer.py +37 -0
- klea_utils/nodes/guard.py +72 -0
- klea_utils/nodes/guard_router.py +46 -0
- klea_utils/nodes/prompts/AnswerGeneral_system.md +17 -0
- klea_utils/nodes/prompts/AnswerGeneral_user.md +1 -0
- klea_utils/nodes/prompts/GuardNode_system.md +1 -0
- klea_utils/nodes/prompts/GuardNode_user.md +1 -0
- klea_utils/nodes/prompts/SummariseMemoryNode_system.md +10 -0
- klea_utils/nodes/prompts/SummariseMemoryNode_user.md +13 -0
- klea_utils/nodes/prompts/__init__.py +0 -0
- klea_utils/nodes/summarise_memory.py +102 -0
- klea_utils/plogging.py +64 -0
- klea_utils/stores/__init__.py +9 -0
- klea_utils/stores/config.py +33 -0
- klea_utils/stores/ingestion.py +530 -0
- klea_utils/stores/retrieval.py +158 -0
- klea_utils/stores/utils.py +180 -0
- klea_utils/tools.py +66 -0
- klea_utils/ui/__init__.py +9 -0
- klea_utils/ui/vs_create.py +241 -0
- klea_utils-0.1.0.dist-info/METADATA +66 -0
- klea_utils-0.1.0.dist-info/RECORD +36 -0
- klea_utils-0.1.0.dist-info/WHEEL +5 -0
- klea_utils-0.1.0.dist-info/entry_points.txt +2 -0
- klea_utils-0.1.0.dist-info/licenses/LICENSE +674 -0
- klea_utils-0.1.0.dist-info/top_level.txt +1 -0
klea_utils/__init__.py
ADDED
|
File without changes
|
klea_utils/api.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""
|
|
3
|
+
API related common utils
|
|
4
|
+
|
|
5
|
+
File: klea_utils/api.py
|
|
6
|
+
|
|
7
|
+
Copyright 2026 Ankur Sinha
|
|
8
|
+
Author: Ankur Sinha <sanjay DOT ankur AT gmail DOT com>
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
import httpx
|
|
12
|
+
from pydantic import AnyUrl
|
|
13
|
+
from pydantic import ValidationError as PydanticValidationError
|
|
14
|
+
from tenacity import (
|
|
15
|
+
retry,
|
|
16
|
+
retry_if_exception_type,
|
|
17
|
+
stop_after_attempt,
|
|
18
|
+
wait_random_exponential,
|
|
19
|
+
)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def validate_url(value: str) -> str:
    """Return *value* if it is a valid HTTP(S) URL, else raise ``ValueError``.

    The string is first parsed with pydantic's ``AnyUrl``. Because ``AnyUrl``
    accepts any scheme (``ftp://...``, or even ``localhost:1234``, which parses
    with scheme ``localhost``), the scheme is additionally required to be
    ``http`` or ``https``.

    :param value: candidate URL
    :type value: str
    :returns: *value*, unchanged
    :raises ValueError: if *value* cannot be parsed as a URL, or if its scheme
        is not HTTP(S)
    """
    try:
        parsed = AnyUrl(value)
    except PydanticValidationError as err:
        raise ValueError(f"'{value}' is not a valid HTTP(S) URL") from err
    # AnyUrl alone does not restrict the scheme; enforce the documented contract.
    if parsed.scheme not in ("http", "https"):
        raise ValueError(f"'{value}' is not a valid HTTP(S) URL")
    return value
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
@retry(
    wait=wait_random_exponential(multiplier=1, max=10),
    stop=stop_after_attempt(10),
    # Transient failures expected while the service is still starting up.
    # ConnectTimeout is a TimeoutException (not a ConnectError), and
    # RemoteProtocolError is raised when the server drops the connection
    # without responding, so both are listed explicitly.
    retry=retry_if_exception_type(
        (
            httpx.ConnectError,
            httpx.ConnectTimeout,
            httpx.HTTPStatusError,
            httpx.ReadError,
            httpx.ReadTimeout,
            httpx.RemoteProtocolError,
        )
    ),
    reraise=True,
)
async def check_api_is_ready(url: str):
    """Poll a health endpoint until it answers successfully.

    Each attempt issues a GET request to *url*. Connection failures, timeouts,
    dropped connections and non-success HTTP statuses are retried with
    randomised exponential back-off (at most 10 s between attempts, at most
    10 attempts). Once attempts are exhausted the last exception is re-raised.

    :param url: url of health end point
    :type url: str
    :returns: the decoded JSON body of the first successful response
    :raises httpx.HTTPError: the last retried error if the endpoint never
        becomes ready
    """
    async with httpx.AsyncClient() as client:
        response = await client.get(url)
        response.raise_for_status()

    return response.json()
|
klea_utils/errors.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""
|
|
3
|
+
Custom errors
|
|
4
|
+
|
|
5
|
+
File: klea_utils/errors.py
|
|
6
|
+
|
|
7
|
+
Copyright 2026 Ankur Sinha
|
|
8
|
+
Author: Ankur Sinha <sanjay DOT ankur AT gmail DOT com>
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class LLMInitializationError(Exception):
    """Error signalling that an LLM could not be initialised."""
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class PromptTemplateError(Exception):
    """Error signalling a problem with a prompt template."""
|
|
File without changes
|
klea_utils/graph/base.py
ADDED
|
@@ -0,0 +1,403 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""
|
|
3
|
+
Base class for LangGraph-based orchestrators
|
|
4
|
+
|
|
5
|
+
File: klea_utils/graph/base.py
|
|
6
|
+
|
|
7
|
+
Copyright 2026 Ankur Sinha
|
|
8
|
+
Author: Ankur Sinha <sanjay DOT ankur AT gmail DOT com>
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
import json
|
|
12
|
+
import logging
|
|
13
|
+
import os
|
|
14
|
+
import sys
|
|
15
|
+
from abc import ABC, abstractmethod
|
|
16
|
+
from pathlib import Path
|
|
17
|
+
from textwrap import dedent
|
|
18
|
+
from typing import Any, List, Literal, Type, final
|
|
19
|
+
|
|
20
|
+
from fastmcp import Client
|
|
21
|
+
from fastmcp.mcp_config import MCPConfig
|
|
22
|
+
from langgraph.checkpoint.memory import InMemorySaver
|
|
23
|
+
from langgraph.graph.state import CompiledStateGraph
|
|
24
|
+
from mcp.types import Tool
|
|
25
|
+
from pydantic import BaseModel, create_model
|
|
26
|
+
|
|
27
|
+
from klea_utils.stores.config import VectorStoresConfig
|
|
28
|
+
from klea_utils.stores.retrieval import VSRetriever
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class BaseLangGraph(ABC):
    """Abstract base class for LangGraph-based orchestrators.

    Provides common infrastructure for:

    - Configuration loading from env files
    - MCP client creation from the configured MCP servers
    - Vector store retriever setup
    - LLM model setup (delegated to subclasses)
    - LangGraph compilation and execution
    - Session checkpointing
    - Dual-stream logging

    Subclasses must implement:

    - :meth:`_configure_resources`: Populate vector store and MCP configuration
    - :meth:`_setup_models`: Create LLM model instances
    - :meth:`_create_graph`: Build and compile the LangGraph
    - Set ``env_class`` and ``config_class`` to the appropriate Pydantic classes
    """

    #: Pydantic BaseSettings class for env loading.
    #: Subclasses must set this to their AppEnv class.
    env_class: Type[BaseModel]

    #: Pydantic BaseModel class for configuration loading.
    #: Subclasses must set this to their AppConfig class.
    config_class: Type[BaseModel]

    #: Name of the environment variable that controls the env file path.
    env_var: str = "ENV_FILE"

    #: Default config file name if the environment variable is not set.
    env_file_default: str = "config.env"

    #: Logger name for this orchestrator.
    logger_name: str = "BaseLangGraph"

    def __init__(
        self,
        logging_level: int = logging.DEBUG,
        memory: bool = True,
    ):
        """Initialise the base orchestrator.

        Only resolves the env file path and creates the logger (and the
        checkpointer, if enabled). Configuration is loaded and the graph is
        built by :meth:`setup`.

        :param logging_level: Logging level for the orchestrator (used as the
            logger's stderr level)
        :param memory: Whether to enable checkpoint-based session memory
        """
        self.env_file = os.getenv(self.env_var, self.env_file_default)
        # Both populated by _load_env()
        self.app_env: BaseModel
        self.app_config: BaseModel

        self.c_model = None

        self.memory = memory

        # domain name -> Markdown description of that domain's MCP tools
        self.tools_description: dict[str, str] = {}
        # domain name -> MCP configuration for that domain's servers
        self.domain_mcp_configs: dict[str, MCPConfig] = {}
        self.checkpointer: InMemorySaver | None = InMemorySaver() if memory else None

        self.config_dict: dict[str, Any]

        self.graph: CompiledStateGraph | None = None

        self.mcp_config: MCPConfig | None = None
        self.mcp_client: Client | None = None
        self.mcp_tools: list[Tool] | None = None

        self.stores_config: VectorStoresConfig | None = None
        self.stores: VSRetriever | None = None
        self.embedding_model: str | None = None
        self.default_k: int = 5
        self.k_max: int = 10

        # Built dynamically from the vector store domains in _get_vector_stores()
        self.QueryDomainSchema: Type[BaseModel] | None = None

        from klea_utils.plogging import setup_logger

        self.logger = setup_logger(self.logger_name, stderr_level=logging_level)

    def _load_env(self) -> None:
        """Load the env file and the JSON application config it points to.

        Parses ``self.env_file`` with ``self.env_class`` into ``self.app_env``.
        It then reads the JSON file named by the env's ``app_config_file``
        setting and validates it with ``self.config_class`` into
        ``self.app_config``.

        :raises FileNotFoundError: if the env file is missing; if the env class
            has no ``app_config_file`` field or it is empty; or if the config
            file it names does not exist
        """
        env_file_path = Path(self.env_file)
        if not env_file_path.exists():
            raise FileNotFoundError(
                f"Could not find env file: {self.env_file}. You can use the {self.env_var} environment variable to specify the env file."
            )

        # ``_env_file`` is the pydantic-settings keyword selecting the env file.
        self.app_env = self.env_class(_env_file=self.env_file)
        self.logger.debug(f"env file: {self.env_file}")
        self.logger.debug(f"env: {self.app_env}")

        config_file_setting = None
        if "app_config_file" in self.env_class.model_fields:
            config_file_setting = self.app_env.app_config_file
        # An unset/empty setting is treated like a missing one instead of
        # letting Path(None) raise a TypeError.
        if not config_file_setting:
            raise FileNotFoundError(
                f"No config file provided. Please provide one in the env file ({self.env_file}). "
                + f"You can use the {self.env_var} environment variable to specify the env file."
            )

        config_file = Path(config_file_setting)
        if not config_file.exists():
            raise FileNotFoundError(f"Could not find config file: {config_file}")

        with config_file.open("r", encoding="utf-8") as f:
            config_dict = json.load(f)
        self.logger.debug(f"{config_dict = }")
        self.app_config = self.config_class(**config_dict)
        self.logger.debug(f"{self.app_config = }")

    def _create_mcp_client(self) -> None:
        """Create a ``fastmcp.Client`` from ``self.mcp_config``.

        ``self.mcp_config`` is expected to have been populated by
        :meth:`_configure_resources`. If it is unset or lists no servers, a
        warning is logged and ``self.mcp_client`` is set to ``None``.
        """
        if not (self.mcp_config and self.mcp_config.mcpServers):
            self.logger.warning("No MCP server configured.")
            self.mcp_client = None
            return

        self.logger.debug(f"{self.mcp_config = }")
        self.mcp_client = Client(self.mcp_config)

    async def _get_mcp_tools(self) -> None:
        """Fetch the MCP tool list and build per-domain tool descriptions.

        Does nothing when no MCP client was created.
        """
        if not self.mcp_client:
            return
        # The fastmcp client is used as an async context manager for the
        # duration of the request.
        async with self.mcp_client:
            self.mcp_tools = await self.mcp_client.list_tools()
        self.logger.debug(f"{self.mcp_tools =}")
        self._build_tools_description()

    @staticmethod
    def _format_tool_description(index: int, tool: Tool) -> str:
        """Render one MCP tool as a numbered Markdown section.

        Built from explicit strings rather than ``dedent`` over an f-string:
        interpolating a multi-line description before dedenting can defeat the
        common-indent detection and leave the whole section indented.
        """
        text = f"\n## {index}. {tool.name}\n\n### Description\n\n{tool.description}\n\n"
        if tool.inputSchema:
            text += f"\n### Parameters\n\n{tool.inputSchema.get('properties')}\n\n"
        return text

    def _build_tools_description(self) -> None:
        """Build per-domain tool descriptions from the fetched MCP tools.

        Resets ``self.tools_description``. For every domain whose MCP config
        lists servers, it stores a Markdown description of that domain's tools.
        Tools whose name contains ``"dummy"`` are skipped.
        """
        self.tools_description = {}
        if not self.mcp_tools or not self.domain_mcp_configs:
            return

        # Server names per domain; domains without servers get no description.
        domain_servers: dict[str, list[str]] = {
            domain: list(config.mcpServers.keys())
            for domain, config in self.domain_mcp_configs.items()
            if config.mcpServers
        }
        # Count distinct servers, so a server shared by several domains is not
        # counted twice.
        num_servers = len({name for names in domain_servers.values() for name in names})

        for domain, server_names in domain_servers.items():
            prefixes = tuple(f"{name}_" for name in server_names)
            sections: list[str] = []
            for tool in self.mcp_tools:
                if "dummy" in tool.name:
                    continue
                # With several servers, tool names are prefixed with their
                # server's name; with a single server every tool belongs to it.
                if num_servers > 1 and not tool.name.startswith(prefixes):
                    continue
                sections.append(self._format_tool_description(len(sections) + 1, tool))
            self.tools_description[domain] = "".join(sections)

    async def _get_vector_stores(self) -> None:
        """Set up the vector store retriever and the query-domain schema.

        Requires both ``self.stores_config`` and ``self.embedding_model``;
        otherwise a warning is logged and nothing is created.
        """
        if not (self.stores_config and self.embedding_model):
            self.logger.warning("No vector stores configured.")
            return

        self.stores = VSRetriever(
            vs_config=self.stores_config,
            logger=self.logger,
            embedding_model=self.embedding_model,
            default_k=self.default_k,
            k_max=self.k_max,
        )
        self.stores.setup()
        self.logger.info(f"Vector stores loaded: {self.stores.domains}")

        # Dynamically generate a schema restricted to the known domains, plus
        # an "undefined" fallback.
        all_domains = [*self.stores.domains, "undefined"]
        self.QueryDomainSchema = create_model(
            "QueryDomainSchema",
            # The field is a list, so its default must be a list too (the
            # previous default was the bare string "undefined").
            query_domains=(List[Literal[tuple(all_domains)]], ["undefined"]),
        )

    def _export_graph_png(self, filename: str) -> None:
        """Export the LangGraph as a Mermaid PNG diagram.

        Skipped when running inside Docker (``RUNNING_IN_DOCKER`` env var set
        to a non-empty value). Failures are logged rather than raised.

        :param filename: Output file path for the PNG
        """
        if os.environ.get("RUNNING_IN_DOCKER", 0):
            return
        try:
            if self.graph is None:
                raise RuntimeError("The graph has not been created: call setup() first.")
            self.graph.get_graph().draw_mermaid_png(output_file_path=filename)
        except Exception as e:
            # Exception, not BaseException: KeyboardInterrupt and SystemExit
            # must still propagate.
            self.logger.error("Something went wrong generating lang graph png")
            self.logger.error(e)

    # ------------------------------------------------------------------
    # Abstract methods -- subclasses must implement these
    # ------------------------------------------------------------------

    @abstractmethod
    def _configure_resources(self) -> None:
        """Configure vector stores and MCP servers.

        Subclasses should implement this to populate ``self.stores_config``,
        ``self.mcp_config``, and ``self.domain_mcp_configs``. These are used to
        create the vector store retriever, the MCP client, and the per-domain
        tool descriptions.
        """
        ...

    @abstractmethod
    def _setup_models(self) -> None:
        """Set up LLM model instances.

        Subclasses should assign model instances to ``self.c_model`` (and
        optionally ``self.r_model`` for reasoning models).
        """
        ...

    @abstractmethod
    async def _create_graph(self) -> None:
        """Build and compile the LangGraph, storing it in ``self.graph``.

        This is where subclasses define their nodes, edges, and conditional
        routing logic.
        """
        ...

    # ------------------------------------------------------------------
    # Hook methods -- override for pre/post setup work
    # ------------------------------------------------------------------

    def _pre_setup(self) -> None:
        """Hook called before the standard setup sequence.

        Override to perform subclass-specific initialisation before
        config loading and model setup.
        """
        pass

    async def _pre_graph(self) -> None:
        """Hook called after MCP and vector store setup, before graph compilation.

        Override to perform subclass-specific initialisation that depends
        on config and MCP client but must happen before the LangGraph is built.
        """
        pass

    def _post_setup(self) -> None:
        """Hook called after the standard setup sequence.

        Override to perform subclass-specific finalisation after
        the LangGraph has been compiled.
        """
        pass

    # ------------------------------------------------------------------
    # Template method
    # ------------------------------------------------------------------

    @final
    async def setup(self) -> None:
        """Set up the orchestrator.

        Calls hooks and template methods in this order:

        1. ``_pre_setup()``
        2. ``_load_env()``
        3. ``_configure_resources()``
        4. ``_setup_models()``
        5. ``_create_mcp_client()``
        6. ``_get_mcp_tools()``
        7. ``_get_vector_stores()``
        8. ``_pre_graph()``
        9. ``_create_graph()``
        10. ``_post_setup()``
        """
        self._pre_setup()
        self._load_env()
        self._configure_resources()
        self._setup_models()
        self._create_mcp_client()
        await self._get_mcp_tools()
        await self._get_vector_stores()
        await self._pre_graph()
        await self._create_graph()
        self._post_setup()

    # ------------------------------------------------------------------
    # Execution methods -- identical across all implementations
    # ------------------------------------------------------------------

    @staticmethod
    def _run_config(thread_id: str) -> dict:
        """Build the LangGraph run config that selects the checkpoint thread."""
        return {"configurable": {"thread_id": thread_id}}

    def _require_graph(self) -> CompiledStateGraph:
        """Return the compiled graph.

        :raises RuntimeError: if :meth:`setup` has not built the graph yet
        """
        if self.graph is None:
            raise RuntimeError("The graph has not been created: call setup() first.")
        return self.graph

    async def run_graph_invoke_state(
        self, state: dict, thread_id: str = "default_thread"
    ) -> dict:
        """Run the graph, accepting and returning full state dicts.

        :param state: Initial graph state (must contain ``query`` key)
        :param thread_id: Session/thread identifier for checkpointing
        :returns: Final graph state
        :raises ValueError: if *state* has no ``query`` key
        """
        if "query" not in state:
            message = f"Provided state should include the key 'query': {state}"
            self.logger.error(message)
            # Raise rather than sys.exit(): library code must not terminate
            # the host process.
            raise ValueError(message)

        final_state = await self._require_graph().ainvoke(
            state, config=self._run_config(thread_id)
        )
        self.logger.debug(final_state)
        return final_state

    # TODO: fields to be extracted from the final state to be returned should
    # be configurable with a schema
    async def run_graph_invoke(
        self, query: str, thread_id: str = "default_thread"
    ) -> str:
        """Run the graph with a simple string query.

        :param query: User query string
        :param thread_id: Session/thread identifier for checkpointing
        :returns: The ``message_for_user`` field from the final state, or a
            fallback message when it is missing or empty
        """
        final_state = await self._require_graph().ainvoke(
            {"query": query}, config=self._run_config(thread_id)
        )

        self.logger.debug(f"{final_state =}")
        if message := final_state.get("message_for_user", None):
            return message
        return "I was unable to answer"

    async def run_graph_stream(self, query: str, thread_id: str = "default_thread"):
        """Run the graph and yield intermediate ``message_for_user`` values.

        :param query: User query string
        :param thread_id: Session/thread identifier for checkpointing
        :yields: ``message_for_user`` strings from each node
        """
        graph = self._require_graph()
        # astream() returns an async iterator and must be consumed with
        # ``async for``; a plain ``for`` raises TypeError.
        async for chunk in graph.astream(
            {"query": query}, config=self._run_config(thread_id)
        ):
            # Each chunk maps node names to that node's output.
            for node, state in chunk.items():
                self.logger.debug(f"{node}: {repr(state)}")
                # Guard against non-dict entries (e.g. a node returning None).
                message = (
                    state.get("message_for_user", None)
                    if isinstance(state, dict)
                    else None
                )
                if message:
                    self.logger.info(f"User message: {message}")
                    yield message
                else:
                    self.logger.debug(f"Working in node: {node}")

    async def graph_stream(self, query: str, thread_id: str = "default_thread") -> Any:
        """Run the graph and return the raw astream result.

        :param query: User query string
        :param thread_id: Session/thread identifier for checkpointing
        :returns: Raw async iterator from ``graph.astream()``; consume it with
            ``async for``
        """
        # astream() is not awaitable (awaiting it raised TypeError); return
        # the iterator itself.
        return self._require_graph().astream(
            {"query": query}, config=self._run_config(thread_id)
        )
|