langroid 0.1.139__py3-none-any.whl → 0.1.219__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (97)
  1. langroid/__init__.py +70 -0
  2. langroid/agent/__init__.py +22 -0
  3. langroid/agent/base.py +120 -33
  4. langroid/agent/batch.py +134 -35
  5. langroid/agent/callbacks/__init__.py +0 -0
  6. langroid/agent/callbacks/chainlit.py +608 -0
  7. langroid/agent/chat_agent.py +164 -100
  8. langroid/agent/chat_document.py +19 -2
  9. langroid/agent/openai_assistant.py +20 -10
  10. langroid/agent/special/__init__.py +33 -10
  11. langroid/agent/special/doc_chat_agent.py +521 -108
  12. langroid/agent/special/lance_doc_chat_agent.py +258 -0
  13. langroid/agent/special/lance_rag/__init__.py +9 -0
  14. langroid/agent/special/lance_rag/critic_agent.py +136 -0
  15. langroid/agent/special/lance_rag/lance_rag_task.py +80 -0
  16. langroid/agent/special/lance_rag/query_planner_agent.py +180 -0
  17. langroid/agent/special/lance_tools.py +44 -0
  18. langroid/agent/special/neo4j/__init__.py +0 -0
  19. langroid/agent/special/neo4j/csv_kg_chat.py +174 -0
  20. langroid/agent/special/neo4j/neo4j_chat_agent.py +370 -0
  21. langroid/agent/special/neo4j/utils/__init__.py +0 -0
  22. langroid/agent/special/neo4j/utils/system_message.py +46 -0
  23. langroid/agent/special/relevance_extractor_agent.py +23 -7
  24. langroid/agent/special/retriever_agent.py +29 -174
  25. langroid/agent/special/sql/__init__.py +7 -0
  26. langroid/agent/special/sql/sql_chat_agent.py +47 -23
  27. langroid/agent/special/sql/utils/__init__.py +11 -0
  28. langroid/agent/special/sql/utils/description_extractors.py +95 -46
  29. langroid/agent/special/sql/utils/populate_metadata.py +28 -21
  30. langroid/agent/special/table_chat_agent.py +43 -9
  31. langroid/agent/task.py +423 -114
  32. langroid/agent/tool_message.py +67 -10
  33. langroid/agent/tools/__init__.py +8 -0
  34. langroid/agent/tools/duckduckgo_search_tool.py +66 -0
  35. langroid/agent/tools/google_search_tool.py +11 -0
  36. langroid/agent/tools/metaphor_search_tool.py +67 -0
  37. langroid/agent/tools/recipient_tool.py +6 -24
  38. langroid/agent/tools/sciphi_search_rag_tool.py +79 -0
  39. langroid/cachedb/__init__.py +6 -0
  40. langroid/embedding_models/__init__.py +24 -0
  41. langroid/embedding_models/base.py +9 -1
  42. langroid/embedding_models/models.py +117 -17
  43. langroid/embedding_models/protoc/embeddings.proto +19 -0
  44. langroid/embedding_models/protoc/embeddings_pb2.py +33 -0
  45. langroid/embedding_models/protoc/embeddings_pb2.pyi +50 -0
  46. langroid/embedding_models/protoc/embeddings_pb2_grpc.py +79 -0
  47. langroid/embedding_models/remote_embeds.py +153 -0
  48. langroid/language_models/__init__.py +22 -0
  49. langroid/language_models/azure_openai.py +47 -4
  50. langroid/language_models/base.py +26 -10
  51. langroid/language_models/config.py +5 -0
  52. langroid/language_models/openai_gpt.py +407 -121
  53. langroid/language_models/prompt_formatter/__init__.py +9 -0
  54. langroid/language_models/prompt_formatter/base.py +4 -6
  55. langroid/language_models/prompt_formatter/hf_formatter.py +135 -0
  56. langroid/language_models/utils.py +10 -9
  57. langroid/mytypes.py +10 -4
  58. langroid/parsing/__init__.py +33 -1
  59. langroid/parsing/document_parser.py +259 -63
  60. langroid/parsing/image_text.py +32 -0
  61. langroid/parsing/parse_json.py +143 -0
  62. langroid/parsing/parser.py +20 -7
  63. langroid/parsing/repo_loader.py +108 -46
  64. langroid/parsing/search.py +8 -0
  65. langroid/parsing/table_loader.py +44 -0
  66. langroid/parsing/url_loader.py +59 -13
  67. langroid/parsing/urls.py +18 -9
  68. langroid/parsing/utils.py +130 -9
  69. langroid/parsing/web_search.py +73 -0
  70. langroid/prompts/__init__.py +7 -0
  71. langroid/prompts/chat-gpt4-system-prompt.md +68 -0
  72. langroid/prompts/prompts_config.py +1 -1
  73. langroid/utils/__init__.py +10 -0
  74. langroid/utils/algorithms/__init__.py +3 -0
  75. langroid/utils/configuration.py +0 -1
  76. langroid/utils/constants.py +4 -0
  77. langroid/utils/logging.py +2 -5
  78. langroid/utils/output/__init__.py +15 -2
  79. langroid/utils/output/status.py +33 -0
  80. langroid/utils/pandas_utils.py +30 -0
  81. langroid/utils/pydantic_utils.py +446 -4
  82. langroid/utils/system.py +36 -1
  83. langroid/vector_store/__init__.py +34 -2
  84. langroid/vector_store/base.py +33 -2
  85. langroid/vector_store/chromadb.py +42 -13
  86. langroid/vector_store/lancedb.py +226 -60
  87. langroid/vector_store/meilisearch.py +7 -6
  88. langroid/vector_store/momento.py +3 -2
  89. langroid/vector_store/qdrantdb.py +82 -11
  90. {langroid-0.1.139.dist-info → langroid-0.1.219.dist-info}/METADATA +190 -129
  91. langroid-0.1.219.dist-info/RECORD +127 -0
  92. langroid/agent/special/recipient_validator_agent.py +0 -157
  93. langroid/parsing/json.py +0 -64
  94. langroid/utils/web/selenium_login.py +0 -36
  95. langroid-0.1.139.dist-info/RECORD +0 -103
  96. {langroid-0.1.139.dist-info → langroid-0.1.219.dist-info}/LICENSE +0 -0
  97. {langroid-0.1.139.dist-info → langroid-0.1.219.dist-info}/WHEEL +0 -0
langroid/agent/chat_agent.py
@@ -1,5 +1,5 @@
+ import copy
  import inspect
- import json
  import logging
  import textwrap
  from contextlib import ExitStack
@@ -7,8 +7,9 @@ from typing import Dict, List, Optional, Set, Tuple, Type, cast

  from rich import print
  from rich.console import Console
+ from rich.markup import escape

- from langroid.agent.base import Agent, AgentConfig
+ from langroid.agent.base import Agent, AgentConfig, noop_fn
  from langroid.agent.chat_document import ChatDocument
  from langroid.agent.tool_message import ToolMessage
  from langroid.language_models.base import (
@@ -19,8 +20,9 @@ from langroid.language_models.base import (
  )
  from langroid.language_models.openai_gpt import OpenAIGPT
  from langroid.utils.configuration import settings
+ from langroid.utils.output import status

- console = Console(quiet=settings.quiet)
+ console = Console()

  logger = logging.getLogger(__name__)

@@ -41,18 +43,36 @@ class ChatAgentConfig(AgentConfig):

      system_message: str = "You are a helpful assistant."
      user_message: Optional[str] = None
-     use_tools: bool = True
-     use_functions_api: bool = False
+     use_tools: bool = False
+     use_functions_api: bool = True

-     def _switch_fn_to_tools(self) -> None:
+     def _set_fn_or_tools(self, fn_available: bool) -> None:
          """
-         Switch to using our own ToolMessage mechanism,
-         in case the LLM is not an OpenAI model.
+         Enable Langroid Tool or OpenAI-like fn-calling,
+         depending on config settings and availability of fn-calling.
          """
-         if not self.use_functions_api:
+         if self.use_functions_api and not fn_available:
+             logger.debug(
+                 """
+                 You have enabled `use_functions_api` but the LLM does not support it.
+                 So we will enable `use_tools` instead, so we can use
+                 Langroid's ToolMessage mechanism.
+                 """
+             )
+             self.use_functions_api = False
+             self.use_tools = True
+
+         if not self.use_functions_api or not self.use_tools:
              return
-         self.use_functions_api = False
-         self.use_tools = True
+         if self.use_functions_api and self.use_tools:
+             logger.debug(
+                 """
+                 You have enabled both `use_tools` and `use_functions_api`.
+                 Turning off `use_tools`, since the LLM supports function-calling.
+                 """
+             )
+             self.use_tools = False
+             self.use_functions_api = True


  class ChatAgent(Agent):
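
The defaults flip here: OpenAI-style function-calling is now preferred, with a debug-logged fallback to Langroid's JSON ToolMessage mechanism when the model cannot do fn-calling. A minimal sketch of the resulting behavior (assumes the top-level re-exports added in this release's langroid/__init__.py, and an OpenAI API key in the environment):

import langroid as lr

config = lr.ChatAgentConfig(name="demo")
# New defaults: prefer OpenAI-style function-calling.
assert (config.use_tools, config.use_functions_api) == (False, True)

# ChatAgent.__init__ now calls config._set_fn_or_tools(self._fn_call_available()):
# with an OpenAI chat model the flags stay as-is; with any other LLM they are
# flipped so the agent falls back to Langroid's ToolMessage mechanism.
agent = lr.ChatAgent(config)
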
@@ -72,7 +92,9 @@ class ChatAgent(Agent):
      """

      def __init__(
-         self, config: ChatAgentConfig, task: Optional[List[LLMMessage]] = None
+         self,
+         config: ChatAgentConfig = ChatAgentConfig(),
+         task: Optional[List[LLMMessage]] = None,
      ):
          """
          Chat-mode agent initialized with task spec as the initial message sequence
@@ -82,23 +104,7 @@ class ChatAgent(Agent):
          """
          super().__init__(config)
          self.config: ChatAgentConfig = config
-         if (
-             self.llm is not None
-             and (
-                 not isinstance(self.llm, OpenAIGPT)
-                 or not self.llm.is_openai_chat_model()
-             )
-             and self.config.use_functions_api
-         ):
-             # for non-OpenAI models, use Langroid Tool instead of Function-calling
-             logger.warning(
-                 f"""
-                 Function calling not available for {self.llm.config.chat_model},
-                 switching to Langroid Tools instead.
-                 """
-             )
-             self.config._switch_fn_to_tools()
-
+         self.config._set_fn_or_tools(self._fn_call_available())
          self.message_history: List[LLMMessage] = []
          self.tool_instructions_added: bool = False
          # An agent's "task" is defined by a system msg and an optional user msg;
@@ -130,8 +136,42 @@ class ChatAgent(Agent):
          self.llm_functions_usable: Set[str] = set()
          self.llm_function_force: Optional[Dict[str, str]] = None

+     def clone(self, i: int = 0) -> "ChatAgent":
+         """Create i'th clone of this agent, ensuring tool use/handling is cloned.
+         Important: We assume all member variables are in the __init__ method here
+         and in the Agent class.
+         TODO: We are attempting to clone an agent after its state has been
+         changed in possibly many ways. Below is an imperfect solution. Caution advised.
+         Revisit later.
+         """
+         agent_cls = type(self)
+         config_copy = copy.deepcopy(self.config)
+         config_copy.name = f"{config_copy.name}-{i}"
+         new_agent = agent_cls(config_copy)
+         new_agent.system_tool_instructions = self.system_tool_instructions
+         new_agent.system_json_tool_instructions = self.system_json_tool_instructions
+         new_agent.llm_tools_map = self.llm_tools_map
+         new_agent.llm_functions_map = self.llm_functions_map
+         new_agent.llm_functions_handled = self.llm_functions_handled
+         new_agent.llm_functions_usable = self.llm_functions_usable
+         new_agent.llm_function_force = self.llm_function_force
+         # Caution - we are copying the vector-db, maybe we don't always want this?
+         new_agent.vecdb = self.vecdb
+         return new_agent
+
+     def _fn_call_available(self) -> bool:
+         """Does this agent's LLM support function calling?"""
+         return (
+             self.llm is not None
+             and isinstance(self.llm, OpenAIGPT)
+             and self.llm.is_openai_chat_model()
+         )
+
      def set_system_message(self, msg: str) -> None:
          self.system_message = msg
+         if len(self.message_history) > 0:
+             # if there is message history, update the system message in it
+             self.message_history[0].content = msg

      def set_user_message(self, msg: str) -> None:
          self.user_message = msg
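
A quick sketch of how the new clone() is meant to be used, e.g. by batch runners that fan one agent out over many inputs (names are illustrative; assumes an OpenAI API key is configured):

import langroid as lr

base = lr.ChatAgent(lr.ChatAgentConfig(name="worker"))
# Each clone gets a distinct name "<name>-i"; tool registrations and the
# vector-db handle are shared, per the Caution note in the docstring.
clones = [base.clone(i) for i in range(3)]
assert clones[1].config.name == "worker-1"
assert clones[2].vecdb is base.vecdb
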
@@ -188,46 +228,24 @@
          enabled_classes: List[Type[ToolMessage]] = list(self.llm_tools_map.values())
          if len(enabled_classes) == 0:
              return "You can ask questions in natural language."
-
          json_instructions = "\n\n".join(
              [
-                 textwrap.dedent(
-                     f"""
-                     TOOL: {msg_cls.default_value("request")}
-                     PURPOSE: {msg_cls.default_value("purpose")}
-                     JSON FORMAT: {
-                         json.dumps(
-                             msg_cls.llm_function_schema(request=True).parameters,
-                             indent=4,
-                         )
-                     }
-                     {"EXAMPLE: " + msg_cls.usage_example() if msg_cls.examples() else ""}
-                     """.lstrip()
-                 )
-                 for i, msg_cls in enumerate(enabled_classes)
+                 msg_cls.json_instructions(tool=self.config.use_tools)
+                 for _, msg_cls in enumerate(enabled_classes)
                  if msg_cls.default_value("request") in self.llm_tools_usable
              ]
          )
-         return textwrap.dedent(
-             f"""
-             === ALL AVAILABLE TOOLS and THEIR JSON FORMAT INSTRUCTIONS ===
-             You have access to the following TOOLS to accomplish your task:
-
-             {json_instructions}
-
-             When one of the above TOOLs is applicable, you must express your
-             request as "TOOL:" followed by the request in the above JSON format.
-             """
-             + """
-             The JSON format will be:
-             \\{
-                 "request": "<tool_name>",
-                 "<arg1>": <value1>,
-                 "<arg2>": <value2>,
-                 ...
-             \\}
-             ----------------------------
-             """.lstrip()
+         # if any of the enabled classes has json_group_instructions, then use that,
+         # else fall back to ToolMessage.json_group_instructions
+         for msg_cls in enabled_classes:
+             if hasattr(msg_cls, "json_group_instructions") and callable(
+                 getattr(msg_cls, "json_group_instructions")
+             ):
+                 return msg_cls.json_group_instructions().format(
+                     json_instructions=json_instructions
+                 )
+         return ToolMessage.json_group_instructions().format(
+             json_instructions=json_instructions
          )

      def tool_instructions(self) -> str:
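
The per-tool formatting that used to be inlined here now lives on ToolMessage itself. A hedged sketch of what that looks like from user code (CityTool is a made-up example tool):

from langroid.agent.tool_message import ToolMessage

class CityTool(ToolMessage):  # hypothetical tool for illustration
    request: str = "city_info"
    purpose: str = "Look up facts about a <city>"
    city: str

# Each enabled tool now renders its own JSON-format block:
print(CityTool.json_instructions(tool=True))
# ...and the group-level wrapper is a class-level template containing a
# {json_instructions} placeholder, per the .format(...) calls above.
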
@@ -301,7 +319,7 @@ class ChatAgent(Agent):
          Useful when we want to replace a long user prompt, that may contain context
          documents plus a question, with just the question.
          Args:
-             message (str): user message
+             message (str): new message to replace with
              role (str): role of message to replace
          """
          if len(self.message_history) == 0:
@@ -337,7 +355,8 @@ class ChatAgent(Agent):

              """.lstrip()
          )
-         return LLMMessage(role=Role.SYSTEM, content=content)
+         # remove leading and trailing newlines and other whitespace
+         return LLMMessage(role=Role.SYSTEM, content=content.strip())

      def enable_message(
          self,
@@ -443,10 +462,10 @@
          message_class: The only ToolMessage class to allow
          """
          request = message_class.__fields__["request"].default
-         for r in self.llm_functions_usable:
-             if r != request:
-                 self.llm_tools_usable.discard(r)
-                 self.llm_functions_usable.discard(r)
+         to_remove = [r for r in self.llm_tools_usable if r != request]
+         for r in to_remove:
+             self.llm_tools_usable.discard(r)
+             self.llm_functions_usable.discard(r)

      def llm_response(
          self, message: Optional[str | ChatDocument] = None
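
Worth noting: the old loop discarded from the very set it was iterating, which Python forbids; the rewrite snapshots the names to remove first. In miniature:

usable = {"tool_a", "tool_b", "keep_me"}

# Discarding from a set while iterating it raises
# "RuntimeError: Set changed size during iteration", so snapshot first:
to_remove = [r for r in usable if r != "keep_me"]
for r in to_remove:
    usable.discard(r)
assert usable == {"keep_me"}
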
@@ -463,6 +482,8 @@
          if self.llm is None:
              return None
          hist, output_len = self._prep_llm_messages(message)
+         if len(hist) == 0:
+             return None
          with StreamingIfAllowed(self.llm, self.llm.get_stream()):
              response = self.llm_response_messages(hist, output_len)
          # TODO - when response contains function_call we should include
@@ -472,9 +493,7 @@
              response.metadata.tool_ids = (
                  []
                  if isinstance(message, str)
-                 else message.metadata.tool_ids
-                 if message is not None
-                 else []
+                 else message.metadata.tool_ids if message is not None else []
              )
          return response

@@ -497,9 +516,7 @@
              response.metadata.tool_ids = (
                  []
                  if isinstance(message, str)
-                 else message.metadata.tool_ids
-                 if message is not None
-                 else []
+                 else message.metadata.tool_ids if message is not None else []
              )
          return response

@@ -546,8 +563,9 @@
          if settings.debug:
              print(
                  f"""
-                 [red]LLM Initial Msg History:
-                 {self.message_history_str()}
+                 [grey37]LLM Initial Msg History:
+                 {escape(self.message_history_str())}
+                 [/grey37]
                  """
              )
          else:
@@ -587,7 +605,10 @@
              raise ValueError(
                  """
                  The message history is longer than the max chat context
-                 length allowed, and we have run out of messages to drop."""
+                 length allowed, and we have run out of messages to drop.
+                 HINT: In your `OpenAIGPTConfig` object, try increasing
+                 `chat_context_length` or decreasing `max_output_tokens`.
+                 """
              )
          # drop the second message, i.e. first msg after the sys msg
          # (typically user msg).
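
The new HINT points at two OpenAIGPTConfig knobs. A sketch of acting on it (model name and values are illustrative):

from langroid.language_models.openai_gpt import OpenAIGPT, OpenAIGPTConfig

llm_config = OpenAIGPTConfig(
    chat_model="gpt-4-turbo",     # illustrative model name
    chat_context_length=16_000,   # raise the assumed context window...
    max_output_tokens=256,        # ...and/or shrink the output budget
)
llm = OpenAIGPT(llm_config)
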
@@ -655,10 +676,17 @@
          """
          assert self.config.llm is not None and self.llm is not None
          output_len = output_len or self.config.llm.max_output_tokens
+         streamer = noop_fn
+         if self.llm.get_stream():
+             streamer = self.callbacks.start_llm_stream()
+         self.llm.config.streamer = streamer
          with ExitStack() as stack:  # for conditionally using rich spinner
-             if not self.llm.get_stream() and not settings.quiet:
+             if not self.llm.get_stream():
                  # show rich spinner only if not streaming!
-                 cm = console.status("LLM responding to messages...")
+                 cm = status(
+                     "LLM responding to messages...",
+                     log_if_quiet=False,
+                 )
                  stack.enter_context(cm)
              if self.llm.get_stream() and not settings.quiet:
                  console.print(f"[green]{self.indent}", end="")
@@ -670,17 +698,31 @@
              functions=functions,
              function_call=fun_call,
          )
+         if self.llm.get_stream():
+             self.callbacks.finish_llm_stream(
+                 content=str(response),
+                 is_tool=self.has_tool_message_attempt(
+                     ChatDocument.from_LLMResponse(response, displayed=True)
+                 ),
+             )
+         self.llm.config.streamer = noop_fn
+         if response.cached:
+             self.callbacks.cancel_llm_stream()
+
          if not self.llm.get_stream() or response.cached:
              # We would have already displayed the msg "live" ONLY if
              # streaming was enabled, AND we did not find a cached response.
              # If we are here, it means the response has not yet been displayed.
              cached = f"[red]{self.indent}(cached)[/red]" if response.cached else ""
-             if response.function_call is not None:
-                 response_str = str(response.function_call)
-             else:
-                 response_str = response.message
              if not settings.quiet:
-                 print(cached + "[green]" + response_str)
+                 print(cached + "[green]" + escape(str(response)))
+             self.callbacks.show_llm_response(
+                 content=str(response),
+                 is_tool=self.has_tool_message_attempt(
+                     ChatDocument.from_LLMResponse(response, displayed=True)
+                 ),
+                 cached=response.cached,
+             )
          self.update_token_usage(
              response,
              messages,
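
The pattern above hands a token-consumer from the UI callbacks to the LLM object for the duration of one call, then resets it. A stripped-down sketch of the same hand-off (all names here are stand-ins, not the langroid API):

from typing import Any, Callable

def noop_fn(*args: Any, **kwargs: Any) -> None:
    """Do-nothing default streamer (mirrors langroid.agent.base.noop_fn)."""

def start_llm_stream() -> Callable[[str], None]:
    # Stand-in for callbacks.start_llm_stream(): returns a token consumer.
    return lambda token: print(token, end="", flush=True)

stream_enabled = True
streamer: Callable[..., None] = start_llm_stream() if stream_enabled else noop_fn
try:
    for token in ("Hello", ", ", "world"):  # pretend these arrive from the LLM
        streamer(token)
finally:
    streamer = noop_fn  # always reset after the call
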
@@ -706,24 +748,42 @@
              "auto" if self.llm_function_force is None else self.llm_function_force
          )
          assert self.llm is not None
+
+         streamer = noop_fn
+         if self.llm.get_stream():
+             streamer = self.callbacks.start_llm_stream()
+         self.llm.config.streamer = streamer
+
          response = await self.llm.achat(
              messages,
              output_len,
              functions=functions,
              function_call=fun_call,
          )
-
+         if self.llm.get_stream():
+             self.callbacks.finish_llm_stream(
+                 content=str(response),
+                 is_tool=self.has_tool_message_attempt(
+                     ChatDocument.from_LLMResponse(response, displayed=True)
+                 ),
+             )
+         self.llm.config.streamer = noop_fn
+         if response.cached:
+             self.callbacks.cancel_llm_stream()
          if not self.llm.get_stream() or response.cached:
              # We would have already displayed the msg "live" ONLY if
              # streaming was enabled, AND we did not find a cached response.
              # If we are here, it means the response has not yet been displayed.
              cached = f"[red]{self.indent}(cached)[/red]" if response.cached else ""
-             if response.function_call is not None:
-                 response_str = str(response.function_call)
-             else:
-                 response_str = response.message
              if not settings.quiet:
-                 print(cached + "[green]" + response_str)
+                 print(cached + "[green]" + escape(str(response)))
+             self.callbacks.show_llm_response(
+                 content=str(response),
+                 is_tool=self.has_tool_message_attempt(
+                     ChatDocument.from_LLMResponse(response, displayed=True)
+                 ),
+                 cached=response.cached,
+             )

          self.update_token_usage(
              response,
@@ -785,12 +845,14 @@
          """
          # explicitly call THIS class's respond method,
          # not a derived class's (or else there would be infinite recursion!)
+         n_msgs = len(self.message_history)
          with StreamingIfAllowed(self.llm, self.llm.get_stream()):  # type: ignore
              response = cast(ChatDocument, ChatAgent.llm_response(self, message))
-         # clear the last two messages, which are the
-         # user message and the assistant response
-         self.message_history.pop()
-         self.message_history.pop()
+         # If there is a response, then we will have two additional
+         # messages in the message history, i.e. the user message and the
+         # assistant response. We want to (carefully) remove these two messages.
+         self.message_history.pop() if len(self.message_history) > n_msgs else None
+         self.message_history.pop() if len(self.message_history) > n_msgs else None
          return response

      async def llm_response_forget_async(self, message: str) -> ChatDocument:
@@ -799,14 +861,16 @@
          """
          # explicitly call THIS class's respond method,
          # not a derived class's (or else there would be infinite recursion!)
+         n_msgs = len(self.message_history)
          with StreamingIfAllowed(self.llm, self.llm.get_stream()):  # type: ignore
              response = cast(
                  ChatDocument, await ChatAgent.llm_response_async(self, message)
              )
-         # clear the last two messages, which are the
-         # user message and the assistant response
-         self.message_history.pop()
-         self.message_history.pop()
+         # If there is a response, then we will have two additional
+         # messages in the message history, i.e. the user message and the
+         # assistant response. We want to (carefully) remove these two messages.
+         self.message_history.pop() if len(self.message_history) > n_msgs else None
+         self.message_history.pop() if len(self.message_history) > n_msgs else None
          return response

      def chat_num_tokens(self, messages: Optional[List[LLMMessage]] = None) -> int:
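
Both *_forget variants now snapshot the history length and pop only what the call actually appended, instead of blindly popping twice. A minimal stand-alone sketch of the guard (plain lists, not the langroid API):

from typing import List

def pop_new_entries(history: List[str], n_before: int) -> None:
    # Drop only entries appended after the snapshot; a no-op if the call
    # failed before adding the user/assistant pair.
    while len(history) > n_before:
        history.pop()

history = ["system"]
n_msgs = len(history)
history += ["user: hi", "assistant: hello"]  # what llm_response would add
pop_new_entries(history, n_msgs)
assert history == ["system"]
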
langroid/agent/chat_document.py
@@ -3,6 +3,7 @@ from typing import List, Optional, Union

  from pydantic import BaseModel, Extra

+ from langroid.agent.tool_message import ToolMessage
  from langroid.language_models.base import (
      LLMFunctionCall,
      LLMMessage,
@@ -12,7 +13,7 @@ from langroid.language_models.base import (
  )
  from langroid.mytypes import DocMetaData, Document, Entity
  from langroid.parsing.agent_chats import parse_message
- from langroid.parsing.json import extract_top_level_json, top_level_json_field
+ from langroid.parsing.parse_json import extract_top_level_json, top_level_json_field
  from langroid.utils.output.printing import shorten_text


@@ -53,6 +54,7 @@ class ChatDocLoggerFields(BaseModel):

  class ChatDocument(Document):
      function_call: Optional[LLMFunctionCall] = None
+     tool_messages: List[ToolMessage] = []
      metadata: ChatDocMetaData
      attachment: None | ChatDocAttachment = None

@@ -82,7 +84,7 @@ class ChatDocument(Document):
              json_data = json.loads(j)
              tool = json_data.get("request")
              if tool is not None:
-                 tools.append(tool)
+                 tools.append(str(tool))
          return tools

      def log_fields(self) -> ChatDocLoggerFields:
@@ -103,6 +105,8 @@ class ChatDocument(Document):
          content = self.content
          sender_entity = self.metadata.sender
          sender_name = self.metadata.sender_name
+         if tool_type == "FUNC":
+             content += str(self.function_call)
          return ChatDocLoggerFields(
              sender_entity=sender_entity,
              sender_name=sender_name,
@@ -140,6 +144,9 @@ class ChatDocument(Document):
              ChatDocument: ChatDocument representation of this LLMResponse.
          """
          recipient, message = response.get_recipient_and_message()
+         message = message.strip()
+         if message in ["''", '""']:
+             message = ""
          return ChatDocument(
              content=message,
              function_call=response.function_call,
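
The normalization added above guards against models that "answer" with a literal quoted empty string. The same logic in isolation:

# Some LLMs emit a quoted empty string instead of no content; after
# stripping, treat "''" or '""' as truly empty.
for raw in ["  ''  ", '""', "real answer"]:
    message = raw.strip()
    if message in ["''", '""']:
        message = ""
    print(repr(message))  # '', '', 'real answer'
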
@@ -188,6 +195,16 @@ class ChatDocument(Document):
          if isinstance(message, ChatDocument):
              content = message.content
              fun_call = message.function_call
+             if message.metadata.sender == Entity.USER and fun_call is not None:
+                 # This may happen when a (parent agent's) LLM generates a
+                 # Function-call, and it ends up being sent to the current task's
+                 # LLM (possibly because the function-call is mis-named or has other
+                 # issues and couldn't be handled by handler methods).
+                 # But a function-call can only be generated by an entity with
+                 # Role.ASSISTANT, so we instead put the content of the function-call
+                 # in the content of the message.
+                 content += " " + str(fun_call)
+                 fun_call = None
              sender_name = message.metadata.sender_name
              tool_ids = message.metadata.tool_ids
              tool_id = tool_ids[-1] if len(tool_ids) > 0 else ""
langroid/agent/openai_assistant.py
@@ -8,7 +8,7 @@ from enum import Enum
  from typing import Any, Dict, List, Optional, Tuple, Type, cast, no_type_check

  from openai.types.beta import Assistant, Thread
- from openai.types.beta.threads import Run, ThreadMessage
+ from openai.types.beta.threads import Message, Run
  from openai.types.beta.threads.runs import RunStep
  from pydantic import BaseModel
  from rich import print
@@ -41,6 +41,8 @@ class AssistantTool(BaseModel):
      def dct(self) -> Dict[str, Any]:
          d = super().dict()
          d["type"] = d["type"].value
+         if self.type != ToolType.FUNCTION:
+             d.pop("function")
          return d


@@ -78,7 +80,6 @@ class OpenAIAssistantConfig(ChatAgentConfig):


  class OpenAIAssistant(ChatAgent):
-
      """
      A ChatAgent powered by OpenAI Assistant API:
      mainly, in `llm_response` method, we avoid maintaining conversation state,
@@ -257,22 +258,22 @@ class OpenAIAssistant(ChatAgent):
              self.llm.cache.store(assistant_key, self.assistant.id)

      @staticmethod
-     def thread_msg_to_llm_msg(msg: ThreadMessage) -> LLMMessage:
+     def thread_msg_to_llm_msg(msg: Message) -> LLMMessage:
          """
-         Convert a ThreadMessage to an LLMMessage
+         Convert a Message to an LLMMessage
          """
          return LLMMessage(
              content=msg.content[0].text.value,  # type: ignore
-             role=msg.role,
+             role=Role(msg.role),
          )

-     def _update_messages_hash(self, msg: ThreadMessage | LLMMessage) -> None:
+     def _update_messages_hash(self, msg: Message | LLMMessage) -> None:
          """
          Update the hash-state in the thread with the given message.
          """
          if self.thread is None:
              raise ValueError("Thread is None")
-         if isinstance(msg, ThreadMessage):
+         if isinstance(msg, Message):
              llm_msg = self.thread_msg_to_llm_msg(msg)
          else:
              llm_msg = msg
@@ -491,7 +492,7 @@ class OpenAIAssistant(ChatAgent):
              LLMMessage(
                  # TODO: could be image, deal with it later
                  content=m.content[0].text.value,  # type: ignore
-                 role=m.role,
+                 role=Role(m.role),
              )
              for m in thread_msgs
          ]
@@ -646,7 +647,7 @@ class OpenAIAssistant(ChatAgent):
              tool_outputs=tool_outputs,  # type: ignore
          )

-     def process_citations(self, thread_msg: ThreadMessage) -> None:
+     def process_citations(self, thread_msg: Message) -> None:
          """
          Process citations in the thread message.
          Modifies the thread message in-place.
@@ -664,7 +665,16 @@ class OpenAIAssistant(ChatAgent):
              )
              # Gather citations based on annotation attributes
              if file_citation := getattr(annotation, "file_citation", None):
-                 cited_file = self.client.files.retrieve(file_citation.file_id)
+                 try:
+                     cited_file = self.client.files.retrieve(file_citation.file_id)
+                 except Exception:
+                     logger.warning(
+                         f"""
+                         Could not retrieve cited file with id {file_citation.file_id},
+                         ignoring.
+                         """
+                     )
+                     continue
                  citations.append(
                      f"[{index}] '{file_citation.quote}',-- from {cited_file.filename}"
                  )
langroid/agent/special/__init__.py
@@ -1,27 +1,50 @@
- from .doc_chat_agent import DocChatAgent, DocChatAgentConfig
- from .recipient_validator_agent import (
-     RecipientValidatorConfig,
-     RecipientValidatorAttachment,
-     RecipientValidator,
+ from .relevance_extractor_agent import (
+     RelevanceExtractorAgent,
+     RelevanceExtractorAgentConfig,
  )
+ from .doc_chat_agent import DocChatAgent, DocChatAgentConfig
  from .retriever_agent import (
      RecordMetadata,
      RecordDoc,
      RetrieverAgentConfig,
      RetrieverAgent,
  )
+ from .lance_doc_chat_agent import LanceDocChatAgent
  from .table_chat_agent import (
      dataframe_summary,
      TableChatAgent,
      TableChatAgentConfig,
      RunCodeTool,
  )
- from .relevance_extractor_agent import (
-     RelevanceExtractorAgent,
-     RelevanceExtractorAgentConfig,
- )
  from . import sql
+ from . import relevance_extractor_agent
  from . import doc_chat_agent
- from . import recipient_validator_agent
  from . import retriever_agent
+ from . import lance_tools
+ from . import lance_doc_chat_agent
+ from . import lance_rag
  from . import table_chat_agent
+
+ __all__ = [
+     "RelevanceExtractorAgent",
+     "RelevanceExtractorAgentConfig",
+     "DocChatAgent",
+     "DocChatAgentConfig",
+     "RecordMetadata",
+     "RecordDoc",
+     "RetrieverAgentConfig",
+     "RetrieverAgent",
+     "LanceDocChatAgent",
+     "dataframe_summary",
+     "TableChatAgent",
+     "TableChatAgentConfig",
+     "RunCodeTool",
+     "sql",
+     "relevance_extractor_agent",
+     "doc_chat_agent",
+     "retriever_agent",
+     "lance_tools",
+     "lance_doc_chat_agent",
+     "lance_rag",
+     "table_chat_agent",
+ ]
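
With the explicit __all__, the special-agent names are importable directly from the subpackage. A hedged sketch (assumes an OpenAI API key, and that optional extras such as lancedb are installed for the Lance agents):

from langroid.agent.special import (
    DocChatAgent,
    DocChatAgentConfig,
    LanceDocChatAgent,
)

# DocChatAgentConfig has usable defaults; tweak as needed.
agent = DocChatAgent(DocChatAgentConfig(name="rag-demo"))
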