langroid 0.1.265__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langroid/agent/base.py +15 -1
- langroid/agent/chat_agent.py +68 -16
- langroid/agent/chat_document.py +57 -3
- langroid/agent/special/doc_chat_agent.py +8 -26
- langroid/agent/task.py +149 -26
- langroid/agent/tools/__init__.py +4 -0
- langroid/agent/tools/rewind_tool.py +136 -0
- langroid/language_models/__init__.py +3 -0
- langroid/language_models/base.py +23 -4
- langroid/language_models/mock_lm.py +96 -0
- langroid/language_models/utils.py +2 -1
- langroid/mytypes.py +4 -35
- langroid/parsing/document_parser.py +5 -0
- langroid/parsing/parser.py +17 -2
- langroid/utils/__init__.py +2 -0
- langroid/utils/object_registry.py +66 -0
- langroid/utils/system.py +1 -2
- langroid/vector_store/base.py +3 -2
- {langroid-0.1.265.dist-info → langroid-0.2.0.dist-info}/METADATA +5 -5
- {langroid-0.1.265.dist-info → langroid-0.2.0.dist-info}/RECORD +23 -21
- pyproject.toml +1 -1
- langroid/language_models/openai_assistants.py +0 -3
- {langroid-0.1.265.dist-info → langroid-0.2.0.dist-info}/LICENSE +0 -0
- {langroid-0.1.265.dist-info → langroid-0.2.0.dist-info}/WHEEL +0 -0
langroid/agent/task.py
CHANGED
@@ -3,6 +3,7 @@ from __future__ import annotations
|
|
3
3
|
import asyncio
|
4
4
|
import copy
|
5
5
|
import logging
|
6
|
+
import threading
|
6
7
|
from collections import Counter, deque
|
7
8
|
from types import SimpleNamespace
|
8
9
|
from typing import (
|
@@ -13,7 +14,6 @@ from typing import (
|
|
13
14
|
Dict,
|
14
15
|
List,
|
15
16
|
Optional,
|
16
|
-
Set,
|
17
17
|
Tuple,
|
18
18
|
Type,
|
19
19
|
cast,
|
@@ -47,6 +47,7 @@ from langroid.utils.constants import (
|
|
47
47
|
USER_QUIT_STRINGS,
|
48
48
|
)
|
49
49
|
from langroid.utils.logging import RichFileLogger, setup_file_logger
|
50
|
+
from langroid.utils.object_registry import scheduled_cleanup
|
50
51
|
from langroid.utils.system import hash
|
51
52
|
|
52
53
|
logger = logging.getLogger(__name__)
|
@@ -65,14 +66,17 @@ class TaskConfig(BaseModel):
|
|
65
66
|
we have config classes for `Agent`, `ChatAgent`, `LanguageModel`, etc.
|
66
67
|
|
67
68
|
Attributes:
|
68
|
-
inf_loop_cycle_len: max exact-loop cycle length: 0 => no inf loop test
|
69
|
-
inf_loop_dominance_factor: dominance factor for exact-loop detection
|
70
|
-
inf_loop_wait_factor: wait this * cycle_len msgs before loop-check
|
69
|
+
inf_loop_cycle_len (int): max exact-loop cycle length: 0 => no inf loop test
|
70
|
+
inf_loop_dominance_factor (float): dominance factor for exact-loop detection
|
71
|
+
inf_loop_wait_factor (int): wait this * cycle_len msgs before loop-check
|
72
|
+
restart_as_subtask (bool): whether to restart *every* run of this task
|
73
|
+
when run as a subtask.
|
71
74
|
"""
|
72
75
|
|
73
76
|
inf_loop_cycle_len: int = 10
|
74
77
|
inf_loop_dominance_factor: float = 1.5
|
75
78
|
inf_loop_wait_factor: int = 5
|
79
|
+
restart_as_subtask: bool = False
|
76
80
|
|
77
81
|
|
78
82
|
class Task:
|
@@ -107,6 +111,7 @@ class Task:
|
|
107
111
|
|
108
112
|
# class variable called `cache` that is a RedisCache object
|
109
113
|
_cache: RedisCache | None = None
|
114
|
+
_background_tasks_started: bool = False
|
110
115
|
|
111
116
|
def __init__(
|
112
117
|
self,
|
@@ -149,7 +154,7 @@ class Task:
|
|
149
154
|
One run of step() is considered a "turn".
|
150
155
|
system_message (str): if not empty, overrides agent's system_message
|
151
156
|
user_message (str): if not empty, overrides agent's user_message
|
152
|
-
restart (bool): if true, resets the agent's message history
|
157
|
+
restart (bool): if true, resets the agent's message history *at every run*.
|
153
158
|
default_human_response (str): default response from user; useful for
|
154
159
|
testing, to avoid interactive input from user.
|
155
160
|
[Instead of this, setting `interactive` usually suffices]
|
@@ -187,6 +192,8 @@ class Task:
|
|
187
192
|
set_parent_agent=noop_fn,
|
188
193
|
)
|
189
194
|
self.config = config
|
195
|
+
# how to behave as a sub-task; can be overridden by `add_sub_task()`
|
196
|
+
self.config_sub_task = copy.deepcopy(config)
|
190
197
|
# counts of distinct pending messages in history,
|
191
198
|
# to help detect (exact) infinite loops
|
192
199
|
self.message_counter: Counter[str] = Counter()
|
@@ -208,24 +215,25 @@ class Task:
|
|
208
215
|
the config may affect other agents using the same config.
|
209
216
|
"""
|
210
217
|
)
|
211
|
-
|
218
|
+
self.restart = restart
|
219
|
+
agent = cast(ChatAgent, agent)
|
220
|
+
self.agent: ChatAgent = agent
|
212
221
|
if isinstance(agent, ChatAgent) and len(agent.message_history) == 0 or restart:
|
213
|
-
agent
|
214
|
-
agent.
|
215
|
-
agent.clear_dialog()
|
222
|
+
self.agent.clear_history(0)
|
223
|
+
self.agent.clear_dialog()
|
216
224
|
# possibly change the system and user messages
|
217
225
|
if system_message:
|
218
226
|
# we always have at least 1 task_message
|
219
|
-
agent.set_system_message(system_message)
|
227
|
+
self.agent.set_system_message(system_message)
|
220
228
|
if user_message:
|
221
|
-
agent.set_user_message(user_message)
|
229
|
+
self.agent.set_user_message(user_message)
|
222
230
|
self.max_cost: float = 0
|
223
231
|
self.max_tokens: int = 0
|
224
232
|
self.session_id: str = ""
|
225
233
|
self.logger: None | RichFileLogger = None
|
226
234
|
self.tsv_logger: None | logging.Logger = None
|
227
235
|
self.color_log: bool = False if settings.notebook else True
|
228
|
-
|
236
|
+
|
229
237
|
self.step_progress = False # progress in current step?
|
230
238
|
self.n_stalled_steps = 0 # how many consecutive steps with no progress?
|
231
239
|
self.max_stalled_steps = max_stalled_steps
|
@@ -305,7 +313,6 @@ class Task:
|
|
305
313
|
|
306
314
|
# other sub_tasks this task can delegate to
|
307
315
|
self.sub_tasks: List[Task] = []
|
308
|
-
self.parent_task: Set[Task] = set()
|
309
316
|
self.caller: Task | None = None # which task called this task's `run` method
|
310
317
|
|
311
318
|
def clone(self, i: int) -> "Task":
|
@@ -321,7 +328,7 @@ class Task:
|
|
321
328
|
single_round=self.single_round,
|
322
329
|
system_message=self.agent.system_message,
|
323
330
|
user_message=self.agent.user_message,
|
324
|
-
restart=
|
331
|
+
restart=self.restart,
|
325
332
|
default_human_response=self.default_human_response,
|
326
333
|
interactive=self.interactive,
|
327
334
|
erase_substeps=self.erase_substeps,
|
@@ -338,6 +345,19 @@ class Task:
|
|
338
345
|
cls._cache = RedisCache(RedisCacheConfig(fake=False))
|
339
346
|
return cls._cache
|
340
347
|
|
348
|
+
@classmethod
|
349
|
+
def _start_background_tasks(cls) -> None:
|
350
|
+
"""Start background object registry cleanup thread. NOT USED."""
|
351
|
+
if cls._background_tasks_started:
|
352
|
+
return
|
353
|
+
cls._background_tasks_started = True
|
354
|
+
cleanup_thread = threading.Thread(
|
355
|
+
target=scheduled_cleanup,
|
356
|
+
args=(600,),
|
357
|
+
daemon=True,
|
358
|
+
)
|
359
|
+
cleanup_thread.start()
|
360
|
+
|
341
361
|
def __repr__(self) -> str:
|
342
362
|
return f"{self.name}"
|
343
363
|
|
@@ -416,24 +436,37 @@ class Task:
|
|
416
436
|
def _leave(self) -> str:
|
417
437
|
return self._indent + "<<<"
|
418
438
|
|
419
|
-
def add_sub_task(
|
439
|
+
def add_sub_task(
|
440
|
+
self,
|
441
|
+
task: (
|
442
|
+
Task | List[Task] | Tuple[Task, TaskConfig] | List[Tuple[Task, TaskConfig]]
|
443
|
+
),
|
444
|
+
) -> None:
|
420
445
|
"""
|
421
446
|
Add a sub-task (or list of subtasks) that this task can delegate
|
422
447
|
(or fail-over) to. Note that the sequence of sub-tasks is important,
|
423
448
|
since these are tried in order, as the parent task searches for a valid
|
424
|
-
response.
|
449
|
+
response (unless a sub-task is explicitly addressed).
|
425
450
|
|
426
451
|
Args:
|
427
|
-
task
|
452
|
+
task: A task, or list of tasks, or a tuple of task and task config,
|
453
|
+
or a list of tuples of task and task config.
|
454
|
+
These tasks are added as sub-tasks of the current task.
|
455
|
+
The task configs (if any) dictate how the tasks are run when
|
456
|
+
invoked as sub-tasks of other tasks. This allows users to specify
|
457
|
+
behavior applicable only in the context of a particular task-subtask
|
458
|
+
combination.
|
428
459
|
"""
|
429
|
-
|
430
460
|
if isinstance(task, list):
|
431
461
|
for t in task:
|
432
462
|
self.add_sub_task(t)
|
433
463
|
return
|
434
|
-
assert isinstance(task, Task), f"added task must be a Task, not {type(task)}"
|
435
464
|
|
436
|
-
task
|
465
|
+
if isinstance(task, tuple):
|
466
|
+
task, config = task
|
467
|
+
else:
|
468
|
+
config = TaskConfig()
|
469
|
+
task.config_sub_task = config
|
437
470
|
self.sub_tasks.append(task)
|
438
471
|
self.name_sub_task_map[task.name] = task
|
439
472
|
self.responders.append(cast(Responder, task))
|
@@ -460,12 +493,28 @@ class Task:
|
|
460
493
|
sender=Entity.USER,
|
461
494
|
),
|
462
495
|
)
|
496
|
+
elif msg is None and len(self.agent.message_history) > 1:
|
497
|
+
# if agent has a history beyond system msg, set the
|
498
|
+
# pending message to the ChatDocument linked from
|
499
|
+
# last message in the history
|
500
|
+
last_agent_msg = self.agent.message_history[-1]
|
501
|
+
self.pending_message = ChatDocument.from_id(last_agent_msg.chat_document_id)
|
502
|
+
if self.pending_message is not None:
|
503
|
+
self.pending_sender = self.pending_message.metadata.sender
|
463
504
|
else:
|
464
|
-
|
505
|
+
if isinstance(msg, ChatDocument):
|
506
|
+
# carefully deep-copy: fresh metadata.id, register
|
507
|
+
# as new obj in registry
|
508
|
+
self.pending_message = ChatDocument.deepcopy(msg)
|
465
509
|
if self.pending_message is not None and self.caller is not None:
|
466
510
|
# msg may have come from `caller`, so we pretend this is from
|
467
511
|
# the CURRENT task's USER entity
|
468
512
|
self.pending_message.metadata.sender = Entity.USER
|
513
|
+
# update parent, child, agent pointers
|
514
|
+
if msg is not None:
|
515
|
+
msg.metadata.child_id = self.pending_message.metadata.id
|
516
|
+
self.pending_message.metadata.parent_id = msg.metadata.id
|
517
|
+
self.pending_message.metadata.agent_id = self.agent.id
|
469
518
|
|
470
519
|
self._show_pending_message_if_debug()
|
471
520
|
|
@@ -484,6 +533,13 @@ class Task:
|
|
484
533
|
self.log_message(Entity.USER, self.pending_message)
|
485
534
|
return self.pending_message
|
486
535
|
|
536
|
+
def reset_all_sub_tasks(self) -> None:
|
537
|
+
"""Recursively reset message history of own agent and all sub-tasks"""
|
538
|
+
self.agent.clear_history(0)
|
539
|
+
self.agent.clear_dialog()
|
540
|
+
for t in self.sub_tasks:
|
541
|
+
t.reset_all_sub_tasks()
|
542
|
+
|
487
543
|
def run(
|
488
544
|
self,
|
489
545
|
msg: Optional[str | ChatDocument] = None,
|
@@ -495,6 +551,14 @@ class Task:
|
|
495
551
|
) -> Optional[ChatDocument]:
|
496
552
|
"""Synchronous version of `run_async()`.
|
497
553
|
See `run_async()` for details."""
|
554
|
+
if (self.restart and caller is None) or (
|
555
|
+
self.config_sub_task.restart_as_subtask and caller is not None
|
556
|
+
):
|
557
|
+
# We are either at top level, with restart = True, OR
|
558
|
+
# we are a sub-task with restart_as_subtask = True,
|
559
|
+
# so reset own agent and recursively for all sub-tasks
|
560
|
+
self.reset_all_sub_tasks()
|
561
|
+
|
498
562
|
self.task_progress = False
|
499
563
|
self.n_stalled_steps = 0
|
500
564
|
self.max_cost = max_cost
|
@@ -597,6 +661,18 @@ class Task:
|
|
597
661
|
# have come from another LLM), as far as this agent is concerned, the initial
|
598
662
|
# message can be considered to be from the USER
|
599
663
|
# (from the POV of this agent's LLM).
|
664
|
+
|
665
|
+
if (
|
666
|
+
self.restart
|
667
|
+
and caller is None
|
668
|
+
or self.config_sub_task.restart_as_subtask
|
669
|
+
and caller is not None
|
670
|
+
):
|
671
|
+
# We are either at top level, with restart = True, OR
|
672
|
+
# we are a sub-task with restart_as_subtask = True,
|
673
|
+
# so reset own agent and recursively for all sub-tasks
|
674
|
+
self.reset_all_sub_tasks()
|
675
|
+
|
600
676
|
self.task_progress = False
|
601
677
|
self.n_stalled_steps = 0
|
602
678
|
self.max_cost = max_cost
|
@@ -701,6 +777,23 @@ class Task:
|
|
701
777
|
if self.erase_substeps:
|
702
778
|
# TODO I don't like directly accessing agent message_history. Revisit.
|
703
779
|
# (Pchalasani)
|
780
|
+
# Note: msg history will consist of:
|
781
|
+
# - H: the original msg history, ending at idx= self.message_history_idx
|
782
|
+
# - R: this agent's response, which presumably leads to:
|
783
|
+
# - X: a series of back-and-forth msgs (including with agent's own
|
784
|
+
# responders and with sub-tasks)
|
785
|
+
# - F: the final result message, from this agent.
|
786
|
+
# Here we are deleting all of [X] from the agent's message history,
|
787
|
+
# so that it simply looks as if the sub-tasks never happened.
|
788
|
+
|
789
|
+
dropped = self.agent.message_history[
|
790
|
+
self.message_history_idx + 2 : n_messages - 1
|
791
|
+
]
|
792
|
+
# first delete the linked ChatDocuments (and descendants) from
|
793
|
+
# ObjectRegistry
|
794
|
+
for msg in dropped:
|
795
|
+
ChatDocument.delete_id(msg.chat_document_id)
|
796
|
+
# then delete the messages from the agent's message_history
|
704
797
|
del self.agent.message_history[
|
705
798
|
self.message_history_idx + 2 : n_messages - 1
|
706
799
|
]
|
@@ -750,9 +843,11 @@ class Task:
|
|
750
843
|
|
751
844
|
if (
|
752
845
|
Entity.USER in self.responders
|
846
|
+
and self.interactive
|
753
847
|
and not self.human_tried
|
754
848
|
and not self.agent.has_tool_message_attempt(self.pending_message)
|
755
849
|
):
|
850
|
+
# When in interactive mode,
|
756
851
|
# Give human first chance if they haven't been tried in last step,
|
757
852
|
# and the msg is not a tool-call attempt;
|
758
853
|
# This ensures human gets a chance to respond,
|
@@ -778,6 +873,8 @@ class Task:
|
|
778
873
|
recipient=recipient,
|
779
874
|
),
|
780
875
|
)
|
876
|
+
# no need to register this dummy msg in ObjectRegistry
|
877
|
+
ChatDocument.delete_id(log_doc.id())
|
781
878
|
self.log_message(r, log_doc)
|
782
879
|
continue
|
783
880
|
self.human_tried = r == Entity.USER
|
@@ -844,6 +941,7 @@ class Task:
|
|
844
941
|
|
845
942
|
if (
|
846
943
|
Entity.USER in self.responders
|
944
|
+
and self.interactive
|
847
945
|
and not self.human_tried
|
848
946
|
and not self.agent.has_tool_message_attempt(self.pending_message)
|
849
947
|
):
|
@@ -870,6 +968,8 @@ class Task:
|
|
870
968
|
recipient=recipient,
|
871
969
|
),
|
872
970
|
)
|
971
|
+
# no need to register this dummy msg in ObjectRegistry
|
972
|
+
ChatDocument.delete_id(log_doc.id())
|
873
973
|
self.log_message(r, log_doc)
|
874
974
|
continue
|
875
975
|
self.human_tried = r == Entity.USER
|
@@ -905,10 +1005,26 @@ class Task:
|
|
905
1005
|
# Contrast this with self.pending_message.metadata.sender, which is an ENTITY
|
906
1006
|
# of this agent, or a sub-task's agent.
|
907
1007
|
if not self.is_pass_thru:
|
908
|
-
|
909
|
-
|
910
|
-
|
1008
|
+
if (
|
1009
|
+
self.pending_message is not None
|
1010
|
+
and self.pending_message.metadata.agent_id == self.agent.id
|
1011
|
+
):
|
1012
|
+
# when pending msg is from our own agent, respect the sender set there,
|
1013
|
+
# since sometimes a response may "mock" as if the response is from
|
1014
|
+
# another entity (e.g. when using RewindTool, the agent handler
|
1015
|
+
# returns a result as if it were from the LLM).
|
1016
|
+
self.pending_sender = result.metadata.sender
|
1017
|
+
else:
|
1018
|
+
# when pending msg is from a sub-task, the sender is the sub-task
|
1019
|
+
self.pending_sender = r
|
911
1020
|
self.pending_message = result
|
1021
|
+
# set the parent/child links ONLY if not already set by agent internally,
|
1022
|
+
# which may happen when using the RewindTool
|
1023
|
+
if parent is not None and not result.metadata.parent_id:
|
1024
|
+
result.metadata.parent_id = parent.id()
|
1025
|
+
if parent is not None and not parent.metadata.child_id:
|
1026
|
+
parent.metadata.child_id = result.id()
|
1027
|
+
|
912
1028
|
self.log_message(self.pending_sender, result, mark=True)
|
913
1029
|
self.step_progress = True
|
914
1030
|
self.task_progress = True
|
@@ -941,9 +1057,10 @@ class Task:
|
|
941
1057
|
responder = (
|
942
1058
|
Entity.LLM if self.pending_sender == Entity.USER else Entity.USER
|
943
1059
|
)
|
1060
|
+
parent_id = "" if parent is None else parent.id()
|
944
1061
|
self.pending_message = ChatDocument(
|
945
1062
|
content=NO_ANSWER,
|
946
|
-
metadata=ChatDocMetaData(sender=responder,
|
1063
|
+
metadata=ChatDocMetaData(sender=responder, parent_id=parent_id),
|
947
1064
|
)
|
948
1065
|
self.pending_sender = responder
|
949
1066
|
self.log_message(self.pending_sender, self.pending_message, mark=True)
|
@@ -1089,7 +1206,7 @@ class Task:
|
|
1089
1206
|
# regardless of which entity actually produced the result,
|
1090
1207
|
# when we return the result, we set entity to USER
|
1091
1208
|
# since to the "parent" task, this result is equivalent to a response from USER
|
1092
|
-
|
1209
|
+
result_doc = ChatDocument(
|
1093
1210
|
content=content,
|
1094
1211
|
function_call=fun_call,
|
1095
1212
|
tool_messages=tool_messages,
|
@@ -1101,8 +1218,14 @@ class Task:
|
|
1101
1218
|
sender_name=self.name,
|
1102
1219
|
recipient=recipient,
|
1103
1220
|
tool_ids=tool_ids,
|
1221
|
+
parent_id=result_msg.id() if result_msg else "",
|
1222
|
+
agent_id=str(self.agent.id),
|
1104
1223
|
),
|
1105
1224
|
)
|
1225
|
+
if self.pending_message is not None:
|
1226
|
+
self.pending_message.metadata.child_id = result_doc.id()
|
1227
|
+
|
1228
|
+
return result_doc
|
1106
1229
|
|
1107
1230
|
def _is_empty_message(self, msg: str | ChatDocument | None) -> bool:
|
1108
1231
|
"""
|
langroid/agent/tools/__init__.py
CHANGED
@@ -1,7 +1,9 @@
|
|
1
1
|
from . import google_search_tool
|
2
2
|
from . import recipient_tool
|
3
|
+
from . import rewind_tool
|
3
4
|
from .google_search_tool import GoogleSearchTool
|
4
5
|
from .recipient_tool import AddRecipientTool, RecipientTool
|
6
|
+
from .rewind_tool import RewindTool
|
5
7
|
|
6
8
|
__all__ = [
|
7
9
|
"GoogleSearchTool",
|
@@ -9,4 +11,6 @@ __all__ = [
|
|
9
11
|
"RecipientTool",
|
10
12
|
"google_search_tool",
|
11
13
|
"recipient_tool",
|
14
|
+
"rewind_tool",
|
15
|
+
"RewindTool",
|
12
16
|
]
|
@@ -0,0 +1,136 @@
|
|
1
|
+
"""
|
2
|
+
The `rewind_tool` is used to rewind to the `n`th previous Assistant message
|
3
|
+
and replace it with a new `content`. This is useful in several scenarios and
|
4
|
+
- saves token-cost + inference time,
|
5
|
+
- reduces distracting clutter in chat history, which helps improve response quality.
|
6
|
+
|
7
|
+
This is intended to mimic how a human user might use a chat interface, where they
|
8
|
+
go down a conversation path, and want to go back in history to "edit and re-submit"
|
9
|
+
a previous message, to get a better response.
|
10
|
+
|
11
|
+
See usage examples in `tests/main/test_rewind_tool.py`.
|
12
|
+
"""
|
13
|
+
|
14
|
+
from typing import List, Tuple
|
15
|
+
|
16
|
+
import langroid.language_models as lm
|
17
|
+
from langroid.agent.chat_agent import ChatAgent
|
18
|
+
from langroid.agent.chat_document import ChatDocument
|
19
|
+
from langroid.agent.tool_message import ToolMessage
|
20
|
+
|
21
|
+
|
22
|
+
def prune_messages(agent: ChatAgent, idx: int) -> ChatDocument | None:
|
23
|
+
"""
|
24
|
+
Clear the message history of agent, starting at index `idx`,
|
25
|
+
taking care to first clear all dependent messages (possibly from other agents'
|
26
|
+
message histories) that are linked to the message at `idx`, via the `child_id` field
|
27
|
+
of the `metadata` field of the ChatDocument linked from the message at `idx`.
|
28
|
+
|
29
|
+
Args:
|
30
|
+
agent (ChatAgent): The agent whose message history is to be pruned.
|
31
|
+
idx (int): The index from which to start clearing the message history.
|
32
|
+
|
33
|
+
Returns:
|
34
|
+
The parent ChatDocument of the ChatDocument linked from the message at `idx`,
|
35
|
+
if it exists, else None.
|
36
|
+
|
37
|
+
"""
|
38
|
+
assert idx >= 0, "Invalid index for message history!"
|
39
|
+
chat_doc_id = agent.message_history[idx].chat_document_id
|
40
|
+
chat_doc = ChatDocument.from_id(chat_doc_id)
|
41
|
+
assert chat_doc is not None, "ChatDocument not found in registry!"
|
42
|
+
|
43
|
+
parent = ChatDocument.from_id(chat_doc.metadata.parent_id) # may be None
|
44
|
+
# We're invalidating the msg at idx,
|
45
|
+
# so starting with chat_doc, go down the child links
|
46
|
+
# and clear history of each agent, to the msg_idx
|
47
|
+
curr_doc = chat_doc
|
48
|
+
while child_doc := curr_doc.metadata.child:
|
49
|
+
if child_doc.metadata.msg_idx >= 0:
|
50
|
+
child_agent = ChatAgent.from_id(child_doc.metadata.agent_id)
|
51
|
+
if child_agent is not None:
|
52
|
+
child_agent.clear_history(child_doc.metadata.msg_idx)
|
53
|
+
curr_doc = child_doc
|
54
|
+
|
55
|
+
# Clear out ObjectRegistry entries for this ChatDocument
|
56
|
+
# and all descendants (in case they weren't already cleared above)
|
57
|
+
ChatDocument.delete_id(chat_doc.id())
|
58
|
+
|
59
|
+
# Finally, clear this agent's history back to idx,
|
60
|
+
# so that the caller can put the replacement msg at idx
|
61
|
+
agent.clear_history(idx)
|
62
|
+
return parent
|
63
|
+
|
64
|
+
|
65
|
+
class RewindTool(ToolMessage):
|
66
|
+
"""
|
67
|
+
Used by LLM to rewind (i.e. backtrack) to the `n`th Assistant message
|
68
|
+
and replace with a new msg.
|
69
|
+
"""
|
70
|
+
|
71
|
+
request: str = "rewind_tool"
|
72
|
+
purpose: str = """
|
73
|
+
To rewind the conversation and replace the
|
74
|
+
<n>'th Assistant message with <content>
|
75
|
+
"""
|
76
|
+
n: int
|
77
|
+
content: str
|
78
|
+
|
79
|
+
@classmethod
|
80
|
+
def examples(cls) -> List["ToolMessage" | Tuple[str, "ToolMessage"]]:
|
81
|
+
return [
|
82
|
+
cls(n=1, content="What are the 3 major causes of heart disease?"),
|
83
|
+
(
|
84
|
+
"""
|
85
|
+
I want to change my 2nd message to Bob, to say
|
86
|
+
'who wrote the book Grime and Banishment?'
|
87
|
+
""",
|
88
|
+
cls(n=2, content="who wrote the book 'Grime and Banishment'?"),
|
89
|
+
),
|
90
|
+
]
|
91
|
+
|
92
|
+
def response(self, agent: ChatAgent) -> str | ChatDocument:
|
93
|
+
"""
|
94
|
+
Define the tool-handler method for this tool here itself,
|
95
|
+
since it is a generic tool whose functionality should be the
|
96
|
+
same for any agent.
|
97
|
+
|
98
|
+
When LLM has correctly used this tool, rewind this agent's
|
99
|
+
`message_history` to the `n`th assistant msg, and replace it with `content`.
|
100
|
+
We need to mock it as if the LLM is sending this message.
|
101
|
+
|
102
|
+
Within a multi-agent scenario, this also means that any other messages dependent
|
103
|
+
on this message will need to be invalidated --
|
104
|
+
so go down the chain of child messages and clear each agent's history
|
105
|
+
back to the `msg_idx` corresponding to the child message.
|
106
|
+
|
107
|
+
Returns:
|
108
|
+
(ChatDocument): with content set to self.content.
|
109
|
+
"""
|
110
|
+
idx = agent.nth_message_idx_with_role(lm.Role.ASSISTANT, self.n)
|
111
|
+
if idx < 0:
|
112
|
+
# set up a corrective message from AGENT
|
113
|
+
msg = f"""
|
114
|
+
Could not rewind to {self.n}th Assistant message!
|
115
|
+
Please check the value of `n` and try again.
|
116
|
+
Or it may be too early to use the `rewind_tool`.
|
117
|
+
"""
|
118
|
+
return agent.create_agent_response(msg)
|
119
|
+
|
120
|
+
parent = prune_messages(agent, idx)
|
121
|
+
|
122
|
+
# create ChatDocument with new content, to be returned as result of this tool
|
123
|
+
result_doc = agent.create_llm_response(self.content)
|
124
|
+
result_doc.metadata.parent_id = "" if parent is None else parent.id()
|
125
|
+
result_doc.metadata.agent_id = agent.id
|
126
|
+
result_doc.metadata.msg_idx = idx
|
127
|
+
|
128
|
+
# replace the message at idx with this new message
|
129
|
+
agent.message_history.append(ChatDocument.to_LLMMessage(result_doc))
|
130
|
+
|
131
|
+
# set the replaced doc's parent's child to this result_doc
|
132
|
+
if parent is not None:
|
133
|
+
# first remove this parent's (old) child from the registry
|
134
|
+
ChatDocument.delete_id(parent.metadata.child_id)
|
135
|
+
parent.metadata.child_id = result_doc.id()
|
136
|
+
return result_doc
|
@@ -20,6 +20,7 @@ from .openai_gpt import (
|
|
20
20
|
OpenAIGPTConfig,
|
21
21
|
OpenAIGPT,
|
22
22
|
)
|
23
|
+
from .mock_lm import MockLM, MockLMConfig
|
23
24
|
from .azure_openai import AzureConfig, AzureGPT
|
24
25
|
|
25
26
|
|
@@ -43,4 +44,6 @@ __all__ = [
|
|
43
44
|
"OpenAIGPT",
|
44
45
|
"AzureConfig",
|
45
46
|
"AzureGPT",
|
47
|
+
"MockLM",
|
48
|
+
"MockLMConfig",
|
46
49
|
]
|
langroid/language_models/base.py
CHANGED
@@ -4,7 +4,17 @@ import logging
|
|
4
4
|
from abc import ABC, abstractmethod
|
5
5
|
from datetime import datetime
|
6
6
|
from enum import Enum
|
7
|
-
from typing import
|
7
|
+
from typing import (
|
8
|
+
Any,
|
9
|
+
Callable,
|
10
|
+
Dict,
|
11
|
+
List,
|
12
|
+
Optional,
|
13
|
+
Tuple,
|
14
|
+
Type,
|
15
|
+
Union,
|
16
|
+
cast,
|
17
|
+
)
|
8
18
|
|
9
19
|
from langroid.cachedb.base import CacheDBConfig
|
10
20
|
from langroid.parsing.agent_chats import parse_message
|
@@ -134,12 +144,15 @@ class LLMMessage(BaseModel):
|
|
134
144
|
content: str
|
135
145
|
function_call: Optional[LLMFunctionCall] = None
|
136
146
|
timestamp: datetime = Field(default_factory=datetime.utcnow)
|
147
|
+
# link to corresponding chat document, for provenance/rewind purposes
|
148
|
+
chat_document_id: str = ""
|
137
149
|
|
138
150
|
def api_dict(self) -> Dict[str, Any]:
|
139
151
|
"""
|
140
|
-
Convert to dictionary for API request
|
141
|
-
|
142
|
-
|
152
|
+
Convert to dictionary for API request, keeping ONLY
|
153
|
+
the fields that are expected in an API call!
|
154
|
+
E.g., DROP the tool_id, since it is only for use in the Assistant API,
|
155
|
+
not the completion API.
|
143
156
|
Returns:
|
144
157
|
dict: dictionary representation of LLM message
|
145
158
|
"""
|
@@ -155,8 +168,10 @@ class LLMMessage(BaseModel):
|
|
155
168
|
dict_no_none["function_call"]["arguments"] = json.dumps(
|
156
169
|
dict_no_none["function_call"]["arguments"]
|
157
170
|
)
|
171
|
+
# IMPORTANT! drop fields that are not expected in API call
|
158
172
|
dict_no_none.pop("tool_id", None)
|
159
173
|
dict_no_none.pop("timestamp", None)
|
174
|
+
dict_no_none.pop("chat_document_id", None)
|
160
175
|
return dict_no_none
|
161
176
|
|
162
177
|
def __str__(self) -> str:
|
@@ -268,11 +283,15 @@ class LanguageModel(ABC):
|
|
268
283
|
"""
|
269
284
|
)
|
270
285
|
from langroid.language_models.azure_openai import AzureGPT
|
286
|
+
from langroid.language_models.mock_lm import MockLM, MockLMConfig
|
271
287
|
from langroid.language_models.openai_gpt import OpenAIGPT
|
272
288
|
|
273
289
|
if config is None or config.type is None:
|
274
290
|
return None
|
275
291
|
|
292
|
+
if config.type == "mock":
|
293
|
+
return MockLM(cast(MockLMConfig, config))
|
294
|
+
|
276
295
|
openai: Union[Type[AzureGPT], Type[OpenAIGPT]]
|
277
296
|
|
278
297
|
if config.type == "azure":
|