langroid 0.1.265__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
langroid/agent/task.py CHANGED
@@ -3,6 +3,7 @@ from __future__ import annotations
  import asyncio
  import copy
  import logging
+ import threading
  from collections import Counter, deque
  from types import SimpleNamespace
  from typing import (
@@ -13,7 +14,6 @@ from typing import (
  Dict,
  List,
  Optional,
- Set,
  Tuple,
  Type,
  cast,
@@ -47,6 +47,7 @@ from langroid.utils.constants import (
  USER_QUIT_STRINGS,
  )
  from langroid.utils.logging import RichFileLogger, setup_file_logger
+ from langroid.utils.object_registry import scheduled_cleanup
  from langroid.utils.system import hash

  logger = logging.getLogger(__name__)
@@ -65,14 +66,17 @@ class TaskConfig(BaseModel):
  we have config classes for `Agent`, `ChatAgent`, `LanguageModel`, etc.

  Attributes:
- inf_loop_cycle_len: max exact-loop cycle length: 0 => no inf loop test
- inf_loop_dominance_factor: dominance factor for exact-loop detection
- inf_loop_wait_factor: wait this * cycle_len msgs before loop-check
+ inf_loop_cycle_len (int): max exact-loop cycle length: 0 => no inf loop test
+ inf_loop_dominance_factor (float): dominance factor for exact-loop detection
+ inf_loop_wait_factor (int): wait this * cycle_len msgs before loop-check
+ restart_as_subtask (bool): whether to restart *every* run of this task
+ when run as a subtask.
  """

  inf_loop_cycle_len: int = 10
  inf_loop_dominance_factor: float = 1.5
  inf_loop_wait_factor: int = 5
+ restart_as_subtask: bool = False


  class Task:
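
The `TaskConfig` hunk above adds `restart_as_subtask` next to the existing infinite-loop-detection knobs. A minimal sketch of constructing such a config (values shown are the defaults from the hunk, apart from the new flag):

```python
from langroid.agent.task import TaskConfig

cfg = TaskConfig(
    inf_loop_cycle_len=10,          # 0 disables the exact-loop check
    inf_loop_dominance_factor=1.5,
    inf_loop_wait_factor=5,
    restart_as_subtask=True,        # new in 0.2.0: reset this task on every sub-task run
)
```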
@@ -107,6 +111,7 @@ class Task:

  # class variable called `cache` that is a RedisCache object
  _cache: RedisCache | None = None
+ _background_tasks_started: bool = False

  def __init__(
  self,
@@ -149,7 +154,7 @@ class Task:
  One run of step() is considered a "turn".
  system_message (str): if not empty, overrides agent's system_message
  user_message (str): if not empty, overrides agent's user_message
- restart (bool): if true, resets the agent's message history
+ restart (bool): if true, resets the agent's message history *at every run*.
  default_human_response (str): default response from user; useful for
  testing, to avoid interactive input from user.
  [Instead of this, setting `interactive` usually suffices]
@@ -187,6 +192,8 @@ class Task:
  set_parent_agent=noop_fn,
  )
  self.config = config
+ # how to behave as a sub-task; can be overridden by `add_sub_task()`
+ self.config_sub_task = copy.deepcopy(config)
  # counts of distinct pending messages in history,
  # to help detect (exact) infinite loops
  self.message_counter: Counter[str] = Counter()
@@ -208,24 +215,25 @@ class Task:
  the config may affect other agents using the same config.
  """
  )
-
+ self.restart = restart
+ agent = cast(ChatAgent, agent)
+ self.agent: ChatAgent = agent
  if isinstance(agent, ChatAgent) and len(agent.message_history) == 0 or restart:
- agent = cast(ChatAgent, agent)
- agent.clear_history(0)
- agent.clear_dialog()
+ self.agent.clear_history(0)
+ self.agent.clear_dialog()
  # possibly change the system and user messages
  if system_message:
  # we always have at least 1 task_message
- agent.set_system_message(system_message)
+ self.agent.set_system_message(system_message)
  if user_message:
- agent.set_user_message(user_message)
+ self.agent.set_user_message(user_message)
  self.max_cost: float = 0
  self.max_tokens: int = 0
  self.session_id: str = ""
  self.logger: None | RichFileLogger = None
  self.tsv_logger: None | logging.Logger = None
  self.color_log: bool = False if settings.notebook else True
- self.agent = agent
+
  self.step_progress = False # progress in current step?
  self.n_stalled_steps = 0 # how many consecutive steps with no progress?
  self.max_stalled_steps = max_stalled_steps
@@ -305,7 +313,6 @@ class Task:

  # other sub_tasks this task can delegate to
  self.sub_tasks: List[Task] = []
- self.parent_task: Set[Task] = set()
  self.caller: Task | None = None # which task called this task's `run` method

  def clone(self, i: int) -> "Task":
@@ -321,7 +328,7 @@ class Task:
  single_round=self.single_round,
  system_message=self.agent.system_message,
  user_message=self.agent.user_message,
- restart=False,
+ restart=self.restart,
  default_human_response=self.default_human_response,
  interactive=self.interactive,
  erase_substeps=self.erase_substeps,
@@ -338,6 +345,19 @@ class Task:
  cls._cache = RedisCache(RedisCacheConfig(fake=False))
  return cls._cache

+ @classmethod
+ def _start_background_tasks(cls) -> None:
+ """Start background object registry cleanup thread. NOT USED."""
+ if cls._background_tasks_started:
+ return
+ cls._background_tasks_started = True
+ cleanup_thread = threading.Thread(
+ target=scheduled_cleanup,
+ args=(600,),
+ daemon=True,
+ )
+ cleanup_thread.start()
+
  def __repr__(self) -> str:
  return f"{self.name}"

@@ -416,24 +436,37 @@ class Task:
  def _leave(self) -> str:
  return self._indent + "<<<"

- def add_sub_task(self, task: Task | List[Task]) -> None:
+ def add_sub_task(
+ self,
+ task: (
+ Task | List[Task] | Tuple[Task, TaskConfig] | List[Tuple[Task, TaskConfig]]
+ ),
+ ) -> None:
  """
  Add a sub-task (or list of subtasks) that this task can delegate
  (or fail-over) to. Note that the sequence of sub-tasks is important,
  since these are tried in order, as the parent task searches for a valid
- response.
+ response (unless a sub-task is explicitly addressed).

  Args:
- task (Task|List[Task]): sub-task(s) to add
+ task: A task, or list of tasks, or a tuple of task and task config,
+ or a list of tuples of task and task config.
+ These tasks are added as sub-tasks of the current task.
+ The task configs (if any) dictate how the tasks are run when
+ invoked as sub-tasks of other tasks. This allows users to specify
+ behavior applicable only in the context of a particular task-subtask
+ combination.
  """
-
  if isinstance(task, list):
  for t in task:
  self.add_sub_task(t)
  return
- assert isinstance(task, Task), f"added task must be a Task, not {type(task)}"

- task.parent_task.add(self) # add myself to set of parent tasks of `task`
+ if isinstance(task, tuple):
+ task, config = task
+ else:
+ config = TaskConfig()
+ task.config_sub_task = config
  self.sub_tasks.append(task)
  self.name_sub_task_map[task.name] = task
  self.responders.append(cast(Responder, task))
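
The widened `add_sub_task` signature is what lets a parent attach a per-subtask `TaskConfig` (stored on the sub-task as `config_sub_task`). A minimal sketch, where `main_task` and `helper_task` are assumed to be already-constructed `Task` objects:

```python
from langroid.agent.task import TaskConfig

# plain form, unchanged from 0.1.x
main_task.add_sub_task(helper_task)

# new tuple form: the TaskConfig applies only when helper_task runs under main_task
main_task.add_sub_task((helper_task, TaskConfig(restart_as_subtask=True)))

# lists of either form are also accepted
main_task.add_sub_task([(helper_task, TaskConfig(restart_as_subtask=True))])
```

These calls are alternative ways of registering the sub-task, not steps meant to run in sequence.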
@@ -460,12 +493,28 @@ class Task:
  sender=Entity.USER,
  ),
  )
+ elif msg is None and len(self.agent.message_history) > 1:
+ # if agent has a history beyond system msg, set the
+ # pending message to the ChatDocument linked from
+ # last message in the history
+ last_agent_msg = self.agent.message_history[-1]
+ self.pending_message = ChatDocument.from_id(last_agent_msg.chat_document_id)
+ if self.pending_message is not None:
+ self.pending_sender = self.pending_message.metadata.sender
  else:
- self.pending_message = copy.deepcopy(msg)
+ if isinstance(msg, ChatDocument):
+ # carefully deep-copy: fresh metadata.id, register
+ # as new obj in registry
+ self.pending_message = ChatDocument.deepcopy(msg)
  if self.pending_message is not None and self.caller is not None:
  # msg may have come from `caller`, so we pretend this is from
  # the CURRENT task's USER entity
  self.pending_message.metadata.sender = Entity.USER
+ # update parent, child, agent pointers
+ if msg is not None:
+ msg.metadata.child_id = self.pending_message.metadata.id
+ self.pending_message.metadata.parent_id = msg.metadata.id
+ self.pending_message.metadata.agent_id = self.agent.id

  self._show_pending_message_if_debug()

@@ -484,6 +533,13 @@ class Task:
  self.log_message(Entity.USER, self.pending_message)
  return self.pending_message

+ def reset_all_sub_tasks(self) -> None:
+ """Recursively reset message history of own agent and all sub-tasks"""
+ self.agent.clear_history(0)
+ self.agent.clear_dialog()
+ for t in self.sub_tasks:
+ t.reset_all_sub_tasks()
+
  def run(
  self,
  msg: Optional[str | ChatDocument] = None,
@@ -495,6 +551,14 @@ class Task:
  ) -> Optional[ChatDocument]:
  """Synchronous version of `run_async()`.
  See `run_async()` for details."""
+ if (self.restart and caller is None) or (
+ self.config_sub_task.restart_as_subtask and caller is not None
+ ):
+ # We are either at top level, with restart = True, OR
+ # we are a sub-task with restart_as_subtask = True,
+ # so reset own agent and recursively for all sub-tasks
+ self.reset_all_sub_tasks()
+
  self.task_progress = False
  self.n_stalled_steps = 0
  self.max_cost = max_cost
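
The same reset check appears in both `run()` and `run_async()`: a top-level run with `restart=True`, or a sub-task run whose effective `config_sub_task.restart_as_subtask` is True, now clears the agent's history (recursively, via `reset_all_sub_tasks()`) before stepping. A sketch of the top-level case, assuming `agent` is an already-configured `ChatAgent`:

```python
from langroid.agent.task import Task

# restart=True now means "reset message history at *every* run",
# not just once at construction
task = Task(agent, restart=True, interactive=False)

task.run("Summarize the latest report.")   # starts from a clean history
task.run("Now list three action items.")   # history is reset again before this run
```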
@@ -597,6 +661,18 @@ class Task:
  # have come from another LLM), as far as this agent is concerned, the initial
  # message can be considered to be from the USER
  # (from the POV of this agent's LLM).
+
+ if (
+ self.restart
+ and caller is None
+ or self.config_sub_task.restart_as_subtask
+ and caller is not None
+ ):
+ # We are either at top level, with restart = True, OR
+ # we are a sub-task with restart_as_subtask = True,
+ # so reset own agent and recursively for all sub-tasks
+ self.reset_all_sub_tasks()
+
  self.task_progress = False
  self.n_stalled_steps = 0
  self.max_cost = max_cost
@@ -701,6 +777,23 @@ class Task:
  if self.erase_substeps:
  # TODO I don't like directly accessing agent message_history. Revisit.
  # (Pchalasani)
+ # Note: msg history will consist of:
+ # - H: the original msg history, ending at idx= self.message_history_idx
+ # - R: this agent's response, which presumably leads to:
+ # - X: a series of back-and-forth msgs (including with agent's own
+ # responders and with sub-tasks)
+ # - F: the final result message, from this agent.
+ # Here we are deleting all of [X] from the agent's message history,
+ # so that it simply looks as if the sub-tasks never happened.
+
+ dropped = self.agent.message_history[
+ self.message_history_idx + 2 : n_messages - 1
+ ]
+ # first delete the linked ChatDocuments (and descendants) from
+ # ObjectRegistry
+ for msg in dropped:
+ ChatDocument.delete_id(msg.chat_document_id)
+ # then delete the messages from the agent's message_history
  del self.agent.message_history[
  self.message_history_idx + 2 : n_messages - 1
  ]
@@ -750,9 +843,11 @@ class Task:

  if (
  Entity.USER in self.responders
+ and self.interactive
  and not self.human_tried
  and not self.agent.has_tool_message_attempt(self.pending_message)
  ):
+ # When in interactive mode,
  # Give human first chance if they haven't been tried in last step,
  # and the msg is not a tool-call attempt;
  # This ensures human gets a chance to respond,
@@ -778,6 +873,8 @@ class Task:
  recipient=recipient,
  ),
  )
+ # no need to register this dummy msg in ObjectRegistry
+ ChatDocument.delete_id(log_doc.id())
  self.log_message(r, log_doc)
  continue
  self.human_tried = r == Entity.USER
@@ -844,6 +941,7 @@ class Task:

  if (
  Entity.USER in self.responders
+ and self.interactive
  and not self.human_tried
  and not self.agent.has_tool_message_attempt(self.pending_message)
  ):
@@ -870,6 +968,8 @@ class Task:
  recipient=recipient,
  ),
  )
+ # no need to register this dummy msg in ObjectRegistry
+ ChatDocument.delete_id(log_doc.id())
  self.log_message(r, log_doc)
  continue
  self.human_tried = r == Entity.USER
@@ -905,10 +1005,26 @@ class Task:
  # Contrast this with self.pending_message.metadata.sender, which is an ENTITY
  # of this agent, or a sub-task's agent.
  if not self.is_pass_thru:
- self.pending_sender = r
- result.metadata.parent = parent
- if not self.is_pass_thru:
+ if (
+ self.pending_message is not None
+ and self.pending_message.metadata.agent_id == self.agent.id
+ ):
+ # when pending msg is from our own agent, respect the sender set there,
+ # since sometimes a response may "mock" as if the response is from
+ # another entity (e.g when using RewindTool, the agent handler
+ # returns a result as if it were from the LLM).
+ self.pending_sender = result.metadata.sender
+ else:
+ # when pending msg is from a sub-task, the sender is the sub-task
+ self.pending_sender = r
  self.pending_message = result
+ # set the parent/child links ONLY if not already set by agent internally,
+ # which may happen when using the RewindTool
+ if parent is not None and not result.metadata.parent_id:
+ result.metadata.parent_id = parent.id()
+ if parent is not None and not parent.metadata.child_id:
+ parent.metadata.child_id = result.id()
+
  self.log_message(self.pending_sender, result, mark=True)
  self.step_progress = True
  self.task_progress = True
@@ -941,9 +1057,10 @@ class Task:
  responder = (
  Entity.LLM if self.pending_sender == Entity.USER else Entity.USER
  )
+ parent_id = "" if parent is None else parent.id()
  self.pending_message = ChatDocument(
  content=NO_ANSWER,
- metadata=ChatDocMetaData(sender=responder, parent=parent),
+ metadata=ChatDocMetaData(sender=responder, parent_id=parent_id),
  )
  self.pending_sender = responder
  self.log_message(self.pending_sender, self.pending_message, mark=True)
@@ -1089,7 +1206,7 @@ class Task:
  # regardless of which entity actually produced the result,
  # when we return the result, we set entity to USER
  # since to the "parent" task, this result is equivalent to a response from USER
- return ChatDocument(
+ result_doc = ChatDocument(
  content=content,
  function_call=fun_call,
  tool_messages=tool_messages,
@@ -1101,8 +1218,14 @@ class Task:
  sender_name=self.name,
  recipient=recipient,
  tool_ids=tool_ids,
+ parent_id=result_msg.id() if result_msg else "",
+ agent_id=str(self.agent.id),
  ),
  )
+ if self.pending_message is not None:
+ self.pending_message.metadata.child_id = result_doc.id()
+
+ return result_doc

  def _is_empty_message(self, msg: str | ChatDocument | None) -> bool:
  """
@@ -1,7 +1,9 @@
  from . import google_search_tool
  from . import recipient_tool
+ from . import rewind_tool
  from .google_search_tool import GoogleSearchTool
  from .recipient_tool import AddRecipientTool, RecipientTool
+ from .rewind_tool import RewindTool

  __all__ = [
  "GoogleSearchTool",
@@ -9,4 +11,6 @@ __all__ = [
  "RecipientTool",
  "google_search_tool",
  "recipient_tool",
+ "rewind_tool",
+ "RewindTool",
  ]
@@ -0,0 +1,136 @@
+ """
+ The `rewind_tool` is used to rewind to the `n`th previous Assistant message
+ and replace it with a new `content`. This is useful in several scenarios, since it:
+ - saves token-cost + inference time,
+ - reduces distracting clutter in chat history, which helps improve response quality.
+
+ This is intended to mimic how a human user might use a chat interface, where they
+ go down a conversation path, and want to go back in history to "edit and re-submit"
+ a previous message, to get a better response.
+
+ See usage examples in `tests/main/test_rewind_tool.py`.
+ """
+
+ from typing import List, Tuple
+
+ import langroid.language_models as lm
+ from langroid.agent.chat_agent import ChatAgent
+ from langroid.agent.chat_document import ChatDocument
+ from langroid.agent.tool_message import ToolMessage
+
+
+ def prune_messages(agent: ChatAgent, idx: int) -> ChatDocument | None:
+ """
+ Clear the message history of agent, starting at index `idx`,
+ taking care to first clear all dependent messages (possibly from other agents'
+ message histories) that are linked to the message at `idx`, via the `child_id` field
+ of the `metadata` field of the ChatDocument linked from the message at `idx`.
+
+ Args:
+ agent (ChatAgent): The agent whose message history is to be pruned.
+ idx (int): The index from which to start clearing the message history.
+
+ Returns:
+ The parent ChatDocument of the ChatDocument linked from the message at `idx`,
+ if it exists, else None.
+
+ """
+ assert idx >= 0, "Invalid index for message history!"
+ chat_doc_id = agent.message_history[idx].chat_document_id
+ chat_doc = ChatDocument.from_id(chat_doc_id)
+ assert chat_doc is not None, "ChatDocument not found in registry!"
+
+ parent = ChatDocument.from_id(chat_doc.metadata.parent_id) # may be None
+ # We're invalidating the msg at idx,
+ # so starting with chat_doc, go down the child links
+ # and clear history of each agent, to the msg_idx
+ curr_doc = chat_doc
+ while child_doc := curr_doc.metadata.child:
+ if child_doc.metadata.msg_idx >= 0:
+ child_agent = ChatAgent.from_id(child_doc.metadata.agent_id)
+ if child_agent is not None:
+ child_agent.clear_history(child_doc.metadata.msg_idx)
+ curr_doc = child_doc
+
+ # Clear out ObjectRegistry entries for this ChatDocument
+ # and all descendants (in case they weren't already cleared above)
+ ChatDocument.delete_id(chat_doc.id())
+
+ # Finally, clear this agent's history back to idx,
+ # and replace the msg at idx with the new content
+ agent.clear_history(idx)
+ return parent
+
+
+ class RewindTool(ToolMessage):
+ """
+ Used by LLM to rewind (i.e. backtrack) to the `n`th Assistant message
+ and replace with a new msg.
+ """
+
+ request: str = "rewind_tool"
+ purpose: str = """
+ To rewind the conversation and replace the
+ <n>'th Assistant message with <content>
+ """
+ n: int
+ content: str
+
+ @classmethod
+ def examples(cls) -> List["ToolMessage" | Tuple[str, "ToolMessage"]]:
+ return [
+ cls(n=1, content="What are the 3 major causes of heart disease?"),
+ (
+ """
+ I want to change my 2nd message to Bob, to say
+ 'who wrote the book Grime and Banishment?'
+ """,
+ cls(n=2, content="who wrote the book 'Grime and Banishment'?"),
+ ),
+ ]
+
+ def response(self, agent: ChatAgent) -> str | ChatDocument:
+ """
+ Define the tool-handler method for this tool here itself,
+ since it is a generic tool whose functionality should be the
+ same for any agent.
+
+ When LLM has correctly used this tool, rewind this agent's
+ `message_history` to the `n`th assistant msg, and replace it with `content`.
+ We need to mock it as if the LLM is sending this message.
+
+ Within a multi-agent scenario, this also means that any other messages dependent
+ on this message will need to be invalidated --
+ so go down the chain of child messages and clear each agent's history
+ back to the `msg_idx` corresponding to the child message.
+
+ Returns:
+ (ChatDocument): with content set to self.content.
+ """
+ idx = agent.nth_message_idx_with_role(lm.Role.ASSISTANT, self.n)
+ if idx < 0:
+ # set up a corrective message from AGENT
+ msg = f"""
+ Could not rewind to {self.n}th Assistant message!
+ Please check the value of `n` and try again.
+ Or it may be too early to use the `rewind_tool`.
+ """
+ return agent.create_agent_response(msg)
+
+ parent = prune_messages(agent, idx)
+
+ # create ChatDocument with new content, to be returned as result of this tool
+ result_doc = agent.create_llm_response(self.content)
+ result_doc.metadata.parent_id = "" if parent is None else parent.id()
+ result_doc.metadata.agent_id = agent.id
+ result_doc.metadata.msg_idx = idx
+
+ # replace the message at idx with this new message
+ agent.message_history.append(ChatDocument.to_LLMMessage(result_doc))
+
+ # set the replaced doc's parent's child to this result_doc
+ if parent is not None:
+ # first remove this parent's child from the registry
+ ChatDocument.delete_id(parent.metadata.child_id)
+ parent.metadata.child_id = result_doc.id()
+ return result_doc
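
A hedged sketch of wiring `RewindTool` into an agent: `enable_message` is langroid's usual way of registering a `ToolMessage` with a `ChatAgent`, and the JSON payload shown is simply what an LLM-generated call would look like given the `request` / `n` / `content` fields above:

```python
from langroid.agent.chat_agent import ChatAgent, ChatAgentConfig
from langroid.agent.tools.rewind_tool import RewindTool

agent = ChatAgent(ChatAgentConfig(name="Bob"))
agent.enable_message(RewindTool)  # LLM may use the tool; agent handles it via RewindTool.response

# An LLM wanting to rewrite its 1st assistant message would emit something like:
# {"request": "rewind_tool", "n": 1, "content": "What are the 3 major causes of heart disease?"}
# The handler prunes history back to that point (including dependent messages in
# other agents' histories) and re-inserts the new content as if the LLM had said it.
```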
@@ -20,6 +20,7 @@ from .openai_gpt import (
  OpenAIGPTConfig,
  OpenAIGPT,
  )
+ from .mock_lm import MockLM, MockLMConfig
  from .azure_openai import AzureConfig, AzureGPT


@@ -43,4 +44,6 @@ __all__ = [
  "OpenAIGPT",
  "AzureConfig",
  "AzureGPT",
+ "MockLM",
+ "MockLMConfig",
  ]
@@ -4,7 +4,17 @@ import logging
  from abc import ABC, abstractmethod
  from datetime import datetime
  from enum import Enum
- from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
+ from typing import (
+ Any,
+ Callable,
+ Dict,
+ List,
+ Optional,
+ Tuple,
+ Type,
+ Union,
+ cast,
+ )

  from langroid.cachedb.base import CacheDBConfig
  from langroid.parsing.agent_chats import parse_message
@@ -134,12 +144,15 @@ class LLMMessage(BaseModel):
  content: str
  function_call: Optional[LLMFunctionCall] = None
  timestamp: datetime = Field(default_factory=datetime.utcnow)
+ # link to corresponding chat document, for provenance/rewind purposes
+ chat_document_id: str = ""

  def api_dict(self) -> Dict[str, Any]:
  """
- Convert to dictionary for API request.
- DROP the tool_id, since it is only for use in the Assistant API,
- not the completion API.
+ Convert to dictionary for API request, keeping ONLY
+ the fields that are expected in an API call!
+ E.g., DROP the tool_id, since it is only for use in the Assistant API,
+ not the completion API.
  Returns:
  dict: dictionary representation of LLM message
  """
@@ -155,8 +168,10 @@ class LLMMessage(BaseModel):
  dict_no_none["function_call"]["arguments"] = json.dumps(
  dict_no_none["function_call"]["arguments"]
  )
+ # IMPORTANT! drop fields that are not expected in API call
  dict_no_none.pop("tool_id", None)
  dict_no_none.pop("timestamp", None)
+ dict_no_none.pop("chat_document_id", None)
  return dict_no_none

  def __str__(self) -> str:
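
The effect of the `api_dict()` change is that bookkeeping fields such as `chat_document_id` never leak into the LLM API payload. A small sketch (the `Role` enum is assumed from the surrounding `base.py` module):

```python
from langroid.language_models.base import LLMMessage, Role

msg = LLMMessage(
    role=Role.USER,
    content="hello",
    chat_document_id="some-chat-doc-id",  # internal provenance link only
)
payload = msg.api_dict()
# payload keeps role/content (and function_call, if set);
# tool_id, timestamp and chat_document_id are stripped before the API call
```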
@@ -268,11 +283,15 @@ class LanguageModel(ABC):
  """
  )
  from langroid.language_models.azure_openai import AzureGPT
+ from langroid.language_models.mock_lm import MockLM, MockLMConfig
  from langroid.language_models.openai_gpt import OpenAIGPT

  if config is None or config.type is None:
  return None

+ if config.type == "mock":
+ return MockLM(cast(MockLMConfig, config))
+
  openai: Union[Type[AzureGPT], Type[OpenAIGPT]]

  if config.type == "azure":
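
With the `"mock"` branch above, a `MockLMConfig` can stand in wherever a real LLM config is expected, which is mainly useful in tests. A hedged sketch, assuming the enclosing factory method is `LanguageModel.create` and that `MockLMConfig` defaults its `type` to `"mock"` (as the branch suggests); see `langroid/language_models/mock_lm.py` for its actual response-configuration fields:

```python
from langroid.language_models.base import LanguageModel
from langroid.language_models.mock_lm import MockLMConfig

cfg = MockLMConfig()             # configure canned responses per mock_lm.py's fields
llm = LanguageModel.create(cfg)  # dispatches to MockLM via the "mock" branch above
```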