aap-dspy 0.1.1.dev1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aap_dspy-0.1.1.dev1/PKG-INFO +17 -0
- aap_dspy-0.1.1.dev1/README.md +2 -0
- aap_dspy-0.1.1.dev1/pyproject.toml +34 -0
- aap_dspy-0.1.1.dev1/src/aap_dspy/__init__.py +0 -0
- aap_dspy-0.1.1.dev1/src/aap_dspy/chain.py +193 -0
- aap_dspy-0.1.1.dev1/src/aap_dspy/py.typed +0 -0
- aap_dspy-0.1.1.dev1/src/aap_dspy/retriever.py +40 -0
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
Metadata-Version: 2.3
|
|
2
|
+
Name: aap-dspy
|
|
3
|
+
Version: 0.1.1.dev1
|
|
4
|
+
Summary: DSPy integration of agent design patterns
|
|
5
|
+
Keywords: agent,ai,pattern,llm,dspy
|
|
6
|
+
Author: Ly Hon Quang
|
|
7
|
+
Author-email: Ly Hon Quang <lyhonquang@gmail.com>
|
|
8
|
+
License: MIT
|
|
9
|
+
Requires-Dist: aap-core>=0.1.0
|
|
10
|
+
Requires-Dist: dspy>=3.0.4
|
|
11
|
+
Requires-Dist: weaviate ; extra == 'weaviate'
|
|
12
|
+
Requires-Python: >=3.10.12
|
|
13
|
+
Provides-Extra: weaviate
|
|
14
|
+
Description-Content-Type: text/markdown
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
This is a DSPy integration for the aap library
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
[project]
|
|
2
|
+
name = "aap_dspy"
|
|
3
|
+
version = "0.1.1.dev1"
|
|
4
|
+
description = "DSPy integration of agent design patterns"
|
|
5
|
+
readme = "README.md"
|
|
6
|
+
license = { text = "MIT" }
|
|
7
|
+
keywords = ["agent", "ai", "pattern", "llm", "dspy"]
|
|
8
|
+
authors = [
|
|
9
|
+
{ name = "Ly Hon Quang", email = "lyhonquang@gmail.com" }
|
|
10
|
+
]
|
|
11
|
+
requires-python = ">=3.10.12"
|
|
12
|
+
dependencies = [
|
|
13
|
+
"aap_core>=0.1.0",
|
|
14
|
+
"dspy>=3.0.4",
|
|
15
|
+
]
|
|
16
|
+
|
|
17
|
+
[dependency-groups]
|
|
18
|
+
dev = [
|
|
19
|
+
"ipykernel>=7.1.0",
|
|
20
|
+
"sentence-transformers>=5.2.0",
|
|
21
|
+
"torch",
|
|
22
|
+
]
|
|
23
|
+
lint = [
|
|
24
|
+
"ruff==0.14.0",
|
|
25
|
+
]
|
|
26
|
+
|
|
27
|
+
[build-system]
|
|
28
|
+
requires = ["uv_build>=0.8.14,<0.9.0"]
|
|
29
|
+
build-backend = "uv_build"
|
|
30
|
+
|
|
31
|
+
[project.optional-dependencies]
|
|
32
|
+
weaviate = [
|
|
33
|
+
"weaviate"
|
|
34
|
+
]
|
|
File without changes
|
|
@@ -0,0 +1,193 @@
|
|
|
1
|
+
import abc
|
|
2
|
+
from typing import Any, Dict, Generic, List, Tuple, TypeVar
|
|
3
|
+
from aap_core.chain import BaseCausalMultiTurnsChain
|
|
4
|
+
from aap_core.types import AgentMessage, AgentResponse
|
|
5
|
+
import dspy
|
|
6
|
+
from pydantic import Field, PrivateAttr
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
# Type variable for the concrete dspy.Signature subclass that an adapter
# produces (msg2sig) and consumes (sig2msg).
Signature = TypeVar("Signature", bound=dspy.Signature)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class BaseSignatureAdapter(abc.ABC, Generic[Signature]):
    """Convert between ``AgentMessage`` and ``dspy.Signature`` objects.

    The adapter also keeps a *prefill* dictionary of static values for
    signature fields that do not exist on the ``AgentMessage`` while it moves
    through the workflow. Concrete subclasses decide how those values are
    matched to, and applied on, the signature fields.
    """

    @property
    def _prefill_dict(self) -> Dict[str, Any]:
        """Per-instance prefill dictionary, created on first access.

        This class is a plain ABC (not a pydantic model), so a class-level
        default would either be shared by every instance or not be a dict at
        all. Creating it lazily gives each instance its own dictionary without
        requiring subclasses to call ``super().__init__()``.
        """
        prefill = getattr(self, "_prefill_values", None)
        if prefill is None:
            prefill = {}
            self._prefill_values = prefill
        return prefill

    @_prefill_dict.setter
    def _prefill_dict(self, value: Dict[str, Any]) -> None:
        self._prefill_values = value

    @abc.abstractmethod
    def msg2sig(self, message: AgentMessage) -> List[Signature]:
        """Convert an ``AgentMessage`` into dspy signatures for the predictor.

        The signature fields are only known by the end application, so the
        filling logic must be implemented by the child class. Two sources of
        values need to be handled:

        - the prefill dictionary of this adapter (static filling);
        - the context dictionary of the ``AgentMessage`` (dynamic filling).

        Args:
            message (AgentMessage): message to convert.

        Returns:
            List[Signature]: the conversation so far, in ``dspy.Signature`` format.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def sig2msg(self, signatures: List[Signature], name: str) -> List[AgentResponse]:
        """Convert dspy signatures back into agent responses.

        Used after the dspy flow has completed, to turn its output back into
        agent messages.

        On identifying the source of each message: the source can be the
        assistant or a tool. The user message is the ``dspy.InputField`` and
        dspy already handles the system message. To unify source names, assume
        that a signature holding both an output field and ``ToolCalls`` is a
        tool message; otherwise it is an assistant message.

        Args:
            signatures (List[Signature]): dspy output.
            name (str): agent name.

        Returns:
            List[AgentResponse]: responses extracted from ``signatures``.
        """
        raise NotImplementedError

    @classmethod
    def with_prefill(cls, prefill_dict: Dict[str, Any]) -> "BaseSignatureAdapter":
        """Create a new adapter instance with the given prefill dictionary.

        The child class is responsible for matching and evaluating the
        prefill entries against the signature.

        Args:
            prefill_dict (Dict[str, Any]): the prefill dictionary. It is
                copied, so later ``add_prefill``/``remove_prefill`` calls do
                not mutate the caller's dictionary.

        Returns:
            BaseSignatureAdapter: a new instance holding the prefill values.
        """
        obj = cls()
        obj._prefill_dict = dict(prefill_dict)
        return obj

    def add_prefill(self, key: str, value: Any) -> None:
        """Add or overwrite a key-value pair in the prefill dictionary.

        Args:
            key (str): the key to add.
            value (Any): the value to store.
        """
        self._prefill_dict[key] = value

    def remove_prefill(self, key: str) -> None:
        """Remove a key from the prefill dictionary.

        Args:
            key (str): the key to remove.

        Raises:
            KeyError: if ``key`` is not in the prefill dictionary.
        """
        del self._prefill_dict[key]
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
class ChatCausalMultiTurnsChain(
    BaseCausalMultiTurnsChain[dspy.Signature, dspy.Prediction],
    arbitrary_types_allowed=True,
):
    """Multi-turn chain that runs a dspy predictor over adapter-built signatures.

    ``adapter`` converts the incoming ``AgentMessage`` into a list of
    ``dspy.Signature`` instances. The last one is fed to ``predictor``, and the
    prediction is appended to the conversation as a new signature.

    The dspy authors propose two tool-calling approaches:

    1. [dspy fully managed](https://dspy.ai/learn/programming/tools/#approach-1-using-dspyreact-fully-managed):
       use ``dspy.ReAct``, a subclass of it, or a custom ``dspy.Module`` that
       calls tools internally. The signature has no ``dspy.ToolCalls`` field;
       all tools stay inside the dspy module. This class only sees the final
       prediction, and ``_process_tools`` is never triggered.

    2. [Manual tool handling](https://dspy.ai/learn/programming/tools/#approach-2-manual-tool-handling):
       the signature declares a ``dspy.ToolCalls`` output field, which is
       detected in ``__init__``. When a prediction contains at least one tool
       call, ``_generate_response`` reports it and ``_process_tools`` executes
       the calls.

    If the signature declares an input field annotated exactly as
    ``dspy.History``, it is detected in ``__init__`` and rebuilt as a
    ``dspy.History`` object before every predictor call.

    Reference: https://dspy.ai
    """

    predictor: dspy.Module = Field(..., description="dspy predictor")
    adapter: BaseSignatureAdapter = Field(
        ...,
        description="The adapter convert between AgentMessage and dspy.Signature",
    )
    _signature: type[dspy.Signature] = PrivateAttr()
    _tool_calls_field: str | None = PrivateAttr(None)
    _lm: dspy.LM | None = PrivateAttr(None)
    _history_field_name: str | None = PrivateAttr(None)

    def __init__(self, signature: str | type[dspy.Signature], **kwargs):
        super().__init__(**kwargs)
        self._signature = dspy.ensure_signature(signature)
        # First input field annotated exactly as dspy.History, if any.
        self._history_field_name = next(
            (
                name
                for name, field in self._signature.input_fields.items()
                if field.annotation is dspy.History
            ),
            None,
        )
        # First output field annotated exactly as dspy.ToolCalls, if any.
        self._tool_calls_field = next(
            (
                name
                for name, field in self._signature.output_fields.items()
                if field.annotation is dspy.ToolCalls
            ),
            None,
        )

    def _prepare_conversation(self, message: AgentMessage) -> List[dspy.Signature]:
        return self.adapter.msg2sig(message)

    def _has_tool_calls(self, prediction: dspy.Prediction) -> bool:
        """Return True when ``prediction`` requests at least one tool call."""
        if self._tool_calls_field is None:
            return False
        tool_calls = prediction.get(self._tool_calls_field)
        if tool_calls is None:
            return False
        # dspy.ToolCalls is a pydantic model and therefore always truthy, so
        # inspect its ``tool_calls`` list: an empty list means "no tool call".
        return bool(getattr(tool_calls, "tool_calls", tool_calls))

    def _generate_response(
        self, conversation: List[dspy.Signature], **kwargs
    ) -> Tuple[List[dspy.Signature], dspy.Prediction, bool]:
        """Run the predictor on the latest signature and append its output.

        Args:
            conversation: signatures built so far; the last one supplies the
                predictor inputs. The list is extended in place and returned.

        Returns:
            The extended conversation, the raw prediction, and whether the
            prediction requested at least one tool call.
        """
        sig = conversation[-1].model_dump(exclude_none=True)
        if self._history_field_name is not None:
            history = sig.get(self._history_field_name)
            # model_dump turns dspy.History into a plain dict, but dspy reads
            # the history through attribute access, so rebuild the object. The
            # key is absent when the history was None (exclude_none=True).
            if isinstance(history, dict):
                sig[self._history_field_name] = dspy.History(
                    messages=history["messages"]
                )
        if self._lm is not None:
            # Run the predictor with this chain's LM instead of the global one.
            with dspy.context(lm=self._lm):
                data = self.predictor(**sig)
        else:
            data = self.predictor(**sig)

        has_tool = self._has_tool_calls(data)
        sig.update(data.items())
        conversation.append(self._signature(**sig))
        return conversation, data, has_tool

    def _process_tools(
        self, conversation: List[dspy.Signature], response: dspy.Prediction
    ) -> List[dspy.Signature]:
        """Execute the requested tool calls and append one signature per result.

        Each result is written into the first ``str`` output field of the
        signature (non-str results are converted with ``str``). The remaining
        fields are copied from the latest conversation entry, which holds the
        inputs of the turn that requested the tools, overlaid with
        ``response``. If the signature has no ``str`` output field, the tools
        are still executed but their results are not recorded.
        """
        result_field = next(
            (
                name
                for name, field in self._signature.output_fields.items()
                if field.annotation is str
            ),
            None,
        )
        # Start from the turn that requested the tools so the required input
        # fields are present when the new signature is validated.
        base: Dict[str, Any] = dict(conversation[-1]) if conversation else {}
        base.update(response.items())

        for call in response[self._tool_calls_field].tool_calls:
            result = call.execute()
            if result_field is None:
                continue
            fields = {
                **base,
                result_field: result if isinstance(result, str) else str(result),
            }
            conversation.append(self._signature(**fields))

        return conversation

    def _append_responses(
        self, message: AgentMessage, conversation: List[dspy.Signature]
    ) -> AgentMessage:
        """Convert new conversation entries back to responses on ``message``.

        When ``store_immediate_steps`` is set, every entry after the prefix
        produced from history is converted, including intermediate tool steps.
        Otherwise only the final entry is converted.
        NOTE(review): the prefix length assumes ``adapter.msg2sig`` emits
        ``min(len(message.responses), include_history)`` history entries plus
        one entry for the current turn; confirm against the adapter.
        """
        start_index = (
            min(len(message.responses), self.include_history) + 1
            if self.store_immediate_steps
            else len(conversation) - 1
        )
        end_index = len(conversation)
        message.responses.extend(
            self.adapter.sig2msg(conversation[start_index:end_index], self.name)
        )
        # TODO: handle other modals later
        return message

    def with_lm(self, lm: dspy.LM | None) -> "ChatCausalMultiTurnsChain":
        """Set the language model used by this chain.

        If set to None, the default dspy context is used.

        Args:
            lm (dspy.LM | None): the language model to use for the chain.

        Returns:
            ChatCausalMultiTurnsChain: this chain, to allow call chaining.
        """
        self._lm = lm
        return self
|
|
File without changes
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
from aap_core.retriever import BaseRetriever
|
|
2
|
+
from aap_core.types import AgentMessage
|
|
3
|
+
from pydantic import ConfigDict, Field, field_validator
|
|
4
|
+
|
|
5
|
+
from dspy import Retrieve, Embeddings
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class RetrieverAdapter(BaseRetriever):
    """Expose a dspy retriever (``dspy.Retrieve`` or ``dspy.Embeddings``) as an aap retriever.

    Retrieved passages are joined into a single string and stored in
    ``message.context`` under ``data_key`` with its ``"context."`` prefix
    removed.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)
    retriever: Retrieve | Embeddings = Field(
        ..., description="The dspy's retriever to use"
    )
    data_key: str = Field(
        default="context.data", description="The key to the data in the message"
    )

    @field_validator("data_key")
    @classmethod
    def check_starts_with_prefix(cls, v: str) -> str:
        """Require ``data_key`` to address the message context."""
        if not v.startswith("context."):
            raise ValueError("data_key must start with 'context.'")
        return v

    def retrieve(self, message: AgentMessage, **kwargs) -> AgentMessage:
        """Retrieve passages for ``message.query`` and store them in the context.

        Args:
            message (AgentMessage): message whose ``query`` is sent to the retriever.

        Returns:
            AgentMessage: the same message, with the space-joined passages
            (an empty string when nothing was retrieved) stored in its context.
        """
        output = self.retriever(message.query)
        # dspy retrievers return a dspy.Prediction whose ``passages`` attribute
        # holds the texts; iterating the Prediction itself would yield its keys.
        # Fall back to the raw output for retrievers returning passages directly.
        passages = list(getattr(output, "passages", output))

        # Strip only the leading "context." so nested names stay intact.
        data_key = self.data_key.removeprefix("context.")
        content = " ".join(passages)
        if message.context is None:
            message.context = {data_key: content}
        else:
            message.context[data_key] = content
        return message
|