langroid 0.1.85__py3-none-any.whl → 0.1.219__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (107)
  1. langroid/__init__.py +95 -0
  2. langroid/agent/__init__.py +40 -0
  3. langroid/agent/base.py +222 -91
  4. langroid/agent/batch.py +264 -0
  5. langroid/agent/callbacks/chainlit.py +608 -0
  6. langroid/agent/chat_agent.py +247 -101
  7. langroid/agent/chat_document.py +41 -4
  8. langroid/agent/openai_assistant.py +842 -0
  9. langroid/agent/special/__init__.py +50 -0
  10. langroid/agent/special/doc_chat_agent.py +837 -141
  11. langroid/agent/special/lance_doc_chat_agent.py +258 -0
  12. langroid/agent/special/lance_rag/__init__.py +9 -0
  13. langroid/agent/special/lance_rag/critic_agent.py +136 -0
  14. langroid/agent/special/lance_rag/lance_rag_task.py +80 -0
  15. langroid/agent/special/lance_rag/query_planner_agent.py +180 -0
  16. langroid/agent/special/lance_tools.py +44 -0
  17. langroid/agent/special/neo4j/__init__.py +0 -0
  18. langroid/agent/special/neo4j/csv_kg_chat.py +174 -0
  19. langroid/agent/special/neo4j/neo4j_chat_agent.py +370 -0
  20. langroid/agent/special/neo4j/utils/__init__.py +0 -0
  21. langroid/agent/special/neo4j/utils/system_message.py +46 -0
  22. langroid/agent/special/relevance_extractor_agent.py +127 -0
  23. langroid/agent/special/retriever_agent.py +32 -198
  24. langroid/agent/special/sql/__init__.py +11 -0
  25. langroid/agent/special/sql/sql_chat_agent.py +47 -23
  26. langroid/agent/special/sql/utils/__init__.py +22 -0
  27. langroid/agent/special/sql/utils/description_extractors.py +95 -46
  28. langroid/agent/special/sql/utils/populate_metadata.py +28 -21
  29. langroid/agent/special/table_chat_agent.py +43 -9
  30. langroid/agent/task.py +475 -122
  31. langroid/agent/tool_message.py +75 -13
  32. langroid/agent/tools/__init__.py +13 -0
  33. langroid/agent/tools/duckduckgo_search_tool.py +66 -0
  34. langroid/agent/tools/google_search_tool.py +11 -0
  35. langroid/agent/tools/metaphor_search_tool.py +67 -0
  36. langroid/agent/tools/recipient_tool.py +16 -29
  37. langroid/agent/tools/run_python_code.py +60 -0
  38. langroid/agent/tools/sciphi_search_rag_tool.py +79 -0
  39. langroid/agent/tools/segment_extract_tool.py +36 -0
  40. langroid/cachedb/__init__.py +9 -0
  41. langroid/cachedb/base.py +22 -2
  42. langroid/cachedb/momento_cachedb.py +26 -2
  43. langroid/cachedb/redis_cachedb.py +78 -11
  44. langroid/embedding_models/__init__.py +34 -0
  45. langroid/embedding_models/base.py +21 -2
  46. langroid/embedding_models/models.py +120 -18
  47. langroid/embedding_models/protoc/embeddings.proto +19 -0
  48. langroid/embedding_models/protoc/embeddings_pb2.py +33 -0
  49. langroid/embedding_models/protoc/embeddings_pb2.pyi +50 -0
  50. langroid/embedding_models/protoc/embeddings_pb2_grpc.py +79 -0
  51. langroid/embedding_models/remote_embeds.py +153 -0
  52. langroid/language_models/__init__.py +45 -0
  53. langroid/language_models/azure_openai.py +80 -27
  54. langroid/language_models/base.py +117 -12
  55. langroid/language_models/config.py +5 -0
  56. langroid/language_models/openai_assistants.py +3 -0
  57. langroid/language_models/openai_gpt.py +558 -174
  58. langroid/language_models/prompt_formatter/__init__.py +15 -0
  59. langroid/language_models/prompt_formatter/base.py +4 -6
  60. langroid/language_models/prompt_formatter/hf_formatter.py +135 -0
  61. langroid/language_models/utils.py +18 -21
  62. langroid/mytypes.py +25 -8
  63. langroid/parsing/__init__.py +46 -0
  64. langroid/parsing/document_parser.py +260 -63
  65. langroid/parsing/image_text.py +32 -0
  66. langroid/parsing/parse_json.py +143 -0
  67. langroid/parsing/parser.py +122 -59
  68. langroid/parsing/repo_loader.py +114 -52
  69. langroid/parsing/search.py +68 -63
  70. langroid/parsing/spider.py +3 -2
  71. langroid/parsing/table_loader.py +44 -0
  72. langroid/parsing/url_loader.py +59 -11
  73. langroid/parsing/urls.py +85 -37
  74. langroid/parsing/utils.py +298 -4
  75. langroid/parsing/web_search.py +73 -0
  76. langroid/prompts/__init__.py +11 -0
  77. langroid/prompts/chat-gpt4-system-prompt.md +68 -0
  78. langroid/prompts/prompts_config.py +1 -1
  79. langroid/utils/__init__.py +17 -0
  80. langroid/utils/algorithms/__init__.py +3 -0
  81. langroid/utils/algorithms/graph.py +103 -0
  82. langroid/utils/configuration.py +36 -5
  83. langroid/utils/constants.py +4 -0
  84. langroid/utils/globals.py +2 -2
  85. langroid/utils/logging.py +2 -5
  86. langroid/utils/output/__init__.py +21 -0
  87. langroid/utils/output/printing.py +47 -1
  88. langroid/utils/output/status.py +33 -0
  89. langroid/utils/pandas_utils.py +30 -0
  90. langroid/utils/pydantic_utils.py +616 -2
  91. langroid/utils/system.py +98 -0
  92. langroid/vector_store/__init__.py +40 -0
  93. langroid/vector_store/base.py +203 -6
  94. langroid/vector_store/chromadb.py +59 -32
  95. langroid/vector_store/lancedb.py +463 -0
  96. langroid/vector_store/meilisearch.py +10 -7
  97. langroid/vector_store/momento.py +262 -0
  98. langroid/vector_store/qdrantdb.py +104 -22
  99. {langroid-0.1.85.dist-info → langroid-0.1.219.dist-info}/METADATA +329 -149
  100. langroid-0.1.219.dist-info/RECORD +127 -0
  101. {langroid-0.1.85.dist-info → langroid-0.1.219.dist-info}/WHEEL +1 -1
  102. langroid/agent/special/recipient_validator_agent.py +0 -157
  103. langroid/parsing/json.py +0 -64
  104. langroid/utils/web/selenium_login.py +0 -36
  105. langroid-0.1.85.dist-info/RECORD +0 -94
  106. /langroid/{scripts → agent/callbacks}/__init__.py +0 -0
  107. {langroid-0.1.85.dist-info → langroid-0.1.219.dist-info}/LICENSE +0 -0
langroid/agent/openai_assistant.py (new file)
@@ -0,0 +1,842 @@
+ import asyncio
+ import json
+
+ # setup logger
+ import logging
+ import time
+ from enum import Enum
+ from typing import Any, Dict, List, Optional, Tuple, Type, cast, no_type_check
+
+ from openai.types.beta import Assistant, Thread
+ from openai.types.beta.threads import Message, Run
+ from openai.types.beta.threads.runs import RunStep
+ from pydantic import BaseModel
+ from rich import print
+
+ from langroid.agent.chat_agent import ChatAgent, ChatAgentConfig
+ from langroid.agent.chat_document import ChatDocument
+ from langroid.agent.tool_message import ToolMessage
+ from langroid.language_models.base import LLMFunctionCall, LLMMessage, LLMResponse, Role
+ from langroid.language_models.openai_gpt import (
+     OpenAIChatModel,
+     OpenAIGPT,
+     OpenAIGPTConfig,
+ )
+ from langroid.utils.configuration import settings
+ from langroid.utils.system import generate_user_id, update_hash
+
+ logger = logging.getLogger(__name__)
+
+
+ class ToolType(str, Enum):
+     RETRIEVAL = "retrieval"
+     CODE_INTERPRETER = "code_interpreter"
+     FUNCTION = "function"
+
+
+ class AssistantTool(BaseModel):
+     type: ToolType
+     function: Dict[str, Any] | None = None
+
+     def dct(self) -> Dict[str, Any]:
+         d = super().dict()
+         d["type"] = d["type"].value
+         if self.type != ToolType.FUNCTION:
+             d.pop("function")
+         return d
+
+
+ class AssistantToolCall(BaseModel):
+     id: str
+     type: ToolType
+     function: LLMFunctionCall
+
+
+ class RunStatus(str, Enum):
+     QUEUED = "queued"
+     IN_PROGRESS = "in_progress"
+     COMPLETED = "completed"
+     REQUIRES_ACTION = "requires_action"
+     EXPIRED = "expired"
+     CANCELLING = "cancelling"
+     CANCELLED = "cancelled"
+     FAILED = "failed"
+     TIMEOUT = "timeout"
+
+
+ class OpenAIAssistantConfig(ChatAgentConfig):
+     use_cached_assistant: bool = False  # set in script via user dialog
+     assistant_id: str | None = None
+     use_tools = False
+     use_functions_api = True
+     use_cached_thread: bool = False  # set in script via user dialog
+     thread_id: str | None = None
+     # set to True once we can add Assistant msgs in threads
+     cache_responses: bool = True
+     timeout: int = 30  # can be different from llm.timeout
+     llm = OpenAIGPTConfig(chat_model=OpenAIChatModel.GPT4_TURBO)
+     tools: List[AssistantTool] = []
+     files: List[str] = []
+
+
+ class OpenAIAssistant(ChatAgent):
+     """
+     A ChatAgent powered by OpenAI Assistant API:
+     mainly, in `llm_response` method, we avoid maintaining conversation state,
+     and instead let the Assistant API do it for us.
+     Also handles persistent storage of Assistant and Threads:
+     stores their ids (for given user, org) in a cache, and
+     reuses them based on config.use_cached_assistant and config.use_cached_thread.
+
+     This class can be used as a drop-in replacement for ChatAgent.
+     """
+
+     def __init__(self, config: OpenAIAssistantConfig):
+         super().__init__(config)
+         self.config: OpenAIAssistantConfig = config
+         self.llm: OpenAIGPT = OpenAIGPT(self.config.llm)
+         # handles for various entities and methods
+         self.client = self.llm.client
+         self.runs = self.llm.client.beta.threads.runs
+         self.threads = self.llm.client.beta.threads
+         self.thread_messages = self.llm.client.beta.threads.messages
+         self.assistants = self.llm.client.beta.assistants
+         # which tool_ids are awaiting output submissions
+         self.pending_tool_ids: List[str] = []
+         self.cached_tool_ids: List[str] = []
+
+         self.thread: Thread | None = None
+         self.assistant: Assistant | None = None
+         self.run: Run | None = None
+
+         self._maybe_create_assistant(self.config.assistant_id)
+         self._maybe_create_thread(self.config.thread_id)
+         self._cache_store()
+
+         self.add_assistant_files(self.config.files)
+         self.add_assistant_tools(self.config.tools)
+
+     def add_assistant_files(self, files: List[str]) -> None:
+         """Add file_ids to assistant"""
+         if self.assistant is None:
+             raise ValueError("Assistant is None")
+         self.files = [
+             self.client.files.create(file=open(f, "rb"), purpose="assistants")
+             for f in files
+         ]
+         self.config.files = list(set(self.config.files + files))
+         self.assistant = self.assistants.update(
+             self.assistant.id,
+             file_ids=[f.id for f in self.files],
+         )
+
+     def add_assistant_tools(self, tools: List[AssistantTool]) -> None:
+         """Add tools to assistant"""
+         if self.assistant is None:
+             raise ValueError("Assistant is None")
+         all_tool_dicts = [t.dct() for t in self.config.tools]
+         for t in tools:
+             if t.dct() not in all_tool_dicts:
+                 self.config.tools.append(t)
+         self.assistant = self.assistants.update(
+             self.assistant.id,
+             tools=[tool.dct() for tool in self.config.tools],  # type: ignore
+         )
+
+     def enable_message(
+         self,
+         message_class: Optional[Type[ToolMessage]],
+         use: bool = True,
+         handle: bool = True,
+         force: bool = False,
+         require_recipient: bool = False,
+         include_defaults: bool = True,
+     ) -> None:
+         """Override ChatAgent's method: extract the function-related args.
+         See that method for details. But specifically about the `include_defaults` arg:
+         Normally the OpenAI completion API ignores these fields, but the Assistant
+         fn-calling seems to pay attn to these, and if we don't want this,
+         we should set this to False.
+         """
+         super().enable_message(
+             message_class,
+             use=use,
+             handle=handle,
+             force=force,
+             require_recipient=require_recipient,
+             include_defaults=include_defaults,
+         )
+         if message_class is None or not use:
+             # no specific msg class, or
+             # we are not enabling USAGE/GENERATION of this tool/fn,
+             # then there's no need to attach the fn to the assistant
+             # (HANDLING the fn will still work via self.agent_response)
+             return
+         if self.config.use_tools:
+             sys_msg = self._create_system_and_tools_message()
+             self.set_system_message(sys_msg.content)
+         if not self.config.use_functions_api:
+             return
+         functions, _ = self._function_args()
+         if functions is None:
+             return
+         # add the functions to the assistant:
+         if self.assistant is None:
+             raise ValueError("Assistant is None")
+         tools = self.assistant.tools
+         tools.extend(
+             [
+                 {
+                     "type": "function",  # type: ignore
+                     "function": f.dict(),
+                 }
+                 for f in functions
+             ]
+         )
+         self.assistant = self.assistants.update(
+             self.assistant.id,
+             tools=tools,  # type: ignore
+         )
+
+     def _cache_thread_key(self) -> str:
+         """Key to use for caching or retrieving thread id"""
+         org = self.llm.client.organization or ""
+         uid = generate_user_id(org)
+         name = self.config.name
+         return "Thread:" + name + ":" + uid
+
+     def _cache_assistant_key(self) -> str:
+         """Key to use for caching or retrieving assistant id"""
+         org = self.llm.client.organization or ""
+         uid = generate_user_id(org)
+         name = self.config.name
+         return "Assistant:" + name + ":" + uid
+
+     @no_type_check
+     def _cache_messages_key(self) -> str:
+         """Key to use when caching or retrieving thread messages"""
+         if self.thread is None:
+             raise ValueError("Thread is None")
+         return "Messages:" + self.thread.metadata["hash"]
+
+     @no_type_check
+     def _cache_thread_lookup(self) -> str | None:
+         """Try to retrieve cached thread_id associated with
+         this user + machine + organization"""
+         key = self._cache_thread_key()
+         return self.llm.cache.retrieve(key)
+
+     @no_type_check
+     def _cache_assistant_lookup(self) -> str | None:
+         """Try to retrieve cached assistant_id associated with
+         this user + machine + organization"""
+         key = self._cache_assistant_key()
+         return self.llm.cache.retrieve(key)
+
+     @no_type_check
+     def _cache_messages_lookup(self) -> LLMResponse | None:
+         """Try to retrieve cached response for the message-list-hash"""
+         if not settings.cache:
+             return None
+         key = self._cache_messages_key()
+         cached_dict = self.llm.cache.retrieve(key)
+         if cached_dict is None:
+             return None
+         return LLMResponse.parse_obj(cached_dict)
+
+     def _cache_store(self) -> None:
+         """
+         Cache the assistant_id, thread_id associated with
+         this user + machine + organization
+         """
+         if self.thread is None or self.assistant is None:
+             raise ValueError("Thread or Assistant is None")
+         thread_key = self._cache_thread_key()
+         self.llm.cache.store(thread_key, self.thread.id)
+
+         assistant_key = self._cache_assistant_key()
+         self.llm.cache.store(assistant_key, self.assistant.id)
+
+     @staticmethod
+     def thread_msg_to_llm_msg(msg: Message) -> LLMMessage:
+         """
+         Convert a Message to an LLMMessage
+         """
+         return LLMMessage(
+             content=msg.content[0].text.value,  # type: ignore
+             role=Role(msg.role),
+         )
+
+     def _update_messages_hash(self, msg: Message | LLMMessage) -> None:
+         """
+         Update the hash-state in the thread with the given message.
+         """
+         if self.thread is None:
+             raise ValueError("Thread is None")
+         if isinstance(msg, Message):
+             llm_msg = self.thread_msg_to_llm_msg(msg)
+         else:
+             llm_msg = msg
+         hash = self.thread.metadata["hash"]  # type: ignore
+         most_recent_msg = llm_msg.content
+         most_recent_role = llm_msg.role
+         hash = update_hash(hash, f"{most_recent_role}:{most_recent_msg}")
+         # TODO is this inplace?
+         self.thread = self.threads.update(
+             self.thread.id,
+             metadata={
+                 "hash": hash,
+             },
+         )
+         assert self.thread.metadata["hash"] == hash  # type: ignore
+
+     def _maybe_create_thread(self, id: str | None = None) -> None:
+         """Retrieve or create a thread if one does not exist,
+         or retrieve it from cache"""
+         if id is not None:
+             try:
+                 self.thread = self.threads.retrieve(thread_id=id)
+             except Exception:
+                 logger.warning(
+                     f"""
+                     Could not retrieve thread with id {id},
+                     so creating a new one.
+                     """
+                 )
+                 self.thread = None
+         if self.thread is not None:
+             return
+         cached = self._cache_thread_lookup()
+         if cached is not None:
+             if self.config.use_cached_thread:
+                 self.thread = self.llm.client.beta.threads.retrieve(thread_id=cached)
+             else:
+                 logger.warning(
+                     f"""
+                     Found cached thread id {cached},
+                     but config.use_cached_thread = False, so deleting it.
+                     """
+                 )
+                 try:
+                     self.llm.client.beta.threads.delete(thread_id=cached)
+                 except Exception:
+                     logger.warning(
+                         f"""
+                         Could not delete thread with id {cached}, ignoring.
+                         """
+                     )
+                 self.llm.cache.delete_keys([self._cache_thread_key()])
+         if self.thread is None:
+             if self.assistant is None:
+                 raise ValueError("Assistant is None")
+             self.thread = self.llm.client.beta.threads.create()
+             hash_key_str = (
+                 (self.assistant.instructions or "")
+                 + str(self.config.use_tools)
+                 + str(self.config.use_functions_api)
+             )
+             hash_hex = update_hash(None, s=hash_key_str)
+             self.thread = self.threads.update(
+                 self.thread.id,
+                 metadata={
+                     "hash": hash_hex,
+                 },
+             )
+             assert self.thread.metadata["hash"] == hash_hex  # type: ignore
+
+     def _maybe_create_assistant(self, id: str | None = None) -> None:
+         """Retrieve or create an assistant if one does not exist,
+         or retrieve it from cache"""
+         if id is not None:
+             try:
+                 self.assistant = self.assistants.retrieve(assistant_id=id)
+             except Exception:
+                 logger.warning(
+                     f"""
+                     Could not retrieve assistant with id {id},
+                     so creating a new one.
+                     """
+                 )
+                 self.assistant = None
+         if self.assistant is not None:
+             return
+         cached = self._cache_assistant_lookup()
+         if cached is not None:
+             if self.config.use_cached_assistant:
+                 self.assistant = self.llm.client.beta.assistants.retrieve(
+                     assistant_id=cached
+                 )
+             else:
+                 logger.warning(
+                     f"""
+                     Found cached assistant id {cached},
+                     but config.use_cached_assistant = False, so deleting it.
+                     """
+                 )
+                 try:
+                     self.llm.client.beta.assistants.delete(assistant_id=cached)
+                 except Exception:
+                     logger.warning(
+                         f"""
+                         Could not delete assistant with id {cached}, ignoring.
+                         """
+                     )
+                 self.llm.cache.delete_keys([self._cache_assistant_key()])
+         if self.assistant is None:
+             self.assistant = self.llm.client.beta.assistants.create(
+                 name=self.config.name,
+                 instructions=self.config.system_message,
+                 tools=[],
+                 model=self.config.llm.chat_model,
+             )
+
+     def _get_run(self) -> Run:
+         """Retrieve the run object associated with this thread and run,
+         to see its latest status.
+         """
+         if self.thread is None or self.run is None:
+             raise ValueError("Thread or Run is None")
+         return self.runs.retrieve(thread_id=self.thread.id, run_id=self.run.id)
+
+     def _get_run_steps(self) -> List[RunStep]:
+         if self.thread is None or self.run is None:
+             raise ValueError("Thread or Run is None")
+         result = self.runs.steps.list(thread_id=self.thread.id, run_id=self.run.id)
+         if result is None:
+             return []
+         return result.data
+
+     def _get_code_logs(self) -> List[Tuple[str, str]]:
+         """
+         Get list of input, output strings from code logs
+         """
+         run_steps = self._get_run_steps()
+         # each step may have multiple tool-calls,
+         # each tool-call may have multiple outputs
+         tool_calls = [  # list of list of tool-calls
+             s.step_details.tool_calls
+             for s in run_steps
+             if s.step_details is not None and hasattr(s.step_details, "tool_calls")
+         ]
+         code_logs = []
+         for tcl in tool_calls:  # each tool-call-list
+             for tc in tcl:
+                 if tc is None or tc.type != ToolType.CODE_INTERPRETER:
+                     continue
+                 io = tc.code_interpreter  # type: ignore
+                 input = io.input
+                 # TODO for CodeInterpreterOutputImage, there is no "logs"
+                 # revisit when we handle images.
+                 outputs = "\n\n".join(
+                     o.logs
+                     for o in io.outputs
+                     if o.type == "logs" and hasattr(o, "logs")
+                 )
+                 code_logs.append((input, outputs))
+         # return the reversed list, since they are stored in reverse chron order
+         return code_logs[::-1]
+
+     def _get_code_logs_str(self) -> str:
+         """
+         Get string representation of code logs
+         """
+         code_logs = self._get_code_logs()
+         return "\n\n".join(
+             f"INPUT:\n{input}\n\nOUTPUT:\n{output}" for input, output in code_logs
+         )
+
+     def _add_thread_message(self, msg: str, role: Role) -> None:
+         """
+         Add a message with the given role to the thread.
+         Args:
+             msg (str): message to add
+             role (Role): role of the message
+         """
+         if self.thread is None:
+             raise ValueError("Thread is None")
+         # CACHING TRICK! Since the API only allows inserting USER messages,
+         # we prepend the role to the message, so that we can store ASSISTANT msgs
+         # as well! When the LLM sees the thread messages, they will contain
+         # the right sequence of alternating roles, so that it has no trouble
+         # responding when it is its turn.
+         msg = f"{role.value.upper()}: {msg}"
+         thread_msg = self.thread_messages.create(
+             content=msg,
+             thread_id=self.thread.id,
+             # We ALWAYS store user role since only user role allowed currently
+             role=Role.USER.value,
+         )
+         self._update_messages_hash(thread_msg)
+
+     def _get_thread_messages(self, n: int = 20) -> List[LLMMessage]:
+         """
+         Get the last n messages in the thread, in cleaned-up form (LLMMessage).
+         Args:
+             n (int): number of messages to retrieve
+         Returns:
+             List[LLMMessage]: list of messages
+         """
+         if self.thread is None:
+             raise ValueError("Thread is None")
+         result = self.thread_messages.list(
+             thread_id=self.thread.id,
+             limit=n,
+         )
+         num = len(result.data)
+         if result.has_more and num < n:  # type: ignore
+             logger.warning(f"Retrieving last {num} messages, but there are more")
+         thread_msgs = result.data
+         for msg in thread_msgs:
+             self.process_citations(msg)
+         return [
+             LLMMessage(
+                 # TODO: could be image, deal with it later
+                 content=m.content[0].text.value,  # type: ignore
+                 role=Role(m.role),
+             )
+             for m in thread_msgs
+         ]
+
+     def _wait_for_run(
+         self,
+         until_not: List[RunStatus] = [RunStatus.QUEUED, RunStatus.IN_PROGRESS],
+         until: List[RunStatus] = [],
+         timeout: int = 30,
+     ) -> RunStatus:
+         """
+         Poll the run until it either:
+         - EXITs the statuses specified in `until_not`, or
+         - ENTERs the statuses specified in `until`, or times out.
+         """
+         if self.thread is None or self.run is None:
+             raise ValueError("Thread or Run is None")
+         while True:
+             run = self._get_run()
+             if run.status not in until_not or run.status in until:
+                 return cast(RunStatus, run.status)
+             time.sleep(1)
+             timeout -= 1
+             if timeout <= 0:
+                 return cast(RunStatus, RunStatus.TIMEOUT)
+
+     async def _wait_for_run_async(
+         self,
+         until_not: List[RunStatus] = [RunStatus.QUEUED, RunStatus.IN_PROGRESS],
+         until: List[RunStatus] = [],
+         timeout: int = 30,
+     ) -> RunStatus:
+         """Async version of _wait_for_run"""
+         if self.thread is None or self.run is None:
+             raise ValueError("Thread or Run is None")
+         while True:
+             run = self._get_run()
+             if run.status not in until_not or run.status in until:
+                 return cast(RunStatus, run.status)
+             await asyncio.sleep(1)
+             timeout -= 1
+             if timeout <= 0:
+                 return cast(RunStatus, RunStatus.TIMEOUT)
+
+     def set_system_message(self, msg: str) -> None:
+         """
+         Override ChatAgent's method.
+         The Task may use this method to set the system message
+         of the chat assistant.
+         """
+         super().set_system_message(msg)
+         if self.assistant is None:
+             raise ValueError("Assistant is None")
+         self.assistant = self.assistants.update(self.assistant.id, instructions=msg)
+
+     def _start_run(self) -> None:
+         """
+         Run the assistant on the thread.
+         """
+         if self.thread is None or self.assistant is None:
+             raise ValueError("Thread or Assistant is None")
+         self.run = self.runs.create(
+             thread_id=self.thread.id,
+             assistant_id=self.assistant.id,
+         )
+
+     def _run_result(self) -> LLMResponse:
+         """Result from run completed on the thread."""
+         status = self._wait_for_run(
+             timeout=self.config.timeout,
+         )
+         return self._process_run_result(status)
+
+     async def _run_result_async(self) -> LLMResponse:
+         """(Async) Result from run completed on the thread."""
+         status = await self._wait_for_run_async(
+             timeout=self.config.timeout,
+         )
+         return self._process_run_result(status)
+
+     def _process_run_result(self, status: RunStatus) -> LLMResponse:
+         """Process the result of the run."""
+         function_call: LLMFunctionCall | None = None
+         response = ""
+         tool_id = ""
+         # IMPORTANT: FIRST save hash key to store result,
+         # before it gets updated with the response
+         key = self._cache_messages_key()
+         if status == RunStatus.TIMEOUT:
+             logger.warning("Timeout waiting for run to complete, return empty string")
+         elif status == RunStatus.COMPLETED:
+             messages = self._get_thread_messages(n=1)
+             response = messages[0].content
+             # update hash to include the response.
+             self._update_messages_hash(messages[0])
+         elif status == RunStatus.REQUIRES_ACTION:
+             tool_calls = self._parse_run_required_action()
+             # pick the FIRST tool call with type "function"
+             tool_call_fn = [t for t in tool_calls if t.type == ToolType.FUNCTION][0]
+             # TODO Handling only first tool/fn call for now
+             # revisit later: multi-tools affects the task.run() loop.
+             function_call = tool_call_fn.function
+             tool_id = tool_call_fn.id
+         result = LLMResponse(
+             message=response,
+             tool_id=tool_id,
+             function_call=function_call,
+             usage=None,  # TODO
+             cached=False,  # TODO - revisit when able to insert Assistant responses
+         )
+         self.llm.cache.store(key, result.dict())
+         return result
+
+     def _parse_run_required_action(self) -> List[AssistantToolCall]:
+         """
+         Parse the required_action field of the run, i.e. get the list of tool calls.
+         Currently only tool calls are supported.
+         """
+         # see https://platform.openai.com/docs/assistants/tools/function-calling
+         run = self._get_run()
+         if run.status != RunStatus.REQUIRES_ACTION:  # type: ignore
+             return []
+
+         if (action := run.required_action.type) != "submit_tool_outputs":
+             raise ValueError(f"Unexpected required_action type {action}")
+         tool_calls = run.required_action.submit_tool_outputs.tool_calls
+         return [
+             AssistantToolCall(
+                 id=tool_call.id,
+                 type=ToolType(tool_call.type),
+                 function=LLMFunctionCall.from_dict(tool_call.function.model_dump()),
+             )
+             for tool_call in tool_calls
+         ]
+
+     def _submit_tool_outputs(self, msg: LLMMessage) -> None:
+         """
+         Submit the tool (fn) outputs to the run/thread
+         """
+         if self.run is None or self.thread is None:
+             raise ValueError("Run or Thread is None")
+         tool_outputs = [
+             {
+                 "tool_call_id": msg.tool_id,
+                 "output": msg.content,
+             }
+         ]
+         # run enters queued, in_progress state after this
+         self.runs.submit_tool_outputs(
+             thread_id=self.thread.id,
+             run_id=self.run.id,
+             tool_outputs=tool_outputs,  # type: ignore
+         )
+
+     def process_citations(self, thread_msg: Message) -> None:
+         """
+         Process citations in the thread message.
+         Modifies the thread message in-place.
+         """
+         # could there be multiple content items?
+         # TODO content could be MessageContentImageFile; handle that later
+         annotated_content = thread_msg.content[0].text  # type: ignore
+         annotations = annotated_content.annotations
+         citations = []
+         # Iterate over the annotations and add footnotes
+         for index, annotation in enumerate(annotations):
+             # Replace the text with a footnote
+             annotated_content.value = annotated_content.value.replace(
+                 annotation.text, f" [{index}]"
+             )
+             # Gather citations based on annotation attributes
+             if file_citation := getattr(annotation, "file_citation", None):
+                 try:
+                     cited_file = self.client.files.retrieve(file_citation.file_id)
+                 except Exception:
+                     logger.warning(
+                         f"""
+                         Could not retrieve cited file with id {file_citation.file_id},
+                         ignoring.
+                         """
+                     )
+                     continue
+                 citations.append(
+                     f"[{index}] '{file_citation.quote}',-- from {cited_file.filename}"
+                 )
+             elif file_path := getattr(annotation, "file_path", None):
+                 cited_file = self.client.files.retrieve(file_path.file_id)
+                 citations.append(
+                     f"[{index}] Click <here> to download {cited_file.filename}"
+                 )
+                 # Note: File download functionality not implemented above for brevity
+         sep = "\n" if len(citations) > 0 else ""
+         annotated_content.value += sep + "\n".join(citations)
+
+     def _llm_response_preprocess(
+         self,
+         message: Optional[str | ChatDocument] = None,
+     ) -> LLMResponse | None:
+         """
+         Preprocess message and return response if found in cache, else None.
+         """
+         is_tool_output = False
+         if message is not None:
+             llm_msg = ChatDocument.to_LLMMessage(message)
+             tool_id = llm_msg.tool_id
+             if tool_id in self.pending_tool_ids:
+                 if isinstance(message, ChatDocument):
+                     message.pop_tool_ids()
+                 result_msg = f"Result for Tool_id {tool_id}: {llm_msg.content}"
+                 if tool_id in self.cached_tool_ids:
+                     self.cached_tool_ids.remove(tool_id)
+                     # add actual result of cached fn-call
+                     self._add_thread_message(result_msg, role=Role.USER)
+                 else:
+                     is_tool_output = True
+                     # submit tool/fn result to the thread/run
+                     self._submit_tool_outputs(llm_msg)
+                     # We cannot ACTUALLY add this result to thread now
+                     # since run is in `action_required` state,
+                     # so we just update the message hash
+                     self._update_messages_hash(
+                         LLMMessage(content=result_msg, role=Role.USER)
+                     )
+                 self.pending_tool_ids.remove(tool_id)
+             else:
+                 # add message to the thread
+                 self._add_thread_message(llm_msg.content, role=Role.USER)
+
+         # When message is None, the thread may have no user msgs,
+         # Note: system message is NOT placed in the thread by the OpenAI system.
+
+         # check if we have cached the response.
+         # TODO: handle the case of structured result (fn-call, tool, etc)
+         response = self._cache_messages_lookup()
+         if response is not None:
+             response.cached = True
+             # store the result in the thread so
+             # it looks like assistant produced it
+             if self.config.cache_responses:
+                 self._add_thread_message(
+                     json.dumps(response.dict()), role=Role.ASSISTANT
+                 )
+             return response  # type: ignore
+         else:
+             # create a run for this assistant on this thread,
+             # i.e. actually "run"
+             if not is_tool_output:
+                 # DO NOT start a run if we submitted tool outputs,
+                 # since submission of tool outputs resumes a run from
+                 # status = "requires_action"
+                 self._start_run()
+             return None
+
+     def _llm_response_postprocess(
+         self,
+         response: LLMResponse,
+         cached: bool,
+         message: Optional[str | ChatDocument] = None,
+     ) -> Optional[ChatDocument]:
+         # code from ChatAgent.llm_response_messages
+         if response.function_call is not None:
+             self.pending_tool_ids += [response.tool_id]
+             if cached:
+                 # add to cached tools list so we don't create an Assistant run
+                 # in _llm_response_preprocess
+                 self.cached_tool_ids += [response.tool_id]
+             response_str = str(response.function_call)
+         else:
+             response_str = response.message
+         cache_str = "[red](cached)[/red]" if cached else ""
+         if not settings.quiet:
+             if not cached and self._get_code_logs_str():
+                 print(
+                     f"[magenta]CODE-INTERPRETER LOGS:\n"
+                     "-------------------------------\n"
+                     f"{self._get_code_logs_str()}[/magenta]"
+                 )
+             print(f"{cache_str}[green]" + response_str + "[/green]")
+         cdoc = ChatDocument.from_LLMResponse(response, displayed=False)
+         # Note message.metadata.tool_ids may have been popped above
+         tool_ids = (
+             []
+             if (message is None or isinstance(message, str))
+             else message.metadata.tool_ids
+         )
+
+         if response.tool_id != "":
+             tool_ids.append(response.tool_id)
+         cdoc.metadata.tool_ids = tool_ids
+         return cdoc
+
+     def llm_response(
+         self, message: Optional[str | ChatDocument] = None
+     ) -> Optional[ChatDocument]:
+         """
+         Override ChatAgent's method: this is the main LLM response method.
+         In the ChatAgent, this updates `self.message_history` and then calls
+         `self.llm_response_messages`, but since we are relying on the Assistant API
+         to maintain conversation state, this method is simpler: Simply start a run
+         on the message-thread, and wait for it to complete.
+
+         Args:
+             message (Optional[str | ChatDocument], optional): message to respond to
+                 (if absent, the LLM response will be based on the
+                 instructions in the system_message). Defaults to None.
+         Returns:
+             Optional[ChatDocument]: LLM response
+         """
+         response = self._llm_response_preprocess(message)
+         cached = True
+         if response is None:
+             cached = False
+             response = self._run_result()
+         return self._llm_response_postprocess(response, cached=cached, message=message)
+
+     async def llm_response_async(
+         self, message: Optional[str | ChatDocument] = None
+     ) -> Optional[ChatDocument]:
+         """
+         Async version of llm_response.
+         """
+         response = self._llm_response_preprocess(message)
+         cached = True
+         if response is None:
+             cached = False
+             response = await self._run_result_async()
+         return self._llm_response_postprocess(response, cached=cached, message=message)
+
+     def agent_response(
+         self,
+         msg: Optional[str | ChatDocument] = None,
+     ) -> Optional[ChatDocument]:
+         response = super().agent_response(msg)
+         if msg is None:
+             return response
+         if response is None:
+             return None
+         try:
+             # When the agent response is to a tool message,
+             # we prefix it with "TOOL Result: " so that it is clear to the
+             # LLM that this is the result of the last TOOL;
+             # This ensures our caching trick works.
+             if self.config.use_tools and len(self.get_tool_messages(msg)) > 0:
+                 response.content = "TOOL Result: " + response.content
+             return response
+         except Exception:
+             return response
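
The new OpenAIAssistant agent added above can be exercised with a few lines of code. The sketch below is illustrative only and is not part of the diff: the agent name, system message, and question are made up, and it assumes OPENAI_API_KEY is set in the environment.

# Minimal usage sketch -- illustrative assumptions, not part of the package diff.
from langroid.agent.openai_assistant import (
    OpenAIAssistant,
    OpenAIAssistantConfig,
)

config = OpenAIAssistantConfig(
    name="DemoAssistant",  # hypothetical agent name
    system_message="You are a concise, helpful assistant.",
    use_cached_assistant=False,  # create a fresh Assistant rather than reuse a cached id
    use_cached_thread=False,  # create a fresh Thread rather than reuse a cached id
)
agent = OpenAIAssistant(config)

# Conversation state lives server-side in the Thread, so successive
# llm_response calls continue the same conversation.
answer = agent.llm_response("What is the capital of France?")
if answer is not None:
    print(answer.content)

Note the design choice visible in _add_thread_message above: because the Assistant API only accepts user-role messages, assistant responses are stored in the thread with an "ASSISTANT:" text prefix, which is what makes the response-caching scheme work.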