aap-dspy 0.1.1.dev1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,17 @@
1
+ Metadata-Version: 2.3
2
+ Name: aap-dspy
3
+ Version: 0.1.1.dev1
4
+ Summary: DSPy integration of agent design pattern
5
+ Keywords: agent,ai,pattern,llm,dspy
6
+ Author: Ly Hon Quang
7
+ Author-email: Ly Hon Quang <lyhonquang@gmail.com>
8
+ License: MIT
9
+ Requires-Dist: aap-core>=0.1.0
10
+ Requires-Dist: dspy>=3.0.4
11
+ Requires-Dist: weaviate ; extra == 'weaviate'
12
+ Requires-Python: >=3.10.12
13
+ Provides-Extra: weaviate
14
+ Description-Content-Type: text/markdown
15
+
16
+
17
+ This is a DSPy integration for the aap library
@@ -0,0 +1,2 @@
1
+
2
+ This is a DSPy integration for the aap library
@@ -0,0 +1,34 @@
1
+ [project]
2
+ name = "aap_dspy"
3
+ version = "0.1.1.dev1"
4
+ description = "DSPy integration of agent design pattern"
5
+ readme = "README.md"
6
+ license = { text = "MIT" }
7
+ keywords = ["agent", "ai", "pattern", "llm", "dspy"]
8
+ authors = [
9
+ { name = "Ly Hon Quang", email = "lyhonquang@gmail.com" }
10
+ ]
11
+ requires-python = ">=3.10.12"
12
+ dependencies = [
13
+ "aap_core>=0.1.0",
14
+ "dspy>=3.0.4",
15
+ ]
16
+
17
+ [dependency-groups]
18
+ dev = [
19
+ "ipykernel>=7.1.0",
20
+ "sentence-transformers>=5.2.0",
21
+ "torch",
22
+ ]
23
+ lint = [
24
+ "ruff==0.14.0",
25
+ ]
26
+
27
+ [build-system]
28
+ requires = ["uv_build>=0.8.14,<0.9.0"]
29
+ build-backend = "uv_build"
30
+
31
+ [project.optional-dependencies]
32
+ weaviate = [
33
+ "weaviate"
34
+ ]
File without changes
@@ -0,0 +1,193 @@
1
+ import abc
2
+ from typing import Any, Dict, Generic, List, Tuple, TypeVar
3
+ from aap_core.chain import BaseCausalMultiTurnsChain
4
+ from aap_core.types import AgentMessage, AgentResponse
5
+ import dspy
6
+ from pydantic import Field, PrivateAttr
7
+
8
+
9
+ Signature = TypeVar("Signature", bound=dspy.Signature)
10
+
11
+
12
class BaseSignatureAdapter(abc.ABC, Generic[Signature]):
    """Convert between ``AgentMessage`` and ``dspy.Signature`` objects.

    The adapter also owns a *prefill dictionary* used to fill values into the
    signature fields. This is useful for static fields that do not exist in the
    ``AgentMessage`` object while it is moving through the workflow.

    Subclasses that override ``__init__`` must call ``super().__init__()`` so
    that the prefill dictionary is created.
    """

    # Static values to fill into signature fields; one dict per instance.
    _prefill_dict: Dict[str, Any]

    def __init__(self) -> None:
        # This class is not a pydantic model, so the dictionary has to be
        # created explicitly on each instance (a class-level default would be
        # shared between all adapters).
        self._prefill_dict = {}

    @abc.abstractmethod
    def msg2sig(self, message: AgentMessage) -> List[Signature]:
        """Convert an ``AgentMessage`` into a list of dspy signatures.

        The signature fields are only known when developing the end
        application, so the filling logic must be implemented by the child
        class. This conversion happens before the data flows into the dspy
        predictor.

        Two sources of values need to be taken care of:
        - the prefill dictionary of this adapter (static filling)
        - the context dictionary of the ``AgentMessage`` (dynamic filling)

        Args:
            message (AgentMessage): message to convert

        Returns:
            List[Signature]: the conversation so far in dspy.Signature format
        """
        raise NotImplementedError

    @abc.abstractmethod
    def sig2msg(self, signatures: List[Signature], name: str) -> List[AgentResponse]:
        """Convert dspy signatures back into agent responses.

        This is used after the flow is completed and the dspy output needs to
        be converted back to agent messages.

        Note about extracting the source name (who generated the message) and
        the message content from a signature: the source can be assistant or
        tool, the user message is the dspy.InputField, and dspy already handles
        the system message. To unify the source name, the following assumption
        can be made: if a signature has both an OutputField and ToolCalls, it
        is a tool message; otherwise it is an assistant message.

        Args:
            signatures (List[Signature]): dspy output
            name (str): agent name

        Returns:
            List[AgentResponse]: responses extracted from the signatures
        """
        raise NotImplementedError

    @classmethod
    def with_prefill(cls, prefill_dict: Dict[str, Any]) -> "BaseSignatureAdapter":
        """Create a new adapter initialised with the given prefill dictionary.

        The child class is responsible for managing the matching and
        evaluation between the prefill dictionary and the signature.

        Args:
            prefill_dict (Dict[str, Any]): initial prefill values.

        Returns:
            BaseSignatureAdapter: a new adapter holding a copy of
                ``prefill_dict``.
        """
        obj = cls()
        # Copy so later add_prefill/remove_prefill calls do not mutate the
        # dictionary owned by the caller.
        obj._prefill_dict = dict(prefill_dict)
        return obj

    def add_prefill(self, key: str, value: Any) -> None:
        """Add (or overwrite) a key-value pair in the prefill dictionary.

        Args:
            key (str): The key to add.
            value (Any): The value to store under ``key``.
        """
        self._prefill_dict[key] = value

    def remove_prefill(self, key: str) -> None:
        """Remove a key from the prefill dictionary.

        Args:
            key (str): The key to remove.

        Raises:
            KeyError: if ``key`` is not present in the prefill dictionary.
        """
        del self._prefill_dict[key]
86
+
87
+
88
class ChatCausalMultiTurnsChain(
    BaseCausalMultiTurnsChain[dspy.Signature, dspy.Prediction],
    arbitrary_types_allowed=True,
):
    """A multi-turn chain that delegates LM calls to a dspy predictor.

    If the signature declares a ``dspy.History`` input field, the history
    carried by the last signature of the conversation is passed to the
    predictor as a ``dspy.History`` object.

    Regarding the tool calling pattern used in the dspy framework, there are 2 approaches proposed by the authors of dspy:
    1. [dspy fully managed](https://dspy.ai/learn/programming/tools/#approach-1-using-dspyreact-fully-managed): using dspy.ReAct, a subclass of it, or a customized dspy.Module that handles tool calling internally.
    In this case the signature of the module doesn't have a dspy.ToolCalls field. All tools stay completely inside the dspy module.
    This class only gets the final output produced by the dspy predictor. _process_tools will not be called at all.

    2. [Manual tool handling](https://dspy.ai/learn/programming/tools/#approach-2-manual-tool-handling): tool calling logic is handled by this class.
    When initialized with the provided signature, this class automatically detects the dspy.ToolCalls field.
    When the chain is invoked, it detects and calls the tools depending on the value of the tool calls field.

    Reference: https://dspy.ai
    """

    predictor: dspy.Module = Field(..., description="dspy predictor")
    adapter: BaseSignatureAdapter = Field(
        ...,
        description="The adapter convert between AgentMessage and dspy.Signature",
    )
    # Signature class resolved from the constructor argument.
    _signature: type[dspy.Signature] = PrivateAttr()
    # Name of the first output field annotated as dspy.ToolCalls, if any.
    _tool_calls_field: str | None = PrivateAttr(None)
    # Optional LM used instead of the default dspy context.
    _lm: dspy.LM | None = PrivateAttr(None)
    # Name of the first input field annotated as dspy.History, if any.
    _history_field_name: str | None = PrivateAttr(None)

    def __init__(self, signature: str | type[dspy.Signature], **kwargs):
        """Create the chain and inspect the signature for special fields.

        Args:
            signature (str | type[dspy.Signature]): signature (or its string
                form) describing the predictor inputs and outputs.
            **kwargs: forwarded to the parent class constructor.
        """
        super().__init__(**kwargs)
        self._signature = dspy.ensure_signature(signature)
        # Remember the first history input field and the first tool-calls
        # output field; both stay None when the signature has none.
        self._history_field_name = next(
            (
                key
                for key, value in self._signature.input_fields.items()
                if value.annotation is dspy.History
            ),
            None,
        )
        self._tool_calls_field = next(
            (
                key
                for key, value in self._signature.output_fields.items()
                if value.annotation is dspy.ToolCalls
            ),
            None,
        )

    def _prepare_conversation(self, message: AgentMessage) -> List[dspy.Signature]:
        """Convert the message into a list of signatures using the adapter."""
        return self.adapter.msg2sig(message)

    def _generate_response(
        self, conversation: List[dspy.Signature], **kwargs
    ) -> Tuple[List[dspy.Signature], dspy.Prediction, bool]:
        """Run the predictor on the last signature of the conversation.

        Args:
            conversation (List[dspy.Signature]): conversation so far; the last
                element provides the predictor inputs.
            **kwargs: accepted for interface compatibility; not used here.

        Returns:
            Tuple[List[dspy.Signature], dspy.Prediction, bool]: the
                conversation with the new signature appended, the raw
                prediction, and whether the prediction contains tool calls.
        """
        sig = conversation[-1].model_dump(exclude_none=True)
        history_key = self._history_field_name
        # exclude_none drops the history field when it is None, so only
        # rebuild the History object when the field is actually present.
        if history_key is not None and history_key in sig:
            # Convert dict to object as the library use object to access the history
            sig[history_key] = dspy.History(messages=sig[history_key]["messages"])

        if self._lm is not None:
            # change context if possible
            with dspy.context(lm=self._lm):
                data = self.predictor(**sig)
        else:
            data = self.predictor(**sig)

        has_tool = False
        if self._tool_calls_field is not None:
            tool_calls = data[self._tool_calls_field]
            # The tool-calls object itself may be truthy even when it holds no
            # calls, so look at its list of calls when it exposes one.
            has_tool = bool(getattr(tool_calls, "tool_calls", tool_calls))

        sig.update(data.items())
        conversation.append(self._signature(**sig))
        return conversation, data, has_tool

    def _process_tools(
        self, conversation: List[dspy.Signature], response: dspy.Prediction
    ) -> List[dspy.Signature]:
        """Execute every tool call of the response and record the results.

        Each tool result is stored in the first output field of the signature
        annotated as ``str``; one signature is appended to the conversation per
        executed call. When the signature has no such field the calls are
        still executed but nothing is appended.

        Args:
            conversation (List[dspy.Signature]): conversation so far.
            response (dspy.Prediction): prediction holding the tool calls.

        Returns:
            List[dspy.Signature]: the same conversation list, extended.
        """
        result_field = next(
            (
                key
                for key, value in self._signature.output_fields.items()
                if value.annotation is str
            ),
            None,
        )
        for call in response[self._tool_calls_field].tool_calls:
            result = call.execute()
            if result_field is None:
                continue
            # Build the field mapping first: passing the response and the
            # result as two ** expansions raises TypeError when the response
            # already contains ``result_field``.
            fields = dict(response.items())
            fields[result_field] = result
            conversation.append(self._signature(**fields))

        return conversation

    def _append_responses(
        self, message: AgentMessage, conversation: List[dspy.Signature]
    ) -> AgentMessage:
        """Convert part of the conversation to responses and add them to the message.

        When ``store_immediate_steps`` is truthy the slice starts at
        ``min(len(message.responses), include_history) + 1``; otherwise only
        the last signature of the conversation is converted.

        Args:
            message (AgentMessage): message whose ``responses`` are extended.
            conversation (List[dspy.Signature]): full conversation.

        Returns:
            AgentMessage: the same message object, updated in place.
        """
        if self.store_immediate_steps:
            start_index = min(len(message.responses), self.include_history) + 1
        else:
            start_index = len(conversation) - 1
        message.responses.extend(
            self.adapter.sig2msg(conversation[start_index:], self.name)
        )
        # TODO: handle other modals later
        return message

    def with_lm(self, lm: dspy.LM | None) -> "ChatCausalMultiTurnsChain":
        """Set the language model context to use for the chain.

        If set to None, the default context will be used.

        Args:
            lm (dspy.LM | None): The language model context to use for the chain.

        Returns:
            ChatCausalMultiTurnsChain: this chain, to allow call chaining.
        """
        self._lm = lm
        return self
File without changes
@@ -0,0 +1,40 @@
1
+ from aap_core.retriever import BaseRetriever
2
+ from aap_core.types import AgentMessage
3
+ from pydantic import ConfigDict, Field, field_validator
4
+
5
+ from dspy import Retrieve, Embeddings
6
+
7
+
8
class RetrieverAdapter(BaseRetriever):
    """Expose a dspy retriever through the ``BaseRetriever`` interface.

    The retrieved passages are joined into one string and stored in the
    message context under ``data_key`` (without its ``context.`` prefix).
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)
    retriever: Retrieve | Embeddings = Field(
        ..., description="The dspy's retriever to use"
    )
    data_key: str = Field(
        default="context.data", description="The key to the data in the message"
    )

    @field_validator("data_key")
    @classmethod
    def check_starts_with_prefix(cls, v: str) -> str:
        """Reject a ``data_key`` that does not start with ``context.``.

        Raises:
            ValueError: if ``v`` lacks the ``context.`` prefix.
        """
        if not v.startswith("context."):
            raise ValueError("data_key must start with 'context.'")
        return v

    def retrieve(self, message: AgentMessage, **kwargs) -> AgentMessage:
        """Query the retriever and store the result in the message context.

        Args:
            message (AgentMessage): message whose ``query`` is sent to the
                retriever and whose ``context`` receives the result.
            **kwargs: accepted for interface compatibility; not used here.

        Returns:
            AgentMessage: the same message object, updated in place.
        """
        results = self.retriever(message.query)
        # Use the ``passages`` attribute when the result exposes one, otherwise
        # iterate the result itself.
        # NOTE(review): assumes a plain dspy.Retrieve also returns an object
        # with ``passages`` - confirm against the dspy version in use.
        passages = getattr(results, "passages", results)

        # Always store a string: the previous code stored a list when there
        # were fewer than two passages, making the value type depend on the
        # number of hits.
        content = " ".join(str(passage) for passage in passages)

        # Strip only the leading prefix; str.replace would also remove any
        # later "context." occurrence inside the key.
        data_key = self.data_key.removeprefix("context.")
        if message.context is None:
            message.context = {data_key: content}
        else:
            message.context[data_key] = content
        return message