langroid 0.33.6__py3-none-any.whl → 0.33.7__py3-none-any.whl

This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only, and reflects the changes between package versions as they appear in their respective public registries.
Files changed (129)
  1. langroid/__init__.py +106 -0
  2. langroid/agent/__init__.py +41 -0
  3. langroid/agent/base.py +1983 -0
  4. langroid/agent/batch.py +398 -0
  5. langroid/agent/callbacks/__init__.py +0 -0
  6. langroid/agent/callbacks/chainlit.py +598 -0
  7. langroid/agent/chat_agent.py +1899 -0
  8. langroid/agent/chat_document.py +454 -0
  9. langroid/agent/openai_assistant.py +882 -0
  10. langroid/agent/special/__init__.py +59 -0
  11. langroid/agent/special/arangodb/__init__.py +0 -0
  12. langroid/agent/special/arangodb/arangodb_agent.py +656 -0
  13. langroid/agent/special/arangodb/system_messages.py +186 -0
  14. langroid/agent/special/arangodb/tools.py +107 -0
  15. langroid/agent/special/arangodb/utils.py +36 -0
  16. langroid/agent/special/doc_chat_agent.py +1466 -0
  17. langroid/agent/special/lance_doc_chat_agent.py +262 -0
  18. langroid/agent/special/lance_rag/__init__.py +9 -0
  19. langroid/agent/special/lance_rag/critic_agent.py +198 -0
  20. langroid/agent/special/lance_rag/lance_rag_task.py +82 -0
  21. langroid/agent/special/lance_rag/query_planner_agent.py +260 -0
  22. langroid/agent/special/lance_tools.py +61 -0
  23. langroid/agent/special/neo4j/__init__.py +0 -0
  24. langroid/agent/special/neo4j/csv_kg_chat.py +174 -0
  25. langroid/agent/special/neo4j/neo4j_chat_agent.py +433 -0
  26. langroid/agent/special/neo4j/system_messages.py +120 -0
  27. langroid/agent/special/neo4j/tools.py +32 -0
  28. langroid/agent/special/relevance_extractor_agent.py +127 -0
  29. langroid/agent/special/retriever_agent.py +56 -0
  30. langroid/agent/special/sql/__init__.py +17 -0
  31. langroid/agent/special/sql/sql_chat_agent.py +654 -0
  32. langroid/agent/special/sql/utils/__init__.py +21 -0
  33. langroid/agent/special/sql/utils/description_extractors.py +190 -0
  34. langroid/agent/special/sql/utils/populate_metadata.py +85 -0
  35. langroid/agent/special/sql/utils/system_message.py +35 -0
  36. langroid/agent/special/sql/utils/tools.py +64 -0
  37. langroid/agent/special/table_chat_agent.py +263 -0
  38. langroid/agent/task.py +2095 -0
  39. langroid/agent/tool_message.py +393 -0
  40. langroid/agent/tools/__init__.py +38 -0
  41. langroid/agent/tools/duckduckgo_search_tool.py +50 -0
  42. langroid/agent/tools/file_tools.py +234 -0
  43. langroid/agent/tools/google_search_tool.py +39 -0
  44. langroid/agent/tools/metaphor_search_tool.py +68 -0
  45. langroid/agent/tools/orchestration.py +303 -0
  46. langroid/agent/tools/recipient_tool.py +235 -0
  47. langroid/agent/tools/retrieval_tool.py +32 -0
  48. langroid/agent/tools/rewind_tool.py +137 -0
  49. langroid/agent/tools/segment_extract_tool.py +41 -0
  50. langroid/agent/xml_tool_message.py +382 -0
  51. langroid/cachedb/__init__.py +17 -0
  52. langroid/cachedb/base.py +58 -0
  53. langroid/cachedb/momento_cachedb.py +108 -0
  54. langroid/cachedb/redis_cachedb.py +153 -0
  55. langroid/embedding_models/__init__.py +39 -0
  56. langroid/embedding_models/base.py +74 -0
  57. langroid/embedding_models/models.py +461 -0
  58. langroid/embedding_models/protoc/__init__.py +0 -0
  59. langroid/embedding_models/protoc/embeddings.proto +19 -0
  60. langroid/embedding_models/protoc/embeddings_pb2.py +33 -0
  61. langroid/embedding_models/protoc/embeddings_pb2.pyi +50 -0
  62. langroid/embedding_models/protoc/embeddings_pb2_grpc.py +79 -0
  63. langroid/embedding_models/remote_embeds.py +153 -0
  64. langroid/exceptions.py +71 -0
  65. langroid/language_models/__init__.py +53 -0
  66. langroid/language_models/azure_openai.py +153 -0
  67. langroid/language_models/base.py +678 -0
  68. langroid/language_models/config.py +18 -0
  69. langroid/language_models/mock_lm.py +124 -0
  70. langroid/language_models/openai_gpt.py +1964 -0
  71. langroid/language_models/prompt_formatter/__init__.py +16 -0
  72. langroid/language_models/prompt_formatter/base.py +40 -0
  73. langroid/language_models/prompt_formatter/hf_formatter.py +132 -0
  74. langroid/language_models/prompt_formatter/llama2_formatter.py +75 -0
  75. langroid/language_models/utils.py +151 -0
  76. langroid/mytypes.py +84 -0
  77. langroid/parsing/__init__.py +52 -0
  78. langroid/parsing/agent_chats.py +38 -0
  79. langroid/parsing/code_parser.py +121 -0
  80. langroid/parsing/document_parser.py +718 -0
  81. langroid/parsing/para_sentence_split.py +62 -0
  82. langroid/parsing/parse_json.py +155 -0
  83. langroid/parsing/parser.py +313 -0
  84. langroid/parsing/repo_loader.py +790 -0
  85. langroid/parsing/routing.py +36 -0
  86. langroid/parsing/search.py +275 -0
  87. langroid/parsing/spider.py +102 -0
  88. langroid/parsing/table_loader.py +94 -0
  89. langroid/parsing/url_loader.py +111 -0
  90. langroid/parsing/urls.py +273 -0
  91. langroid/parsing/utils.py +373 -0
  92. langroid/parsing/web_search.py +156 -0
  93. langroid/prompts/__init__.py +9 -0
  94. langroid/prompts/dialog.py +17 -0
  95. langroid/prompts/prompts_config.py +5 -0
  96. langroid/prompts/templates.py +141 -0
  97. langroid/pydantic_v1/__init__.py +10 -0
  98. langroid/pydantic_v1/main.py +4 -0
  99. langroid/utils/__init__.py +19 -0
  100. langroid/utils/algorithms/__init__.py +3 -0
  101. langroid/utils/algorithms/graph.py +103 -0
  102. langroid/utils/configuration.py +98 -0
  103. langroid/utils/constants.py +30 -0
  104. langroid/utils/git_utils.py +252 -0
  105. langroid/utils/globals.py +49 -0
  106. langroid/utils/logging.py +135 -0
  107. langroid/utils/object_registry.py +66 -0
  108. langroid/utils/output/__init__.py +20 -0
  109. langroid/utils/output/citations.py +41 -0
  110. langroid/utils/output/printing.py +99 -0
  111. langroid/utils/output/status.py +40 -0
  112. langroid/utils/pandas_utils.py +30 -0
  113. langroid/utils/pydantic_utils.py +602 -0
  114. langroid/utils/system.py +286 -0
  115. langroid/utils/types.py +93 -0
  116. langroid/vector_store/__init__.py +50 -0
  117. langroid/vector_store/base.py +359 -0
  118. langroid/vector_store/chromadb.py +214 -0
  119. langroid/vector_store/lancedb.py +406 -0
  120. langroid/vector_store/meilisearch.py +299 -0
  121. langroid/vector_store/momento.py +278 -0
  122. langroid/vector_store/qdrantdb.py +468 -0
  123. {langroid-0.33.6.dist-info → langroid-0.33.7.dist-info}/METADATA +95 -94
  124. langroid-0.33.7.dist-info/RECORD +127 -0
  125. {langroid-0.33.6.dist-info → langroid-0.33.7.dist-info}/WHEEL +1 -1
  126. langroid-0.33.6.dist-info/RECORD +0 -7
  127. langroid-0.33.6.dist-info/entry_points.txt +0 -4
  128. pyproject.toml +0 -356
  129. {langroid-0.33.6.dist-info → langroid-0.33.7.dist-info}/licenses/LICENSE +0 -0
langroid/agent/openai_assistant.py
@@ -0,0 +1,882 @@
+ import asyncio
+ import json
+
+ # setup logger
+ import logging
+ import time
+ from enum import Enum
+ from typing import Any, Dict, List, Optional, Tuple, Type, cast, no_type_check
+
+ import openai
+ from openai.types.beta import Assistant, Thread
+ from openai.types.beta.assistant_update_params import (
+     ToolResources,
+     ToolResourcesCodeInterpreter,
+ )
+ from openai.types.beta.threads import Message, Run
+ from openai.types.beta.threads.runs import RunStep
+ from rich import print
+
+ from langroid.agent.chat_agent import ChatAgent, ChatAgentConfig
+ from langroid.agent.chat_document import ChatDocument
+ from langroid.agent.tool_message import ToolMessage
+ from langroid.language_models.base import LLMFunctionCall, LLMMessage, LLMResponse, Role
+ from langroid.language_models.openai_gpt import (
+     OpenAIChatModel,
+     OpenAIGPT,
+     OpenAIGPTConfig,
+ )
+ from langroid.pydantic_v1 import BaseModel
+ from langroid.utils.configuration import settings
+ from langroid.utils.system import generate_user_id, update_hash
+
+ logger = logging.getLogger(__name__)
+
+
+ class ToolType(str, Enum):
+     RETRIEVAL = "file_search"
+     CODE_INTERPRETER = "code_interpreter"
+     FUNCTION = "function"
+
+
+ class AssistantTool(BaseModel):
+     type: ToolType
+     function: Dict[str, Any] | None = None
+
+     def dct(self) -> Dict[str, Any]:
+         d = super().dict()
+         d["type"] = d["type"].value
+         if self.type != ToolType.FUNCTION:
+             d.pop("function")
+         return d
+
+
+ class AssistantToolCall(BaseModel):
+     id: str
+     type: ToolType
+     function: LLMFunctionCall
+
+
+ class RunStatus(str, Enum):
+     QUEUED = "queued"
+     IN_PROGRESS = "in_progress"
+     COMPLETED = "completed"
+     REQUIRES_ACTION = "requires_action"
+     EXPIRED = "expired"
+     CANCELLING = "cancelling"
+     CANCELLED = "cancelled"
+     FAILED = "failed"
+     TIMEOUT = "timeout"
+
+
+ class OpenAIAssistantConfig(ChatAgentConfig):
+     use_cached_assistant: bool = False  # set in script via user dialog
+     assistant_id: str | None = None
+     use_tools = False
+     use_functions_api = True
+     use_cached_thread: bool = False  # set in script via user dialog
+     thread_id: str | None = None
+     # set to True once we can add Assistant msgs in threads
+     cache_responses: bool = True
+     timeout: int = 30  # can be different from llm.timeout
+     llm = OpenAIGPTConfig(chat_model=OpenAIChatModel.GPT4o)
+     tools: List[AssistantTool] = []
+     files: List[str] = []
+
+
+ class OpenAIAssistant(ChatAgent):
+     """
+     A ChatAgent powered by the OpenAI Assistant API:
+     mainly, in the `llm_response` method, we avoid maintaining conversation state,
+     and instead let the Assistant API do it for us.
+     Also handles persistent storage of Assistants and Threads:
+     stores their ids (for a given user, org) in a cache, and
+     reuses them based on config.use_cached_assistant and config.use_cached_thread.
+
+     This class can be used as a drop-in replacement for ChatAgent.
+     """
+
+     def __init__(self, config: OpenAIAssistantConfig):
+         super().__init__(config)
+         self.config: OpenAIAssistantConfig = config
+         self.llm: OpenAIGPT = OpenAIGPT(self.config.llm)
+         assert (
+             self.llm.cache is not None
+         ), "OpenAIAssistant requires a cache to store Assistant and Thread ids"
+
+         if not isinstance(self.llm.client, openai.OpenAI):
+             raise ValueError("Client must be OpenAI")
+         # handles for various entities and methods
+         self.client: openai.OpenAI = self.llm.client
+         self.runs = self.client.beta.threads.runs
+         self.threads = self.client.beta.threads
+         self.thread_messages = self.client.beta.threads.messages
+         self.assistants = self.client.beta.assistants
+         # which tool_ids are awaiting output submissions
+         self.pending_tool_ids: List[str] = []
+         self.cached_tool_ids: List[str] = []
+
+         self.thread: Thread | None = None
+         self.assistant: Assistant | None = None
+         self.run: Run | None = None
+
+         self._maybe_create_assistant(self.config.assistant_id)
+         self._maybe_create_thread(self.config.thread_id)
+         self._cache_store()
+
+         self.add_assistant_files(self.config.files)
+         self.add_assistant_tools(self.config.tools)
+
+     def add_assistant_files(self, files: List[str]) -> None:
+         """Add file_ids to assistant"""
+         if self.assistant is None:
+             raise ValueError("Assistant is None")
+         self.files = [
+             self.client.files.create(file=open(f, "rb"), purpose="assistants")
+             for f in files
+         ]
+         self.config.files = list(set(self.config.files + files))
+         self.assistant = self.assistants.update(
+             self.assistant.id,
+             tool_resources=ToolResources(
+                 code_interpreter=ToolResourcesCodeInterpreter(
+                     file_ids=[f.id for f in self.files],
+                 ),
+             ),
+         )
+
+     def add_assistant_tools(self, tools: List[AssistantTool]) -> None:
+         """Add tools to assistant"""
+         if self.assistant is None:
+             raise ValueError("Assistant is None")
+         all_tool_dicts = [t.dct() for t in self.config.tools]
+         for t in tools:
+             if t.dct() not in all_tool_dicts:
+                 self.config.tools.append(t)
+         self.assistant = self.assistants.update(
+             self.assistant.id,
+             tools=[tool.dct() for tool in self.config.tools],  # type: ignore
+         )
+
+     def enable_message(
+         self,
+         message_class: Optional[Type[ToolMessage] | List[Type[ToolMessage]]],
+         use: bool = True,
+         handle: bool = True,
+         force: bool = False,
+         require_recipient: bool = False,
+         include_defaults: bool = True,
+     ) -> None:
+         """Override ChatAgent's method: extract the function-related args.
+         See that method for details. But specifically about the `include_defaults`
+         arg: normally the OpenAI completion API ignores these fields, but Assistant
+         fn-calling seems to pay attention to them, so if we don't want this behavior,
+         we should set this arg to False.
+         """
+         if message_class is not None and isinstance(message_class, list):
+             for msg_class in message_class:
+                 self.enable_message(
+                     msg_class,
+                     use=use,
+                     handle=handle,
+                     force=force,
+                     require_recipient=require_recipient,
+                     include_defaults=include_defaults,
+                 )
+             return
+         super().enable_message(
+             message_class,
+             use=use,
+             handle=handle,
+             force=force,
+             require_recipient=require_recipient,
+             include_defaults=include_defaults,
+         )
+         if message_class is None or not use:
+             # no specific msg class, or
+             # we are not enabling USAGE/GENERATION of this tool/fn,
+             # then there's no need to attach the fn to the assistant
+             # (HANDLING the fn will still work via self.agent_response)
+             return
+         if self.config.use_tools:
+             sys_msg = self._create_system_and_tools_message()
+             self.set_system_message(sys_msg.content)
+         if not self.config.use_functions_api:
+             return
+         functions, _, _, _, _ = self._function_args()
+         if functions is None:
+             return
+         # add the functions to the assistant:
+         if self.assistant is None:
+             raise ValueError("Assistant is None")
+         tools = self.assistant.tools
+         tools.extend(
+             [
+                 {
+                     "type": "function",  # type: ignore
+                     "function": f.dict(),
+                 }
+                 for f in functions
+             ]
+         )
+         self.assistant = self.assistants.update(
+             self.assistant.id,
+             tools=tools,  # type: ignore
+         )
+
+     def _cache_thread_key(self) -> str:
+         """Key to use for caching or retrieving thread id"""
+         org = self.client.organization or ""
+         uid = generate_user_id(org)
+         name = self.config.name
+         return "Thread:" + name + ":" + uid
+
+     def _cache_assistant_key(self) -> str:
+         """Key to use for caching or retrieving assistant id"""
+         org = self.client.organization or ""
+         uid = generate_user_id(org)
+         name = self.config.name
+         return "Assistant:" + name + ":" + uid
+
+     @no_type_check
+     def _cache_messages_key(self) -> str:
+         """Key to use when caching or retrieving thread messages"""
+         if self.thread is None:
+             raise ValueError("Thread is None")
+         return "Messages:" + self.thread.metadata["hash"]
+
+     @no_type_check
+     def _cache_thread_lookup(self) -> str | None:
+         """Try to retrieve cached thread_id associated with
+         this user + machine + organization"""
+         key = self._cache_thread_key()
+         if self.llm.cache is None:
+             return None
+         return self.llm.cache.retrieve(key)
+
+     @no_type_check
+     def _cache_assistant_lookup(self) -> str | None:
+         """Try to retrieve cached assistant_id associated with
+         this user + machine + organization"""
+         if self.llm.cache is None:
+             return None
+         key = self._cache_assistant_key()
+         return self.llm.cache.retrieve(key)
+
+     @no_type_check
+     def _cache_messages_lookup(self) -> LLMResponse | None:
+         """Try to retrieve cached response for the message-list-hash"""
+         if not settings.cache or self.llm.cache is None:
+             return None
+         key = self._cache_messages_key()
+         cached_dict = self.llm.cache.retrieve(key)
+         if cached_dict is None:
+             return None
+         return LLMResponse.parse_obj(cached_dict)
+
+     def _cache_store(self) -> None:
+         """
+         Cache the assistant_id, thread_id associated with
+         this user + machine + organization
+         """
+         if self.llm.cache is None:
+             return
+         if self.thread is None or self.assistant is None:
+             raise ValueError("Thread or Assistant is None")
+         thread_key = self._cache_thread_key()
+         self.llm.cache.store(thread_key, self.thread.id)
+
+         assistant_key = self._cache_assistant_key()
+         self.llm.cache.store(assistant_key, self.assistant.id)
+
+     @staticmethod
+     def thread_msg_to_llm_msg(msg: Message) -> LLMMessage:
+         """
+         Convert a Message to an LLMMessage
+         """
+         return LLMMessage(
+             content=msg.content[0].text.value,  # type: ignore
+             role=Role(msg.role),
+         )
+
+     def _update_messages_hash(self, msg: Message | LLMMessage) -> None:
+         """
+         Update the hash-state in the thread with the given message.
+         """
+         if self.thread is None:
+             raise ValueError("Thread is None")
+         if isinstance(msg, Message):
+             llm_msg = self.thread_msg_to_llm_msg(msg)
+         else:
+             llm_msg = msg
+         hash = self.thread.metadata["hash"]  # type: ignore
+         most_recent_msg = llm_msg.content
+         most_recent_role = llm_msg.role
+         hash = update_hash(hash, f"{most_recent_role}:{most_recent_msg}")
+         # TODO is this inplace?
+         self.thread = self.threads.update(
+             self.thread.id,
+             metadata={
+                 "hash": hash,
+             },
+         )
+         assert self.thread.metadata["hash"] == hash  # type: ignore
+
+     def _maybe_create_thread(self, id: str | None = None) -> None:
+         """Retrieve or create a thread if one does not exist,
+         or retrieve it from cache"""
+         if id is not None:
+             try:
+                 self.thread = self.threads.retrieve(thread_id=id)
+             except Exception:
+                 logger.warning(
+                     f"""
+                     Could not retrieve thread with id {id},
+                     so creating a new one.
+                     """
+                 )
+                 self.thread = None
+         if self.thread is not None:
+             return
+         cached = self._cache_thread_lookup()
+         if cached is not None:
+             if self.config.use_cached_thread:
+                 self.thread = self.client.beta.threads.retrieve(thread_id=cached)
+             else:
+                 logger.warning(
+                     f"""
+                     Found cached thread id {cached},
+                     but config.use_cached_thread = False, so deleting it.
+                     """
+                 )
+                 try:
+                     self.client.beta.threads.delete(thread_id=cached)
+                 except Exception:
+                     logger.warning(
+                         f"""
+                         Could not delete thread with id {cached}, ignoring.
+                         """
+                     )
+                 if self.llm.cache is not None:
+                     self.llm.cache.delete_keys([self._cache_thread_key()])
+         if self.thread is None:
+             if self.assistant is None:
+                 raise ValueError("Assistant is None")
+             self.thread = self.client.beta.threads.create()
+             hash_key_str = (
+                 (self.assistant.instructions or "")
+                 + str(self.config.use_tools)
+                 + str(self.config.use_functions_api)
+             )
+             hash_hex = update_hash(None, s=hash_key_str)
+             self.thread = self.threads.update(
+                 self.thread.id,
+                 metadata={
+                     "hash": hash_hex,
+                 },
+             )
+             assert self.thread.metadata["hash"] == hash_hex  # type: ignore
+
+     def _maybe_create_assistant(self, id: str | None = None) -> None:
+         """Retrieve or create an assistant if one does not exist,
+         or retrieve it from cache"""
+         if id is not None:
+             try:
+                 self.assistant = self.assistants.retrieve(assistant_id=id)
+             except Exception:
+                 logger.warning(
+                     f"""
+                     Could not retrieve assistant with id {id},
+                     so creating a new one.
+                     """
+                 )
+                 self.assistant = None
+         if self.assistant is not None:
+             return
+         cached = self._cache_assistant_lookup()
+         if cached is not None:
+             if self.config.use_cached_assistant:
+                 self.assistant = self.client.beta.assistants.retrieve(
+                     assistant_id=cached
+                 )
+             else:
+                 logger.warning(
+                     f"""
+                     Found cached assistant id {cached},
+                     but config.use_cached_assistant = False, so deleting it.
+                     """
+                 )
+                 try:
+                     self.client.beta.assistants.delete(assistant_id=cached)
+                 except Exception:
+                     logger.warning(
+                         f"""
+                         Could not delete assistant with id {cached}, ignoring.
+                         """
+                     )
+                 if self.llm.cache is not None:
+                     self.llm.cache.delete_keys([self._cache_assistant_key()])
+         if self.assistant is None:
+             self.assistant = self.client.beta.assistants.create(
+                 name=self.config.name,
+                 instructions=self.config.system_message,
+                 tools=[],
+                 model=self.config.llm.chat_model,
+             )
+
+     def _get_run(self) -> Run:
+         """Retrieve the run object associated with this thread and run,
+         to see its latest status.
+         """
+         if self.thread is None or self.run is None:
+             raise ValueError("Thread or Run is None")
+         return self.runs.retrieve(thread_id=self.thread.id, run_id=self.run.id)
+
+     def _get_run_steps(self) -> List[RunStep]:
+         if self.thread is None or self.run is None:
+             raise ValueError("Thread or Run is None")
+         result = self.runs.steps.list(thread_id=self.thread.id, run_id=self.run.id)
+         if result is None:
+             return []
+         return result.data
+
+     def _get_code_logs(self) -> List[Tuple[str, str]]:
+         """
+         Get list of input, output strings from code logs
+         """
+         run_steps = self._get_run_steps()
+         # each step may have multiple tool-calls,
+         # each tool-call may have multiple outputs
+         tool_calls = [  # list of list of tool-calls
+             s.step_details.tool_calls
+             for s in run_steps
+             if s.step_details is not None and hasattr(s.step_details, "tool_calls")
+         ]
+         code_logs = []
+         for tcl in tool_calls:  # each tool-call-list
+             for tc in tcl:
+                 if tc is None or tc.type != ToolType.CODE_INTERPRETER:
+                     continue
+                 io = tc.code_interpreter  # type: ignore
+                 input = io.input
+                 # TODO for CodeInterpreterOutputImage, there is no "logs"
+                 # revisit when we handle images.
+                 outputs = "\n\n".join(
+                     o.logs
+                     for o in io.outputs
+                     if o.type == "logs" and hasattr(o, "logs")
+                 )
+                 code_logs.append((input, outputs))
+         # return the reversed list, since they are stored in
+         # reverse chronological order
+         return code_logs[::-1]
+
+     def _get_code_logs_str(self) -> str:
+         """
+         Get string representation of code logs
+         """
+         code_logs = self._get_code_logs()
+         return "\n\n".join(
+             f"INPUT:\n{input}\n\nOUTPUT:\n{output}" for input, output in code_logs
+         )
+
+     def _add_thread_message(self, msg: str, role: Role) -> None:
+         """
+         Add a message with the given role to the thread.
+         Args:
+             msg (str): message to add
+             role (Role): role of the message
+         """
+         if self.thread is None:
+             raise ValueError("Thread is None")
+         # CACHING TRICK! Since the API only allows inserting USER messages,
+         # we prepend the role to the message, so that we can store ASSISTANT msgs
+         # as well! When the LLM sees the thread messages, they will contain
+         # the right sequence of alternating roles, so that it has no trouble
+         # responding when it is its turn.
+         msg = f"{role.value.upper()}: {msg}"
+         thread_msg = self.thread_messages.create(
+             content=msg,
+             thread_id=self.thread.id,
+             # We ALWAYS store user role since only user role allowed currently
+             role=Role.USER.value,
+         )
+         self._update_messages_hash(thread_msg)
+
+     def _get_thread_messages(self, n: int = 20) -> List[LLMMessage]:
+         """
+         Get the last n messages in the thread, in cleaned-up form (LLMMessage).
+         Args:
+             n (int): number of messages to retrieve
+         Returns:
+             List[LLMMessage]: list of messages
+         """
+         if self.thread is None:
+             raise ValueError("Thread is None")
+         result = self.thread_messages.list(
+             thread_id=self.thread.id,
+             limit=n,
+         )
+         num = len(result.data)
+         if result.has_more and num < n:  # type: ignore
+             logger.warning(f"Retrieving last {num} messages, but there are more")
+         thread_msgs = result.data
+         for msg in thread_msgs:
+             self.process_citations(msg)
+         return [
+             LLMMessage(
+                 # TODO: could be image, deal with it later
+                 content=m.content[0].text.value,  # type: ignore
+                 role=Role(m.role),
+             )
+             for m in thread_msgs
+         ]
+
+     def _wait_for_run(
+         self,
+         until_not: List[RunStatus] = [RunStatus.QUEUED, RunStatus.IN_PROGRESS],
+         until: List[RunStatus] = [],
+         timeout: int = 30,
+     ) -> RunStatus:
+         """
+         Poll the run until it either:
+         - EXITs the statuses specified in `until_not`, or
+         - ENTERs one of the statuses specified in `until`, or
+         - the `timeout` (in seconds) is reached.
+         """
+         if self.thread is None or self.run is None:
+             raise ValueError("Thread or Run is None")
+         while True:
+             run = self._get_run()
+             if run.status not in until_not or run.status in until:
+                 return cast(RunStatus, run.status)
+             time.sleep(1)
+             timeout -= 1
+             if timeout <= 0:
+                 return cast(RunStatus, RunStatus.TIMEOUT)
+
+     async def _wait_for_run_async(
+         self,
+         until_not: List[RunStatus] = [RunStatus.QUEUED, RunStatus.IN_PROGRESS],
+         until: List[RunStatus] = [],
+         timeout: int = 30,
+     ) -> RunStatus:
+         """Async version of _wait_for_run"""
+         if self.thread is None or self.run is None:
+             raise ValueError("Thread or Run is None")
+         while True:
+             run = self._get_run()
+             if run.status not in until_not or run.status in until:
+                 return cast(RunStatus, run.status)
+             await asyncio.sleep(1)
+             timeout -= 1
+             if timeout <= 0:
+                 return cast(RunStatus, RunStatus.TIMEOUT)
+
+     def set_system_message(self, msg: str) -> None:
+         """
+         Override ChatAgent's method.
+         The Task may use this method to set the system message
+         of the chat assistant.
+         """
+         super().set_system_message(msg)
+         if self.assistant is None:
+             raise ValueError("Assistant is None")
+         self.assistant = self.assistants.update(self.assistant.id, instructions=msg)
+
+     def _start_run(self) -> None:
+         """
+         Run the assistant on the thread.
+         """
+         if self.thread is None or self.assistant is None:
+             raise ValueError("Thread or Assistant is None")
+         self.run = self.runs.create(
+             thread_id=self.thread.id,
+             assistant_id=self.assistant.id,
+         )
+
+     def _run_result(self) -> LLMResponse:
+         """Result from run completed on the thread."""
+         status = self._wait_for_run(
+             timeout=self.config.timeout,
+         )
+         return self._process_run_result(status)
+
+     async def _run_result_async(self) -> LLMResponse:
+         """(Async) Result from run completed on the thread."""
+         status = await self._wait_for_run_async(
+             timeout=self.config.timeout,
+         )
+         return self._process_run_result(status)
+
+     def _process_run_result(self, status: RunStatus) -> LLMResponse:
+         """Process the result of the run."""
+         function_call: LLMFunctionCall | None = None
+         response = ""
+         tool_id = ""
+         # IMPORTANT: FIRST save hash key to store result,
+         # before it gets updated with the response
+         key = self._cache_messages_key()
+         if status == RunStatus.TIMEOUT:
+             logger.warning("Timeout waiting for run to complete, return empty string")
+         elif status == RunStatus.COMPLETED:
+             messages = self._get_thread_messages(n=1)
+             response = messages[0].content
+             # update hash to include the response.
+             self._update_messages_hash(messages[0])
+         elif status == RunStatus.REQUIRES_ACTION:
+             tool_calls = self._parse_run_required_action()
+             # pick the FIRST tool call with type "function"
+             tool_call_fn = [t for t in tool_calls if t.type == ToolType.FUNCTION][0]
+             # TODO Handling only first tool/fn call for now
+             # revisit later: multi-tools affects the task.run() loop.
+             function_call = tool_call_fn.function
+             tool_id = tool_call_fn.id
+         result = LLMResponse(
+             message=response,
+             tool_id=tool_id,
+             function_call=function_call,
+             usage=None,  # TODO
+             cached=False,  # TODO - revisit when able to insert Assistant responses
+         )
+         if self.llm.cache is not None:
+             self.llm.cache.store(key, result.dict())
+         return result
+
+     def _parse_run_required_action(self) -> List[AssistantToolCall]:
+         """
+         Parse the required_action field of the run, i.e. get the list of tool calls.
+         Currently only tool calls are supported.
+         """
+         # see https://platform.openai.com/docs/assistants/tools/function-calling
+         run = self._get_run()
+         if run.status != RunStatus.REQUIRES_ACTION:  # type: ignore
+             return []
+
+         if (action := run.required_action.type) != "submit_tool_outputs":
+             raise ValueError(f"Unexpected required_action type {action}")
+         tool_calls = run.required_action.submit_tool_outputs.tool_calls
+         return [
+             AssistantToolCall(
+                 id=tool_call.id,
+                 type=ToolType(tool_call.type),
+                 function=LLMFunctionCall.from_dict(tool_call.function.model_dump()),
+             )
+             for tool_call in tool_calls
+         ]
+
665
+
666
+ def _submit_tool_outputs(self, msg: LLMMessage) -> None:
667
+ """
668
+ Submit the tool (fn) outputs to the run/thread
669
+ """
670
+ if self.run is None or self.thread is None:
671
+ raise ValueError("Run or Thread is None")
672
+ tool_outputs = [
673
+ {
674
+ "tool_call_id": msg.tool_id,
675
+ "output": msg.content,
676
+ }
677
+ ]
678
+ # run enters queued, in_progress state after this
679
+ self.runs.submit_tool_outputs(
680
+ thread_id=self.thread.id,
681
+ run_id=self.run.id,
682
+ tool_outputs=tool_outputs, # type: ignore
683
+ )
684
+
+     def process_citations(self, thread_msg: Message) -> None:
+         """
+         Process citations in the thread message.
+         Modifies the thread message in-place.
+         """
+         # could there be multiple content items?
+         # TODO content could be MessageContentImageFile; handle that later
+         annotated_content = thread_msg.content[0].text  # type: ignore
+         annotations = annotated_content.annotations
+         citations = []
+         # Iterate over the annotations and add footnotes
+         for index, annotation in enumerate(annotations):
+             # Replace the text with a footnote
+             annotated_content.value = annotated_content.value.replace(
+                 annotation.text, f" [{index}]"
+             )
+             # Gather citations based on annotation attributes
+             if file_citation := getattr(annotation, "file_citation", None):
+                 try:
+                     cited_file = self.client.files.retrieve(file_citation.file_id)
+                 except Exception:
+                     logger.warning(
+                         f"""
+                         Could not retrieve cited file with id {file_citation.file_id},
+                         ignoring.
+                         """
+                     )
+                     continue
+                 citations.append(
+                     f"[{index}] '{file_citation.quote}',-- from {cited_file.filename}"
+                 )
+             elif file_path := getattr(annotation, "file_path", None):
+                 cited_file = self.client.files.retrieve(file_path.file_id)
+                 citations.append(
+                     f"[{index}] Click <here> to download {cited_file.filename}"
+                 )
+                 # Note: File download functionality not implemented above for brevity
+         sep = "\n" if len(citations) > 0 else ""
+         annotated_content.value += sep + "\n".join(citations)
+
+     def _llm_response_preprocess(
+         self,
+         message: Optional[str | ChatDocument] = None,
+     ) -> LLMResponse | None:
+         """
+         Preprocess message and return response if found in cache, else None.
+         """
+         is_tool_output = False
+         if message is not None:
+             # note: to_LLMMessage returns a list of LLMMessage,
+             # which is allowed to have len > 1, in case the msg
+             # represents results of multiple (non-assistant) tool-calls.
+             # But for OAI Assistant, we only assume exactly one tool-call at a time.
+             # TODO look into multi-tools
+             llm_msg = ChatDocument.to_LLMMessage(message)[0]
+             tool_id = llm_msg.tool_id
+             if tool_id in self.pending_tool_ids:
+                 if isinstance(message, ChatDocument):
+                     message.pop_tool_ids()
+                 result_msg = f"Result for Tool_id {tool_id}: {llm_msg.content}"
+                 if tool_id in self.cached_tool_ids:
+                     self.cached_tool_ids.remove(tool_id)
+                     # add actual result of cached fn-call
+                     self._add_thread_message(result_msg, role=Role.USER)
+                 else:
+                     is_tool_output = True
+                     # submit tool/fn result to the thread/run
+                     self._submit_tool_outputs(llm_msg)
+                     # We cannot ACTUALLY add this result to thread now
+                     # since run is in `action_required` state,
+                     # so we just update the message hash
+                     self._update_messages_hash(
+                         LLMMessage(content=result_msg, role=Role.USER)
+                     )
+                 self.pending_tool_ids.remove(tool_id)
+             else:
+                 # add message to the thread
+                 self._add_thread_message(llm_msg.content, role=Role.USER)
+
+         # When message is None, the thread may have no user msgs,
+         # Note: system message is NOT placed in the thread by the OpenAI system.
+
+         # check if we have cached the response.
+         # TODO: handle the case of structured result (fn-call, tool, etc)
+         response = self._cache_messages_lookup()
+         if response is not None:
+             response.cached = True
+             # store the result in the thread so
+             # it looks like assistant produced it
+             if self.config.cache_responses:
+                 self._add_thread_message(
+                     json.dumps(response.dict()), role=Role.ASSISTANT
+                 )
+             return response  # type: ignore
+         else:
+             # create a run for this assistant on this thread,
+             # i.e. actually "run"
+             if not is_tool_output:
+                 # DO NOT start a run if we submitted tool outputs,
+                 # since submission of tool outputs resumes a run from
+                 # status = "requires_action"
+                 self._start_run()
+             return None
+
+     def _llm_response_postprocess(
+         self,
+         response: LLMResponse,
+         cached: bool,
+         message: Optional[str | ChatDocument] = None,
+     ) -> Optional[ChatDocument]:
+         # code from ChatAgent.llm_response_messages
+         if response.function_call is not None:
+             self.pending_tool_ids += [response.tool_id]
+             if cached:
+                 # add to cached tools list so we don't create an Assistant run
+                 # in _llm_response_preprocess
+                 self.cached_tool_ids += [response.tool_id]
+             response_str = str(response.function_call)
+         else:
+             response_str = response.message
+         cache_str = "[red](cached)[/red]" if cached else ""
+         if not settings.quiet:
+             if not cached and self._get_code_logs_str():
+                 print(
+                     f"[magenta]CODE-INTERPRETER LOGS:\n"
+                     "-------------------------------\n"
+                     f"{self._get_code_logs_str()}[/magenta]"
+                 )
+             print(f"{cache_str}[green]" + response_str + "[/green]")
+         cdoc = ChatDocument.from_LLMResponse(response, displayed=False)
+         # Note message.metadata.tool_ids may have been popped above
+         tool_ids = (
+             []
+             if (message is None or isinstance(message, str))
+             else message.metadata.tool_ids
+         )
+
+         if response.tool_id != "":
+             tool_ids.append(response.tool_id)
+         cdoc.metadata.tool_ids = tool_ids
+         return cdoc
+
+     def llm_response(
+         self, message: Optional[str | ChatDocument] = None
+     ) -> Optional[ChatDocument]:
+         """
+         Override ChatAgent's method: this is the main LLM response method.
+         In the ChatAgent, this updates `self.message_history` and then calls
+         `self.llm_response_messages`, but since we are relying on the Assistant API
+         to maintain conversation state, this method is simpler: simply start a run
+         on the message-thread, and wait for it to complete.
+
+         Args:
+             message (Optional[str | ChatDocument], optional): message to respond to
+                 (if absent, the LLM response will be based on the
+                 instructions in the system_message). Defaults to None.
+         Returns:
+             Optional[ChatDocument]: LLM response
+         """
+         response = self._llm_response_preprocess(message)
+         cached = True
+         if response is None:
+             cached = False
+             response = self._run_result()
+         return self._llm_response_postprocess(response, cached=cached, message=message)
+
+     async def llm_response_async(
+         self, message: Optional[str | ChatDocument] = None
+     ) -> Optional[ChatDocument]:
+         """
+         Async version of llm_response.
+         """
+         response = self._llm_response_preprocess(message)
+         cached = True
+         if response is None:
+             cached = False
+             response = await self._run_result_async()
+         return self._llm_response_postprocess(response, cached=cached, message=message)
+
+     def agent_response(
+         self,
+         msg: Optional[str | ChatDocument] = None,
+     ) -> Optional[ChatDocument]:
+         response = super().agent_response(msg)
+         if msg is None:
+             return response
+         if response is None:
+             return None
+         try:
+             # When the agent response is to a tool message,
+             # we prefix it with "TOOL Result: " so that it is clear to the
+             # LLM that this is the result of the last TOOL;
+             # this ensures our caching trick works.
+             if self.config.use_tools and len(self.get_tool_messages(msg)) > 0:
+                 response.content = "TOOL Result: " + response.content
+             return response
+         except Exception:
+             return response
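
For orientation, here is a minimal usage sketch of the OpenAIAssistant class added above. It is not part of the diff; the config values are illustrative assumptions, and it uses only names visible in the file itself plus langroid's public Task API (an OpenAI API key is assumed to be configured in the environment):

    import langroid as lr
    from langroid.agent.openai_assistant import (
        OpenAIAssistant,
        OpenAIAssistantConfig,
    )

    # Illustrative settings only; name and system_message come from ChatAgentConfig.
    config = OpenAIAssistantConfig(
        name="DemoAssistant",
        system_message="You are a helpful assistant.",
        use_cached_assistant=False,  # create a fresh Assistant rather than reuse one
        use_cached_thread=False,  # likewise for the Thread
    )
    agent = OpenAIAssistant(config)

    # Per the class docstring, OpenAIAssistant is a drop-in replacement for
    # ChatAgent, so it can be wrapped in a Task like any other agent.
    task = lr.Task(agent, interactive=False, single_round=True)
    result = task.run("What is the capital of France?")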