copilotkit 0.1.89__tar.gz → 0.1.90__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31) hide show
  1. {copilotkit-0.1.89 → copilotkit-0.1.90}/PKG-INFO +1 -1
  2. copilotkit-0.1.90/copilotkit/__init__.py +43 -0
  3. {copilotkit-0.1.89 → copilotkit-0.1.90}/copilotkit/a2ui.py +4 -6
  4. {copilotkit-0.1.89 → copilotkit-0.1.90}/copilotkit/action.py +20 -20
  5. {copilotkit-0.1.89 → copilotkit-0.1.90}/copilotkit/agent.py +14 -14
  6. {copilotkit-0.1.89 → copilotkit-0.1.90}/copilotkit/copilotkit_lg_middleware.py +130 -79
  7. {copilotkit-0.1.89 → copilotkit-0.1.90}/copilotkit/crewai/__init__.py +4 -2
  8. {copilotkit-0.1.89 → copilotkit-0.1.90}/copilotkit/crewai/copilotkit_integration.py +134 -56
  9. {copilotkit-0.1.89 → copilotkit-0.1.90}/copilotkit/crewai/crewai_agent.py +124 -108
  10. {copilotkit-0.1.89 → copilotkit-0.1.90}/copilotkit/crewai/crewai_sdk.py +142 -144
  11. {copilotkit-0.1.89 → copilotkit-0.1.90}/copilotkit/exc.py +4 -0
  12. {copilotkit-0.1.89 → copilotkit-0.1.90}/copilotkit/header_propagation.py +7 -6
  13. {copilotkit-0.1.89 → copilotkit-0.1.90}/copilotkit/html.py +3 -1
  14. {copilotkit-0.1.89 → copilotkit-0.1.90}/copilotkit/integrations/fastapi.py +72 -71
  15. copilotkit-0.1.90/copilotkit/langchain.py +30 -0
  16. {copilotkit-0.1.89 → copilotkit-0.1.90}/copilotkit/langgraph.py +84 -77
  17. {copilotkit-0.1.89 → copilotkit-0.1.90}/copilotkit/langgraph_agui_agent.py +50 -24
  18. {copilotkit-0.1.89 → copilotkit-0.1.90}/copilotkit/logging.py +4 -2
  19. {copilotkit-0.1.89 → copilotkit-0.1.90}/copilotkit/parameter.py +23 -15
  20. {copilotkit-0.1.89 → copilotkit-0.1.90}/copilotkit/protocol.py +81 -44
  21. {copilotkit-0.1.89 → copilotkit-0.1.90}/copilotkit/runloop.py +37 -41
  22. {copilotkit-0.1.89 → copilotkit-0.1.90}/copilotkit/sdk.py +27 -32
  23. {copilotkit-0.1.89 → copilotkit-0.1.90}/copilotkit/types.py +15 -1
  24. copilotkit-0.1.90/copilotkit/utils.py +5 -0
  25. {copilotkit-0.1.89 → copilotkit-0.1.90}/pyproject.toml +1 -1
  26. copilotkit-0.1.89/copilotkit/__init__.py +0 -31
  27. copilotkit-0.1.89/copilotkit/langchain.py +0 -29
  28. copilotkit-0.1.89/copilotkit/utils.py +0 -8
  29. {copilotkit-0.1.89 → copilotkit-0.1.90}/README.md +0 -0
  30. {copilotkit-0.1.89 → copilotkit-0.1.90}/copilotkit/integrations/__init__.py +0 -0
  31. {copilotkit-0.1.89 → copilotkit-0.1.90}/copilotkit/py.typed +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: copilotkit
3
- Version: 0.1.89
3
+ Version: 0.1.90
4
4
  Summary: CopilotKit python SDK
5
5
  License: MIT
6
6
  Keywords: copilot,copilotkit,langgraph,langchain,ai,langsmith,langserve
@@ -0,0 +1,43 @@
1
+ """CopilotKit SDK"""
2
+
3
+ from .sdk import (
4
+ CopilotKitRemoteEndpoint,
5
+ CopilotKitContext,
6
+ CopilotKitSDK,
7
+ CopilotKitSDKContext,
8
+ )
9
+ from .action import Action
10
+ from .langgraph import CopilotKitState
11
+ from .parameter import Parameter
12
+ from .agent import Agent
13
+ from .langgraph_agui_agent import LangGraphAGUIAgent
14
+ from .copilotkit_lg_middleware import CopilotKitMiddleware
15
+ from ag_ui_langgraph.middlewares.state_streaming import (
16
+ StateStreamingMiddleware,
17
+ StateItem,
18
+ )
19
+ from .header_propagation import (
20
+ set_forwarded_headers,
21
+ get_forwarded_headers,
22
+ install_httpx_hook,
23
+ )
24
+
25
+
26
+ __all__ = [
27
+ "CopilotKitRemoteEndpoint",
28
+ "CopilotKitSDK",
29
+ "Action",
30
+ "CopilotKitState",
31
+ "Parameter",
32
+ "Agent",
33
+ "CopilotKitContext",
34
+ "CopilotKitSDKContext",
35
+ "CrewAIAgent", # pyright: ignore[reportUnsupportedDunderAll] pylint: disable=undefined-all-variable
36
+ "LangGraphAGUIAgent",
37
+ "CopilotKitMiddleware",
38
+ "StateStreamingMiddleware",
39
+ "StateItem",
40
+ "set_forwarded_headers",
41
+ "get_forwarded_headers",
42
+ "install_httpx_hook",
43
+ ]
@@ -38,7 +38,7 @@ def update_components(
38
38
  "updateComponents": {
39
39
  "surfaceId": surface_id,
40
40
  "components": components,
41
- }
41
+ },
42
42
  }
43
43
 
44
44
 
@@ -54,7 +54,7 @@ def update_data_model(
54
54
  "surfaceId": surface_id,
55
55
  "path": path,
56
56
  "value": data,
57
- }
57
+ },
58
58
  }
59
59
 
60
60
 
@@ -72,7 +72,7 @@ def create_surface(
72
72
  "createSurface": {
73
73
  "surfaceId": surface_id,
74
74
  "catalogId": catalog_id,
75
- }
75
+ },
76
76
  }
77
77
 
78
78
 
@@ -80,9 +80,7 @@ A2UI_OPERATIONS_KEY = "a2ui_operations"
80
80
  """The container key used to wrap A2UI operations for explicit detection."""
81
81
 
82
82
 
83
- def render(
84
- operations: list[dict[str, Any]]
85
- ) -> str:
83
+ def render(operations: list[dict[str, Any]]) -> str:
86
84
  """Wrap operations in the a2ui_operations container and serialize to JSON.
87
85
 
88
86
  Args:
@@ -5,26 +5,32 @@ from inspect import iscoroutinefunction
5
5
  from typing import Optional, List, Callable, TypedDict, Any, cast
6
6
  from .parameter import Parameter, normalize_parameters
7
7
 
8
+
8
9
  class ActionDict(TypedDict):
9
10
  """Dict representation of an action"""
11
+
10
12
  name: str
11
13
  description: str
12
14
  parameters: List[Parameter]
13
15
 
16
+
14
17
  class ActionResultDict(TypedDict):
15
18
  """Dict representation of an action result"""
19
+
16
20
  result: Any
17
21
 
22
+
18
23
  class Action: # pylint: disable=too-few-public-methods
19
24
  """Action class for CopilotKit"""
25
+
20
26
  def __init__(
21
- self,
22
- *,
23
- name: str,
24
- handler: Callable,
25
- description: Optional[str] = None,
26
- parameters: Optional[List[Parameter]] = None,
27
- ):
27
+ self,
28
+ *,
29
+ name: str,
30
+ handler: Callable,
31
+ description: Optional[str] = None,
32
+ parameters: Optional[List[Parameter]] = None,
33
+ ):
28
34
  self.name = name
29
35
  self.description = description
30
36
  self.parameters = parameters
@@ -32,26 +38,20 @@ class Action: # pylint: disable=too-few-public-methods
32
38
 
33
39
  if not re.match(r"^[a-zA-Z0-9_-]+$", name):
34
40
  raise ValueError(
35
- f"Invalid action name '{name}': " +
36
- "must consist of alphanumeric characters, underscores, and hyphens only"
41
+ f"Invalid action name '{name}': "
42
+ + "must consist of alphanumeric characters, underscores, and hyphens only"
37
43
  )
38
44
 
39
- async def execute(
40
- self,
41
- *,
42
- arguments: dict
43
- ) -> ActionResultDict:
45
+ async def execute(self, *, arguments: dict) -> ActionResultDict:
44
46
  """Execute the action"""
45
47
  result = self.handler(**arguments)
46
48
 
47
- return {
48
- "result": await result if iscoroutinefunction(self.handler) else result
49
- }
49
+ return {"result": await result if iscoroutinefunction(self.handler) else result}
50
50
 
51
51
  def dict_repr(self) -> ActionDict:
52
52
  """Dict representation of the action"""
53
53
  return {
54
- 'name': self.name,
55
- 'description': self.description or '',
56
- 'parameters': normalize_parameters(cast(Any, self.parameters)),
54
+ "name": self.name,
55
+ "description": self.description or "",
56
+ "parameters": normalize_parameters(cast(Any, self.parameters)),
57
57
  }
@@ -7,30 +7,34 @@ from .types import Message
7
7
  from .action import ActionDict
8
8
  from .types import MetaEvent
9
9
 
10
+
10
11
  class AgentDict(TypedDict):
11
12
  """Agent dictionary"""
13
+
12
14
  name: str
13
15
  description: Optional[str]
14
16
 
17
+
15
18
  class Agent(ABC):
16
19
  """Agent class for CopilotKit"""
20
+
17
21
  def __init__(
18
- self,
19
- *,
20
- name: str,
21
- description: Optional[str] = None,
22
- ):
22
+ self,
23
+ *,
24
+ name: str,
25
+ description: Optional[str] = None,
26
+ ):
23
27
  self.name = name
24
28
  self.description = description
25
29
 
26
30
  if not re.match(r"^[a-zA-Z0-9_-]+$", name):
27
31
  raise ValueError(
28
- f"Invalid agent name '{name}': " +
29
- "must consist of alphanumeric characters, underscores, and hyphens only"
32
+ f"Invalid agent name '{name}': "
33
+ + "must consist of alphanumeric characters, underscores, and hyphens only"
30
34
  )
31
35
 
32
36
  @abstractmethod
33
- def execute( # pylint: disable=too-many-arguments
37
+ def execute( # pylint: disable=too-many-arguments
34
38
  self,
35
39
  *,
36
40
  state: dict,
@@ -54,13 +58,9 @@ class Agent(ABC):
54
58
  "threadId": thread_id or "",
55
59
  "threadExists": False,
56
60
  "state": {},
57
- "messages": []
61
+ "messages": [],
58
62
  }
59
63
 
60
-
61
64
  def dict_repr(self) -> AgentDict:
62
65
  """Dict representation of the action"""
63
- return {
64
- 'name': self.name,
65
- 'description': self.description or ''
66
- }
66
+ return {"name": self.name, "description": self.description or ""}
@@ -27,8 +27,29 @@ from langchain.agents.middleware import (
27
27
  )
28
28
  from langgraph.runtime import Runtime
29
29
 
30
+ from .header_propagation import install_httpx_hook
30
31
  from .langgraph import CopilotKitProperties
31
32
 
33
+ # Track which httpx clients already have the header-propagation hook installed
34
+ # (by object id) so we never double-install on repeated model calls.
35
+ _hooked_clients: set[int] = set()
36
+
37
+
38
+ def _ensure_httpx_hook(model: Any) -> None:
39
+ """Install the header-propagation httpx hook on a LangChain chat model's
40
+ underlying HTTP client(s), if present. No-op for models that don't expose
41
+ an httpx transport (e.g. non-OpenAI/Anthropic providers).
42
+ """
43
+ for attr in ("client", "async_client"):
44
+ client = getattr(model, attr, None)
45
+ if client is None:
46
+ continue
47
+ cid = id(client)
48
+ if cid not in _hooked_clients:
49
+ install_httpx_hook(client)
50
+ _hooked_clients.add(cid)
51
+
52
+
32
53
  class StateSchema(AgentState):
33
54
  copilotkit: CopilotKitProperties
34
55
 
@@ -36,15 +57,17 @@ class StateSchema(AgentState):
36
57
  # Internal/framework keys that should never be surfaced to the LLM as
37
58
  # user-facing state. These are either reducer-managed message buckets,
38
59
  # CopilotKit/AG-UI plumbing, or graph-internal scaffolding.
39
- _RESERVED_STATE_KEYS = frozenset({
40
- "messages",
41
- "copilotkit",
42
- "ag-ui",
43
- "tools",
44
- "structured_response",
45
- "thread_id",
46
- "remaining_steps",
47
- })
60
+ _RESERVED_STATE_KEYS = frozenset(
61
+ {
62
+ "messages",
63
+ "copilotkit",
64
+ "ag-ui",
65
+ "tools",
66
+ "structured_response",
67
+ "thread_id",
68
+ "remaining_steps",
69
+ }
70
+ )
48
71
 
49
72
 
50
73
  class CopilotKitMiddleware(AgentMiddleware[StateSchema, Any]):
@@ -105,7 +128,8 @@ class CopilotKitMiddleware(AgentMiddleware[StateSchema, Any]):
105
128
  keys: list[str] = [k for k in self._expose_state if k in state]
106
129
  else:
107
130
  keys = [
108
- k for k in state
131
+ k
132
+ for k in state
109
133
  if k not in _RESERVED_STATE_KEYS and not str(k).startswith("_")
110
134
  ]
111
135
 
@@ -133,17 +157,22 @@ class CopilotKitMiddleware(AgentMiddleware[StateSchema, Any]):
133
157
  existing = request.system_message
134
158
  if existing is None:
135
159
  return request.override(system_message=SystemMessage(content=note))
136
- base = existing.content if isinstance(existing.content, str) else str(existing.content)
160
+ base = (
161
+ existing.content
162
+ if isinstance(existing.content, str)
163
+ else str(existing.content)
164
+ )
137
165
  return request.override(
138
166
  system_message=SystemMessage(content=f"{base}\n\n{note}")
139
167
  )
140
168
 
141
169
  # Inject frontend tools and surface user state before model call
142
170
  def wrap_model_call(
143
- self,
144
- request: ModelRequest,
145
- handler: Callable[[ModelRequest], ModelResponse],
171
+ self,
172
+ request: ModelRequest,
173
+ handler: Callable[[ModelRequest], ModelResponse],
146
174
  ) -> ModelResponse:
175
+ _ensure_httpx_hook(request.model)
147
176
  request = self._apply_state_note(request)
148
177
  frontend_tools = request.state.get("copilotkit", {}).get("actions", [])
149
178
 
@@ -185,7 +214,7 @@ class CopilotKitMiddleware(AgentMiddleware[StateSchema, Any]):
185
214
  tc_groups: dict[str, list] = {}
186
215
  for i, msg in enumerate(messages):
187
216
  if isinstance(msg, ToolMessage):
188
- tc_id = getattr(msg, 'tool_call_id', None)
217
+ tc_id = getattr(msg, "tool_call_id", None)
189
218
  if tc_id:
190
219
  tc_groups.setdefault(tc_id, []).append(i)
191
220
 
@@ -195,9 +224,12 @@ class CopilotKitMiddleware(AgentMiddleware[StateSchema, Any]):
195
224
  continue
196
225
  # Separate interrupted placeholders from real results
197
226
  real_indices = [
198
- i for i in indices
199
- if not (isinstance(messages[i].content, str)
200
- and _INTERRUPTED_PAT.match(messages[i].content))
227
+ i
228
+ for i in indices
229
+ if not (
230
+ isinstance(messages[i].content, str)
231
+ and _INTERRUPTED_PAT.match(messages[i].content)
232
+ )
201
233
  ]
202
234
  interrupted_indices = [i for i in indices if i not in real_indices]
203
235
  if real_indices and interrupted_indices:
@@ -215,31 +247,36 @@ class CopilotKitMiddleware(AgentMiddleware[StateSchema, Any]):
215
247
  drop_indices.update(interrupted_indices[:-1])
216
248
 
217
249
  if drop_indices:
218
- messages[:] = [msg for i, msg in enumerate(messages) if i not in drop_indices]
250
+ messages[:] = [
251
+ msg for i, msg in enumerate(messages) if i not in drop_indices
252
+ ]
219
253
 
220
254
  for idx, msg in enumerate(messages):
221
255
  if not isinstance(msg, AIMessage):
222
256
  continue
223
257
 
224
- tool_calls = getattr(msg, 'tool_calls', None) or []
258
+ tool_calls = getattr(msg, "tool_calls", None) or []
225
259
 
226
260
  # 1. Sync content with tool_calls: remove tool_use content blocks
227
261
  # that aren't in msg.tool_calls (e.g. stripped by after_model
228
262
  # but content blocks left behind in checkpoint).
229
263
  if tool_calls and isinstance(msg.content, list):
230
- tc_ids = {tc.get('id') for tc in tool_calls}
264
+ tc_ids = {tc.get("id") for tc in tool_calls}
231
265
  msg.content = [
232
- block for block in msg.content
233
- if not (isinstance(block, dict)
234
- and block.get('type') == 'tool_use'
235
- and block.get('id') not in tc_ids)
266
+ block
267
+ for block in msg.content
268
+ if not (
269
+ isinstance(block, dict)
270
+ and block.get("type") == "tool_use"
271
+ and block.get("id") not in tc_ids
272
+ )
236
273
  ]
237
274
  elif not tool_calls and isinstance(msg.content, list):
238
275
  # No tool_calls at all — strip ALL tool_use content blocks
239
276
  msg.content = [
240
- block for block in msg.content
241
- if not (isinstance(block, dict)
242
- and block.get('type') == 'tool_use')
277
+ block
278
+ for block in msg.content
279
+ if not (isinstance(block, dict) and block.get("type") == "tool_use")
243
280
  ]
244
281
 
245
282
  if not tool_calls:
@@ -253,45 +290,52 @@ class CopilotKitMiddleware(AgentMiddleware[StateSchema, Any]):
253
290
  adjacent_tc_ids: set = set()
254
291
  j = idx + 1
255
292
  while j < len(messages) and isinstance(messages[j], ToolMessage):
256
- tc_id = getattr(messages[j], 'tool_call_id', None)
293
+ tc_id = getattr(messages[j], "tool_call_id", None)
257
294
  if tc_id:
258
295
  adjacent_tc_ids.add(tc_id)
259
296
  j += 1
260
297
 
261
- unanswered = [tc for tc in tool_calls if tc.get('id') not in adjacent_tc_ids]
298
+ unanswered = [
299
+ tc for tc in tool_calls if tc.get("id") not in adjacent_tc_ids
300
+ ]
262
301
  if unanswered:
263
- unanswered_ids = {tc['id'] for tc in unanswered}
264
- msg.tool_calls = [tc for tc in tool_calls if tc.get('id') in adjacent_tc_ids]
302
+ unanswered_ids = {tc["id"] for tc in unanswered}
303
+ msg.tool_calls = [
304
+ tc for tc in tool_calls if tc.get("id") in adjacent_tc_ids
305
+ ]
265
306
 
266
307
  # Also strip matching content blocks
267
308
  if isinstance(msg.content, list):
268
309
  msg.content = [
269
- block for block in msg.content
270
- if not (isinstance(block, dict)
271
- and block.get('type') == 'tool_use'
272
- and block.get('id') in unanswered_ids)
310
+ block
311
+ for block in msg.content
312
+ if not (
313
+ isinstance(block, dict)
314
+ and block.get("type") == "tool_use"
315
+ and block.get("id") in unanswered_ids
316
+ )
273
317
  ]
274
318
 
275
319
  # 3. Fix string args in tool_calls
276
- for tc in (msg.tool_calls or []):
277
- if isinstance(tc.get('args'), str):
320
+ for tc in msg.tool_calls or []:
321
+ if isinstance(tc.get("args"), str):
278
322
  try:
279
- tc['args'] = json.loads(tc['args'])
323
+ tc["args"] = json.loads(tc["args"])
280
324
  except (json.JSONDecodeError, TypeError):
281
- tc['args'] = {}
325
+ tc["args"] = {}
282
326
 
283
327
  # 4. Fix string input in content blocks
284
328
  if isinstance(msg.content, list):
285
329
  for block in msg.content:
286
- if isinstance(block, dict) and block.get('type') == 'tool_use':
287
- inp = block.get('input')
330
+ if isinstance(block, dict) and block.get("type") == "tool_use":
331
+ inp = block.get("input")
288
332
  if isinstance(inp, str):
289
333
  try:
290
- block['input'] = json.loads(inp) if inp else {}
334
+ block["input"] = json.loads(inp) if inp else {}
291
335
  except (json.JSONDecodeError, TypeError):
292
- block['input'] = {}
336
+ block["input"] = {}
293
337
  elif inp is None:
294
- block['input'] = {}
338
+ block["input"] = {}
295
339
 
296
340
  # 5. Remove orphan ToolMessages whose tool_call_id no longer matches
297
341
  # any remaining tool_call in any AIMessage. These can be left over
@@ -299,23 +343,25 @@ class CopilotKitMiddleware(AgentMiddleware[StateSchema, Any]):
299
343
  remaining_tc_ids: set = set()
300
344
  for msg in messages:
301
345
  if isinstance(msg, AIMessage):
302
- for tc in (getattr(msg, 'tool_calls', None) or []):
303
- tc_id = tc.get('id')
346
+ for tc in getattr(msg, "tool_calls", None) or []:
347
+ tc_id = tc.get("id")
304
348
  if tc_id:
305
349
  remaining_tc_ids.add(tc_id)
306
350
  messages[:] = [
307
- msg for msg in messages
351
+ msg
352
+ for msg in messages
308
353
  if not isinstance(msg, ToolMessage)
309
- or getattr(msg, 'tool_call_id', None) in remaining_tc_ids
354
+ or getattr(msg, "tool_call_id", None) in remaining_tc_ids
310
355
  ]
311
356
 
312
357
  return messages
313
358
 
314
359
  async def awrap_model_call(
315
- self,
316
- request: ModelRequest,
317
- handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
360
+ self,
361
+ request: ModelRequest,
362
+ handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
318
363
  ) -> ModelResponse:
364
+ _ensure_httpx_hook(request.model)
319
365
  self._fix_messages_for_bedrock(request.messages)
320
366
  request = self._apply_state_note(request)
321
367
 
@@ -331,9 +377,9 @@ class CopilotKitMiddleware(AgentMiddleware[StateSchema, Any]):
331
377
 
332
378
  # Inject app context before agent runs
333
379
  def before_agent(
334
- self,
335
- state: StateSchema,
336
- runtime: Runtime[Any],
380
+ self,
381
+ state: StateSchema,
382
+ runtime: Runtime[Any],
337
383
  ) -> dict[str, Any] | None:
338
384
  messages = state.get("messages", [])
339
385
 
@@ -342,7 +388,9 @@ class CopilotKitMiddleware(AgentMiddleware[StateSchema, Any]):
342
388
 
343
389
  # Get app context from state or runtime
344
390
  copilotkit_state = state.get("copilotkit", {})
345
- app_context = copilotkit_state.get("context") or getattr(runtime, "context", None)
391
+ app_context = copilotkit_state.get("context") or getattr(
392
+ runtime, "context", None
393
+ )
346
394
 
347
395
  # Check if app_context is missing or empty
348
396
  if not app_context:
@@ -408,7 +456,9 @@ class CopilotKitMiddleware(AgentMiddleware[StateSchema, Any]):
408
456
  # duplicate at the end of the message list.
409
457
  if existing_context_index != -1:
410
458
  existing_id = getattr(messages[existing_context_index], "id", None)
411
- context_message = SystemMessage(content=context_message_content, id=existing_id)
459
+ context_message = SystemMessage(
460
+ content=context_message_content, id=existing_id
461
+ )
412
462
  else:
413
463
  context_message = SystemMessage(content=context_message_content)
414
464
 
@@ -431,26 +481,25 @@ class CopilotKitMiddleware(AgentMiddleware[StateSchema, Any]):
431
481
  }
432
482
 
433
483
  async def abefore_agent(
434
- self,
435
- state: StateSchema,
436
- runtime: Runtime[Any],
484
+ self,
485
+ state: StateSchema,
486
+ runtime: Runtime[Any],
437
487
  ) -> dict[str, Any] | None:
438
488
  # Delegate to sync implementation
439
489
  return self.before_agent(state, runtime)
440
490
 
441
491
  # Intercept frontend tool calls after model returns, before ToolNode executes
442
492
  def after_model(
443
- self,
444
- state: StateSchema,
445
- runtime: Runtime[Any],
493
+ self,
494
+ state: StateSchema,
495
+ runtime: Runtime[Any],
446
496
  ) -> dict[str, Any] | None:
447
497
  frontend_tools = state.get("copilotkit", {}).get("actions", [])
448
498
  if not frontend_tools:
449
499
  return None
450
500
 
451
501
  frontend_tool_names = {
452
- t.get("function", {}).get("name") or t.get("name")
453
- for t in frontend_tools
502
+ t.get("function", {}).get("name") or t.get("name") for t in frontend_tools
454
503
  }
455
504
 
456
505
  # Find last AI message with tool calls
@@ -494,18 +543,18 @@ class CopilotKitMiddleware(AgentMiddleware[StateSchema, Any]):
494
543
  }
495
544
 
496
545
  async def aafter_model(
497
- self,
498
- state: StateSchema,
499
- runtime: Runtime[Any],
546
+ self,
547
+ state: StateSchema,
548
+ runtime: Runtime[Any],
500
549
  ) -> dict[str, Any] | None:
501
550
  # Delegate to sync implementation
502
551
  return self.after_model(state, runtime)
503
552
 
504
553
  # Restore frontend tool calls to AIMessage before agent exits
505
554
  def after_agent(
506
- self,
507
- state: StateSchema,
508
- runtime: Runtime[Any],
555
+ self,
556
+ state: StateSchema,
557
+ runtime: Runtime[Any],
509
558
  ) -> dict[str, Any] | None:
510
559
  copilotkit_state = state.get("copilotkit", {})
511
560
  intercepted_tool_calls = copilotkit_state.get("intercepted_tool_calls")
@@ -520,11 +569,13 @@ class CopilotKitMiddleware(AgentMiddleware[StateSchema, Any]):
520
569
  for msg in messages:
521
570
  if isinstance(msg, AIMessage) and msg.id == original_message_id:
522
571
  existing_tool_calls = getattr(msg, "tool_calls", None) or []
523
- updated_messages.append(AIMessage(
524
- content=msg.content,
525
- tool_calls=[*existing_tool_calls, *intercepted_tool_calls],
526
- id=msg.id,
527
- ))
572
+ updated_messages.append(
573
+ AIMessage(
574
+ content=msg.content,
575
+ tool_calls=[*existing_tool_calls, *intercepted_tool_calls],
576
+ id=msg.id,
577
+ )
578
+ )
528
579
  else:
529
580
  updated_messages.append(msg)
530
581
 
@@ -537,9 +588,9 @@ class CopilotKitMiddleware(AgentMiddleware[StateSchema, Any]):
537
588
  }
538
589
 
539
590
  async def aafter_agent(
540
- self,
541
- state: StateSchema,
542
- runtime: Runtime[Any],
591
+ self,
592
+ state: StateSchema,
593
+ runtime: Runtime[Any],
543
594
  ) -> dict[str, Any] | None:
544
595
  # Delegate to sync implementation
545
596
  return self.after_agent(state, runtime)
@@ -1,6 +1,7 @@
1
1
  """
2
2
  CrewAI
3
3
  """
4
+
4
5
  from .crewai_agent import CrewAIAgent
5
6
  from .crewai_sdk import (
6
7
  CopilotKitProperties,
@@ -20,8 +21,9 @@ from .copilotkit_integration import (
20
21
  create_tool_proxy,
21
22
  FlowInputState,
22
23
  CopilotKitStateUpdateEvent,
23
- emit_copilotkit_state_update_event
24
+ emit_copilotkit_state_update_event,
24
25
  )
26
+
25
27
  __all__ = [
26
28
  "CrewAIAgent",
27
29
  "CopilotKitProperties",
@@ -39,5 +41,5 @@ __all__ = [
39
41
  "create_tool_proxy",
40
42
  "FlowInputState",
41
43
  "CopilotKitStateUpdateEvent",
42
- "emit_copilotkit_state_update_event"
44
+ "emit_copilotkit_state_update_event",
43
45
  ]