langroid 0.19.5__py3-none-any.whl → 0.20.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,649 @@
1
+ import datetime
2
+ import json
3
+ import logging
4
+ import time
5
+ from typing import Any, Callable, Dict, List, Optional, TypeVar, Union
6
+
7
+ from arango.client import ArangoClient
8
+ from arango.database import StandardDatabase
9
+ from arango.exceptions import ArangoError, ServerConnectionError
10
+ from numpy import ceil
11
+ from rich import print
12
+ from rich.console import Console
13
+
14
+ from langroid.agent.chat_agent import ChatAgent, ChatAgentConfig
15
+ from langroid.agent.chat_document import ChatDocument
16
+ from langroid.agent.special.arangodb.system_messages import (
17
+ ADDRESSING_INSTRUCTION,
18
+ DEFAULT_ARANGO_CHAT_SYSTEM_MESSAGE,
19
+ DONE_INSTRUCTION,
20
+ SCHEMA_PROVIDED_SYS_MSG,
21
+ SCHEMA_TOOLS_SYS_MSG,
22
+ )
23
+ from langroid.agent.special.arangodb.tools import (
24
+ AQLCreationTool,
25
+ AQLRetrievalTool,
26
+ ArangoSchemaTool,
27
+ aql_retrieval_tool_name,
28
+ arango_schema_tool_name,
29
+ )
30
+ from langroid.agent.special.arangodb.utils import count_fields, trim_schema
31
+ from langroid.agent.tools.orchestration import DoneTool, ForwardTool
32
+ from langroid.exceptions import LangroidImportError
33
+ from langroid.mytypes import Entity
34
+ from langroid.pydantic_v1 import BaseModel, BaseSettings
35
+ from langroid.utils.constants import SEND_TO
36
+
37
+ logger = logging.getLogger(__name__)
38
+ console = Console()
39
+
40
+ ARANGO_ERROR_MSG = "There was an error in your AQL Query"
41
+ T = TypeVar("T")
42
+
43
+
44
class ArangoSettings(BaseSettings):
    """Connection settings for an ArangoDB instance.

    Holds optional ready-made driver objects (``client``, ``db``) alongside
    the plain connection strings. The nested ``Config`` sets the settings
    prefix to ``ARANGO_``.
    """

    # Pre-built driver objects; ``None`` means "not supplied".
    client: Optional[ArangoClient] = None
    db: Optional[StandardDatabase] = None

    # Plain connection parameters; all default to the empty string.
    url: str = ""
    username: str = ""
    password: str = ""
    database: str = ""

    class Config:
        env_prefix = "ARANGO_"
54
+
55
+
56
class QueryResult(BaseModel):
    """Outcome of a single AQL query.

    ``success`` tells whether the query ran; ``data`` carries either the
    returned records or an error message, and defaults to ``None``.
    """

    success: bool
    data: Optional[
        Union[
            str, int, float, bool, None,
            List[Any],
            Dict[str, Any],
            List[Dict[str, Any]],
        ]
    ] = None

    class Config:
        # Mutable model (set frozen=True for immutability) ...
        frozen = False
        # ... whose assignments are validated.
        validate_assignment = True

        # Permit arbitrary (non-pydantic) types for flexibility.
        arbitrary_types_allowed = True

        # JSON serialization of special types; further encoders for
        # ArangoDB-specific types could be added here.
        json_encoders = {
            datetime.datetime: lambda v: v.isoformat(),
        }
87
+
88
+
89
class ArangoChatAgentConfig(ChatAgentConfig):
    """Configuration for ``ArangoChatAgent``."""

    arango_settings: ArangoSettings = ArangoSettings()
    system_message: str = DEFAULT_ARANGO_CHAT_SYSTEM_MESSAGE
    kg_schema: Optional[Union[str, Dict[str, List[Dict[str, Any]]]]] = None
    database_created: bool = False
    prepopulate_schema: bool = True
    use_functions_api: bool = True
    # how many results to return from AQL query
    max_num_results: int = 10
    # truncate long results to this many tokens
    max_result_tokens: int = 1000
    # max fields to show in schema
    max_schema_fields: int = 500
    # how many attempts to answer user question
    max_tries: int = 10
    use_tools: bool = False
    schema_sample_pct: float = 0
    # Whether the agent is used in a continuous chat with the user,
    # as opposed to returning a result from the task.run()
    chat_mode: bool = False
    addressing_prefix: str = ""
106
+
107
+
108
+ class ArangoChatAgent(ChatAgent):
109
def __init__(self, config: ArangoChatAgentConfig):
    """Build the agent: base-class init, then the ArangoDB-specific setup.

    The setup steps run strictly in the order listed below; the last one
    needs the database handle created by the one before it.
    """
    super().__init__(config)
    self.config: ArangoChatAgentConfig = config
    setup_steps = (
        self.init_state,
        self._validate_config,
        self._import_arango,
        self._initialize_db,
        self._init_tools_sys_message,
    )
    for step in setup_steps:
        step()
117
+
118
def init_state(self) -> None:
    """Reset the base-class state plus this agent's own counters."""
    super().init_state()
    # how many attempts to answer user question
    self.num_tries = 0
    self.current_retrieval_aql_query: str = ""
122
+
123
def user_response(
    self,
    msg: Optional[str | ChatDocument] = None,
) -> Optional[ChatDocument]:
    """Delegate to the base class; a non-empty reply resets ``num_tries``."""
    reply = super().user_response(msg)
    if reply is not None and reply.content != "":
        # reset number of tries if user responds
        self.num_tries = 0
    return reply
132
+
133
def llm_response(
    self, message: Optional[str | ChatDocument] = None
) -> Optional[ChatDocument]:
    """Produce the LLM response, giving up once ``max_tries`` is exceeded.

    When the try budget is exhausted, a canned reply is returned instead of
    calling the LLM: a message addressed to the user in chat mode, otherwise
    a ``DoneTool``. For messages sent by the user, a one-tool-at-a-time
    reminder is appended to the message content (the message object is
    modified in place) before delegating to the base class.
    """
    if self.num_tries > self.config.max_tries:
        if not self.config.chat_mode:
            done = DoneTool(
                content=f"""
                Exceeded maximum number of tries ({self.config.max_tries}).
                """
            )
            return self.create_llm_response(tool_messages=[done])
        return self.create_llm_response(
            content=f"""
            {self.config.addressing_prefix}User
            I give up, since I have exceeded the
            maximum number of tries ({self.config.max_tries}).
            Feel free to give me some hints!
            """
        )

    from_user = (
        isinstance(message, ChatDocument)
        and message.metadata.sender == Entity.USER
    )
    if from_user:
        reminder = """
        (REMEMBER, Do NOT use more than ONE TOOL/FUNCTION at a time!
        you must WAIT for a helper to send you the RESULT(S) before
        making another TOOL/FUNCTION call)
        """
        message.content = message.content + "\n" + reminder

    return super().llm_response(message)
169
+
170
def _validate_config(self) -> None:
    """Check that the config carries enough information to connect.

    Passes when both a client and a db object are supplied; otherwise all
    of url, username, password and database must be non-empty.

    Raises:
        ValueError: if the connection info is incomplete.
    """
    assert isinstance(self.config, ArangoChatAgentConfig)
    settings = self.config.arango_settings
    if settings.client is not None and settings.db is not None:
        return

    required = (
        settings.url,
        settings.username,
        settings.password,
        settings.database,
    )
    if not all(required):
        raise ValueError("ArangoDB connection info must be provided")
185
+
186
def _import_arango(self) -> None:
    """Make sure the ArangoDB driver is importable.

    Rebinds the module-level ``ArangoClient`` name from ``arango.client``.

    Raises:
        LangroidImportError: if the ``arango`` package cannot be imported;
            the original ``ImportError`` is attached as the cause.
    """
    global ArangoClient
    try:
        from arango.client import ArangoClient
    except ImportError as e:
        # Chain explicitly so the traceback shows why the import failed.
        raise LangroidImportError("python-arango", "arango") from e
192
+
193
+ def _has_any_data(self) -> bool:
194
+ for c in self.db.collections(): # type: ignore
195
+ if c["name"].startswith("_"):
196
+ continue
197
+ if self.db.collection(c["name"]).count() > 0: # type: ignore
198
+ return True
199
+ return False
200
+
201
def _initialize_db(self) -> None:
    """Create (or adopt) the client and database handles and probe for data.

    Uses the client/db objects from the config when present, otherwise
    builds them from the url/database/username/password settings. Sets
    ``config.database_created`` according to whether any non-system
    collection has documents and, if so, fetches the schema.

    Raises:
        ConnectionError: on any failure; the underlying exception is
            attached as the cause.
    """
    settings = self.config.arango_settings
    try:
        logger.info("Initializing ArangoDB client connection...")
        self.client = settings.client or ArangoClient(hosts=settings.url)

        logger.info("Connecting to database...")
        self.db = settings.db or self.client.db(
            settings.database,
            username=settings.username,
            password=settings.password,
        )

        logger.info("Checking for existing data in collections...")
        # Check if any non-system collection has data
        self.config.database_created = self._has_any_data()

        # If database has data, get schema
        if self.config.database_created:
            logger.info("Database has existing data, retrieving schema...")
            # this updates self.config.kg_schema
            self.arango_schema_tool(None)
        else:
            logger.info("No existing data found in database")

    except Exception as e:
        logger.error(f"Database initialization failed: {e}")
        # Chain the cause so the driver-level error stays in the traceback.
        raise ConnectionError(
            f"Failed to initialize ArangoDB connection: {e}"
        ) from e
230
+
231
def close(self) -> None:
    """Close the ArangoDB client, if one is set."""
    client = self.client
    if not client:
        return
    client.close()
234
+
235
+ @staticmethod
236
+ def cleanup_graph_db(db) -> None: # type: ignore
237
+ # First delete graphs to properly handle edge collections
238
+ for graph in db.graphs():
239
+ graph_name = graph["name"]
240
+ if not graph_name.startswith("_"): # Skip system graphs
241
+ try:
242
+ db.delete_graph(graph_name)
243
+ except Exception as e:
244
+ print(f"Failed to delete graph {graph_name}: {e}")
245
+
246
+ # Clear existing collections
247
+ for collection in db.collections():
248
+ if not collection["name"].startswith("_"): # Skip system collections
249
+ try:
250
+ db.delete_collection(collection["name"])
251
+ except Exception as e:
252
+ print(f"Failed to delete collection {collection['name']}: {e}")
253
+
254
def with_retry(
    self, func: Callable[[], T], max_retries: int = 3, delay: float = 1.0
) -> T:
    """Call ``func``, retrying when it raises ``ArangoError``.

    After each failed attempt except the last, this waits ``delay``
    seconds and re-runs ``_initialize_db`` before trying again. The last
    failure is re-raised. If ``max_retries`` is not positive, ``func`` is
    called exactly once with no retry handling.
    """
    failures = 0
    while failures < max_retries:
        try:
            return func()
        except ArangoError:
            failures += 1
            if failures == max_retries:
                raise
            logger.warning(
                f"Connection failed (attempt {failures}/{max_retries}). "
                f"Retrying in {delay} seconds..."
            )
            time.sleep(delay)
            # Reconnect if needed
            self._initialize_db()
    # Only reached when the loop body never ran.
    return func()
272
+
273
def read_query(
    self, query: str, bind_vars: Optional[Dict[Any, Any]] = None
) -> QueryResult:
    """Run a read query and return at most ``max_num_results`` records.

    Server-connection errors are passed to ``with_retry``; any other
    failure is reported as an unsuccessful ``QueryResult`` whose ``data``
    is the message built by ``retry_query``. This method does not raise.
    """
    if not self.db:
        return QueryResult(
            success=False, data="No database connection is established."
        )

    def execute_read() -> QueryResult:
        try:
            cursor = self.db.aql.execute(query, bind_vars=bind_vars)
            records = list(cursor)[: self.config.max_num_results]  # type: ignore
            logger.warning(f"Records retrieved: {records}")
            return QueryResult(success=True, data=records or [])
        except ServerConnectionError:
            # Let the retry wrapper see connection failures.
            raise
        except Exception as err:
            logger.error(f"Failed to execute query: {query}\n{err}")
            return QueryResult(success=False, data=self.retry_query(err, query))

    try:
        return self.with_retry(execute_read)  # type: ignore
    except Exception as err:
        return QueryResult(
            success=False, data=f"Failed after max retries: {str(err)}"
        )
302
+
303
def write_query(
    self, query: str, bind_vars: Optional[Dict[Any, Any]] = None
) -> QueryResult:
    """Run a write query; the result carries only a success flag.

    Server-connection errors are passed to ``with_retry``; any other
    failure is reported as an unsuccessful ``QueryResult`` whose ``data``
    is the message built by ``retry_query``. This method does not raise.
    """
    if not self.db:
        return QueryResult(
            success=False, data="No database connection is established."
        )

    def execute_write() -> QueryResult:
        try:
            self.db.aql.execute(query, bind_vars=bind_vars)
        except ServerConnectionError:
            # Let the retry wrapper see connection failures.
            raise
        except Exception as err:
            logger.error(f"Failed to execute query: {query}\n{err}")
            return QueryResult(success=False, data=self.retry_query(err, query))
        return QueryResult(success=True)

    try:
        return self.with_retry(execute_write)  # type: ignore
    except Exception as err:
        return QueryResult(
            success=False, data=f"Failed after max retries: {str(err)}"
        )
329
+
330
+ def _limit_tokens(self, text: str) -> str:
331
+ result = text
332
+ n_toks = self.num_tokens(result)
333
+ if n_toks > self.config.max_result_tokens:
334
+ logger.warning(
335
+ f"""
336
+ Your query resulted in a large result of
337
+ {n_toks} tokens,
338
+ which will be truncated to {self.config.max_result_tokens} tokens.
339
+ If this does not give satisfactory results,
340
+ please retry with a more focused query.
341
+ """
342
+ )
343
+ if self.parser is not None:
344
+ result = self.parser.truncate_tokens(
345
+ result,
346
+ self.config.max_result_tokens,
347
+ )
348
+ else:
349
+ result = result[: self.config.max_result_tokens * 4] # truncate roughly
350
+ return result
351
+
352
def aql_retrieval_tool(self, msg: AQLRetrievalTool) -> str:
    """Handle an AQL query for data retrieval.

    Refuses, with an instruction for the LLM, when the schema has not been
    requested yet or when no data has been created in the database.
    Otherwise counts one try, runs ``msg.aql_query`` as a read query and
    returns the stringified result, truncated to the configured token
    limit. An empty result list yields a hint message instead.
    """
    if not self.tried_schema:
        return f"""
        You need to use `{arango_schema_tool_name}` first to get the
        database schema before using `{aql_retrieval_tool_name}`. This ensures
        you know the correct collection names and edge definitions.
        """
    elif not self.config.database_created:
        # This literal used to be a plain string, so the LLM was shown the
        # raw text "{aql_creation_tool_name}". Resolve the tool name the same
        # way handle_message_fallback does for DoneTool and interpolate it.
        aql_creation_tool_name = AQLCreationTool.default_value("request")
        return f"""
        You need to create the database first using `{aql_creation_tool_name}`.
        """
    self.num_tries += 1
    query = msg.aql_query
    self.current_retrieval_aql_query = query
    logger.info(f"Executing AQL query: {query}")
    response = self.read_query(query)

    if isinstance(response.data, list) and len(response.data) == 0:
        return """
        No results found. Check if your collection names are correct -
        they are case-sensitive. Use exact names from the schema.
        Try modifying your query based on the RETRY-SUGGESTIONS
        in your instructions.
        """
    # truncate long results
    result = str(response.data)
    return self._limit_tokens(result)
380
+
381
def aql_creation_tool(self, msg: AQLCreationTool) -> str:
    """Handle an AQL query that creates data.

    Counts one try, runs ``msg.aql_query`` as a write query, and on success
    marks the database as created. On failure the stringified error data is
    returned.
    """
    self.num_tries += 1
    aql = msg.aql_query
    logger.info(f"Executing AQL query: {aql}")
    outcome = self.write_query(aql)

    if not outcome.success:
        return str(outcome.data)
    self.config.database_created = True
    return "AQL query executed successfully"
392
+
393
def arango_schema_tool(
    self,
    msg: ArangoSchemaTool | None,
) -> Dict[str, List[Dict[str, Any]]] | str:
    """Get the database schema.

    Args:
        msg: the schema request from the LLM, or ``None`` when the agent
            itself pre-populates the schema. With ``None``, all collections
            and their properties are included, and an already cached
            ``config.kg_schema`` is returned as-is. If ``msg.collections``
            is empty/None all collections are included; if
            ``msg.properties`` is False only connection info is shown,
            otherwise all properties and an example doc.

    Returns:
        The schema dict, or a string: the trimmed schema plus a caution
        note when the field count exceeds ``config.max_schema_fields``, or
        an error message if retrieval failed. This method does not raise.

    Side effects: sets ``tried_schema``, counts one try when ``msg`` is
    given, stores the result in ``config.kg_schema`` (except a trimmed
    result produced for an LLM request), and makes a best-effort dump of
    the schema to ``logs/arango-schema.json``.
    """
    import os

    def simplify_doc(doc: Any) -> Any:
        # Keep only the first element of each non-empty list, recursively,
        # so the example document stays small.
        if isinstance(doc, list) and len(doc) > 0:
            return [simplify_doc(doc[0])]
        if isinstance(doc, dict):
            return {k: simplify_doc(v) for k, v in doc.items()}
        return doc

    if msg is not None:
        collections = msg.collections
        properties = msg.properties
    else:
        collections = None
        properties = True
    self.tried_schema = True
    if (
        self.config.kg_schema is not None
        and len(self.config.kg_schema) > 0
        and msg is None
    ):
        # we are trying to pre-populate full schema before the agent runs,
        # so get it if it's already available
        # (Note of course that this "full schema" may actually be incomplete)
        return self.config.kg_schema

    # increment tries only if the LLM is asking for the schema,
    # in which case msg will not be None
    if msg is not None:
        self.num_tries += 1

    try:
        # Get graph schemas (keeping full graph info)
        graph_schema = [
            {"graph_name": g["name"], "edge_definitions": g["edge_definitions"]}
            for g in self.db.graphs()  # type: ignore
        ]

        # Get collection schemas
        collection_schema = []
        for collection in self.db.collections():  # type: ignore
            col_name = collection["name"]
            if col_name.startswith("_"):
                continue
            if collections and col_name not in collections:
                continue

            col_type = collection["type"]
            col_size = self.db.collection(col_name).count()  # type: ignore
            if col_size == 0:
                continue

            if properties:
                # Full property collection with sampling
                lim = self.config.schema_sample_pct * col_size  # type: ignore
                # numpy's ceil returns a float; LIMIT needs an integer.
                limit_amount = int(ceil(lim / 100.0)) or 1
                # Bind parameters instead of string interpolation, so
                # collection names that are not plain identifiers
                # (e.g. containing "-") do not break the query.
                sample_query = """
                    FOR doc IN @@collection
                        LIMIT @limit
                        RETURN doc
                """
                sample_cursor = self.db.aql.execute(  # type: ignore
                    sample_query,
                    bind_vars={"@collection": col_name, "limit": limit_amount},
                )

                properties_list = []
                example_doc = None
                for doc in sample_cursor:  # type: ignore
                    if example_doc is None:
                        example_doc = simplify_doc(doc)
                    for key, value in doc.items():
                        prop = {"name": key, "type": type(value).__name__}
                        if prop not in properties_list:
                            properties_list.append(prop)

                collection_schema.append(
                    {
                        "collection_name": col_name,
                        "collection_type": col_type,
                        f"{col_type}_properties": properties_list,
                        f"example_{col_type}": example_doc,
                    }
                )
            else:
                # Basic info + from/to for edges only
                collection_info = {
                    "collection_name": col_name,
                    "collection_type": col_type,
                }
                if col_type == "edge":
                    # Get a sample edge to extract from/to fields
                    sample_edge = next(
                        self.db.aql.execute(  # type: ignore
                            "FOR e IN @@collection LIMIT 1 RETURN e",
                            bind_vars={"@collection": col_name},
                        ),
                        None,
                    )
                    if sample_edge:
                        collection_info["from_collection"] = sample_edge[
                            "_from"
                        ].split("/")[0]
                        collection_info["to_collection"] = sample_edge[
                            "_to"
                        ].split("/")[0]

                collection_schema.append(collection_info)

        schema = {
            "Graph Schema": graph_schema,
            "Collection Schema": collection_schema,
        }
        schema_str = json.dumps(schema, indent=2)
        logger.warning(f"Schema retrieved:\n{schema_str}")
        # The on-disk dump is a debugging aid only: a missing or unwritable
        # logs/ directory must not turn into a failed schema retrieval.
        try:
            os.makedirs("logs", exist_ok=True)
            with open("logs/arango-schema.json", "w") as f:
                f.write(schema_str)
        except OSError as e:
            logger.warning(f"Could not write logs/arango-schema.json: {e}")

        if (n_fields := count_fields(schema)) > self.config.max_schema_fields:
            logger.warning(
                f"""
                Schema has {n_fields} fields, which exceeds the maximum of
                {self.config.max_schema_fields}. Showing a trimmed version
                that only includes edge info and no other properties.
                """
            )
            schema = trim_schema(schema)
            n_fields = count_fields(schema)
            logger.warning(f"Schema trimmed down to {n_fields} fields.")
            schema_str = (
                json.dumps(schema)
                + "\n"
                + f"""

                CAUTION: The requested schema was too large, so
                the schema has been trimmed down to show only all collection names,
                their types,
                and edge relationships (from/to collections) without any properties.
                To find out more about the schema, you can EITHER:
                - Use the `{arango_schema_tool_name}` tool again with the
                `properties` arg set to True, and `collections` arg set to
                specific collections you want to know more about, OR
                - Use the `{aql_retrieval_tool_name}` tool to learn more about
                the schema by querying the database.

                """
            )
            if msg is None:
                self.config.kg_schema = schema_str
            return schema_str
        self.config.kg_schema = schema
        return schema

    except Exception as e:
        logger.error(f"Schema retrieval failed: {str(e)}")
        return f"Failed to retrieve schema: {str(e)}"
552
+
553
def _init_tools_sys_message(self) -> None:
    """Initialize the system message and enable the tools.

    Fills the ``mode`` placeholder of the configured system message,
    appends either the addressing instruction (chat mode) or the done
    instruction, re-runs the base-class initializer with the updated
    config, and enables the tools.
    """
    self.tried_schema = False
    mode_text = self._format_message()
    self.config.system_message = self.config.system_message.format(mode=mode_text)

    if not self.config.chat_mode:
        self.config.system_message += DONE_INSTRUCTION
    else:
        self.config.addressing_prefix = self.config.addressing_prefix or SEND_TO
        self.config.system_message += ADDRESSING_INSTRUCTION.format(
            prefix=self.config.addressing_prefix
        )

    super().__init__(self.config)

    # The schema tool is enabled regardless of whether
    # self.config.prepopulate_schema is True or False, because even when
    # the schema is provided, the agent may later want to get the schema,
    # e.g. if the db evolves, or the schema was trimmed due to size, or
    # if it needs to bring the schema into recent context.
    always_enabled = [
        ArangoSchemaTool,
        AQLRetrievalTool,
        AQLCreationTool,
        ForwardTool,
    ]
    self.enable_message(always_enabled)
    if not self.config.chat_mode:
        self.enable_message(DoneTool)
584
+
585
def _format_message(self) -> str:
    """Return the schema-related part of the system message.

    With ``prepopulate_schema`` set, the schema is fetched and embedded in
    the "schema provided" message; otherwise the "schema tools" message is
    returned unchanged.

    Raises:
        ValueError: if no database handle is set.
    """
    if self.db is None:
        raise ValueError("Database connection not established")

    assert isinstance(self.config, ArangoChatAgentConfig)
    if self.config.prepopulate_schema:
        return SCHEMA_PROVIDED_SYS_MSG.format(schema=self.arango_schema_tool(None))
    return SCHEMA_TOOLS_SYS_MSG
595
+
596
def handle_message_fallback(
    self, msg: str | ChatDocument
) -> str | ForwardTool | None:
    """Handle an LLM message that used no tool.

    Messages not sent by the LLM are ignored (``None``). In interactive
    mode the message is forwarded to the user; otherwise a nudge is
    returned asking the LLM to either mark its answer as final (via the
    forward tool in chat mode, the done tool otherwise) or use a tool.
    """
    done_tool_name = DoneTool.default_value("request")
    forward_tool_name = ForwardTool.default_value("request")
    aql_retrieval_tool_instructions = AQLRetrievalTool.instructions()
    # TODO the aql_retrieval_tool_instructions may be empty/minimal
    # when using self.config.use_functions_api = True.
    tools_instruction = f"""
    For example you may want to use the TOOL
    `{aql_retrieval_tool_name}` according to these instructions:
    {aql_retrieval_tool_instructions}
    """

    sent_by_llm = (
        isinstance(msg, ChatDocument) and msg.metadata.sender == Entity.LLM
    )
    if not sent_by_llm:
        return None

    if self.interactive:
        return ForwardTool(agent="User")

    if self.config.chat_mode:
        return f"""
        Since you did not explicitly address the User, it is not clear
        whether:
        - you intend this to be the final response to the
        user's query/request, in which case you must use the
        `{forward_tool_name}` to indicate this.
        - OR, you FORGOT to use an Appropriate TOOL,
        in which case you should use the available tools to
        make progress on the user's query/request.
        {tools_instruction}
        """

    return f"""
    The intent of your response is not clear:
    - if you intended this to be the FINAL answer to the user's query,
    then use the `{done_tool_name}` to indicate so,
    with the `content` set to the answer or result.
    - otherwise, use one of the available tools to make progress
    to arrive at the final answer.
    {tools_instruction}
    """
638
+
639
def retry_query(self, e: Exception, query: str) -> str:
    """Log a failed AQL query and build the error message for it.

    The message names the query, includes the exception text, and asks for
    a corrected query.
    """
    logger.error(f"AQL Query failed: {query}\nException: {e}")

    return f"""\
    {ARANGO_ERROR_MSG}: '{query}'
    {e}
    Please try again with a corrected query.
    """