copilotkit 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- copilotkit-0.1.0/PKG-INFO +31 -0
- copilotkit-0.1.0/README.md +9 -0
- copilotkit-0.1.0/copilotkit/__init__.py +9 -0
- copilotkit-0.1.0/copilotkit/action.py +40 -0
- copilotkit-0.1.0/copilotkit/agent.py +288 -0
- copilotkit-0.1.0/copilotkit/demo.py +24 -0
- copilotkit-0.1.0/copilotkit/demos/autotale_ai/__init__.py +0 -0
- copilotkit-0.1.0/copilotkit/demos/autotale_ai/agent.py +74 -0
- copilotkit-0.1.0/copilotkit/demos/autotale_ai/chatbot.py +108 -0
- copilotkit-0.1.0/copilotkit/demos/autotale_ai/state.py +33 -0
- copilotkit-0.1.0/copilotkit/demos/autotale_ai/story/__init__.py +0 -0
- copilotkit-0.1.0/copilotkit/demos/autotale_ai/story/characters.py +30 -0
- copilotkit-0.1.0/copilotkit/demos/autotale_ai/story/outline.py +22 -0
- copilotkit-0.1.0/copilotkit/demos/autotale_ai/story/story.py +107 -0
- copilotkit-0.1.0/copilotkit/demos/autotale_ai/story/style.py +22 -0
- copilotkit-0.1.0/copilotkit/exc.py +31 -0
- copilotkit-0.1.0/copilotkit/integrations/__init__.py +0 -0
- copilotkit-0.1.0/copilotkit/integrations/fastapi.py +139 -0
- copilotkit-0.1.0/copilotkit/langchain.py +71 -0
- copilotkit-0.1.0/copilotkit/parameter.py +64 -0
- copilotkit-0.1.0/copilotkit/sdk.py +102 -0
- copilotkit-0.1.0/copilotkit/state.py +7 -0
- copilotkit-0.1.0/copilotkit/types.py +32 -0
- copilotkit-0.1.0/pyproject.toml +25 -0
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
Metadata-Version: 2.1
|
|
2
|
+
Name: copilotkit
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: CopilotKit python SDK
|
|
5
|
+
Home-page: https://copilotkit.ai
|
|
6
|
+
License: MIT
|
|
7
|
+
Keywords: copilot,copilotkit,langgraph,langchain,ai,langsmith,langserve
|
|
8
|
+
Author: Markus Ecker
|
|
9
|
+
Author-email: markus.ecker@gmail.com
|
|
10
|
+
Requires-Python: >=3.12,<4.0
|
|
11
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
12
|
+
Classifier: Programming Language :: Python :: 3
|
|
13
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
14
|
+
Requires-Dist: fastapi (>=0.111.1,<0.112.0)
|
|
15
|
+
Requires-Dist: httpx (>=0.27.0,<0.28.0)
|
|
16
|
+
Requires-Dist: langchain (>=0.2.12,<0.3.0)
|
|
17
|
+
Requires-Dist: langchain-openai (>=0.1.20,<0.2.0)
|
|
18
|
+
Requires-Dist: langgraph (>=0.2.3,<0.3.0)
|
|
19
|
+
Requires-Dist: partialjson (>=0.0.8,<0.0.9)
|
|
20
|
+
Description-Content-Type: text/markdown
|
|
21
|
+
|
|
22
|
+
# CopilotKit python SDK (alpha)
|
|
23
|
+
|
|
24
|
+
Needs pre-release CopilotKit packages:
|
|
25
|
+
|
|
26
|
+
```
|
|
27
|
+
@copilotkit/react-core@1.1.1-feat-runtime-remote-actions.5
|
|
28
|
+
@copilotkit/react-ui@1.1.1-feat-runtime-remote-actions.5
|
|
29
|
+
@copilotkit/runtime@1.1.1-feat-runtime-remote-actions.5
|
|
30
|
+
```
|
|
31
|
+
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
# CopilotKit python SDK (alpha)
|
|
2
|
+
|
|
3
|
+
Needs pre-release CopilotKit packages:
|
|
4
|
+
|
|
5
|
+
```
|
|
6
|
+
@copilotkit/react-core@1.1.1-feat-runtime-remote-actions.5
|
|
7
|
+
@copilotkit/react-ui@1.1.1-feat-runtime-remote-actions.5
|
|
8
|
+
@copilotkit/runtime@1.1.1-feat-runtime-remote-actions.5
|
|
9
|
+
```
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
"""Actions"""
|
|
2
|
+
|
|
3
|
+
from inspect import iscoroutinefunction
|
|
4
|
+
from typing import Optional, List, Callable
|
|
5
|
+
from .parameter import BaseParameter, normalize_parameters
|
|
6
|
+
|
|
7
|
+
class Action: # pylint: disable=too-few-public-methods
    """A named, callable CopilotKit action.

    Bundles a sync or async handler together with its name, an optional
    human-readable description, and an optional parameter schema.
    """

    def __init__(
        self,
        *,
        name: str,
        handler: Callable,
        description: Optional[str] = None,
        parameters: Optional[List[BaseParameter]] = None,
    ):
        self.name = name
        self.description = description
        self.parameters = parameters
        self.handler = handler

    async def execute(
        self,
        *,
        arguments: dict
    ) -> dict:
        """Invoke the handler with *arguments*, awaiting it when it is a coroutine function."""
        if iscoroutinefunction(self.handler):
            outcome = await self.handler(**arguments)
        else:
            outcome = self.handler(**arguments)
        return {"result": outcome}

    def dict_repr(self):
        """Serializable dict form of this action (name, description, parameters)."""
        return {
            'name': self.name,
            'description': self.description or '',
            'parameters': normalize_parameters(self.parameters),
        }
|
|
@@ -0,0 +1,288 @@
|
|
|
1
|
+
"""Agents"""
|
|
2
|
+
|
|
3
|
+
from typing import Optional, List
|
|
4
|
+
from abc import ABC, abstractmethod
|
|
5
|
+
import uuid
|
|
6
|
+
from langgraph.graph.graph import CompiledGraph
|
|
7
|
+
from langchain.load.dump import dumps as langchain_dumps
|
|
8
|
+
from langchain.load.load import load as langchain_load
|
|
9
|
+
|
|
10
|
+
from langchain.schema import SystemMessage
|
|
11
|
+
|
|
12
|
+
from partialjson.json_parser import JSONParser
|
|
13
|
+
|
|
14
|
+
from .types import Message
|
|
15
|
+
from .langchain import copilotkit_messages_to_langchain
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class Agent(ABC):
    """Abstract base class for CopilotKit agents.

    Concrete subclasses implement `execute`, which runs the agent against
    a state/message payload received from the CopilotKit runtime.
    """
    def __init__(
            self,
            *,
            name: str,
            description: Optional[str] = None,
            merge_state: Optional[callable] = None
        ):
        # unique identifier the frontend uses to address this agent
        self.name = name
        # human-readable summary, surfaced to the frontend via dict_repr()
        self.description = description
        # callable that merges incoming messages/actions into agent state
        self.merge_state = merge_state

    @abstractmethod
    def execute(
        self,
        *,
        state: dict,
        messages: List[Message],
        thread_id: Optional[str] = None,
        node_name: Optional[str] = None,
    ):
        """Execute the agent"""

    def dict_repr(self):
        """Dict representation of the agent, as sent to the frontend."""
        return {
            'name': self.name,
            'description': self.description or ''
        }
|
|
48
|
+
|
|
49
|
+
def langgraph_default_merge_state( # pylint: disable=unused-argument
    *,
    state: dict,
    messages: List[Message],
    actions: List[any]
):
    """Default merge state for LangGraph.

    Drops a leading system message, deduplicates the incoming messages by id
    against the messages already persisted in *state*, and attaches the
    available frontend actions under the "copilotkit" key.
    """
    if messages and isinstance(messages[0], SystemMessage):
        # the system prompt is re-sent on every request; keep only the rest
        messages = messages[1:]

    # revive the previously persisted messages, then append any new ones
    merged = [langchain_load(serialized) for serialized in state.get("messages", [])]
    seen_ids = {existing.id for existing in merged}
    merged.extend(incoming for incoming in messages if incoming.id not in seen_ids)

    return {
        **state,
        "messages": merged,
        "copilotkit": {
            "actions": actions
        },
    }
|
|
75
|
+
|
|
76
|
+
class LangGraphAgent(Agent):
    """LangGraph agent class for CopilotKit.

    Wraps a compiled LangGraph graph and streams its execution events,
    interleaved with CopilotKit state-sync events, back to the caller.
    """
    def __init__(
            self,
            *,
            name: str,
            agent: CompiledGraph,
            description: Optional[str] = None,
            merge_state: Optional[callable] = langgraph_default_merge_state
        ):
        super().__init__(
            name=name,
            description=description,
            merge_state=merge_state
        )
        # the compiled LangGraph graph this agent drives
        self.agent = agent

    def _emit_state_sync_event(
            self,
            *,
            thread_id: str,
            run_id: str,
            node_name: str,
            state: dict,
            running: bool,
            active: bool
        ):
        """Serialize one CopilotKit state-sync event as a JSON string."""
        return langchain_dumps({
            "event": "on_copilotkit_state_sync",
            "thread_id": thread_id,
            "run_id": run_id,
            "agent_name": self.name,
            "node_name": node_name,
            "active": active,
            "state": state,
            "running": running,
            "role": "assistant"
        })

    def execute( # pylint: disable=too-many-arguments
        self,
        *,
        state: dict,
        messages: List[Message],
        thread_id: Optional[str] = None,
        node_name: Optional[str] = None,
        actions: Optional[List[any]] = None,
    ):
        """Run the graph and return an async stream of serialized events.

        A given thread_id with a node other than "__end__" resumes the
        existing thread ("continue" mode); otherwise a fresh run is started
        under a newly generated thread id.
        """

        langchain_messages = copilotkit_messages_to_langchain(messages)
        state = self.merge_state(
            state=state,
            messages=langchain_messages,
            actions=actions
        )

        mode = "continue" if thread_id and node_name != "__end__" else "start"
        thread_id = thread_id or str(uuid.uuid4())

        config = {"configurable": {"thread_id": thread_id}}
        if mode == "continue":
            # resume: write the merged state back into the checkpointer as if
            # it had been produced by the node we are resuming from
            self.agent.update_state(config, state, as_node=node_name)

        return self._stream_events(
            mode=mode,
            thread_id=thread_id,
            state=state,
            node_name=node_name
        )

    async def _stream_events(
        self,
        *,
        mode: str,
        thread_id: str,
        state: dict,
        node_name: Optional[str] = None
    ):
        # Async generator: relays raw LangGraph events and interleaves
        # CopilotKit state-sync events whenever the observable state changes.

        config = {"configurable": {"thread_id": thread_id}}
        streaming_state_extractor = _StreamingStateExtractor({})
        # passing None resumes from the checkpointed state ("continue" mode)
        initial_state = state if mode == "start" else None
        prev_node_name = None

        async for event in self.agent.astream_events(initial_state, config, version="v1"):
            current_node_name = event.get("name")
            event_type = event.get("event")
            run_id = event.get("run_id")

            metadata = event.get("metadata")
            # per-node config: which tool arguments should be mirrored into state
            emit_state = metadata.get("copilotkit:emit-state")

            if emit_state and event_type == "on_chat_model_start":
                # reset the streaming state extractor
                streaming_state_extractor = _StreamingStateExtractor(emit_state)

            # we only want to update the node name under certain conditions
            # since we don't need any internal node names to be sent to the frontend
            if current_node_name in self.agent.nodes.keys():
                node_name = current_node_name

            # we don't have a node name yet, so we can't update the state
            if node_name is None:
                continue

            updated_state = self.agent.get_state(config).values

            if emit_state and event_type == "on_chat_model_stream":
                streaming_state_extractor.buffer_tool_calls(event)

            # overlay partially-streamed tool-call state on the persisted state
            updated_state = {
                **updated_state,
                **streaming_state_extractor.extract_state()
            }

            exiting_node = node_name == current_node_name and event_type == "on_chain_end"

            # we send state sync events when:
            # a) the state has changed
            # b) the node has changed
            # c) the node is ending
            if updated_state != state or prev_node_name != node_name or exiting_node:
                state = updated_state
                prev_node_name = node_name
                yield self._emit_state_sync_event(
                    thread_id=thread_id,
                    run_id=run_id,
                    node_name=node_name,
                    state=state,
                    running=True,
                    active=not exiting_node
                ) + "\n"

            yield langchain_dumps(event) + "\n"

        state = self.agent.get_state(config)
        # an empty .next tuple means the graph has nothing left to run
        is_end_node = state.next == ()

        node_name = list(state.metadata["writes"].keys())[0]
        yield self._emit_state_sync_event(
            thread_id=thread_id,
            run_id=run_id,
            node_name=node_name if not is_end_node else "__end__",
            state=state.values,
            # For now, we assume that the agent is always running
            # In the future, we will have a special node that will
            # indicate that the agent is done
            running=True,

            # at this point, the node is ending so we set active to false
            active=False
        ) + "\n"

    def dict_repr(self):
        """Dict representation of the agent, tagged with its runtime type."""
        super_repr = super().dict_repr()
        return {
            **super_repr,
            'type': 'langgraph'
        }
|
|
237
|
+
|
|
238
|
+
class _StreamingStateExtractor:
|
|
239
|
+
def __init__(self, emit_state: dict):
|
|
240
|
+
self.emit_state = emit_state
|
|
241
|
+
self.tool_call_buffer = {}
|
|
242
|
+
self.current_tool_call = None
|
|
243
|
+
|
|
244
|
+
def buffer_tool_calls(self, event: dict):
|
|
245
|
+
"""Buffer the tool calls"""
|
|
246
|
+
if len(event["data"]["chunk"].tool_call_chunks) > 0:
|
|
247
|
+
chunk = event["data"]["chunk"].tool_call_chunks[0]
|
|
248
|
+
if chunk["name"] is not None:
|
|
249
|
+
self.current_tool_call = chunk["name"]
|
|
250
|
+
self.tool_call_buffer[self.current_tool_call] = chunk["args"]
|
|
251
|
+
elif self.current_tool_call is not None:
|
|
252
|
+
self.tool_call_buffer[self.current_tool_call] = (
|
|
253
|
+
self.tool_call_buffer[self.current_tool_call] + chunk["args"]
|
|
254
|
+
)
|
|
255
|
+
|
|
256
|
+
def get_emit_state_config(self, current_tool_name):
|
|
257
|
+
"""Get the emit state config"""
|
|
258
|
+
|
|
259
|
+
for key, value in self.emit_state.items():
|
|
260
|
+
if current_tool_name == value.get("tool"):
|
|
261
|
+
return (value.get("argument"), key)
|
|
262
|
+
|
|
263
|
+
return (None, None)
|
|
264
|
+
|
|
265
|
+
|
|
266
|
+
def extract_state(self):
|
|
267
|
+
"""Extract the streaming state"""
|
|
268
|
+
parser = JSONParser()
|
|
269
|
+
|
|
270
|
+
state = {}
|
|
271
|
+
|
|
272
|
+
for key, value in self.tool_call_buffer.items():
|
|
273
|
+
argument_name, state_key = self.get_emit_state_config(key)
|
|
274
|
+
|
|
275
|
+
if state_key is None:
|
|
276
|
+
continue
|
|
277
|
+
|
|
278
|
+
try:
|
|
279
|
+
parsed_value = parser.parse(value)
|
|
280
|
+
except Exception as _exc: # pylint: disable=broad-except
|
|
281
|
+
continue
|
|
282
|
+
|
|
283
|
+
if argument_name is None:
|
|
284
|
+
state[state_key] = parsed_value
|
|
285
|
+
else:
|
|
286
|
+
state[state_key] = parsed_value.get(argument_name)
|
|
287
|
+
|
|
288
|
+
return state
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
"""Demo"""
|
|
2
|
+
|
|
3
|
+
from fastapi import FastAPI
|
|
4
|
+
import uvicorn
|
|
5
|
+
from .integrations.fastapi import add_fastapi_endpoint
|
|
6
|
+
from . import CopilotKitSDK, LangGraphAgent
|
|
7
|
+
from .demos.autotale_ai.agent import graph
|
|
8
|
+
|
|
9
|
+
app = FastAPI()
# Expose the children's-book LangGraph agent through the CopilotKit SDK.
sdk = CopilotKitSDK(
    agents=[
        LangGraphAgent(
            name="childrensBookAgent",
            description="Write a children's book.",
            agent=graph,
        )
    ],
)

# Mount the CopilotKit HTTP endpoint on the FastAPI app under /copilotkit.
add_fastapi_endpoint(app, sdk, "/copilotkit")

def main():
    """Run the uvicorn server."""
    # reload=True: development convenience, restarts on code changes
    uvicorn.run("copilotkit.demo:app", host="127.0.0.1", port=8000, reload=True)
|
|
File without changes
|
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
"""
|
|
2
|
+
This is the main entry point for the autotale AI.
|
|
3
|
+
It defines the workflow graph and the entry point for the agent.
|
|
4
|
+
"""
|
|
5
|
+
# pylint: disable=line-too-long, unused-import
|
|
6
|
+
|
|
7
|
+
from langgraph.graph import StateGraph, END
|
|
8
|
+
from langgraph.checkpoint.memory import MemorySaver
|
|
9
|
+
|
|
10
|
+
from langchain_core.messages import ToolMessage
|
|
11
|
+
|
|
12
|
+
from copilotkit.demos.autotale_ai.state import AgentState
|
|
13
|
+
from copilotkit.demos.autotale_ai.chatbot import chatbot_node
|
|
14
|
+
from copilotkit.demos.autotale_ai.story.outline import outline_node
|
|
15
|
+
from copilotkit.demos.autotale_ai.story.characters import characters_node
|
|
16
|
+
from copilotkit.demos.autotale_ai.story.story import story_node
|
|
17
|
+
from copilotkit.demos.autotale_ai.story.style import style_node
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def route_story_writing(state):
    """Route to story writing nodes."""
    newest = state["messages"][-1]
    # a ToolMessage means the chatbot picked a tool; its name doubles as the
    # route target; anything else ends the turn
    return newest.name if isinstance(newest, ToolMessage) else END
|
|
29
|
+
|
|
30
|
+
# Define a new graph over the shared AgentState schema.
workflow = StateGraph(AgentState)
workflow.add_node("chatbot_node", chatbot_node)
workflow.add_node("outline_node", outline_node)
workflow.add_node("characters_node", characters_node)
workflow.add_node("style_node", style_node)
workflow.add_node("story_node", story_node)

# Chatbot: every turn starts at the chatbot, which either replies or
# emits a ToolMessage whose name selects the next node.
workflow.set_entry_point("chatbot_node")

workflow.add_conditional_edges(
    "chatbot_node",
    route_story_writing,
    {
        "set_outline": "outline_node",
        "set_characters": "characters_node",
        "set_story": "story_node",
        "set_style": "style_node",
        END: END,
    }
)
# Each worker node hands control back to the chatbot when done.
workflow.add_edge(
    "outline_node",
    "chatbot_node"
)

workflow.add_edge(
    "characters_node",
    "chatbot_node"
)

workflow.add_edge(
    "story_node",
    "chatbot_node"
)

workflow.add_edge(
    "style_node",
    "chatbot_node"
)

# In-memory checkpointer: state survives across turns of one process only.
memory = MemorySaver()

graph = workflow.compile(checkpointer=memory)
|
|
@@ -0,0 +1,108 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Main chatbot node.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
|
|
7
|
+
from langchain_openai import ChatOpenAI
|
|
8
|
+
from langchain_core.messages import SystemMessage
|
|
9
|
+
from langchain_core.runnables import RunnableConfig
|
|
10
|
+
from langchain_core.messages import ToolMessage, AIMessage
|
|
11
|
+
|
|
12
|
+
from copilotkit.demos.autotale_ai.state import AgentState
|
|
13
|
+
from copilotkit.demos.autotale_ai.story.outline import set_outline
|
|
14
|
+
from copilotkit.demos.autotale_ai.story.characters import set_characters
|
|
15
|
+
from copilotkit.demos.autotale_ai.story.story import set_story
|
|
16
|
+
from copilotkit.demos.autotale_ai.story.style import set_style
|
|
17
|
+
from copilotkit.langchain import configure_copilotkit
|
|
18
|
+
# pylint: disable=line-too-long
|
|
19
|
+
|
|
20
|
+
async def chatbot_node(state: AgentState, config: RunnableConfig):
    """
    The chatbot is responsible for answering the user's questions and selecting
    the next route.
    """

    # Ask CopilotKit to stream messages, and to mirror the partially-streamed
    # "set_outline"/"set_characters" tool arguments into frontend state.
    config = configure_copilotkit(
        config,
        emit_messages=True,
        emit_state={
            "outline": {
                "tool": "set_outline",
                "argument": "outline"
            },
            "characters": {
                "tool": "set_characters",
                "argument": "characters"
            },
        }
    )

    # Tools are unlocked progressively: characters require an outline,
    # the story requires characters.
    tools = [set_outline, set_style]

    if state.get("outline") is not None:
        tools.append(set_characters)

    if state.get("characters") is not None:
        tools.append(set_story)

    system_message = """
You help the user write a children's story. Please assist the user by either having a conversation or by
taking the appropriate actions to advance the story writing process. Do not repeat the whole story again.

Your state consists of the following concepts:

- Outline: The outline of the story. Should be short, 2-3 sentences.
- Characters: The characters that make up the story (depends on outline)
- Story: The final story result. (depends on outline & characters)

If the user asks you to make changes to any of these,
you MUST take into account dependencies and make the changes accordingly.

Example: If after coming up with the characters, the user requires changes in the outline, you must first
regenerate the outline.

Dont bother the user too often, just call the tools.
Especially, dont' repeat the story and so on, just call the tools.
"""
    # Append the current story state so the model knows what already exists.
    if state.get("outline") is not None:
        system_message += f"\n\nThe current outline is: {state['outline']}"

    if state.get("characters") is not None:
        system_message += f"\n\nThe current characters are: {json.dumps(state['characters'])}"

    if state.get("story") is not None:
        system_message += f"\n\nThe current story is: {json.dumps(state['story'])}"

    last_message = state["messages"][-1] if state["messages"] else None

    # A trailing AIMessage means control returned here without new user input
    # (e.g. the user edited state in the UI).
    if last_message and isinstance(last_message, AIMessage):
        system_message += """
The user did not submit the last message. This means they probably changed the state of the story by
in the UI. Figure out if you need to regenerate the outline, characters or story and call the appropriate
tool. If not, just respond to the user.
"""

    response = await ChatOpenAI(model="gpt-4o").bind_tools(tools, parallel_tool_calls=False).ainvoke([
        *state["messages"],
        SystemMessage(
            content=system_message
        )
    ], config)

    if not response.tool_calls:
        # plain conversational reply, no routing needed
        return {
            "messages": response,
        }

    # Echo the first tool call back as a ToolMessage; route_story_writing
    # uses its name to select the next node in the graph.
    return {
        "messages": [
            response,
            ToolMessage(
                name=response.tool_calls[0]["name"],
                content=json.dumps(response.tool_calls[0]["args"]),
                tool_call_id=response.tool_calls[0]["id"]
            )
        ],
    }
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
"""
|
|
2
|
+
This is the state definition for the autotale AI.
|
|
3
|
+
It defines the state of the agent and the state of the conversation.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from typing import List, TypedDict
|
|
7
|
+
from langgraph.graph import MessagesState
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class Character(TypedDict):
    """
    Represents a character in the tale.
    """
    name: str  # the character's name
    appearance: str  # detailed visual description (looks, clothes, etc.)
    traits: List[str]  # short list of adjectives (3-4 per set_characters)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class Page(TypedDict):
    """
    Represents a page in the children's story with an image url.
    """
    # NOTE(review): the docstring mentions an image url, but only the text
    # content is declared here; story_node also attaches "image_description".
    content: str  # the page's text
|
|
25
|
+
|
|
26
|
+
class AgentState(MessagesState):
    """
    This is the state of the agent.
    It is a subclass of the MessagesState class from langgraph.
    """
    outline: str  # short outline of the story, 2-3 sentences
    characters: List[Character]  # main characters extracted from the conversation
    story: List[Page]  # the final story pages
    # style was previously undeclared although style_node writes it and
    # story_node reads it; without a declared key, LangGraph's state schema
    # has no channel for it and the update is dropped/rejected.
    style: str  # graphical style used for image generation
|
|
File without changes
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Characters node.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from typing import List
|
|
6
|
+
import json
|
|
7
|
+
from langchain_core.tools import tool
|
|
8
|
+
|
|
9
|
+
from copilotkit.demos.autotale_ai.state import AgentState, Character
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@tool
def set_characters(characters: List[Character]):
    """
    Extract the book's main characters from the conversation.
    The traits should be short: 3-4 adjectives.
    The appearance should be as detailed as possible. What they look like, their clothes, etc.
    """
    # NOTE: the docstring above is the tool description shown to the LLM;
    # the handler itself only echoes the arguments back.
    return characters
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def characters_node(state: AgentState):
    """
    The characters node is responsible for extracting the characters from the conversation.
    """
    # the chatbot routes here with a ToolMessage carrying the tool args as JSON
    tool_message = state["messages"][-1]
    payload = json.loads(tool_message.content)
    return {"characters": payload["characters"]}
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Outline node.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
from langchain_core.tools import tool
|
|
7
|
+
from copilotkit.demos.autotale_ai.state import AgentState
|
|
8
|
+
|
|
9
|
+
@tool
def set_outline(outline: str):
    """Sets the outline of the story."""
    # NOTE: the docstring above is the tool description shown to the LLM.
    return outline
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
async def outline_node(state: AgentState):
    """
    The outline node is responsible for generating an outline for the story.
    """
    # the chatbot routes here with a ToolMessage carrying the tool args as JSON
    tool_message = state["messages"][-1]
    payload = json.loads(tool_message.content)
    return {"outline": payload["outline"]}
|
|
@@ -0,0 +1,107 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Story node.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from typing import List
|
|
6
|
+
import json
|
|
7
|
+
import asyncio
|
|
8
|
+
|
|
9
|
+
from langchain_core.tools import tool
|
|
10
|
+
from langchain_core.messages import SystemMessage
|
|
11
|
+
from langchain_core.runnables import RunnableConfig
|
|
12
|
+
from langchain_core.pydantic_v1 import BaseModel, Field
|
|
13
|
+
from langchain_openai import ChatOpenAI
|
|
14
|
+
|
|
15
|
+
from copilotkit.demos.autotale_ai.state import AgentState, Character
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class ImageDescription(BaseModel):
    """
    Represents the description of an image of a character in the story.
    """
    # NOTE: the docstring above is part of the structured-output schema sent
    # to the model; change it deliberately.
    description: str  # free-text image description produced by the model
|
|
23
|
+
|
|
24
|
+
async def _generate_page_image_description(
    messages: list,
    page_content: str,
    characters: List[Character],
    style: str,
    config: RunnableConfig
):
    """
    Generate a description of the image of a character.

    Feeds the conversation plus a prompt (page content, characters, style)
    to GPT-4o with structured output, returning the description string.
    """

    system_message = SystemMessage(
        content=f"""
The user and the AI are having a conversation about writing a children's story.
It's your job to generate a vivid description of a page in the story.
Make the description as detailed as possible.

These are the characters in the story:
{characters}

This is the page content:
{page_content}

This is the graphical style of the story:
{style}

Imagine an image of the page. Describe the looks of the page in great detail.
Also describe the setting in which the image is taken.
Make sure to include the name of the characters and full description of the characters in your output.
Describe the style in detail, it's very important for image generation.
"""
    )
    # structured output guarantees the reply parses into ImageDescription
    model = ChatOpenAI(model="gpt-4o").with_structured_output(ImageDescription)
    response = await model.ainvoke([
        *messages,
        system_message
    ], config)

    return response.description
|
|
63
|
+
|
|
64
|
+
class StoryPage(BaseModel):
    """
    Represents a page in the children's story. Keep it simple, 3-4 sentences per page.
    """
    # NOTE: the docstring and field description are part of the tool schema
    # shown to the LLM; change them deliberately.
    content: str = Field(..., description="A single page in the story")
|
|
69
|
+
|
|
70
|
+
@tool
def set_story(pages: List[StoryPage]):
    """
    Considering the outline and characters, write a story.
    Keep it simple, 3-4 sentences per page.
    5 pages max.
    (If the user mentions "chapters" in the conversation they mean pages, treat it as such)
    """
    # NOTE: the docstring above is the tool description shown to the LLM;
    # the handler itself only echoes the arguments back.
    return pages
|
|
79
|
+
|
|
80
|
+
async def story_node(state: AgentState, config: RunnableConfig):
    """
    The story node is responsible for extracting the story from the conversation.

    Parses the pages out of the routing ToolMessage and generates an image
    description for every page concurrently.
    """
    tool_message = state["messages"][-1]
    pages = json.loads(tool_message.content)["pages"]
    characters = state.get("characters", [])
    style = state.get("style", "Pixar movies style 3D images")

    async def illustrate(page):
        # pair each page's text with a generated image description
        description = await _generate_page_image_description(
            state["messages"],
            page["content"],
            characters,
            style,
            config
        )
        return {
            "content": page["content"],
            "image_description": description
        }

    # run all page illustrations concurrently
    story = await asyncio.gather(*(illustrate(page) for page in pages))

    return {
        "story": story
    }
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Style node.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
from langchain_core.tools import tool
|
|
7
|
+
from copilotkit.demos.autotale_ai.state import AgentState
|
|
8
|
+
|
|
9
|
+
@tool
def set_style(style: str):
    """Sets the graphical style of the story."""
    # NOTE: the docstring above is the tool description shown to the LLM.
    return style
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
async def style_node(state: AgentState):
    """
    The style node is responsible for setting the graphical style of the story.
    """
    # the chatbot routes here with a ToolMessage carrying the tool args as JSON
    tool_message = state["messages"][-1]
    payload = json.loads(tool_message.content)
    return {"style": payload["style"]}
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
"""Exceptions for CopilotKit."""
|
|
2
|
+
|
|
3
|
+
class ActionNotFoundException(Exception):
    """Exception raised when an action is not found.

    The previous docstring claimed this also covered agents; agents have
    their own AgentNotFoundException.
    """

    def __init__(self, name: str):
        # keep the missing action's name so callers can report it
        self.name = name
        super().__init__(f"Action '{name}' not found.")
|
|
9
|
+
|
|
10
|
+
class AgentNotFoundException(Exception):
    """Exception raised when an agent is not found."""

    def __init__(self, name: str):
        message = f"Agent '{name}' not found."
        super().__init__(message)
        # keep the missing agent's name so callers can report it
        self.name = name
|
|
16
|
+
|
|
17
|
+
class ActionExecutionException(Exception):
    """Exception raised when an action fails to execute."""

    def __init__(self, name: str, error: Exception):
        message = f"Action '{name}' failed to execute: {error}"
        super().__init__(message)
        # preserve the action name and the underlying error for callers
        self.name = name
        self.error = error
|
|
24
|
+
|
|
25
|
+
class AgentExecutionException(Exception):
    """Exception raised when an agent fails to execute."""

    def __init__(self, name: str, error: Exception):
        message = f"Agent '{name}' failed to execute: {error}"
        super().__init__(message)
        # preserve the agent name and the underlying error for callers
        self.name = name
        self.error = error
|
|
File without changes
|
|
@@ -0,0 +1,139 @@
|
|
|
1
|
+
"""FastAPI integration"""
|
|
2
|
+
|
|
3
|
+
from typing import List
|
|
4
|
+
from fastapi import FastAPI, Request, HTTPException
|
|
5
|
+
from fastapi.responses import JSONResponse, StreamingResponse
|
|
6
|
+
from ..sdk import CopilotKitSDK, CopilotKitSDKContext
|
|
7
|
+
from ..types import Message
|
|
8
|
+
from ..exc import (
|
|
9
|
+
ActionNotFoundException,
|
|
10
|
+
ActionExecutionException,
|
|
11
|
+
AgentNotFoundException,
|
|
12
|
+
AgentExecutionException,
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def add_fastapi_endpoint(fastapi_app: FastAPI, sdk: CopilotKitSDK, prefix: str):
    """Mount the CopilotKit catch-all route on *fastapi_app* under *prefix*."""
    async def bound_handler(request: Request):
        # close over the sdk so the shared handler can service this app
        return await handler(request, sdk)

    # Ensure the prefix starts with a slash and remove trailing slashes
    route_prefix = '/' + prefix.strip('/')

    fastapi_app.add_api_route(
        f"{route_prefix}/{{path:path}}",
        bound_handler,
        methods=['GET', 'POST', 'PUT', 'DELETE', 'OPTIONS'],
    )
|
|
29
|
+
|
|
30
|
+
def body_get_or_raise(body: dict, key: str):
    """Return ``body[key]``, or raise a 400 error if the key is missing/None.

    Fix: the parameter was annotated ``any`` — that is the builtin function,
    not a type; the handler always passes the parsed JSON body, a dict.

    Raises:
        HTTPException: 400 when the key is absent or its value is None.
    """
    value = body.get(key)
    if value is None:
        raise HTTPException(status_code=400, detail=f"{key} is required")
    return value
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
async def handler(request: Request, sdk: CopilotKitSDK):
    """Dispatch an incoming request to the matching CopilotKit endpoint."""

    try:
        body = await request.json()
    except Exception as exc:
        raise HTTPException(status_code=400, detail="Request body is required") from exc

    route = (request.method, request.path_params.get('path'))
    context = {"properties": body.get("properties", {})}

    if route == ('POST', 'info'):
        return await handle_info(sdk=sdk, context=context)

    if route == ('POST', 'actions/execute'):
        return await handle_execute_action(
            sdk=sdk,
            context=context,
            name=body_get_or_raise(body, "name"),
            arguments=body.get("arguments", {}),
        )

    if route == ('POST', 'agents/execute'):
        # threadId / nodeName may be absent from the body; the required
        # fields below raise a 400 when missing.
        return handle_execute_agent(
            sdk=sdk,
            context=context,
            thread_id=body.get("threadId"),
            node_name=body.get("nodeName"),
            name=body_get_or_raise(body, "name"),
            state=body_get_or_raise(body, "state"),
            messages=body_get_or_raise(body, "messages"),
            actions=body.get("actions", []),
        )

    raise HTTPException(status_code=404, detail="Not found")
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
async def handle_info(*, sdk: CopilotKitSDK, context: CopilotKitSDKContext):
    """Serve the SDK's action/agent discovery info as a JSON response."""
    return JSONResponse(content=sdk.info(context=context))
|
|
92
|
+
|
|
93
|
+
async def handle_execute_action(
    *,
    sdk: CopilotKitSDK,
    context: CopilotKitSDKContext,
    name: str,
    arguments: dict,
):
    """Run a single action, translating SDK exceptions into HTTP responses."""
    try:
        result = await sdk.execute_action(
            context=context,
            name=name,
            arguments=arguments,
        )
        return JSONResponse(content=result)
    except ActionNotFoundException as exc:
        # Unknown action name -> 404 with the exception text.
        return JSONResponse(content={"error": str(exc)}, status_code=404)
    except ActionExecutionException as exc:
        # The action itself failed -> 500 with the wrapped error text.
        return JSONResponse(content={"error": str(exc)}, status_code=500)
|
|
112
|
+
|
|
113
|
+
def handle_execute_agent( # pylint: disable=too-many-arguments
    *,
    sdk: CopilotKitSDK,
    context: CopilotKitSDKContext,
    thread_id: str,
    node_name: str,
    name: str,
    state: dict,
    messages: List[Message],
    actions: List[any],
):
    """Start or resume an agent run, streaming its events back to the client."""
    try:
        stream = sdk.execute_agent(
            context=context,
            thread_id=thread_id,
            name=name,
            node_name=node_name,
            state=state,
            messages=messages,
            actions=actions,
        )
        return StreamingResponse(stream, media_type="application/json")
    except AgentNotFoundException as exc:
        # Unknown agent name -> 404 with the exception text.
        return JSONResponse(content={"error": str(exc)}, status_code=404)
    except AgentExecutionException as exc:
        # The agent itself failed -> 500 with the wrapped error text.
        return JSONResponse(content={"error": str(exc)}, status_code=500)
|
|
@@ -0,0 +1,71 @@
|
|
|
1
|
+
"""
|
|
2
|
+
LangChain specific utilities for CopilotKit
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from typing import List, Optional
|
|
6
|
+
from langchain_core.messages import (
|
|
7
|
+
HumanMessage,
|
|
8
|
+
SystemMessage,
|
|
9
|
+
BaseMessage,
|
|
10
|
+
AIMessage,
|
|
11
|
+
ToolMessage
|
|
12
|
+
)
|
|
13
|
+
from langchain_core.runnables import RunnableConfig
|
|
14
|
+
from langchain_core.runnables.config import ensure_config
|
|
15
|
+
|
|
16
|
+
from .types import Message
|
|
17
|
+
def copilotkit_messages_to_langchain(messages: List[Message]) -> List[BaseMessage]:
    """
    Translate CopilotKit wire messages into LangChain message objects.

    Three shapes are recognized: plain text messages (keyed on "content"),
    action-execution requests (keyed on "arguments", mapped to an AIMessage
    with a tool call), and action results (keyed on "actionExecutionId",
    mapped to a ToolMessage). Text messages with an unrecognized role are
    silently skipped, matching the previous behavior.
    """
    role_to_class = {
        "user": HumanMessage,
        "system": SystemMessage,
        "assistant": AIMessage,
    }
    converted: List[BaseMessage] = []
    for msg in messages:
        if "content" in msg:
            message_class = role_to_class.get(msg["role"])
            if message_class is not None:
                converted.append(message_class(content=msg["content"], id=msg["id"]))
        elif "arguments" in msg:
            converted.append(AIMessage(
                id=msg["id"],
                content="",
                tool_calls=[{
                    "name": msg["name"],
                    "args": msg["arguments"],
                    "id": msg["id"],
                }],
            ))
        elif "actionExecutionId" in msg:
            converted.append(ToolMessage(
                id=msg["id"],
                content=msg["result"],
                name=msg["actionName"],
                tool_call_id=msg["actionExecutionId"],
            ))
    return converted
|
|
45
|
+
|
|
46
|
+
def configure_copilotkit(
    config: Optional[RunnableConfig] = None,
    *,
    emit_tool_calls: bool = False,
    emit_messages: bool = False,
    emit_all: bool = False,
    emit_state: Optional[dict] = None
):
    """
    Return a RunnableConfig with CopilotKit emission flags applied.

    Fix: the previous version appended to the list returned by
    ``config.get("tags")`` and assigned into ``config.get("metadata")``
    directly — those are the caller's live objects, so the caller's
    config was mutated in place (the later shallow ``.copy()`` did not
    prevent this). Copies are now taken before modification.

    Args:
        config: Optional base RunnableConfig; it is never mutated.
        emit_tool_calls: Tag the run to emit tool calls to CopilotKit.
        emit_messages: Tag the run to emit messages to CopilotKit.
        emit_all: Shorthand enabling both of the above.
        emit_state: Optional state dict to expose via run metadata.
    """
    # Copy so the caller's config (and its nested containers) is untouched.
    tags = list(config.get("tags", [])) if config else []
    metadata = dict(config.get("metadata", {})) if config else {}

    if emit_tool_calls or emit_all:
        tags.append("copilotkit:emit-tool-calls")
    if emit_messages or emit_all:
        tags.append("copilotkit:emit-messages")

    if emit_state:
        metadata["copilotkit:emit-state"] = emit_state

    merged = {**(config or {}), "tags": tags, "metadata": metadata}
    return ensure_config(merged)
|
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
"""Parameter classes for CopilotKit"""
|
|
2
|
+
|
|
3
|
+
from typing import TypedDict, Optional, Literal, List
|
|
4
|
+
|
|
5
|
+
class BaseParameter(TypedDict):
    """Base parameter class"""
    # The parameter's name; always provided by callers.
    name: str
    # NOTE(review): Optional[...] only widens the *value* type to include
    # None — it does not make the key itself optional in a TypedDict.
    # _normalize_parameter fills in missing 'description'/'required' keys at
    # runtime, so these presumably intend typing.NotRequired — confirm.
    description: Optional[str]
    required: Optional[bool]
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def normalize_parameters(parameters: Optional[List[BaseParameter]]) -> List[BaseParameter]:
|
|
13
|
+
"""Normalize the parameters to ensure they have the correct type and format."""
|
|
14
|
+
if parameters is None:
|
|
15
|
+
return []
|
|
16
|
+
return [_normalize_parameter(parameter) for parameter in parameters]
|
|
17
|
+
|
|
18
|
+
def _normalize_parameter(parameter: BaseParameter) -> BaseParameter:
|
|
19
|
+
"""Normalize a parameter to ensure it has the correct type and format."""
|
|
20
|
+
if not hasattr(parameter, 'type'):
|
|
21
|
+
parameter['type'] = 'string'
|
|
22
|
+
if not hasattr(parameter, 'required'):
|
|
23
|
+
parameter['required'] = False
|
|
24
|
+
if not hasattr(parameter, 'description'):
|
|
25
|
+
parameter['description'] = ''
|
|
26
|
+
|
|
27
|
+
if parameter['type'] == 'object' or parameter['type'] == 'object[]':
|
|
28
|
+
parameter['attributes'] = normalize_parameters(parameter.get('attributes'))
|
|
29
|
+
return parameter
|
|
30
|
+
|
|
31
|
+
class StringParameter(BaseParameter):
    """String parameter class"""
    type: Literal["string"]
    # Optional closed set of allowed string values.
    enum: Optional[list[str]]

class NumberParameter(BaseParameter):
    """Number parameter class"""
    type: Literal["number"]

class BooleanParameter(BaseParameter):
    """Boolean parameter class"""
    type: Literal["boolean"]

class ObjectParameter(BaseParameter):
    """Object parameter class"""
    type: Literal["object"]
    # Nested parameter definitions describing the object's fields;
    # normalized recursively by _normalize_parameter.
    attributes: List[BaseParameter]

class ObjectArrayParameter(BaseParameter):
    """Object array parameter class"""
    type: Literal["object[]"]
    # Nested parameter definitions shared by every element of the array.
    attributes: List[BaseParameter]

class StringArrayParameter(BaseParameter):
    """String array parameter class"""
    type: Literal["string[]"]

class NumberArrayParameter(BaseParameter):
    """Number array parameter class"""
    type: Literal["number[]"]

class BooleanArrayParameter(BaseParameter):
    """Boolean array parameter class"""
    type: Literal["boolean[]"]
|
|
@@ -0,0 +1,102 @@
|
|
|
1
|
+
"""CopilotKit SDK"""
|
|
2
|
+
|
|
3
|
+
from typing import List, Callable, Union, Optional, TypedDict, Any
|
|
4
|
+
from .agent import Agent
|
|
5
|
+
from .action import Action
|
|
6
|
+
from .types import Message
|
|
7
|
+
from .exc import (
|
|
8
|
+
ActionNotFoundException,
|
|
9
|
+
AgentNotFoundException,
|
|
10
|
+
ActionExecutionException,
|
|
11
|
+
AgentExecutionException
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class CopilotKitSDKContext(TypedDict):
    """CopilotKit SDK Context"""
    # Arbitrary request-scoped properties; the FastAPI integration populates
    # this from the request body's "properties" field.
    properties: Any
|
|
18
|
+
|
|
19
|
+
class CopilotKitSDK:
    """CopilotKit SDK: a registry and dispatcher for actions and agents.

    Actions and agents may be supplied as static lists, or as callables that
    receive the request context and return a list, letting availability vary
    per request.

    Fixes in this revision:
    - ``Callable[[], ...]`` annotations contradicted the call sites, which
      always pass the context; annotations now declare the context argument.
    - ``execute_agent`` annotated ``actions`` as ``List[any]`` — ``any`` is
      the builtin function; corrected to ``List[Any]``.
    - ``info`` was annotated as returning ``List[Union[Action, Agent]]`` but
      actually returns a dict; the annotation now matches.
    """

    def __init__(
        self,
        *,
        actions: Optional[
            Union[List[Action], Callable[[CopilotKitSDKContext], List[Action]]]
        ] = None,
        agents: Optional[
            Union[List[Agent], Callable[[CopilotKitSDKContext], List[Agent]]]
        ] = None,
    ):
        self.agents = agents or []
        self.actions = actions or []

    def _resolve_actions(self, context: CopilotKitSDKContext) -> List[Action]:
        """Materialize the action list for this request context."""
        return self.actions(context) if callable(self.actions) else self.actions

    def _resolve_agents(self, context: CopilotKitSDKContext) -> List[Agent]:
        """Materialize the agent list for this request context."""
        return self.agents(context) if callable(self.agents) else self.agents

    def info(
        self,
        *,
        context: CopilotKitSDKContext
    ) -> dict:
        """Return a dict describing available actions and agents."""
        return {
            "actions": [action.dict_repr() for action in self._resolve_actions(context)],
            "agents": [agent.dict_repr() for agent in self._resolve_agents(context)],
        }

    def _get_action(
        self,
        *,
        context: CopilotKitSDKContext,
        name: str,
    ) -> Action:
        """Get an action by name.

        Raises:
            ActionNotFoundException: if no action with that name exists.
        """
        action = next(
            (action for action in self._resolve_actions(context) if action.name == name),
            None,
        )
        if action is None:
            raise ActionNotFoundException(name)
        return action

    def execute_action(
        self,
        *,
        context: CopilotKitSDKContext,
        name: str,
        arguments: dict,
    ) -> dict:
        """Execute an action by name.

        Raises:
            ActionNotFoundException: if the action does not exist.
            ActionExecutionException: if the action raises while executing.
        """
        action = self._get_action(context=context, name=name)

        try:
            # NOTE(review): if Action.execute is a coroutine function this
            # returns an un-awaited coroutine and exceptions raised inside it
            # will NOT be wrapped here — confirm against Action.execute.
            return action.execute(arguments=arguments)
        except Exception as error:
            raise ActionExecutionException(name, error) from error

    def execute_agent( # pylint: disable=too-many-arguments
        self,
        *,
        context: CopilotKitSDKContext,
        name: str,
        thread_id: str,
        node_name: str,
        state: dict,
        messages: List[Message],
        actions: List[Any],
    ):
        """Execute an agent by name.

        Raises:
            AgentNotFoundException: if the agent does not exist.
            AgentExecutionException: if the agent raises while executing.
        """
        agent = next(
            (agent for agent in self._resolve_agents(context) if agent.name == name),
            None,
        )
        if agent is None:
            raise AgentNotFoundException(name)

        try:
            return agent.execute(
                thread_id=thread_id,
                node_name=node_name,
                state=state,
                messages=messages,
                actions=actions,
            )
        except Exception as error:
            raise AgentExecutionException(name, error) from error
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
"""State for CopilotKit"""
|
|
2
|
+
|
|
3
|
+
from typing import TypedDict
|
|
4
|
+
from enum import Enum
|
|
5
|
+
|
|
6
|
+
class MessageRole(Enum):
    """Message role"""
    # Role values mirror the wire-format strings used for text messages.
    ASSISTANT = "assistant"
    SYSTEM = "system"
    USER = "user"
|
|
11
|
+
|
|
12
|
+
class Message(TypedDict):
    """Message"""
    # Unique identifier for the message.
    id: str
    # Creation timestamp as a string (format not shown here —
    # presumably ISO 8601; confirm against the client).
    createdAt: str
|
|
16
|
+
|
|
17
|
+
class TextMessage(Message):
    """Text message"""
    # NOTE(review): annotated as MessageRole, but other code in this package
    # compares role against plain strings ("user", "system", "assistant"),
    # so at runtime this is presumably the string value — confirm.
    role: MessageRole
    # The message's text body.
    content: str
|
|
21
|
+
|
|
22
|
+
class ActionExecutionMessage(Message):
    """Action execution message"""
    # Name of the action to execute.
    name: str
    # Arguments passed to the action (converted to tool-call args for
    # LangChain in copilotkit_messages_to_langchain).
    arguments: dict
    # Execution scope (semantics not shown in this file — confirm).
    scope: str
|
|
27
|
+
|
|
28
|
+
class ResultMessage(Message):
    """Result message"""
    # Id of the ActionExecutionMessage this result answers (used as the
    # tool_call_id when converting to a LangChain ToolMessage).
    actionExecutionId: str
    # Name of the action that produced this result.
    actionName: str
    # The action's result, serialized as a string.
    result: str
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
[tool.poetry]
|
|
2
|
+
name = "copilotkit"
|
|
3
|
+
version = "0.1.0"
|
|
4
|
+
description = "CopilotKit python SDK"
|
|
5
|
+
authors = ["Markus Ecker <markus.ecker@gmail.com>"]
|
|
6
|
+
license = "MIT"
|
|
7
|
+
readme = "README.md"
|
|
8
|
+
homepage = "https://copilotkit.ai"
|
|
9
|
+
keywords = ["copilot", "copilotkit", "langgraph", "langchain", "ai", "langsmith", "langserve"]
|
|
10
|
+
|
|
11
|
+
[tool.poetry.dependencies]
|
|
12
|
+
python = "^3.12"
|
|
13
|
+
langgraph = "^0.2.3"
|
|
14
|
+
httpx = "^0.27.0"
|
|
15
|
+
fastapi = "^0.111.1"
|
|
16
|
+
langchain = "^0.2.12"
|
|
17
|
+
langchain-openai = "^0.1.20"
|
|
18
|
+
partialjson = "^0.0.8"
|
|
19
|
+
|
|
20
|
+
[build-system]
|
|
21
|
+
requires = ["poetry-core"]
|
|
22
|
+
build-backend = "poetry.core.masonry.api"
|
|
23
|
+
|
|
24
|
+
[tool.poetry.scripts]
|
|
25
|
+
demo = "copilotkit.demo:main"
|