copilotkit 0.1.89__tar.gz → 0.1.90__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {copilotkit-0.1.89 → copilotkit-0.1.90}/PKG-INFO +1 -1
- copilotkit-0.1.90/copilotkit/__init__.py +43 -0
- {copilotkit-0.1.89 → copilotkit-0.1.90}/copilotkit/a2ui.py +4 -6
- {copilotkit-0.1.89 → copilotkit-0.1.90}/copilotkit/action.py +20 -20
- {copilotkit-0.1.89 → copilotkit-0.1.90}/copilotkit/agent.py +14 -14
- {copilotkit-0.1.89 → copilotkit-0.1.90}/copilotkit/copilotkit_lg_middleware.py +130 -79
- {copilotkit-0.1.89 → copilotkit-0.1.90}/copilotkit/crewai/__init__.py +4 -2
- {copilotkit-0.1.89 → copilotkit-0.1.90}/copilotkit/crewai/copilotkit_integration.py +134 -56
- {copilotkit-0.1.89 → copilotkit-0.1.90}/copilotkit/crewai/crewai_agent.py +124 -108
- {copilotkit-0.1.89 → copilotkit-0.1.90}/copilotkit/crewai/crewai_sdk.py +142 -144
- {copilotkit-0.1.89 → copilotkit-0.1.90}/copilotkit/exc.py +4 -0
- {copilotkit-0.1.89 → copilotkit-0.1.90}/copilotkit/header_propagation.py +7 -6
- {copilotkit-0.1.89 → copilotkit-0.1.90}/copilotkit/html.py +3 -1
- {copilotkit-0.1.89 → copilotkit-0.1.90}/copilotkit/integrations/fastapi.py +72 -71
- copilotkit-0.1.90/copilotkit/langchain.py +30 -0
- {copilotkit-0.1.89 → copilotkit-0.1.90}/copilotkit/langgraph.py +84 -77
- {copilotkit-0.1.89 → copilotkit-0.1.90}/copilotkit/langgraph_agui_agent.py +50 -24
- {copilotkit-0.1.89 → copilotkit-0.1.90}/copilotkit/logging.py +4 -2
- {copilotkit-0.1.89 → copilotkit-0.1.90}/copilotkit/parameter.py +23 -15
- {copilotkit-0.1.89 → copilotkit-0.1.90}/copilotkit/protocol.py +81 -44
- {copilotkit-0.1.89 → copilotkit-0.1.90}/copilotkit/runloop.py +37 -41
- {copilotkit-0.1.89 → copilotkit-0.1.90}/copilotkit/sdk.py +27 -32
- {copilotkit-0.1.89 → copilotkit-0.1.90}/copilotkit/types.py +15 -1
- copilotkit-0.1.90/copilotkit/utils.py +5 -0
- {copilotkit-0.1.89 → copilotkit-0.1.90}/pyproject.toml +1 -1
- copilotkit-0.1.89/copilotkit/__init__.py +0 -31
- copilotkit-0.1.89/copilotkit/langchain.py +0 -29
- copilotkit-0.1.89/copilotkit/utils.py +0 -8
- {copilotkit-0.1.89 → copilotkit-0.1.90}/README.md +0 -0
- {copilotkit-0.1.89 → copilotkit-0.1.90}/copilotkit/integrations/__init__.py +0 -0
- {copilotkit-0.1.89 → copilotkit-0.1.90}/copilotkit/py.typed +0 -0
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
"""CopilotKit SDK"""
|
|
2
|
+
|
|
3
|
+
from .sdk import (
|
|
4
|
+
CopilotKitRemoteEndpoint,
|
|
5
|
+
CopilotKitContext,
|
|
6
|
+
CopilotKitSDK,
|
|
7
|
+
CopilotKitSDKContext,
|
|
8
|
+
)
|
|
9
|
+
from .action import Action
|
|
10
|
+
from .langgraph import CopilotKitState
|
|
11
|
+
from .parameter import Parameter
|
|
12
|
+
from .agent import Agent
|
|
13
|
+
from .langgraph_agui_agent import LangGraphAGUIAgent
|
|
14
|
+
from .copilotkit_lg_middleware import CopilotKitMiddleware
|
|
15
|
+
from ag_ui_langgraph.middlewares.state_streaming import (
|
|
16
|
+
StateStreamingMiddleware,
|
|
17
|
+
StateItem,
|
|
18
|
+
)
|
|
19
|
+
from .header_propagation import (
|
|
20
|
+
set_forwarded_headers,
|
|
21
|
+
get_forwarded_headers,
|
|
22
|
+
install_httpx_hook,
|
|
23
|
+
)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
__all__ = [
|
|
27
|
+
"CopilotKitRemoteEndpoint",
|
|
28
|
+
"CopilotKitSDK",
|
|
29
|
+
"Action",
|
|
30
|
+
"CopilotKitState",
|
|
31
|
+
"Parameter",
|
|
32
|
+
"Agent",
|
|
33
|
+
"CopilotKitContext",
|
|
34
|
+
"CopilotKitSDKContext",
|
|
35
|
+
"CrewAIAgent", # pyright: ignore[reportUnsupportedDunderAll] pylint: disable=undefined-all-variable
|
|
36
|
+
"LangGraphAGUIAgent",
|
|
37
|
+
"CopilotKitMiddleware",
|
|
38
|
+
"StateStreamingMiddleware",
|
|
39
|
+
"StateItem",
|
|
40
|
+
"set_forwarded_headers",
|
|
41
|
+
"get_forwarded_headers",
|
|
42
|
+
"install_httpx_hook",
|
|
43
|
+
]
|
|
@@ -38,7 +38,7 @@ def update_components(
|
|
|
38
38
|
"updateComponents": {
|
|
39
39
|
"surfaceId": surface_id,
|
|
40
40
|
"components": components,
|
|
41
|
-
}
|
|
41
|
+
},
|
|
42
42
|
}
|
|
43
43
|
|
|
44
44
|
|
|
@@ -54,7 +54,7 @@ def update_data_model(
|
|
|
54
54
|
"surfaceId": surface_id,
|
|
55
55
|
"path": path,
|
|
56
56
|
"value": data,
|
|
57
|
-
}
|
|
57
|
+
},
|
|
58
58
|
}
|
|
59
59
|
|
|
60
60
|
|
|
@@ -72,7 +72,7 @@ def create_surface(
|
|
|
72
72
|
"createSurface": {
|
|
73
73
|
"surfaceId": surface_id,
|
|
74
74
|
"catalogId": catalog_id,
|
|
75
|
-
}
|
|
75
|
+
},
|
|
76
76
|
}
|
|
77
77
|
|
|
78
78
|
|
|
@@ -80,9 +80,7 @@ A2UI_OPERATIONS_KEY = "a2ui_operations"
|
|
|
80
80
|
"""The container key used to wrap A2UI operations for explicit detection."""
|
|
81
81
|
|
|
82
82
|
|
|
83
|
-
def render(
|
|
84
|
-
operations: list[dict[str, Any]]
|
|
85
|
-
) -> str:
|
|
83
|
+
def render(operations: list[dict[str, Any]]) -> str:
|
|
86
84
|
"""Wrap operations in the a2ui_operations container and serialize to JSON.
|
|
87
85
|
|
|
88
86
|
Args:
|
|
@@ -5,26 +5,32 @@ from inspect import iscoroutinefunction
|
|
|
5
5
|
from typing import Optional, List, Callable, TypedDict, Any, cast
|
|
6
6
|
from .parameter import Parameter, normalize_parameters
|
|
7
7
|
|
|
8
|
+
|
|
8
9
|
class ActionDict(TypedDict):
|
|
9
10
|
"""Dict representation of an action"""
|
|
11
|
+
|
|
10
12
|
name: str
|
|
11
13
|
description: str
|
|
12
14
|
parameters: List[Parameter]
|
|
13
15
|
|
|
16
|
+
|
|
14
17
|
class ActionResultDict(TypedDict):
|
|
15
18
|
"""Dict representation of an action result"""
|
|
19
|
+
|
|
16
20
|
result: Any
|
|
17
21
|
|
|
22
|
+
|
|
18
23
|
class Action: # pylint: disable=too-few-public-methods
|
|
19
24
|
"""Action class for CopilotKit"""
|
|
25
|
+
|
|
20
26
|
def __init__(
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
27
|
+
self,
|
|
28
|
+
*,
|
|
29
|
+
name: str,
|
|
30
|
+
handler: Callable,
|
|
31
|
+
description: Optional[str] = None,
|
|
32
|
+
parameters: Optional[List[Parameter]] = None,
|
|
33
|
+
):
|
|
28
34
|
self.name = name
|
|
29
35
|
self.description = description
|
|
30
36
|
self.parameters = parameters
|
|
@@ -32,26 +38,20 @@ class Action: # pylint: disable=too-few-public-methods
|
|
|
32
38
|
|
|
33
39
|
if not re.match(r"^[a-zA-Z0-9_-]+$", name):
|
|
34
40
|
raise ValueError(
|
|
35
|
-
f"Invalid action name '{name}': "
|
|
36
|
-
"must consist of alphanumeric characters, underscores, and hyphens only"
|
|
41
|
+
f"Invalid action name '{name}': "
|
|
42
|
+
+ "must consist of alphanumeric characters, underscores, and hyphens only"
|
|
37
43
|
)
|
|
38
44
|
|
|
39
|
-
async def execute(
|
|
40
|
-
self,
|
|
41
|
-
*,
|
|
42
|
-
arguments: dict
|
|
43
|
-
) -> ActionResultDict:
|
|
45
|
+
async def execute(self, *, arguments: dict) -> ActionResultDict:
|
|
44
46
|
"""Execute the action"""
|
|
45
47
|
result = self.handler(**arguments)
|
|
46
48
|
|
|
47
|
-
return {
|
|
48
|
-
"result": await result if iscoroutinefunction(self.handler) else result
|
|
49
|
-
}
|
|
49
|
+
return {"result": await result if iscoroutinefunction(self.handler) else result}
|
|
50
50
|
|
|
51
51
|
def dict_repr(self) -> ActionDict:
|
|
52
52
|
"""Dict representation of the action"""
|
|
53
53
|
return {
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
54
|
+
"name": self.name,
|
|
55
|
+
"description": self.description or "",
|
|
56
|
+
"parameters": normalize_parameters(cast(Any, self.parameters)),
|
|
57
57
|
}
|
|
@@ -7,30 +7,34 @@ from .types import Message
|
|
|
7
7
|
from .action import ActionDict
|
|
8
8
|
from .types import MetaEvent
|
|
9
9
|
|
|
10
|
+
|
|
10
11
|
class AgentDict(TypedDict):
|
|
11
12
|
"""Agent dictionary"""
|
|
13
|
+
|
|
12
14
|
name: str
|
|
13
15
|
description: Optional[str]
|
|
14
16
|
|
|
17
|
+
|
|
15
18
|
class Agent(ABC):
|
|
16
19
|
"""Agent class for CopilotKit"""
|
|
20
|
+
|
|
17
21
|
def __init__(
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
22
|
+
self,
|
|
23
|
+
*,
|
|
24
|
+
name: str,
|
|
25
|
+
description: Optional[str] = None,
|
|
26
|
+
):
|
|
23
27
|
self.name = name
|
|
24
28
|
self.description = description
|
|
25
29
|
|
|
26
30
|
if not re.match(r"^[a-zA-Z0-9_-]+$", name):
|
|
27
31
|
raise ValueError(
|
|
28
|
-
f"Invalid agent name '{name}': "
|
|
29
|
-
"must consist of alphanumeric characters, underscores, and hyphens only"
|
|
32
|
+
f"Invalid agent name '{name}': "
|
|
33
|
+
+ "must consist of alphanumeric characters, underscores, and hyphens only"
|
|
30
34
|
)
|
|
31
35
|
|
|
32
36
|
@abstractmethod
|
|
33
|
-
def execute(
|
|
37
|
+
def execute( # pylint: disable=too-many-arguments
|
|
34
38
|
self,
|
|
35
39
|
*,
|
|
36
40
|
state: dict,
|
|
@@ -54,13 +58,9 @@ class Agent(ABC):
|
|
|
54
58
|
"threadId": thread_id or "",
|
|
55
59
|
"threadExists": False,
|
|
56
60
|
"state": {},
|
|
57
|
-
"messages": []
|
|
61
|
+
"messages": [],
|
|
58
62
|
}
|
|
59
63
|
|
|
60
|
-
|
|
61
64
|
def dict_repr(self) -> AgentDict:
|
|
62
65
|
"""Dict representation of the action"""
|
|
63
|
-
return {
|
|
64
|
-
'name': self.name,
|
|
65
|
-
'description': self.description or ''
|
|
66
|
-
}
|
|
66
|
+
return {"name": self.name, "description": self.description or ""}
|
|
@@ -27,8 +27,29 @@ from langchain.agents.middleware import (
|
|
|
27
27
|
)
|
|
28
28
|
from langgraph.runtime import Runtime
|
|
29
29
|
|
|
30
|
+
from .header_propagation import install_httpx_hook
|
|
30
31
|
from .langgraph import CopilotKitProperties
|
|
31
32
|
|
|
33
|
+
# Track which httpx clients already have the header-propagation hook installed
|
|
34
|
+
# (by object id) so we never double-install on repeated model calls.
|
|
35
|
+
_hooked_clients: set[int] = set()
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def _ensure_httpx_hook(model: Any) -> None:
|
|
39
|
+
"""Install the header-propagation httpx hook on a LangChain chat model's
|
|
40
|
+
underlying HTTP client(s), if present. No-op for models that don't expose
|
|
41
|
+
an httpx transport (e.g. non-OpenAI/Anthropic providers).
|
|
42
|
+
"""
|
|
43
|
+
for attr in ("client", "async_client"):
|
|
44
|
+
client = getattr(model, attr, None)
|
|
45
|
+
if client is None:
|
|
46
|
+
continue
|
|
47
|
+
cid = id(client)
|
|
48
|
+
if cid not in _hooked_clients:
|
|
49
|
+
install_httpx_hook(client)
|
|
50
|
+
_hooked_clients.add(cid)
|
|
51
|
+
|
|
52
|
+
|
|
32
53
|
class StateSchema(AgentState):
|
|
33
54
|
copilotkit: CopilotKitProperties
|
|
34
55
|
|
|
@@ -36,15 +57,17 @@ class StateSchema(AgentState):
|
|
|
36
57
|
# Internal/framework keys that should never be surfaced to the LLM as
|
|
37
58
|
# user-facing state. These are either reducer-managed message buckets,
|
|
38
59
|
# CopilotKit/AG-UI plumbing, or graph-internal scaffolding.
|
|
39
|
-
_RESERVED_STATE_KEYS = frozenset(
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
60
|
+
_RESERVED_STATE_KEYS = frozenset(
|
|
61
|
+
{
|
|
62
|
+
"messages",
|
|
63
|
+
"copilotkit",
|
|
64
|
+
"ag-ui",
|
|
65
|
+
"tools",
|
|
66
|
+
"structured_response",
|
|
67
|
+
"thread_id",
|
|
68
|
+
"remaining_steps",
|
|
69
|
+
}
|
|
70
|
+
)
|
|
48
71
|
|
|
49
72
|
|
|
50
73
|
class CopilotKitMiddleware(AgentMiddleware[StateSchema, Any]):
|
|
@@ -105,7 +128,8 @@ class CopilotKitMiddleware(AgentMiddleware[StateSchema, Any]):
|
|
|
105
128
|
keys: list[str] = [k for k in self._expose_state if k in state]
|
|
106
129
|
else:
|
|
107
130
|
keys = [
|
|
108
|
-
k
|
|
131
|
+
k
|
|
132
|
+
for k in state
|
|
109
133
|
if k not in _RESERVED_STATE_KEYS and not str(k).startswith("_")
|
|
110
134
|
]
|
|
111
135
|
|
|
@@ -133,17 +157,22 @@ class CopilotKitMiddleware(AgentMiddleware[StateSchema, Any]):
|
|
|
133
157
|
existing = request.system_message
|
|
134
158
|
if existing is None:
|
|
135
159
|
return request.override(system_message=SystemMessage(content=note))
|
|
136
|
-
base =
|
|
160
|
+
base = (
|
|
161
|
+
existing.content
|
|
162
|
+
if isinstance(existing.content, str)
|
|
163
|
+
else str(existing.content)
|
|
164
|
+
)
|
|
137
165
|
return request.override(
|
|
138
166
|
system_message=SystemMessage(content=f"{base}\n\n{note}")
|
|
139
167
|
)
|
|
140
168
|
|
|
141
169
|
# Inject frontend tools and surface user state before model call
|
|
142
170
|
def wrap_model_call(
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
171
|
+
self,
|
|
172
|
+
request: ModelRequest,
|
|
173
|
+
handler: Callable[[ModelRequest], ModelResponse],
|
|
146
174
|
) -> ModelResponse:
|
|
175
|
+
_ensure_httpx_hook(request.model)
|
|
147
176
|
request = self._apply_state_note(request)
|
|
148
177
|
frontend_tools = request.state.get("copilotkit", {}).get("actions", [])
|
|
149
178
|
|
|
@@ -185,7 +214,7 @@ class CopilotKitMiddleware(AgentMiddleware[StateSchema, Any]):
|
|
|
185
214
|
tc_groups: dict[str, list] = {}
|
|
186
215
|
for i, msg in enumerate(messages):
|
|
187
216
|
if isinstance(msg, ToolMessage):
|
|
188
|
-
tc_id = getattr(msg,
|
|
217
|
+
tc_id = getattr(msg, "tool_call_id", None)
|
|
189
218
|
if tc_id:
|
|
190
219
|
tc_groups.setdefault(tc_id, []).append(i)
|
|
191
220
|
|
|
@@ -195,9 +224,12 @@ class CopilotKitMiddleware(AgentMiddleware[StateSchema, Any]):
|
|
|
195
224
|
continue
|
|
196
225
|
# Separate interrupted placeholders from real results
|
|
197
226
|
real_indices = [
|
|
198
|
-
i
|
|
199
|
-
|
|
200
|
-
|
|
227
|
+
i
|
|
228
|
+
for i in indices
|
|
229
|
+
if not (
|
|
230
|
+
isinstance(messages[i].content, str)
|
|
231
|
+
and _INTERRUPTED_PAT.match(messages[i].content)
|
|
232
|
+
)
|
|
201
233
|
]
|
|
202
234
|
interrupted_indices = [i for i in indices if i not in real_indices]
|
|
203
235
|
if real_indices and interrupted_indices:
|
|
@@ -215,31 +247,36 @@ class CopilotKitMiddleware(AgentMiddleware[StateSchema, Any]):
|
|
|
215
247
|
drop_indices.update(interrupted_indices[:-1])
|
|
216
248
|
|
|
217
249
|
if drop_indices:
|
|
218
|
-
messages[:] = [
|
|
250
|
+
messages[:] = [
|
|
251
|
+
msg for i, msg in enumerate(messages) if i not in drop_indices
|
|
252
|
+
]
|
|
219
253
|
|
|
220
254
|
for idx, msg in enumerate(messages):
|
|
221
255
|
if not isinstance(msg, AIMessage):
|
|
222
256
|
continue
|
|
223
257
|
|
|
224
|
-
tool_calls = getattr(msg,
|
|
258
|
+
tool_calls = getattr(msg, "tool_calls", None) or []
|
|
225
259
|
|
|
226
260
|
# 1. Sync content with tool_calls: remove tool_use content blocks
|
|
227
261
|
# that aren't in msg.tool_calls (e.g. stripped by after_model
|
|
228
262
|
# but content blocks left behind in checkpoint).
|
|
229
263
|
if tool_calls and isinstance(msg.content, list):
|
|
230
|
-
tc_ids = {tc.get(
|
|
264
|
+
tc_ids = {tc.get("id") for tc in tool_calls}
|
|
231
265
|
msg.content = [
|
|
232
|
-
block
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
266
|
+
block
|
|
267
|
+
for block in msg.content
|
|
268
|
+
if not (
|
|
269
|
+
isinstance(block, dict)
|
|
270
|
+
and block.get("type") == "tool_use"
|
|
271
|
+
and block.get("id") not in tc_ids
|
|
272
|
+
)
|
|
236
273
|
]
|
|
237
274
|
elif not tool_calls and isinstance(msg.content, list):
|
|
238
275
|
# No tool_calls at all — strip ALL tool_use content blocks
|
|
239
276
|
msg.content = [
|
|
240
|
-
block
|
|
241
|
-
|
|
242
|
-
|
|
277
|
+
block
|
|
278
|
+
for block in msg.content
|
|
279
|
+
if not (isinstance(block, dict) and block.get("type") == "tool_use")
|
|
243
280
|
]
|
|
244
281
|
|
|
245
282
|
if not tool_calls:
|
|
@@ -253,45 +290,52 @@ class CopilotKitMiddleware(AgentMiddleware[StateSchema, Any]):
|
|
|
253
290
|
adjacent_tc_ids: set = set()
|
|
254
291
|
j = idx + 1
|
|
255
292
|
while j < len(messages) and isinstance(messages[j], ToolMessage):
|
|
256
|
-
tc_id = getattr(messages[j],
|
|
293
|
+
tc_id = getattr(messages[j], "tool_call_id", None)
|
|
257
294
|
if tc_id:
|
|
258
295
|
adjacent_tc_ids.add(tc_id)
|
|
259
296
|
j += 1
|
|
260
297
|
|
|
261
|
-
unanswered = [
|
|
298
|
+
unanswered = [
|
|
299
|
+
tc for tc in tool_calls if tc.get("id") not in adjacent_tc_ids
|
|
300
|
+
]
|
|
262
301
|
if unanswered:
|
|
263
|
-
unanswered_ids = {tc[
|
|
264
|
-
msg.tool_calls = [
|
|
302
|
+
unanswered_ids = {tc["id"] for tc in unanswered}
|
|
303
|
+
msg.tool_calls = [
|
|
304
|
+
tc for tc in tool_calls if tc.get("id") in adjacent_tc_ids
|
|
305
|
+
]
|
|
265
306
|
|
|
266
307
|
# Also strip matching content blocks
|
|
267
308
|
if isinstance(msg.content, list):
|
|
268
309
|
msg.content = [
|
|
269
|
-
block
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
310
|
+
block
|
|
311
|
+
for block in msg.content
|
|
312
|
+
if not (
|
|
313
|
+
isinstance(block, dict)
|
|
314
|
+
and block.get("type") == "tool_use"
|
|
315
|
+
and block.get("id") in unanswered_ids
|
|
316
|
+
)
|
|
273
317
|
]
|
|
274
318
|
|
|
275
319
|
# 3. Fix string args in tool_calls
|
|
276
|
-
for tc in
|
|
277
|
-
if isinstance(tc.get(
|
|
320
|
+
for tc in msg.tool_calls or []:
|
|
321
|
+
if isinstance(tc.get("args"), str):
|
|
278
322
|
try:
|
|
279
|
-
tc[
|
|
323
|
+
tc["args"] = json.loads(tc["args"])
|
|
280
324
|
except (json.JSONDecodeError, TypeError):
|
|
281
|
-
tc[
|
|
325
|
+
tc["args"] = {}
|
|
282
326
|
|
|
283
327
|
# 4. Fix string input in content blocks
|
|
284
328
|
if isinstance(msg.content, list):
|
|
285
329
|
for block in msg.content:
|
|
286
|
-
if isinstance(block, dict) and block.get(
|
|
287
|
-
inp = block.get(
|
|
330
|
+
if isinstance(block, dict) and block.get("type") == "tool_use":
|
|
331
|
+
inp = block.get("input")
|
|
288
332
|
if isinstance(inp, str):
|
|
289
333
|
try:
|
|
290
|
-
block[
|
|
334
|
+
block["input"] = json.loads(inp) if inp else {}
|
|
291
335
|
except (json.JSONDecodeError, TypeError):
|
|
292
|
-
block[
|
|
336
|
+
block["input"] = {}
|
|
293
337
|
elif inp is None:
|
|
294
|
-
block[
|
|
338
|
+
block["input"] = {}
|
|
295
339
|
|
|
296
340
|
# 5. Remove orphan ToolMessages whose tool_call_id no longer matches
|
|
297
341
|
# any remaining tool_call in any AIMessage. These can be left over
|
|
@@ -299,23 +343,25 @@ class CopilotKitMiddleware(AgentMiddleware[StateSchema, Any]):
|
|
|
299
343
|
remaining_tc_ids: set = set()
|
|
300
344
|
for msg in messages:
|
|
301
345
|
if isinstance(msg, AIMessage):
|
|
302
|
-
for tc in
|
|
303
|
-
tc_id = tc.get(
|
|
346
|
+
for tc in getattr(msg, "tool_calls", None) or []:
|
|
347
|
+
tc_id = tc.get("id")
|
|
304
348
|
if tc_id:
|
|
305
349
|
remaining_tc_ids.add(tc_id)
|
|
306
350
|
messages[:] = [
|
|
307
|
-
msg
|
|
351
|
+
msg
|
|
352
|
+
for msg in messages
|
|
308
353
|
if not isinstance(msg, ToolMessage)
|
|
309
|
-
|
|
354
|
+
or getattr(msg, "tool_call_id", None) in remaining_tc_ids
|
|
310
355
|
]
|
|
311
356
|
|
|
312
357
|
return messages
|
|
313
358
|
|
|
314
359
|
async def awrap_model_call(
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
360
|
+
self,
|
|
361
|
+
request: ModelRequest,
|
|
362
|
+
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
|
318
363
|
) -> ModelResponse:
|
|
364
|
+
_ensure_httpx_hook(request.model)
|
|
319
365
|
self._fix_messages_for_bedrock(request.messages)
|
|
320
366
|
request = self._apply_state_note(request)
|
|
321
367
|
|
|
@@ -331,9 +377,9 @@ class CopilotKitMiddleware(AgentMiddleware[StateSchema, Any]):
|
|
|
331
377
|
|
|
332
378
|
# Inject app context before agent runs
|
|
333
379
|
def before_agent(
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
380
|
+
self,
|
|
381
|
+
state: StateSchema,
|
|
382
|
+
runtime: Runtime[Any],
|
|
337
383
|
) -> dict[str, Any] | None:
|
|
338
384
|
messages = state.get("messages", [])
|
|
339
385
|
|
|
@@ -342,7 +388,9 @@ class CopilotKitMiddleware(AgentMiddleware[StateSchema, Any]):
|
|
|
342
388
|
|
|
343
389
|
# Get app context from state or runtime
|
|
344
390
|
copilotkit_state = state.get("copilotkit", {})
|
|
345
|
-
app_context = copilotkit_state.get("context") or getattr(
|
|
391
|
+
app_context = copilotkit_state.get("context") or getattr(
|
|
392
|
+
runtime, "context", None
|
|
393
|
+
)
|
|
346
394
|
|
|
347
395
|
# Check if app_context is missing or empty
|
|
348
396
|
if not app_context:
|
|
@@ -408,7 +456,9 @@ class CopilotKitMiddleware(AgentMiddleware[StateSchema, Any]):
|
|
|
408
456
|
# duplicate at the end of the message list.
|
|
409
457
|
if existing_context_index != -1:
|
|
410
458
|
existing_id = getattr(messages[existing_context_index], "id", None)
|
|
411
|
-
context_message = SystemMessage(
|
|
459
|
+
context_message = SystemMessage(
|
|
460
|
+
content=context_message_content, id=existing_id
|
|
461
|
+
)
|
|
412
462
|
else:
|
|
413
463
|
context_message = SystemMessage(content=context_message_content)
|
|
414
464
|
|
|
@@ -431,26 +481,25 @@ class CopilotKitMiddleware(AgentMiddleware[StateSchema, Any]):
|
|
|
431
481
|
}
|
|
432
482
|
|
|
433
483
|
async def abefore_agent(
|
|
434
|
-
|
|
435
|
-
|
|
436
|
-
|
|
484
|
+
self,
|
|
485
|
+
state: StateSchema,
|
|
486
|
+
runtime: Runtime[Any],
|
|
437
487
|
) -> dict[str, Any] | None:
|
|
438
488
|
# Delegate to sync implementation
|
|
439
489
|
return self.before_agent(state, runtime)
|
|
440
490
|
|
|
441
491
|
# Intercept frontend tool calls after model returns, before ToolNode executes
|
|
442
492
|
def after_model(
|
|
443
|
-
|
|
444
|
-
|
|
445
|
-
|
|
493
|
+
self,
|
|
494
|
+
state: StateSchema,
|
|
495
|
+
runtime: Runtime[Any],
|
|
446
496
|
) -> dict[str, Any] | None:
|
|
447
497
|
frontend_tools = state.get("copilotkit", {}).get("actions", [])
|
|
448
498
|
if not frontend_tools:
|
|
449
499
|
return None
|
|
450
500
|
|
|
451
501
|
frontend_tool_names = {
|
|
452
|
-
t.get("function", {}).get("name") or t.get("name")
|
|
453
|
-
for t in frontend_tools
|
|
502
|
+
t.get("function", {}).get("name") or t.get("name") for t in frontend_tools
|
|
454
503
|
}
|
|
455
504
|
|
|
456
505
|
# Find last AI message with tool calls
|
|
@@ -494,18 +543,18 @@ class CopilotKitMiddleware(AgentMiddleware[StateSchema, Any]):
|
|
|
494
543
|
}
|
|
495
544
|
|
|
496
545
|
async def aafter_model(
|
|
497
|
-
|
|
498
|
-
|
|
499
|
-
|
|
546
|
+
self,
|
|
547
|
+
state: StateSchema,
|
|
548
|
+
runtime: Runtime[Any],
|
|
500
549
|
) -> dict[str, Any] | None:
|
|
501
550
|
# Delegate to sync implementation
|
|
502
551
|
return self.after_model(state, runtime)
|
|
503
552
|
|
|
504
553
|
# Restore frontend tool calls to AIMessage before agent exits
|
|
505
554
|
def after_agent(
|
|
506
|
-
|
|
507
|
-
|
|
508
|
-
|
|
555
|
+
self,
|
|
556
|
+
state: StateSchema,
|
|
557
|
+
runtime: Runtime[Any],
|
|
509
558
|
) -> dict[str, Any] | None:
|
|
510
559
|
copilotkit_state = state.get("copilotkit", {})
|
|
511
560
|
intercepted_tool_calls = copilotkit_state.get("intercepted_tool_calls")
|
|
@@ -520,11 +569,13 @@ class CopilotKitMiddleware(AgentMiddleware[StateSchema, Any]):
|
|
|
520
569
|
for msg in messages:
|
|
521
570
|
if isinstance(msg, AIMessage) and msg.id == original_message_id:
|
|
522
571
|
existing_tool_calls = getattr(msg, "tool_calls", None) or []
|
|
523
|
-
updated_messages.append(
|
|
524
|
-
|
|
525
|
-
|
|
526
|
-
|
|
527
|
-
|
|
572
|
+
updated_messages.append(
|
|
573
|
+
AIMessage(
|
|
574
|
+
content=msg.content,
|
|
575
|
+
tool_calls=[*existing_tool_calls, *intercepted_tool_calls],
|
|
576
|
+
id=msg.id,
|
|
577
|
+
)
|
|
578
|
+
)
|
|
528
579
|
else:
|
|
529
580
|
updated_messages.append(msg)
|
|
530
581
|
|
|
@@ -537,9 +588,9 @@ class CopilotKitMiddleware(AgentMiddleware[StateSchema, Any]):
|
|
|
537
588
|
}
|
|
538
589
|
|
|
539
590
|
async def aafter_agent(
|
|
540
|
-
|
|
541
|
-
|
|
542
|
-
|
|
591
|
+
self,
|
|
592
|
+
state: StateSchema,
|
|
593
|
+
runtime: Runtime[Any],
|
|
543
594
|
) -> dict[str, Any] | None:
|
|
544
595
|
# Delegate to sync implementation
|
|
545
596
|
return self.after_agent(state, runtime)
|
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
"""
|
|
2
2
|
CrewAI
|
|
3
3
|
"""
|
|
4
|
+
|
|
4
5
|
from .crewai_agent import CrewAIAgent
|
|
5
6
|
from .crewai_sdk import (
|
|
6
7
|
CopilotKitProperties,
|
|
@@ -20,8 +21,9 @@ from .copilotkit_integration import (
|
|
|
20
21
|
create_tool_proxy,
|
|
21
22
|
FlowInputState,
|
|
22
23
|
CopilotKitStateUpdateEvent,
|
|
23
|
-
emit_copilotkit_state_update_event
|
|
24
|
+
emit_copilotkit_state_update_event,
|
|
24
25
|
)
|
|
26
|
+
|
|
25
27
|
__all__ = [
|
|
26
28
|
"CrewAIAgent",
|
|
27
29
|
"CopilotKitProperties",
|
|
@@ -39,5 +41,5 @@ __all__ = [
|
|
|
39
41
|
"create_tool_proxy",
|
|
40
42
|
"FlowInputState",
|
|
41
43
|
"CopilotKitStateUpdateEvent",
|
|
42
|
-
"emit_copilotkit_state_update_event"
|
|
44
|
+
"emit_copilotkit_state_update_event",
|
|
43
45
|
]
|