klea-utils 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
klea_utils/__init__.py ADDED
File without changes
klea_utils/api.py ADDED
@@ -0,0 +1,50 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ API related common utils
4
+
5
+ File: klea_utils/api.py
6
+
7
+ Copyright 2026 Ankur Sinha
8
+ Author: Ankur Sinha <sanjay DOT ankur AT gmail DOT com>
9
+ """
10
+
11
+ import httpx
12
+ from pydantic import AnyUrl
13
+ from pydantic import ValidationError as PydanticValidationError
14
+ from tenacity import (
15
+ retry,
16
+ retry_if_exception_type,
17
+ stop_after_attempt,
18
+ wait_random_exponential,
19
+ )
20
+
21
+
22
def validate_url(value: str) -> str:
    """Return *value* if it is a valid HTTP(S) URL, else raise ``ValueError``.

    :param value: candidate URL string
    :type value: str
    :returns: the unchanged input string
    :rtype: str
    :raises ValueError: if *value* cannot be parsed as a URL, or if its
        scheme is neither ``http`` nor ``https``
    """
    try:
        parsed = AnyUrl(value)
    except PydanticValidationError as err:
        # keep the pydantic error as the cause so the parse failure is visible
        raise ValueError(f"'{value}' is not a valid HTTP(S) URL") from err

    # AnyUrl accepts any scheme (ftp, file, ...); enforce the documented
    # HTTP(S) contract explicitly.
    if parsed.scheme not in ("http", "https"):
        raise ValueError(f"'{value}' is not a valid HTTP(S) URL")
    return value
29
+
30
+
31
@retry(
    retry=retry_if_exception_type(
        (httpx.ConnectError, httpx.HTTPStatusError, httpx.ReadError, httpx.ReadTimeout)
    ),
    stop=stop_after_attempt(10),
    wait=wait_random_exponential(multiplier=1, max=10),
    reraise=True,
)
async def check_api_is_ready(url: str):
    """Poll a health end point until it answers successfully.

    The request is retried with randomised exponential back-off (at most
    10 attempts) when it fails with a connection error, a read
    error/timeout, or a non-success HTTP status. Once the attempts are
    exhausted the last exception is re-raised.

    :param url: url of health end point
    :type url: str
    :returns: the decoded JSON body of the successful response

    """
    # NOTE(review): httpx.ConnectTimeout is not in the retried set -- confirm
    # whether that is intentional.
    async with httpx.AsyncClient() as client:
        reply = await client.get(url)
        # turn 4xx/5xx into HTTPStatusError so that the retry policy applies
        reply.raise_for_status()

    return reply.json()
klea_utils/errors.py ADDED
@@ -0,0 +1,17 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Custom errors
4
+
5
+ File: klea_utils/errors.py
6
+
7
+ Copyright 2026 Ankur Sinha
8
+ Author: Ankur Sinha <sanjay DOT ankur AT gmail DOT com>
9
+ """
10
+
11
+
12
class LLMInitializationError(Exception):
    """Error type for failures while initialising an LLM.

    NOTE(review): the raise sites live outside this module; the meaning is
    inferred from the class name.
    """
14
+
15
+
16
class PromptTemplateError(Exception):
    """Error type for prompt template problems.

    NOTE(review): the raise sites live outside this module; the meaning is
    inferred from the class name.
    """
File without changes
@@ -0,0 +1,403 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Base class for LangGraph-based orchestrators
4
+
5
+ File: klea_utils/graph/base.py
6
+
7
+ Copyright 2026 Ankur Sinha
8
+ Author: Ankur Sinha <sanjay DOT ankur AT gmail DOT com>
9
+ """
10
+
11
+ import json
12
+ import logging
13
+ import os
14
+ import sys
15
+ from abc import ABC, abstractmethod
16
+ from pathlib import Path
17
+ from textwrap import dedent
18
+ from typing import Any, List, Literal, Type, final
19
+
20
+ from fastmcp import Client
21
+ from fastmcp.mcp_config import MCPConfig
22
+ from langgraph.checkpoint.memory import InMemorySaver
23
+ from langgraph.graph.state import CompiledStateGraph
24
+ from mcp.types import Tool
25
+ from pydantic import BaseModel, create_model
26
+
27
+ from klea_utils.stores.config import VectorStoresConfig
28
+ from klea_utils.stores.retrieval import VSRetriever
29
+
30
+
31
class BaseLangGraph(ABC):
    """Abstract base class for LangGraph-based orchestrators.

    Provides common infrastructure for:

    - Env file and JSON configuration loading
    - MCP client creation from ``self.mcp_config``
    - Per-domain MCP tool descriptions
    - Vector store retriever setup
    - LLM model setup (delegated to subclasses)
    - LangGraph execution (invoke and stream helpers)
    - Optional in-memory session checkpointing
    - Logging through ``klea_utils.plogging.setup_logger``

    Subclasses must implement:

    - :meth:`_configure_resources`: Populate store and MCP configuration
    - :meth:`_setup_models`: Create LLM model instances
    - :meth:`_create_graph`: Build and compile the LangGraph

    and must set ``env_class`` and ``config_class``.
    """

    #: Pydantic BaseSettings class for env loading.
    #: Subclasses must set this to their AppEnv class.
    env_class: Type[BaseModel]

    #: Pydantic BaseModel class for configuration loading.
    #: Subclasses must set this to their AppConfig class.
    config_class: Type[BaseModel]

    #: Name of the environment variable that controls the env file path.
    env_var: str = "ENV_FILE"

    #: Default env file name if the environment variable is not set.
    env_file_default: str = "config.env"

    #: Logger name for this orchestrator.
    logger_name: str = "BaseLangGraph"

    def __init__(
        self,
        logging_level: int = logging.DEBUG,
        memory: bool = True,
    ):
        """Initialise the base orchestrator.

        Only sets attribute defaults and the logger; nothing is loaded until
        :meth:`setup` is awaited.

        :param logging_level: Level passed to the logger as ``stderr_level``
        :param memory: Whether to enable checkpoint-based session memory
        """
        self.env_file = os.getenv(self.env_var, self.env_file_default)
        # populated by _load_env()
        self.app_env: BaseModel

        self.c_model = None

        self.memory = memory

        self.tools_description: dict[str, str] = {}
        self.domain_mcp_configs: dict[str, MCPConfig] = {}
        self.checkpointer: InMemorySaver | None = InMemorySaver() if memory else None

        self.config_dict: dict[str, Any]

        self.graph: CompiledStateGraph | None = None

        self.mcp_config: MCPConfig | None = None
        self.mcp_client: Client | None = None
        self.mcp_tools: list[Tool] | None = None

        self.stores_config: VectorStoresConfig | None = None
        self.stores: VSRetriever | None = None
        self.embedding_model: str | None = None
        self.default_k: int = 5
        self.k_max: int = 10

        # created dynamically in _get_vector_stores() once domains are known
        self.QueryDomainSchema: Type[BaseModel] | None = None

        # imported here rather than at module level, as in the original code
        from klea_utils.plogging import setup_logger

        self.logger = setup_logger(self.logger_name, stderr_level=logging_level)

    def _load_env(self) -> None:
        """Load the env file and the JSON application configuration.

        Instantiates ``self.env_class`` from ``self.env_file`` into
        ``self.app_env``, then reads the JSON file named by its
        ``app_config_file`` field and instantiates ``self.config_class``
        from it into ``self.app_config``.

        :raises FileNotFoundError: if the env file does not exist, if
            ``env_class`` has no ``app_config_file`` field, or if the
            configuration file does not exist
        """
        if not Path(self.env_file).exists():
            raise FileNotFoundError(
                f"Could not find env file: {self.env_file}. "
                f"You can use the {self.env_var} environment variable "
                "to specify the env file."
            )

        self.app_env = self.env_class(_env_file=self.env_file)
        self.logger.debug(f"env file: {self.env_file}")
        self.logger.debug(f"env: {self.app_env}")

        if "app_config_file" not in self.env_class.model_fields:
            raise FileNotFoundError(
                f"No config file provided. Please provide one in the env file ({self.env_file}). "
                f"You can use the {self.env_var} environment variable to specify the env file."
            )

        config_file = Path(self.app_env.app_config_file)
        if not config_file.exists():
            raise FileNotFoundError(f"Could not find config file: {config_file}")

        # JSON is UTF-8 by specification; do not depend on the locale default
        with open(config_file, "r", encoding="utf-8") as f:
            config_dict = json.load(f)
        self.logger.debug(f"{config_dict = }")
        self.app_config = self.config_class(**config_dict)
        self.logger.debug(f"{self.app_config = }")

    def _create_mcp_client(self) -> None:
        """Create the MCP client from ``self.mcp_config``.

        A ``fastmcp.Client`` is created only when ``self.mcp_config`` is set
        and lists at least one server; otherwise a warning is logged and
        ``self.mcp_client`` is set to ``None``.
        """
        if self.mcp_config and self.mcp_config.mcpServers:
            self.logger.debug(f"{self.mcp_config = }")
            self.mcp_client = Client(self.mcp_config)
            return

        self.logger.warning("No MCP server configured.")
        self.mcp_client = None

    async def _get_mcp_tools(self) -> None:
        """Fetch the MCP tool list and rebuild the per-domain descriptions.

        Does nothing when no MCP client has been created.
        """
        if not self.mcp_client:
            return

        async with self.mcp_client:
            self.mcp_tools = await self.mcp_client.list_tools()
            self.logger.debug(f"{self.mcp_tools =}")
        self._build_tools_description()

    def _build_tools_description(self) -> None:
        """Build per-domain tool descriptions from fetched MCP tools.

        Resets ``self.tools_description`` and fills it with one Markdown
        string per domain that has MCP servers configured. Tools whose name
        contains ``"dummy"`` are skipped. When more than one server is
        configured overall, a tool is assigned to a domain only if its name
        starts with ``<server name>_`` for one of that domain's servers.
        """
        self.tools_description = {}
        if not self.mcp_tools or not self.domain_mcp_configs:
            return

        # map domains to the names of their servers
        domain_servers: dict[str, list[str]] = {
            domain: list(config.mcpServers)
            for domain, config in self.domain_mcp_configs.items()
            if config.mcpServers
        }
        num_servers = sum(len(names) for names in domain_servers.values())

        for domain, server_names in domain_servers.items():
            prefixes = tuple(f"{name}_" for name in server_names)
            sections: list[str] = []
            ctr = 0
            for t in self.mcp_tools:
                if "dummy" in t.name:
                    continue
                # with several servers, tools are prefixed with server names;
                # with a single server every tool belongs to it
                if num_servers > 1 and not t.name.startswith(prefixes):
                    continue
                ctr += 1
                # The text is written out explicitly instead of dedenting an
                # interpolated template: a multi-line description would
                # otherwise change the common indentation dedent() removes.
                sections.append(
                    f"\n## {ctr}. {t.name}\n\n### Description\n\n{t.description}\n\n"
                )
                if t.inputSchema:
                    sections.append(
                        f"\n### Parameters\n\n{t.inputSchema.get('properties')}\n\n"
                    )
            self.tools_description[domain] = "".join(sections)

    async def _get_vector_stores(self) -> None:
        """Set up the vector store retriever and the domain schema.

        Requires both ``self.stores_config`` and ``self.embedding_model``;
        otherwise only a warning is logged. On success ``self.stores`` is
        created and set up, and ``self.QueryDomainSchema`` is generated with
        a ``query_domains`` field restricted to the store domains plus
        ``"undefined"``.
        """
        if not (self.stores_config and self.embedding_model):
            self.logger.warning("No vector stores configured.")
            return

        self.stores = VSRetriever(
            vs_config=self.stores_config,
            logger=self.logger,
            embedding_model=self.embedding_model,
            default_k=self.default_k,
            k_max=self.k_max,
        )
        self.stores.setup()
        self.logger.info(f"Vector stores loaded: {self.stores.domains}")

        # dynamically generate schema for domains
        all_domains = self.stores.domains.copy()
        all_domains.append("undefined")

        # NOTE(review): the default is the string "undefined" although the
        # field is typed as a list -- confirm whether ["undefined"] was meant.
        self.QueryDomainSchema = create_model(
            "QueryDomainSchema",
            query_domains=(List[Literal[tuple(all_domains)]], "undefined"),
        )

    def _export_graph_png(self, filename: str) -> None:
        """Export the LangGraph as a Mermaid PNG diagram.

        Skipped when the ``RUNNING_IN_DOCKER`` environment variable is set to
        any non-empty value. Failures are logged and never propagated.

        :param filename: Output file path for the PNG
        """
        if os.environ.get("RUNNING_IN_DOCKER", 0):
            return
        try:
            self._require_graph().get_graph().draw_mermaid_png(
                output_file_path=filename
            )
        except Exception as e:
            # Best effort only. Exception (not BaseException) so that
            # KeyboardInterrupt and SystemExit still propagate.
            self.logger.error("Something went wrong generating lang graph png")
            self.logger.error(e)

    def _require_graph(self) -> CompiledStateGraph:
        """Return the compiled graph, failing clearly if it is missing.

        :raises RuntimeError: if ``self.graph`` has not been created yet
        """
        if self.graph is None:
            raise RuntimeError(
                "Graph has not been created. Call setup() before running the graph."
            )
        return self.graph

    # ------------------------------------------------------------------
    # Abstract methods -- subclasses must implement these
    # ------------------------------------------------------------------

    @abstractmethod
    def _configure_resources(self) -> None:
        """Configure vector stores and MCP servers

        Subclasses should implement this to populate ``self.stores_config``,
        ``self.mcp_config``, and ``self.domain_mcp_configs``, which will be used
        to create the vector store class, mcp client, and per-domain tool descriptions.
        """
        ...

    @abstractmethod
    def _setup_models(self) -> None:
        """Set up LLM model instances.

        Subclasses should assign model instances to ``self.c_model`` (and
        optionally ``self.r_model`` for reasoning models).
        """
        ...

    @abstractmethod
    async def _create_graph(self) -> None:
        """Build and compile the LangGraph, storing it in ``self.graph``.

        This is where subclasses define their nodes, edges, and conditional
        routing logic.
        """
        ...

    # ------------------------------------------------------------------
    # Hook methods -- override for pre/post setup work
    # ------------------------------------------------------------------

    def _pre_setup(self) -> None:
        """Hook called before the standard setup sequence.

        Override to perform subclass-specific initialisation before
        config loading and model setup.
        """
        pass

    async def _pre_graph(self) -> None:
        """Hook called after resource setup but before graph compilation.

        Override to perform subclass-specific initialisation that depends
        on config, the MCP client, or the vector stores, but must happen
        before the LangGraph is built.
        """
        pass

    def _post_setup(self) -> None:
        """Hook called after the standard setup sequence.

        Override to perform subclass-specific finalisation after
        the LangGraph has been compiled.
        """
        pass

    # ------------------------------------------------------------------
    # Template method
    # ------------------------------------------------------------------

    @final
    async def setup(self) -> None:
        """Set up the orchestrator.

        Calls hooks and template methods in this order:

        1. ``_pre_setup()``
        2. ``_load_env()``
        3. ``_configure_resources()``
        4. ``_setup_models()``
        5. ``_create_mcp_client()``
        6. ``_get_mcp_tools()``
        7. ``_get_vector_stores()``
        8. ``_pre_graph()``
        9. ``_create_graph()``
        10. ``_post_setup()``
        """
        self._pre_setup()
        self._load_env()
        self._configure_resources()
        self._setup_models()
        self._create_mcp_client()
        await self._get_mcp_tools()
        await self._get_vector_stores()
        await self._pre_graph()
        await self._create_graph()
        self._post_setup()

    # ------------------------------------------------------------------
    # Execution methods
    # ------------------------------------------------------------------

    async def run_graph_invoke_state(
        self, state: dict, thread_id: str = "default_thread"
    ) -> dict:
        """Run the graph, accepting and returning full state dicts.

        :param state: Initial graph state (must contain ``query`` key)
        :param thread_id: Session/thread identifier for checkpointing
        :returns: Final graph state
        :raises ValueError: if *state* has no ``query`` key
        :raises RuntimeError: if the graph has not been created yet
        """
        if "query" not in state:
            message = f"Provided state should include the key 'query': {state}"
            self.logger.error(message)
            # raise instead of exiting: library code must not terminate the
            # host process on bad input
            raise ValueError(message)

        config = {"configurable": {"thread_id": thread_id}}
        final_state = await self._require_graph().ainvoke(state, config=config)
        self.logger.debug(final_state)
        return final_state

    # TODO: fields to be extracted from the final state to be returned should
    # be configurable with a schema
    async def run_graph_invoke(
        self, query: str, thread_id: str = "default_thread"
    ) -> str:
        """Run the graph with a simple string query.

        :param query: User query string
        :param thread_id: Session/thread identifier for checkpointing
        :returns: The ``message_for_user`` field from the final state, or a
            fixed fallback sentence when that field is missing or empty
        :raises RuntimeError: if the graph has not been created yet
        """
        config = {"configurable": {"thread_id": thread_id}}

        final_state = await self._require_graph().ainvoke(
            {"query": query}, config=config
        )

        self.logger.debug(f"{final_state =}")
        return final_state.get("message_for_user") or "I was unable to answer"

    async def run_graph_stream(self, query: str, thread_id: str = "default_thread"):
        """Run the graph and yield intermediate ``message_for_user`` values.

        :param query: User query string
        :param thread_id: Session/thread identifier for checkpointing
        :yields: ``message_for_user`` strings from each node
        :raises RuntimeError: if the graph has not been created yet
        """
        config = {"configurable": {"thread_id": thread_id}}

        # astream() is an async iterator, so it must be consumed with
        # ``async for``
        async for chunk in self._require_graph().astream(
            {"query": query}, config=config
        ):
            for node, state in chunk.items():
                self.logger.debug(f"{node}: {repr(state)}")
                # Defensive: treat payloads that are not mappings as carrying
                # no message instead of failing on ``.get``.
                message = (
                    state.get("message_for_user") if isinstance(state, dict) else None
                )
                if message:
                    self.logger.info(f"User message: {message}")
                    yield message
                else:
                    self.logger.debug(f"Working in node: {node}")

    async def graph_stream(self, query: str, thread_id: str = "default_thread") -> Any:
        """Run the graph and return the raw astream result.

        :param query: User query string
        :param thread_id: Session/thread identifier for checkpointing
        :returns: Raw async iterator from ``graph.astream()``; consume it
            with ``async for``
        :raises RuntimeError: if the graph has not been created yet
        """
        config = {"configurable": {"thread_id": thread_id}}

        # astream() returns an async iterator, which is not awaitable; it is
        # handed back to the caller as is
        return self._require_graph().astream({"query": query}, config=config)