grasp_agents 0.2.10__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. grasp_agents/__init__.py +15 -14
  2. grasp_agents/cloud_llm.py +118 -131
  3. grasp_agents/comm_processor.py +201 -0
  4. grasp_agents/generics_utils.py +15 -7
  5. grasp_agents/llm.py +60 -31
  6. grasp_agents/llm_agent.py +229 -278
  7. grasp_agents/llm_agent_memory.py +58 -0
  8. grasp_agents/llm_policy_executor.py +482 -0
  9. grasp_agents/memory.py +20 -134
  10. grasp_agents/message_history.py +140 -0
  11. grasp_agents/openai/__init__.py +54 -36
  12. grasp_agents/openai/completion_chunk_converters.py +78 -0
  13. grasp_agents/openai/completion_converters.py +53 -30
  14. grasp_agents/openai/content_converters.py +13 -14
  15. grasp_agents/openai/converters.py +44 -68
  16. grasp_agents/openai/message_converters.py +58 -72
  17. grasp_agents/openai/openai_llm.py +101 -42
  18. grasp_agents/openai/tool_converters.py +24 -19
  19. grasp_agents/packet.py +24 -0
  20. grasp_agents/packet_pool.py +91 -0
  21. grasp_agents/printer.py +29 -15
  22. grasp_agents/processor.py +194 -0
  23. grasp_agents/prompt_builder.py +173 -176
  24. grasp_agents/run_context.py +21 -41
  25. grasp_agents/typing/completion.py +58 -12
  26. grasp_agents/typing/completion_chunk.py +173 -0
  27. grasp_agents/typing/converters.py +8 -12
  28. grasp_agents/typing/events.py +86 -0
  29. grasp_agents/typing/io.py +4 -13
  30. grasp_agents/typing/message.py +12 -50
  31. grasp_agents/typing/tool.py +52 -26
  32. grasp_agents/usage_tracker.py +6 -6
  33. grasp_agents/utils.py +3 -3
  34. grasp_agents/workflow/looped_workflow.py +132 -0
  35. grasp_agents/workflow/parallel_processor.py +95 -0
  36. grasp_agents/workflow/sequential_workflow.py +66 -0
  37. grasp_agents/workflow/workflow_processor.py +78 -0
  38. {grasp_agents-0.2.10.dist-info → grasp_agents-0.3.1.dist-info}/METADATA +41 -50
  39. grasp_agents-0.3.1.dist-info/RECORD +51 -0
  40. grasp_agents/agent_message.py +0 -27
  41. grasp_agents/agent_message_pool.py +0 -92
  42. grasp_agents/base_agent.py +0 -51
  43. grasp_agents/comm_agent.py +0 -217
  44. grasp_agents/llm_agent_state.py +0 -79
  45. grasp_agents/tool_orchestrator.py +0 -203
  46. grasp_agents/workflow/looped_agent.py +0 -120
  47. grasp_agents/workflow/sequential_agent.py +0 -63
  48. grasp_agents/workflow/workflow_agent.py +0 -73
  49. grasp_agents-0.2.10.dist-info/RECORD +0 -46
  50. {grasp_agents-0.2.10.dist-info → grasp_agents-0.3.1.dist-info}/WHEEL +0 -0
  51. {grasp_agents-0.2.10.dist-info → grasp_agents-0.3.1.dist-info}/licenses/LICENSE.md +0 -0
grasp_agents/openai/tool_converters.py CHANGED
@@ -1,38 +1,43 @@
1
1
  from typing import Any
2
2
 
3
+ from openai import pydantic_function_tool
3
4
  from pydantic import BaseModel
4
5
 
5
- from ..typing.tool import BaseTool, ToolChoice
6
+ from ..typing.tool import BaseTool, NamedToolChoice, ToolChoice
6
7
  from . import (
7
- ChatCompletionFunctionDefinition,
8
- ChatCompletionNamedToolChoiceFunction,
9
- ChatCompletionNamedToolChoiceParam,
10
- ChatCompletionToolChoiceOptionParam,
11
- ChatCompletionToolParam,
8
+ OpenAIFunctionDefinition,
9
+ OpenAINamedToolChoiceFunction,
10
+ OpenAINamedToolChoiceParam,
11
+ OpenAIToolChoiceOptionParam,
12
+ OpenAIToolParam,
12
13
  )
13
14
 
14
15
 
def to_api_tool(
    tool: BaseTool[BaseModel, Any, Any], strict: bool | None = None
) -> OpenAIToolParam:
    """Build the OpenAI ``tools`` entry describing ``tool``.

    Args:
        tool: Tool whose name, description and input model are exported.
        strict: When truthy, delegate to ``pydantic_function_tool`` so the
            schema is generated in OpenAI strict mode. When ``False`` the
            ``strict`` flag is sent explicitly; when ``None`` it is omitted.

    Returns:
        An ``OpenAIToolParam`` of type ``"function"``.
    """
    if strict:
        # Enforce strict mode for pydantic models
        return pydantic_function_tool(
            model=tool.in_type, name=tool.name, description=tool.description
        )

    definition_fields: dict[str, Any] = {
        "name": tool.name,
        "description": tool.description,
        "parameters": tool.in_type.model_json_schema(),
    }
    # Only include the key when the caller expressed a preference.
    if strict is not None:
        definition_fields["strict"] = strict

    return OpenAIToolParam(
        type="function", function=OpenAIFunctionDefinition(**definition_fields)
    )
28
35
 
29
36
 
def to_api_tool_choice(tool_choice: ToolChoice) -> OpenAIToolChoiceOptionParam:
    """Translate a framework ``ToolChoice`` into OpenAI's ``tool_choice`` param.

    A ``NamedToolChoice`` is mapped to a named-function choice; every other
    value is passed through unchanged.
    """
    if not isinstance(tool_choice, NamedToolChoice):
        return tool_choice

    named_function = OpenAINamedToolChoiceFunction(name=tool_choice.name)
    return OpenAINamedToolChoiceParam(type="function", function=named_function)
grasp_agents/packet.py ADDED
@@ -0,0 +1,24 @@
1
+ from collections.abc import Sequence
2
+ from typing import Generic, TypeVar
3
+ from uuid import uuid4
4
+
5
+ from pydantic import BaseModel, ConfigDict, Field
6
+
7
+ from .typing.io import ProcName
8
+
9
+ _PayloadT_co = TypeVar("_PayloadT_co", covariant=True)
10
+
11
+
class Packet(BaseModel, Generic[_PayloadT_co]):
    """Frozen envelope carrying payloads from a sender to its recipients."""

    # Short identifier: the first 8 characters of a random UUID4 string.
    id: str = Field(default_factory=lambda: str(uuid4())[:8])
    payloads: Sequence[_PayloadT_co]
    sender: ProcName
    # Defaults to an empty list (no recipients).
    recipients: Sequence[ProcName] = Field(default_factory=list)

    # Unknown fields are rejected; attributes cannot be reassigned.
    model_config = ConfigDict(extra="forbid", frozen=True)

    def __repr__(self) -> str:
        recipients_str = ", ".join(self.recipients)
        payload_count = len(self.payloads)
        return (
            f"From: {self.sender}, To: {recipients_str}, "
            f"Payloads: {payload_count}"
        )
grasp_agents/packet_pool.py ADDED
@@ -0,0 +1,91 @@
1
+ import asyncio
2
+ import logging
3
+ from typing import Any, Generic, Protocol, TypeVar
4
+
5
+ from .packet import Packet
6
+ from .run_context import CtxT, RunContext
7
+ from .typing.io import ProcName
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+
12
+ _PayloadT_contra = TypeVar("_PayloadT_contra", contravariant=True)
13
+
14
+
class PacketHandler(Protocol[_PayloadT_contra, CtxT]):
    """Structural type for the async callbacks ``PacketPool`` dispatches to.

    A handler receives one packet, the optional run context supplied at
    registration, and any extra keyword arguments given at registration.
    It returns nothing.
    """

    async def __call__(
        self,
        packet: Packet[_PayloadT_contra],
        ctx: RunContext[CtxT] | None,
        **kwargs: Any,
    ) -> None: ...
22
+
23
+
class PacketPool(Generic[CtxT]):
    """Per-processor packet queues, each drained by a background task.

    ``post`` enqueues a packet for every one of its recipients, and
    ``register_packet_handler`` starts a task that feeds that processor's
    queued packets to its handler.
    """

    def __init__(self) -> None:
        # Inbox per processor name; created by ``post`` or on registration.
        self._queues: dict[ProcName, asyncio.Queue[Packet[Any]]] = {}
        self._packet_handlers: dict[ProcName, PacketHandler[Any, CtxT]] = {}
        # Background consumer task per registered processor.
        self._tasks: dict[ProcName, asyncio.Task[None]] = {}

    def _get_or_create_queue(
        self, processor_name: ProcName
    ) -> asyncio.Queue[Packet[Any]]:
        """Return the queue for ``processor_name``, creating it if missing."""
        queue = self._queues.get(processor_name)
        if queue is None:
            queue = asyncio.Queue()
            self._queues[processor_name] = queue
        return queue

    async def post(self, packet: Packet[Any]) -> None:
        """Enqueue ``packet`` on the queue of each of its recipients.

        Recipients without a handler still get a queue. Because registration
        reuses an existing queue, packets posted early are consumed once a
        handler is registered for that name.
        """
        for recipient_id in packet.recipients:
            await self._get_or_create_queue(recipient_id).put(packet)

    def register_packet_handler(
        self,
        processor_name: ProcName,
        handler: PacketHandler[Any, CtxT],
        ctx: RunContext[CtxT] | None = None,
        **run_kwargs: Any,
    ) -> None:
        """Register ``handler`` for ``processor_name`` and start its consumer.

        Uses ``asyncio.create_task``, so it must be called while an event
        loop is running. If a consumer task for this name is still alive,
        only the handler is replaced; the ``ctx`` and ``run_kwargs`` bound to
        that running task stay in effect.
        """
        self._packet_handlers[processor_name] = handler
        self._get_or_create_queue(processor_name)

        task = self._tasks.get(processor_name)
        # A finished task (e.g. one that exited its loop) is replaced, so
        # packets for this processor are not left stranded in its queue.
        if task is None or task.done():
            self._tasks[processor_name] = asyncio.create_task(
                self._handle_packets(processor_name, ctx=ctx, **run_kwargs)
            )

    async def _handle_packets(
        self,
        processor_name: ProcName,
        ctx: RunContext[CtxT] | None = None,
        **run_kwargs: Any,
    ) -> None:
        """Feed queued packets to the processor's handler until stopped.

        Exceptions raised by the handler are logged and do not stop the loop.
        The loop ends when the task is cancelled, or when no handler is
        registered for ``processor_name`` at the time a packet arrives.
        """
        queue = self._queues[processor_name]
        while True:
            packet = await queue.get()
            try:
                handler = self._packet_handlers.get(processor_name)
                if handler is None:
                    break

                try:
                    await handler(packet, ctx=ctx, **run_kwargs)
                except Exception:
                    logger.exception("Error handling packet for %s", processor_name)
            finally:
                # Balance every successful get() (including the early-exit
                # path) so that queue.join() callers cannot hang.
                queue.task_done()

    async def unregister_packet_handler(self, processor_name: ProcName) -> None:
        """Cancel the consumer for ``processor_name`` and forget its state.

        The processor's queue is dropped, discarding any packets still in it.
        """
        if task := self._tasks.get(processor_name):
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                logger.debug("%s exited", processor_name)

        self._tasks.pop(processor_name, None)
        self._queues.pop(processor_name, None)
        self._packet_handlers.pop(processor_name, None)

    async def stop_all(self) -> None:
        """Unregister every processor that currently has a consumer task."""
        # Iterate over a snapshot: unregistering mutates self._tasks.
        for processor_name in list(self._tasks):
            await self.unregister_packet_handler(processor_name)
grasp_agents/printer.py CHANGED
@@ -6,13 +6,14 @@ from typing import Literal, TypeAlias
6
6
 
7
7
  from termcolor._types import Color # type: ignore[import]
8
8
 
9
+ from .typing.completion import Usage
9
10
  from .typing.content import Content, ContentPartText
10
11
  from .typing.message import AssistantMessage, Message, Role, ToolMessage
11
12
 
12
13
  logger = logging.getLogger(__name__)
13
14
 
14
15
 
15
- ColoringMode: TypeAlias = Literal["agent_id", "role"]
16
+ ColoringMode: TypeAlias = Literal["agent", "role"]
16
17
 
17
18
  ROLE_TO_COLOR: Mapping[Role, Color] = {
18
19
  Role.SYSTEM: "magenta",
@@ -50,9 +51,9 @@ class Printer:
50
51
  return ROLE_TO_COLOR[role]
51
52
 
52
53
  @staticmethod
53
- def get_agent_color(agent_id: str) -> Color:
54
+ def get_agent_color(agent_name: str) -> Color:
54
55
  idx = int(
55
- hashlib.md5(agent_id.encode()).hexdigest(), # noqa :S324
56
+ hashlib.md5(agent_name.encode()).hexdigest(), # noqa :S324
56
57
  16,
57
58
  ) % len(AVAILABLE_COLORS)
58
59
 
@@ -82,16 +83,22 @@ class Printer:
82
83
 
83
84
  return content_str
84
85
 
85
- def print_llm_message(self, message: Message, agent_id: str) -> None:
86
+ def print_llm_message(
87
+ self, message: Message, agent_name: str, usage: Usage | None = None
88
+ ) -> None:
86
89
  if not self.print_messages:
87
90
  return
88
91
 
92
+ if usage is not None and not isinstance(message, AssistantMessage):
93
+ raise ValueError(
94
+ "Usage information can only be printed for AssistantMessage"
95
+ )
96
+
89
97
  role = message.role
90
- usage = message.usage if isinstance(message, AssistantMessage) else None
91
98
  content_str = self.content_to_str(message.content or "", message.role)
92
99
 
93
- if self.color_by == "agent_id":
94
- color = self.get_agent_color(agent_id)
100
+ if self.color_by == "agent":
101
+ color = self.get_agent_color(agent_name)
95
102
  elif self.color_by == "role":
96
103
  color = self.get_role_color(role)
97
104
 
@@ -99,11 +106,11 @@ class Printer:
99
106
 
100
107
  # Print message title
101
108
 
102
- out = f"\n<{agent_id}>"
109
+ out = f"\n<{agent_name}>"
103
110
  out += "[" + role.value.upper() + "]"
104
111
 
105
112
  if isinstance(message, ToolMessage):
106
- out += f"\nTool call ID: {message.tool_call_id}"
113
+ out += f"\n{message.name} | {message.tool_call_id}"
107
114
 
108
115
  # Print message content
109
116
 
@@ -123,12 +130,12 @@ class Printer:
123
130
 
124
131
  if isinstance(message, AssistantMessage) and message.tool_calls is not None:
125
132
  for tool_call in message.tool_calls:
126
- if self.color_by == "agent_id":
127
- tool_color = self.get_agent_color(agent_id=agent_id)
133
+ if self.color_by == "agent":
134
+ tool_color = self.get_agent_color(agent_name=agent_name)
128
135
  elif self.color_by == "role":
129
136
  tool_color = self.get_role_color(role=Role.TOOL)
130
137
  logger.debug(
131
- f"\n[TOOL_CALL]<{agent_id}>\n{tool_call.tool_name} "
138
+ f"\n[TOOL_CALL]<{agent_name}>\n{tool_call.tool_name} "
132
139
  f"| {tool_call.id}\n{tool_call.tool_arguments}",
133
140
  extra={"color": tool_color}, # type: ignore
134
141
  )
@@ -148,9 +155,16 @@ class Printer:
148
155
  **log_kwargs, # type: ignore
149
156
  )
150
157
 
def print_llm_messages(
    self,
    messages: Sequence[Message],
    agent_name: str,
    usages: Sequence[Usage | None] | None = None,
) -> None:
    """Print every message in ``messages`` via ``print_llm_message``.

    Args:
        messages: Messages to print, in order.
        agent_name: Agent name forwarded to ``print_llm_message``.
        usages: Optional per-message usage, aligned with ``messages`` by
            index. Missing trailing entries are treated as ``None``; extra
            entries are ignored.
    """
    if not self.print_messages:
        return

    # Pad (never truncate): a plain zip() would silently skip the trailing
    # messages whenever fewer usages than messages are supplied.
    _usages: list[Usage | None] = list(usages or [])
    _usages.extend([None] * (len(messages) - len(_usages)))

    for _message, _usage in zip(messages, _usages, strict=False):
        self.print_llm_message(_message, usage=_usage, agent_name=agent_name)
grasp_agents/processor.py ADDED
@@ -0,0 +1,194 @@
1
+ from abc import ABC, abstractmethod
2
+ from collections.abc import AsyncIterator, Sequence
3
+ from typing import Any, ClassVar, Generic, cast, final
4
+
5
+ from pydantic import BaseModel, TypeAdapter
6
+
7
+ from .generics_utils import AutoInstanceAttributesMixin
8
+ from .packet import Packet
9
+ from .run_context import CtxT, RunContext
10
+ from .typing.events import Event, PacketEvent, ProcOutputEvent
11
+ from .typing.io import InT_contra, MemT_co, OutT_co, ProcName
12
+ from .typing.tool import BaseTool
13
+
14
+
class Processor(
    AutoInstanceAttributesMixin, ABC, Generic[InT_contra, OutT_co, MemT_co, CtxT]
):
    """Abstract base for processors that map input payloads to an output packet.

    Type parameters: ``InT_contra`` is the input payload type, ``OutT_co`` the
    output payload type, ``MemT_co`` the memory type and ``CtxT`` the run
    context type. The default ``_process`` / ``_process_stream`` pass their
    inputs through unchanged; concrete processors override them.
    """

    # Generic-argument position -> instance attribute holding the concrete
    # type. NOTE(review): presumably consumed by AutoInstanceAttributesMixin
    # (defined elsewhere) to populate _in_type / _out_type — confirm.
    _generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {
        0: "_in_type",
        1: "_out_type",
    }

    @abstractmethod
    def __init__(self, name: ProcName, **kwargs: Any) -> None:
        self._in_type: type[InT_contra]
        self._out_type: type[OutT_co]

        # _in_type / _out_type are only declared above, so a base __init__
        # has to assign them before the adapters below read them.
        super().__init__()

        self._in_type_adapter: TypeAdapter[InT_contra] = TypeAdapter(self._in_type)
        self._out_type_adapter: TypeAdapter[OutT_co] = TypeAdapter(self._out_type)

        self._name: ProcName = name
        # Declared but not assigned here; subclasses are expected to set it
        # before `memory` is read.
        self._memory: MemT_co

    @property
    def in_type(self) -> type[InT_contra]:  # type: ignore[reportInvalidTypeVarUse]
        # Exposing the type of a contravariant variable only, should be type safe
        return self._in_type

    @property
    def out_type(self) -> type[OutT_co]:
        return self._out_type

    @property
    def name(self) -> ProcName:
        return self._name

    @property
    def memory(self) -> MemT_co:
        return self._memory

    def _validate_and_resolve_inputs(
        self,
        chat_inputs: Any | None = None,
        in_packet: Packet[InT_contra] | None = None,
        in_args: InT_contra | Sequence[InT_contra] | None = None,
    ) -> Sequence[InT_contra] | None:
        """Ensure at most one input source is given and normalise ``in_args``.

        Returns:
            ``in_packet.payloads`` if a packet is given; ``[in_args]`` when
            ``in_args`` is a single ``in_type`` instance; ``None`` when
            ``in_args`` is ``None``; otherwise ``in_args`` itself, treated as
            a sequence of inputs.

        Raises:
            ValueError: if more than one of ``chat_inputs``, ``in_packet`` and
                ``in_args`` is not ``None``.
        """
        n_provided = sum(
            source is not None for source in (chat_inputs, in_packet, in_args)
        )
        if n_provided > 1:
            raise ValueError(
                "At most one of chat_inputs, in_args, or in_packet can be provided."
            )

        if in_packet is not None:
            return in_packet.payloads
        # NOTE(review): isinstance() raises TypeError when _in_type is not a
        # plain class (e.g. a parameterised generic) — confirm callers avoid it.
        if isinstance(in_args, self._in_type):
            return cast("Sequence[InT_contra]", [in_args])
        if in_args is None:
            return None
        return cast("Sequence[InT_contra]", in_args)

    async def _process(
        self,
        chat_inputs: Any | None = None,
        *,
        in_args: Sequence[InT_contra] | None = None,
        forgetful: bool = False,
        ctx: RunContext[CtxT] | None = None,
    ) -> Sequence[OutT_co]:
        """Default implementation: return ``in_args`` unchanged as outputs.

        Raises:
            ValueError: if ``in_args`` is ``None``.
        """
        # Explicit check instead of `assert`, which is stripped under -O.
        if in_args is None:
            raise ValueError("Default implementation of _process requires in_args")

        return cast("Sequence[OutT_co]", in_args)

    async def _process_stream(
        self,
        chat_inputs: Any | None = None,
        *,
        in_args: Sequence[InT_contra] | None = None,
        forgetful: bool = False,
        ctx: RunContext[CtxT] | None = None,
    ) -> AsyncIterator[Event[Any]]:
        """Default implementation: yield one ``ProcOutputEvent`` per input.

        Raises:
            ValueError: if ``in_args`` is ``None`` (on first iteration).
        """
        if in_args is None:
            raise ValueError(
                "Default implementation of _process_stream requires in_args"
            )
        outputs = cast("Sequence[OutT_co]", in_args)
        for out in outputs:
            yield ProcOutputEvent(data=out, name=self.name)

    def _validate_outputs(self, out_payloads: Sequence[OutT_co]) -> Sequence[OutT_co]:
        """Validate each payload against ``out_type`` and return a new list."""
        return [
            self._out_type_adapter.validate_python(payload) for payload in out_payloads
        ]

    async def run(
        self,
        chat_inputs: Any | None = None,
        *,
        in_packet: Packet[InT_contra] | None = None,
        in_args: InT_contra | Sequence[InT_contra] | None = None,
        forgetful: bool = False,
        ctx: RunContext[CtxT] | None = None,
    ) -> Packet[OutT_co]:
        """Run the processor once and return its validated outputs as a packet.

        The returned packet has this processor as sender and no recipients.

        Raises:
            ValueError: if more than one input source is provided.
        """
        resolved_in_args = self._validate_and_resolve_inputs(
            chat_inputs=chat_inputs, in_packet=in_packet, in_args=in_args
        )
        outputs = await self._process(
            chat_inputs=chat_inputs,
            in_args=resolved_in_args,
            forgetful=forgetful,
            ctx=ctx,
        )
        val_outputs = self._validate_outputs(outputs)

        return Packet(payloads=val_outputs, sender=self.name)

    async def run_stream(
        self,
        chat_inputs: Any | None = None,
        *,
        in_packet: Packet[InT_contra] | None = None,
        in_args: InT_contra | Sequence[InT_contra] | None = None,
        forgetful: bool = False,
        ctx: RunContext[CtxT] | None = None,
    ) -> AsyncIterator[Event[Any]]:
        """Stream events from ``_process_stream``, ending with a ``PacketEvent``.

        ``ProcOutputEvent`` items are collected rather than re-yielded; all
        other events are forwarded as they arrive. After the stream ends, the
        validated outputs are yielded once as a ``PacketEvent``.

        Raises:
            ValueError: if more than one input source is provided.
        """
        resolved_in_args = self._validate_and_resolve_inputs(
            chat_inputs=chat_inputs, in_packet=in_packet, in_args=in_args
        )

        outputs: list[OutT_co] = []
        async for output_event in self._process_stream(
            chat_inputs=chat_inputs,
            in_args=resolved_in_args,
            forgetful=forgetful,
            ctx=ctx,
        ):
            if isinstance(output_event, ProcOutputEvent):
                outputs.append(output_event.data)
            else:
                yield output_event

        val_outputs = self._validate_outputs(outputs)
        out_packet = Packet[OutT_co](payloads=val_outputs, sender=self.name)

        yield PacketEvent(data=out_packet, name=self.name)

    @final
    def as_tool(
        self, tool_name: str, tool_description: str
    ) -> BaseTool[InT_contra, OutT_co, Any]:  # type: ignore[override]
        """Wrap this processor as a ``BaseTool`` named ``tool_name``.

        The tool validates its input with ``in_type.model_validate``, runs the
        processor with ``forgetful=True`` and returns the first output payload.

        Raises:
            TypeError: if ``in_type`` is not a pydantic ``BaseModel`` subclass.
        """
        # TODO: stream tools
        processor_instance = self
        in_type = processor_instance.in_type
        out_type = processor_instance.out_type
        if not issubclass(in_type, BaseModel):
            raise TypeError(
                "Cannot create a tool from an agent with "
                f"non-BaseModel input type: {in_type}"
            )

        class ProcessorTool(BaseTool[in_type, out_type, Any]):
            name: str = tool_name
            description: str = tool_description

            async def run(
                self, inp: InT_contra, ctx: RunContext[CtxT] | None = None
            ) -> OutT_co:
                result = await processor_instance.run(
                    in_args=in_type.model_validate(inp), forgetful=True, ctx=ctx
                )

                # Only the first output payload is returned to the tool caller.
                return result.payloads[0]

        return ProcessorTool()  # type: ignore[return-value]