grasp_agents 0.2.10__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. grasp_agents/__init__.py +15 -14
  2. grasp_agents/cloud_llm.py +118 -131
  3. grasp_agents/comm_processor.py +201 -0
  4. grasp_agents/generics_utils.py +15 -7
  5. grasp_agents/llm.py +60 -31
  6. grasp_agents/llm_agent.py +229 -278
  7. grasp_agents/llm_agent_memory.py +58 -0
  8. grasp_agents/llm_policy_executor.py +482 -0
  9. grasp_agents/memory.py +20 -134
  10. grasp_agents/message_history.py +140 -0
  11. grasp_agents/openai/__init__.py +54 -36
  12. grasp_agents/openai/completion_chunk_converters.py +78 -0
  13. grasp_agents/openai/completion_converters.py +53 -30
  14. grasp_agents/openai/content_converters.py +13 -14
  15. grasp_agents/openai/converters.py +44 -68
  16. grasp_agents/openai/message_converters.py +58 -72
  17. grasp_agents/openai/openai_llm.py +101 -42
  18. grasp_agents/openai/tool_converters.py +24 -19
  19. grasp_agents/packet.py +24 -0
  20. grasp_agents/packet_pool.py +91 -0
  21. grasp_agents/printer.py +29 -15
  22. grasp_agents/processor.py +194 -0
  23. grasp_agents/prompt_builder.py +173 -176
  24. grasp_agents/run_context.py +21 -41
  25. grasp_agents/typing/completion.py +58 -12
  26. grasp_agents/typing/completion_chunk.py +173 -0
  27. grasp_agents/typing/converters.py +8 -12
  28. grasp_agents/typing/events.py +86 -0
  29. grasp_agents/typing/io.py +4 -13
  30. grasp_agents/typing/message.py +12 -50
  31. grasp_agents/typing/tool.py +52 -26
  32. grasp_agents/usage_tracker.py +6 -6
  33. grasp_agents/utils.py +3 -3
  34. grasp_agents/workflow/looped_workflow.py +132 -0
  35. grasp_agents/workflow/parallel_processor.py +95 -0
  36. grasp_agents/workflow/sequential_workflow.py +66 -0
  37. grasp_agents/workflow/workflow_processor.py +78 -0
  38. {grasp_agents-0.2.10.dist-info → grasp_agents-0.3.1.dist-info}/METADATA +41 -50
  39. grasp_agents-0.3.1.dist-info/RECORD +51 -0
  40. grasp_agents/agent_message.py +0 -27
  41. grasp_agents/agent_message_pool.py +0 -92
  42. grasp_agents/base_agent.py +0 -51
  43. grasp_agents/comm_agent.py +0 -217
  44. grasp_agents/llm_agent_state.py +0 -79
  45. grasp_agents/tool_orchestrator.py +0 -203
  46. grasp_agents/workflow/looped_agent.py +0 -120
  47. grasp_agents/workflow/sequential_agent.py +0 -63
  48. grasp_agents/workflow/workflow_agent.py +0 -73
  49. grasp_agents-0.2.10.dist-info/RECORD +0 -46
  50. {grasp_agents-0.2.10.dist-info → grasp_agents-0.3.1.dist-info}/WHEEL +0 -0
  51. {grasp_agents-0.2.10.dist-info → grasp_agents-0.3.1.dist-info}/licenses/LICENSE.md +0 -0
grasp_agents/llm_agent_memory.py ADDED
@@ -0,0 +1,58 @@
+ from collections.abc import Sequence
+ from typing import Any, Protocol
+
+ from pydantic import Field
+
+ from .memory import Memory
+ from .message_history import MessageHistory
+ from .run_context import RunContext
+ from .typing.io import LLMPrompt
+ from .typing.message import Message
+
+
+ class SetMemoryHandler(Protocol):
+     def __call__(
+         self,
+         prev_memory: "LLMAgentMemory",
+         in_args: Any | None,
+         sys_prompt: LLMPrompt | None,
+         ctx: RunContext[Any] | None,
+     ) -> "LLMAgentMemory": ...
+
+
+ class LLMAgentMemory(Memory):
+     message_history: MessageHistory = Field(default_factory=MessageHistory)
+
+     def reset(
+         self, sys_prompt: LLMPrompt | None = None, ctx: RunContext[Any] | None = None
+     ):
+         self.message_history.reset(sys_prompt=sys_prompt)
+
+     def update(
+         self,
+         *,
+         message_batch: Sequence[Message] | None = None,
+         message_list: Sequence[Message] | None = None,
+         ctx: RunContext[Any] | None = None,
+     ):
+         if message_batch is not None and message_list is not None:
+             raise ValueError(
+                 "Only one of message_batch or messages should be provided."
+             )
+         if message_batch is not None:
+             self.message_history.add_message_batch(message_batch)
+         elif message_list is not None:
+             self.message_history.add_message_list(message_list)
+         else:
+             raise ValueError("Either message_batch or messages must be provided.")
+
+     @property
+     def is_empty(self) -> bool:
+         return len(self.message_history) == 0
+
+     @property
+     def batch_size(self) -> int:
+         return self.message_history.batch_size
+
+     def __repr__(self) -> str:
+         return f"Message History: {len(self.message_history)}"
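For orientation, a minimal usage sketch of the new LLMAgentMemory API shown above (not part of the diff; the prompt and message text are illustrative, and import paths assume the installed package layout):

    from grasp_agents.llm_agent_memory import LLMAgentMemory
    from grasp_agents.typing.message import UserMessage

    memory = LLMAgentMemory()
    # Passes the system prompt through to MessageHistory.reset
    memory.reset(sys_prompt="You are a concise assistant.")
    # Exactly one of message_batch / message_list may be given per update
    memory.update(message_batch=[UserMessage.from_text("Hello!")])
    print(memory.is_empty, memory.batch_size)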
grasp_agents/llm_policy_executor.py ADDED
@@ -0,0 +1,482 @@
+ import asyncio
+ import json
+ from collections.abc import AsyncIterator, Coroutine, Sequence
+ from itertools import starmap
+ from logging import getLogger
+ from typing import Any, ClassVar, Generic, Protocol, TypeVar
+
+ from pydantic import BaseModel
+
+ from .generics_utils import AutoInstanceAttributesMixin
+ from .llm import LLM, LLMSettings
+ from .llm_agent_memory import LLMAgentMemory
+ from .run_context import CtxT, RunContext
+ from .typing.completion import Completion
+ from .typing.converters import Converters
+ from .typing.events import (
+     CompletionChunkEvent,
+     CompletionEvent,
+     Event,
+     GenMessageEvent,
+     ToolCallEvent,
+     ToolMessageEvent,
+     UserMessageEvent,
+ )
+ from .typing.message import AssistantMessage, Messages, ToolMessage, UserMessage
+ from .typing.tool import BaseTool, NamedToolChoice, ToolCall, ToolChoice
+
+ logger = getLogger(__name__)
+
+
+ _FinalAnswerT = TypeVar("_FinalAnswerT")
+
+
+ class ExitToolCallLoopHandler(Protocol[CtxT]):
+     def __call__(
+         self,
+         conversation: Messages,
+         *,
+         ctx: RunContext[CtxT] | None,
+         **kwargs: Any,
+     ) -> bool: ...
+
+
+ class ManageMemoryHandler(Protocol[CtxT]):
+     def __call__(
+         self,
+         memory: LLMAgentMemory,
+         *,
+         ctx: RunContext[CtxT] | None,
+         **kwargs: Any,
+     ) -> None: ...
+
+
+ class LLMPolicyExecutor(AutoInstanceAttributesMixin, Generic[_FinalAnswerT, CtxT]):
+     _generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {
+         0: "_final_answer_type",
+     }
+
+     def __init__(
+         self,
+         agent_name: str,
+         llm: LLM[LLMSettings, Converters],
+         tools: list[BaseTool[BaseModel, Any, CtxT]] | None,
+         max_turns: int,
+         react_mode: bool = False,
+         final_answer_as_tool_call: bool = False,
+     ) -> None:
+         self._final_answer_type: type[_FinalAnswerT]
+         super().__init__()
+
+         self._agent_name = agent_name
+
+         _tools: list[BaseTool[BaseModel, Any, CtxT]] | None = tools
+         self._final_answer_tool_name: str | None = None
+         if tools and final_answer_as_tool_call:
+             final_answer_tool = self.get_final_answer_tool()
+             self._final_answer_tool_name = final_answer_tool.name
+             _tools = tools + [final_answer_tool]
+
+         self._llm = llm
+         self._llm.tools = _tools
+
+         self._max_turns = max_turns
+         self._react_mode = react_mode
+
+         self.exit_tool_call_loop_impl: ExitToolCallLoopHandler[CtxT] | None = None
+         self.manage_memory_impl: ManageMemoryHandler[CtxT] | None = None
+
+     @property
+     def agent_name(self) -> str:
+         return self._agent_name
+
+     @property
+     def llm(self) -> LLM[LLMSettings, Converters]:
+         return self._llm
+
+     @property
+     def tools(self) -> dict[str, BaseTool[BaseModel, Any, CtxT]]:
+         return self._llm.tools or {}
+
+     @property
+     def max_turns(self) -> int:
+         return self._max_turns
+
+     def _exit_tool_call_loop_fn(
+         self,
+         conversation: Messages,
+         *,
+         ctx: RunContext[CtxT] | None = None,
+         **kwargs: Any,
+     ) -> bool:
+         if self.exit_tool_call_loop_impl:
+             return self.exit_tool_call_loop_impl(conversation, ctx=ctx, **kwargs)
+
+         assert conversation, "Conversation must not be empty"
+         assert isinstance(conversation[-1], AssistantMessage), (
+             "Last message in conversation must be an AssistantMessage"
+         )
+
+         return not bool(conversation[-1].tool_calls)
+
+     def _manage_memory_fn(
+         self,
+         memory: LLMAgentMemory,
+         *,
+         ctx: RunContext[CtxT] | None = None,
+         **kwargs: Any,
+     ) -> None:
+         if self.manage_memory_impl:
+             self.manage_memory_impl(memory=memory, ctx=ctx, **kwargs)
+
+     async def generate_message_batch(
+         self,
+         memory: LLMAgentMemory,
+         tool_choice: ToolChoice | None = None,
+         ctx: RunContext[CtxT] | None = None,
+     ) -> Sequence[AssistantMessage]:
+         completion_batch = await self.llm.generate_completion_batch(
+             memory.message_history, tool_choice=tool_choice
+         )
+         if (
+             len(completion_batch[0].messages) > 1
+             and memory.message_history.batch_size > 1
+         ):
+             raise ValueError(
+                 "Batch size must be 1 when generating completions with n>1."
+             )
+         message_batch = [c.messages[0] for c in completion_batch]
+         memory.update(message_batch=message_batch)
+
+         if ctx is not None:
+             ctx.completions[self.agent_name].extend(completion_batch)
+             self._track_usage(completion_batch, ctx=ctx)
+             self._print_completions(completion_batch, ctx=ctx)
+
+         return message_batch
+
+     async def generate_message_stream(
+         self,
+         memory: LLMAgentMemory,
+         tool_choice: ToolChoice | None = None,
+         ctx: RunContext[CtxT] | None = None,
+     ) -> AsyncIterator[CompletionChunkEvent | CompletionEvent | GenMessageEvent]:
+         message_hist = memory.message_history
+         if memory.message_history.batch_size > 1:
+             raise ValueError("Batch size must be 1 when streaming completions.")
+         conversation = message_hist.conversations[0]
+
+         completion: Completion | None = None
+         async for event in await self.llm.generate_completion_stream(
+             conversation, tool_choice=tool_choice
+         ):
+             yield event
+             if isinstance(event, CompletionEvent):
+                 completion = event.data
+
+         if completion is None:
+             raise RuntimeError("No completion generated during stream.")
+         if len(completion.messages) > 1:
+             raise ValueError("Streaming completion must have n=1")
+
+         message = completion.messages[0]
+         memory.update(message_batch=[message])
+
+         yield GenMessageEvent(name=self.agent_name, data=message)
+
+         if ctx is not None:
+             self._track_usage([completion], ctx=ctx)
+             ctx.completions[self.agent_name].append(completion)
+
+     async def call_tools(
+         self,
+         calls: Sequence[ToolCall],
+         memory: LLMAgentMemory,
+         ctx: RunContext[CtxT] | None = None,
+     ) -> Sequence[ToolMessage]:
+         corouts: list[Coroutine[Any, Any, BaseModel]] = []
+         for call in calls:
+             tool = self.tools[call.tool_name]
+             args = json.loads(call.tool_arguments)
+             corouts.append(tool(ctx=ctx, **args))
+
+         outs = await asyncio.gather(*corouts)
+         tool_messages = list(
+             starmap(ToolMessage.from_tool_output, zip(outs, calls, strict=False))
+         )
+         memory.update(message_list=tool_messages)
+
+         if ctx is not None:
+             ctx.printer.print_llm_messages(tool_messages, agent_name=self.agent_name)
+
+         return tool_messages
+
+     async def call_tools_stream(
+         self,
+         calls: Sequence[ToolCall],
+         memory: LLMAgentMemory,
+         ctx: RunContext[CtxT] | None = None,
+     ) -> AsyncIterator[ToolMessageEvent]:
+         tool_messages = await self.call_tools(calls, memory=memory, ctx=ctx)
+         for tool_message, call in zip(tool_messages, calls, strict=False):
+             yield ToolMessageEvent(name=call.tool_name, data=tool_message)
+
+     def _extract_final_answer_from_tool_calls(
+         self, gen_message: AssistantMessage, memory: LLMAgentMemory
+     ) -> AssistantMessage | None:
+         final_answer_message: AssistantMessage | None = None
+         for tool_call in gen_message.tool_calls or []:
+             if tool_call.tool_name == self._final_answer_tool_name:
+                 final_answer_message = AssistantMessage(
+                     name=self.agent_name, content=tool_call.tool_arguments
+                 )
+                 gen_message.tool_calls = None
+                 memory.update(message_list=[final_answer_message])
+                 return final_answer_message
+
+         return final_answer_message
+
+     async def _generate_final_answer(
+         self, memory: LLMAgentMemory, ctx: RunContext[CtxT] | None = None
+     ) -> AssistantMessage:
+         assert self._final_answer_tool_name is not None
+
+         user_message = UserMessage.from_text(
+             "Exceeded the maximum number of turns: provide a final answer now!"
+         )
+         memory.update(message_list=[user_message])
+         if ctx is not None:
+             ctx.printer.print_llm_messages([user_message], agent_name=self.agent_name)
+
+         tool_choice = NamedToolChoice(name=self._final_answer_tool_name)
+         gen_message = (
+             await self.generate_message_batch(memory, tool_choice=tool_choice, ctx=ctx)
+         )[0]
+
+         final_answer_message = self._extract_final_answer_from_tool_calls(
+             gen_message, memory=memory
+         )
+         if final_answer_message is None:
+             raise RuntimeError(
+                 "Final answer tool call did not return a final answer message."
+             )
+         return final_answer_message
+
+     async def _generate_final_answer_stream(
+         self, memory: LLMAgentMemory, ctx: RunContext[CtxT] | None = None
+     ) -> AsyncIterator[Event[Any]]:
+         assert self._final_answer_tool_name is not None
+
+         user_message = UserMessage.from_text(
+             "Exceeded the maximum number of turns: provide a final answer now!",
+         )
+         memory.update(message_list=[user_message])
+         yield UserMessageEvent(name=self.agent_name, data=user_message)
+
+         tool_choice = NamedToolChoice(name=self._final_answer_tool_name)
+         event: Event[Any] | None = None
+         async for event in self.generate_message_stream(
+             memory, tool_choice=tool_choice, ctx=ctx
+         ):
+             yield event
+
+         assert isinstance(event, GenMessageEvent)
+         gen_message = event.data
+         final_answer_message = self._extract_final_answer_from_tool_calls(
+             gen_message, memory=memory
+         )
+         if final_answer_message is None:
+             raise RuntimeError(
+                 "Final answer tool call did not return a final answer message."
+             )
+         yield GenMessageEvent(name=self.agent_name, data=final_answer_message)
+
+     async def execute(
+         self, memory: LLMAgentMemory, ctx: RunContext[CtxT] | None = None
+     ) -> AssistantMessage | Sequence[AssistantMessage]:
+         # 1. Generate the first message:
+         # In ReAct mode, we generate the first message without tool calls
+         # to force the agent to plan its actions in a separate message.
+         tool_choice: ToolChoice | None = None
+         if self.tools:
+             tool_choice = "none" if self._react_mode else "auto"
+         gen_message_batch = await self.generate_message_batch(
+             memory, tool_choice=tool_choice, ctx=ctx
+         )
+         if not self.tools:
+             return gen_message_batch
+
+         if memory.message_history.batch_size > 1:
+             raise ValueError("Batch size must be 1 for tool call loop.")
+         gen_message = gen_message_batch[0]
+         turns = 0
+
+         while True:
+             conversation = memory.message_history.conversations[0]
+
+             # 2. Check if we should exit the tool call loop
+
+             # When final_answer_tool_name is None, we use exit_tool_call_loop_impl
+             # to determine whether to exit the loop.
+             if self._final_answer_tool_name is None and self._exit_tool_call_loop_fn(
+                 conversation, ctx=ctx, num_turns=turns
+             ):
+                 return gen_message
+
+             # When final_answer_tool_name is set, we check if the last message contains
+             # a tool call to the final answer tool. If it does, we exit the loop.
+             if self._final_answer_tool_name is not None:
+                 final_answer = self._extract_final_answer_from_tool_calls(
+                     gen_message, memory=memory
+                 )
+                 if final_answer is not None:
+                     return final_answer
+
+             # Exit if the maximum number of turns is reached
+             if turns >= self.max_turns:
+                 # When final_answer_tool_name is set, we force the agent to provide
+                 # a final answer by generating a message with a final answer
+                 # tool call.
+                 # Otherwise, we simply return the last generated message.
+                 if self._final_answer_tool_name is not None:
+                     final_answer = await self._generate_final_answer(memory, ctx=ctx)
+                 else:
+                     final_answer = gen_message
+                 logger.info(
+                     f"Max turns reached: {self.max_turns}. Exiting the tool call loop."
+                 )
+                 return final_answer
+
+             # 3. Call tools if there are any tool calls in the generated message.
+
+             if gen_message.tool_calls:
+                 await self.call_tools(gen_message.tool_calls, memory=memory, ctx=ctx)
+
+             # Apply the memory management function if provided.
+             self._manage_memory_fn(memory, ctx=ctx, num_turns=turns)
+
+             # 4. Generate the next message based on the updated memory.
+             # In ReAct mode, we set tool_choice to "none" if we just called tools,
+             # so the next message will be an observation/planning message with
+             # no immediate tool calls.
+             # If we are not in ReAct mode, we set tool_choice to "auto" to allow
+             # the LLM to choose freely whether to call tools.
+
+             tool_choice = (
+                 "none" if (self._react_mode and gen_message.tool_calls) else "required"
+             )
+             gen_message = (
+                 await self.generate_message_batch(
+                     memory, tool_choice=tool_choice, ctx=ctx
+                 )
+             )[0]
+
+             turns += 1
+
+     async def execute_stream(
+         self, memory: LLMAgentMemory, ctx: RunContext[CtxT] | None = None
+     ) -> AsyncIterator[Event[Any]]:
+         if memory.message_history.batch_size > 1:
+             raise ValueError("Batch size must be 1 when streaming.")
+
+         tool_choice: ToolChoice = "none" if self._react_mode else "auto"
+         gen_message: AssistantMessage | None = None
+         async for event in self.generate_message_stream(
+             memory, tool_choice=tool_choice, ctx=ctx
+         ):
+             yield event
+             if isinstance(event, GenMessageEvent):
+                 gen_message = event.data
+         assert isinstance(gen_message, AssistantMessage)
+
+         turns = 0
+
+         while True:
+             conversation = memory.message_history.conversations[0]
+
+             if self._final_answer_tool_name is None and self._exit_tool_call_loop_fn(
+                 conversation, ctx=ctx, num_turns=turns
+             ):
+                 return
+
+             if self._final_answer_tool_name is not None:
+                 final_answer_message = self._extract_final_answer_from_tool_calls(
+                     gen_message, memory=memory
+                 )
+                 if final_answer_message is not None:
+                     yield GenMessageEvent(
+                         name=self.agent_name, data=final_answer_message
+                     )
+                     return
+
+             if turns >= self.max_turns:
+                 if self._final_answer_tool_name is not None:
+                     async for event in self._generate_final_answer_stream(
+                         memory, ctx=ctx
+                     ):
+                         yield event
+                 logger.info(
+                     f"Max turns reached: {self.max_turns}. Exiting the tool call loop."
+                 )
+                 return
+
+             if gen_message.tool_calls:
+                 for tool_call in gen_message.tool_calls:
+                     yield ToolCallEvent(name=self.agent_name, data=tool_call)
+
+                 async for tool_message_event in self.call_tools_stream(
+                     gen_message.tool_calls, memory=memory, ctx=ctx
+                 ):
+                     yield tool_message_event
+
+             self._manage_memory_fn(memory, ctx=ctx, num_turns=turns)
+
+             tool_choice = (
+                 "none" if (self._react_mode and gen_message.tool_calls) else "required"
+             )
+             async for event in self.generate_message_stream(
+                 memory, tool_choice=tool_choice, ctx=ctx
+             ):
+                 yield event
+                 if isinstance(event, GenMessageEvent):
+                     gen_message = event.data
+
+             turns += 1
+
+     def _track_usage(
+         self, completion_batch: Sequence[Completion], ctx: RunContext[CtxT]
+     ) -> None:
+         ctx.usage_tracker.update(
+             completions=completion_batch, model_name=self.llm.model_name
+         )
+
+     def get_final_answer_tool(self) -> BaseTool[BaseModel, None, Any]:
+         if not issubclass(self._final_answer_type, BaseModel):
+             raise TypeError(
+                 "final_answer_type must be a subclass of BaseModel to create "
+                 "a final answer tool."
+             )
+
+         class FinalAnswerTool(BaseTool[self._final_answer_type, None, Any]):
+             name: str = "final_answer"
+             description: str = (
+                 "You must use this tool to provide the final answer. "
+                 "Do not provide the final answer anywhere else. "
+                 "Input arguments correspond to the final answer."
+             )
+
+             async def run(
+                 self, inp: _FinalAnswerT, ctx: RunContext[Any] | None = None
+             ) -> None:
+                 return None
+
+         return FinalAnswerTool()
+
+     def _print_completions(
+         self, completion_batch: Sequence[Completion], ctx: RunContext[CtxT]
+     ) -> None:
+         messages = [c.messages[0] for c in completion_batch]
+         usages = [c.usage for c in completion_batch]
+         ctx.printer.print_llm_messages(
+             messages, usages=usages, agent_name=self.agent_name
+         )
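The executor's two hook protocols, ExitToolCallLoopHandler and ManageMemoryHandler, are plain callables assigned to exit_tool_call_loop_impl and manage_memory_impl. A rough sketch, assuming an LLMPolicyExecutor instance named executor has already been constructed (the three-turn limit and the trimming logic are illustrative):

    from typing import Any

    from grasp_agents.llm_agent_memory import LLMAgentMemory
    from grasp_agents.run_context import RunContext
    from grasp_agents.typing.message import Messages


    def exit_after_three_turns(
        conversation: Messages, *, ctx: RunContext[Any] | None, **kwargs: Any
    ) -> bool:
        # _exit_tool_call_loop_fn forwards num_turns via **kwargs; this hook is
        # only consulted when no final-answer tool is configured
        return kwargs.get("num_turns", 0) >= 3


    def trim_history(
        memory: LLMAgentMemory, *, ctx: RunContext[Any] | None, **kwargs: Any
    ) -> None:
        # Invoked via _manage_memory_fn on each turn of the tool-call loop
        pass


    executor.exit_tool_call_loop_impl = exit_after_three_turns
    executor.manage_memory_impl = trim_history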
grasp_agents/memory.py CHANGED
@@ -1,144 +1,30 @@
- import logging
- from collections.abc import Iterator, Sequence
- from copy import deepcopy
+ from abc import ABC, abstractmethod
+ from typing import Any

- from .typing.io import LLMPrompt
- from .typing.message import Conversation, Message, SystemMessage
+ from pydantic import BaseModel, ConfigDict

- logger = logging.getLogger(__name__)
+ from .run_context import RunContext


- class MessageHistory:
-     def __init__(self, sys_prompt: LLMPrompt | None = None) -> None:
-         self._sys_prompt = sys_prompt
-         self._batched_conversations: list[Conversation]
-         self.reset()
-
-     @property
-     def sys_prompt(self) -> LLMPrompt | None:
-         return self._sys_prompt
-
-     def add_message_batch(self, message_batch: Sequence[Message]) -> None:
-         """
-         Adds a batch of messages to the current batched conversations.
-         This method verifies that the size of the input message batch matches
-         the expected batch size (self.batch_size).
-         If there is a mismatch, the method adjusts by duplicating either
-         the message or the conversation as necessary:
-
-         - If the message batch contains exactly one message and
-           self.batch_size > 1, the single message is duplicated to match
-           the batch size.
-         - If the message batch contains multiple messages but
-           self.batch_size == 1, the entire conversation is duplicated to
-           accommodate each message in the batch.
-         - If the message batch size does not match self.batch_size and none of
-           the above adjustments apply, a ValueError is raised.
-
-         Afterwards, each message in the batch is appended to its corresponding
-         conversation in the batched conversations.
-
-         Args:
-             message_batch: A sequence of Message objects
-                 representing the batch of messages to be added. Must align with
-                 or be adjusted to match the current batch size.
-
-         Raises:
-             ValueError: If the message batch size does not match the current
-                 batch size and cannot be automatically adjusted.
-
-         """
-         message_batch_size = len(message_batch)
-
-         if message_batch_size == 1 and self.batch_size > 1:
-             logger.info(
-                 "Message batch size is 1, current batch size is "
-                 f"{self.batch_size}: duplicating the message to match the "
-                 "current batch size"
-             )
-             message_batch = self._duplicate_message_to_current_batch_size(message_batch)
-             message_batch_size = self.batch_size
-         elif message_batch_size > 1 and self.batch_size == 1:
-             logger.info(
-                 f"Message batch size is {len(message_batch)}, current batch "
-                 "size is 1: duplicating the conversation to match the message "
-                 "batch size"
-             )
-             self._duplicate_conversation_to_message_batch_size(message_batch_size)
-         elif message_batch_size != self.batch_size:
-             raise ValueError(
-                 f"Message batch size {message_batch_size} does not match "
-                 f"current batch size {self.batch_size}"
-             )
-
-         for batch_id in range(message_batch_size):
-             self._batched_conversations[batch_id].append(message_batch[batch_id])
-
-     def add_message_batches(self, message_batches: Sequence[Sequence[Message]]) -> None:
-         for message_batch in message_batches:
-             self.add_message_batch(message_batch)
-
-     def add_message(self, message: Message) -> None:
-         for conversation in self._batched_conversations:
-             conversation.append(message)
-
-     def add_messages(self, messages: Sequence[Message]) -> None:
-         for message in messages:
-             self.add_message(message)
-
-     def __len__(self) -> int:
-         return len(self._batched_conversations[0])
-
-     def __repr__(self) -> str:
-         return f"{self.__class__.__name__}(len={len(self)}; bs={self.batch_size})"
-
-     def __getitem__(self, idx: int) -> tuple[Message, ...]:
-         return tuple(conversation[idx] for conversation in self._batched_conversations)
-
-     def __iter__(self) -> Iterator[tuple[Message, ...]]:
-         for idx in range(len(self)):
-             yield tuple(
-                 conversation[idx] for conversation in self._batched_conversations
-             )
-
-     def _duplicate_message_to_current_batch_size(
-         self, message_batch: Sequence[Message]
-     ) -> Sequence[Message]:
-         assert len(message_batch) == 1, (
-             "Message batch size must be 1 to duplicate to current batch size"
-         )
-
-         return [deepcopy(message_batch[0]) for _ in range(self.batch_size)]
-
-     def _duplicate_conversation_to_message_batch_size(
-         self, target_batch_size: int
+ class Memory(BaseModel, ABC):
+     @abstractmethod
+     def reset(
+         self, *args: Any, ctx: RunContext[Any] | None = None, **kwargs: Any
      ) -> None:
-         assert self.batch_size == 1, "Batch size must be 1 to duplicate conversation"
-         self._batched_conversations = [
-             deepcopy(self._batched_conversations[0]) for _ in range(target_batch_size)
-         ]
+         pass

-     @property
-     def batched_conversations(self) -> list[Conversation]:
-         return self._batched_conversations
-
-     @property
-     def batch_size(self) -> int:
-         return len(self._batched_conversations)
-
-     def reset(
-         self, sys_prompt: LLMPrompt | None = None, *, batch_size: int = 1
+     @abstractmethod
+     def update(
+         self, *args: Any, ctx: RunContext[Any] | None = None, **kwargs: Any
      ) -> None:
-         if sys_prompt is not None:
-             self._sys_prompt = sys_prompt
+         pass

-         conv: Conversation
-         if self._sys_prompt is not None:
-             conv = [SystemMessage(content=self._sys_prompt)]
-         else:
-             conv = []
+     @property
+     @abstractmethod
+     def is_empty(self) -> bool:
+         pass

-         self._batched_conversations = [deepcopy(conv) for _ in range(batch_size)]
+     def __repr__(self) -> str:
+         return f"{self.__class__.__name__}()"

-     def erase(self) -> None:
-         self._batched_conversations = [[]]
+     model_config = ConfigDict(arbitrary_types_allowed=True)
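After this change, memory.py holds only the abstract Memory base (the batched MessageHistory moved to the new message_history.py), and concrete memories subclass it as Pydantic models. A minimal sketch of a custom subclass, not taken from the package (the ScratchpadMemory name and its field are made up for illustration):

    from typing import Any

    from pydantic import Field

    from grasp_agents.memory import Memory
    from grasp_agents.run_context import RunContext


    class ScratchpadMemory(Memory):
        notes: list[str] = Field(default_factory=list)

        def reset(
            self, *args: Any, ctx: RunContext[Any] | None = None, **kwargs: Any
        ) -> None:
            self.notes = []

        def update(
            self, *args: Any, ctx: RunContext[Any] | None = None, **kwargs: Any
        ) -> None:
            # Store whatever positional payloads the caller passes in
            self.notes.extend(str(arg) for arg in args)

        @property
        def is_empty(self) -> bool:
            return not self.notes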