@copilotkit/runtime 1.6.0-next.1 → 1.6.0-next.10

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61):
  1. package/CHANGELOG.md +69 -0
  2. package/__snapshots__/schema/schema.graphql +1 -0
  3. package/dist/{chunk-CCQ73DAH.mjs → chunk-3ORZOFAL.mjs} +2 -2
  4. package/dist/{chunk-CHDIEE43.mjs → chunk-5NJDGEB3.mjs} +87 -44
  5. package/dist/chunk-5NJDGEB3.mjs.map +1 -0
  6. package/dist/{chunk-A25FIW7J.mjs → chunk-6XDMOAOM.mjs} +2 -2
  7. package/dist/{chunk-OS5YD32G.mjs → chunk-FZJAYGIR.mjs} +45 -12
  8. package/dist/chunk-FZJAYGIR.mjs.map +1 -0
  9. package/dist/{chunk-C7GTLEVO.mjs → chunk-JUWN34MZ.mjs} +2 -2
  10. package/dist/{copilot-runtime-67033bfa.d.ts → copilot-runtime-15bfc4f4.d.ts} +2 -2
  11. package/dist/graphql/types/converted/index.d.ts +1 -1
  12. package/dist/{groq-adapter-9d15c927.d.ts → groq-adapter-fb9aa3ab.d.ts} +1 -1
  13. package/dist/{index-f6d1f30b.d.ts → index-5bec5424.d.ts} +2 -1
  14. package/dist/index.d.ts +4 -4
  15. package/dist/index.js +199 -123
  16. package/dist/index.js.map +1 -1
  17. package/dist/index.mjs +5 -5
  18. package/dist/{langserve-7cc5be48.d.ts → langserve-6f7af8d3.d.ts} +1 -1
  19. package/dist/lib/index.d.ts +4 -4
  20. package/dist/lib/index.js +192 -116
  21. package/dist/lib/index.js.map +1 -1
  22. package/dist/lib/index.mjs +5 -5
  23. package/dist/lib/integrations/index.d.ts +4 -4
  24. package/dist/lib/integrations/index.js +5 -3
  25. package/dist/lib/integrations/index.js.map +1 -1
  26. package/dist/lib/integrations/index.mjs +5 -5
  27. package/dist/lib/integrations/nest/index.d.ts +3 -3
  28. package/dist/lib/integrations/nest/index.js +5 -3
  29. package/dist/lib/integrations/nest/index.js.map +1 -1
  30. package/dist/lib/integrations/nest/index.mjs +3 -3
  31. package/dist/lib/integrations/node-express/index.d.ts +3 -3
  32. package/dist/lib/integrations/node-express/index.js +5 -3
  33. package/dist/lib/integrations/node-express/index.js.map +1 -1
  34. package/dist/lib/integrations/node-express/index.mjs +3 -3
  35. package/dist/lib/integrations/node-http/index.d.ts +3 -3
  36. package/dist/lib/integrations/node-http/index.js +5 -3
  37. package/dist/lib/integrations/node-http/index.js.map +1 -1
  38. package/dist/lib/integrations/node-http/index.mjs +2 -2
  39. package/dist/service-adapters/index.d.ts +4 -4
  40. package/dist/service-adapters/index.js +70 -37
  41. package/dist/service-adapters/index.js.map +1 -1
  42. package/dist/service-adapters/index.mjs +1 -1
  43. package/package.json +5 -4
  44. package/src/agents/langgraph/event-source.ts +21 -5
  45. package/src/graphql/types/enums.ts +1 -0
  46. package/src/lib/runtime/__tests__/remote-action-constructors.test.ts +236 -0
  47. package/src/lib/runtime/copilot-runtime.ts +9 -3
  48. package/src/lib/runtime/remote-action-constructors.ts +9 -7
  49. package/src/lib/runtime/remote-lg-action.ts +35 -7
  50. package/src/service-adapters/conversion.ts +39 -46
  51. package/src/service-adapters/groq/groq-adapter.ts +6 -3
  52. package/src/service-adapters/openai/openai-adapter.ts +1 -1
  53. package/src/service-adapters/openai/openai-assistant-adapter.ts +1 -1
  54. package/src/service-adapters/openai/utils.ts +39 -13
  55. package/src/service-adapters/unify/unify-adapter.ts +1 -1
  56. package/tsconfig.json +3 -2
  57. package/dist/chunk-CHDIEE43.mjs.map +0 -1
  58. package/dist/chunk-OS5YD32G.mjs.map +0 -1
  59. /package/dist/{chunk-CCQ73DAH.mjs.map → chunk-3ORZOFAL.mjs.map} +0 -0
  60. /package/dist/{chunk-A25FIW7J.mjs.map → chunk-6XDMOAOM.mjs.map} +0 -0
  61. /package/dist/{chunk-C7GTLEVO.mjs.map → chunk-JUWN34MZ.mjs.map} +0 -0
@@ -20,6 +20,7 @@ interface LangGraphEventWithState {
20
20
  lastToolCallId: string | null;
21
21
  lastToolCallName: string | null;
22
22
  currentContent: string | null;
23
+ processedToolCallIds: Set<string>;
23
24
  }
24
25
 
25
26
  export class RemoteLangGraphEventSource {
@@ -87,11 +88,20 @@ export class RemoteLangGraphEventSource {
87
88
  acc.lastMessageId = this.getCurrentMessageId(event) ?? acc.lastMessageId;
88
89
  const toolCallChunks = this.getCurrentToolCallChunks(event) ?? [];
89
90
  const responseMetadata = this.getResponseMetadata(event);
91
+ // Check if a given event is a tool call
92
+ const toolCallCheck = toolCallChunks && toolCallChunks.length > 0;
93
+ let isToolCallEnd = responseMetadata?.finish_reason === "tool_calls";
90
94
 
91
95
  acc.isToolCallStart = toolCallChunks.some((chunk: any) => chunk.name && chunk.id);
92
96
  acc.isMessageStart = prevMessageId !== acc.lastMessageId && !acc.isToolCallStart;
93
- acc.isToolCall = toolCallChunks && toolCallChunks.length > 0;
94
- acc.isToolCallEnd = responseMetadata?.finish_reason === "tool_calls";
97
+
98
+ let previousRoundHadToolCall = acc.isToolCall;
99
+ acc.isToolCall = toolCallCheck;
100
+ // Previous "acc.isToolCall" was set but now it won't pass the check, it means the tool call just ended.
101
+ if (previousRoundHadToolCall && !toolCallCheck) {
102
+ isToolCallEnd = true;
103
+ }
104
+ acc.isToolCallEnd = isToolCallEnd;
95
105
  acc.isMessageEnd = responseMetadata?.finish_reason === "stop";
96
106
  ({ name: acc.lastToolCallName, id: acc.lastToolCallId } = toolCallChunks.find(
97
107
  (chunk: any) => chunk.name && chunk.id,
@@ -112,6 +122,7 @@ export class RemoteLangGraphEventSource {
112
122
  lastToolCallId: null,
113
123
  lastToolCallName: null,
114
124
  currentContent: null,
125
+ processedToolCallIds: new Set<string>(),
115
126
  } as LangGraphEventWithState,
116
127
  ),
117
128
  mergeMap((acc): RuntimeEvent[] => {
@@ -148,9 +159,13 @@ export class RemoteLangGraphEventSource {
148
159
 
149
160
  // Tool call ended: emit ActionExecutionEnd
150
161
  if (
151
- responseMetadata?.finish_reason === "tool_calls" &&
152
- this.shouldEmitToolCall(shouldEmitToolCalls, acc.lastToolCallName)
162
+ acc.isToolCallEnd &&
163
+ this.shouldEmitToolCall(shouldEmitToolCalls, acc.lastToolCallName) &&
164
+ acc.lastToolCallId &&
165
+ !acc.processedToolCallIds.has(acc.lastToolCallId)
153
166
  ) {
167
+ acc.processedToolCallIds.add(acc.lastToolCallId);
168
+
154
169
  events.push({
155
170
  type: RuntimeEventTypes.ActionExecutionEnd,
156
171
  actionExecutionId: acc.lastToolCallId,
@@ -158,7 +173,7 @@ export class RemoteLangGraphEventSource {
158
173
  }
159
174
 
160
175
  // Message ended: emit TextMessageEnd
161
- if (responseMetadata?.finish_reason === "stop" && shouldEmitMessages) {
176
+ else if (responseMetadata?.finish_reason === "stop" && shouldEmitMessages) {
162
177
  events.push({
163
178
  type: RuntimeEventTypes.TextMessageEnd,
164
179
  messageId: acc.lastMessageId,
@@ -236,6 +251,7 @@ export class RemoteLangGraphEventSource {
236
251
  }
237
252
  // Message started: emit TextMessageStart
238
253
  else if (acc.isMessageStart && shouldEmitMessages) {
254
+ acc.processedToolCallIds.clear();
239
255
  events.push({
240
256
  type: RuntimeEventTypes.TextMessageStart,
241
257
  messageId: acc.lastMessageId,
@@ -5,6 +5,7 @@ export enum MessageRole {
5
5
  assistant = "assistant",
6
6
  system = "system",
7
7
  tool = "tool",
8
+ developer = "developer",
8
9
  }
9
10
 
10
11
  export enum CopilotRequestType {
@@ -0,0 +1,236 @@
1
+ import { TextEncoder } from "util";
2
+ import { RemoteLangGraphEventSource } from "../../../agents/langgraph/event-source";
3
+ import telemetry from "../../telemetry-client";
4
+ import {
5
+ constructLGCRemoteAction,
6
+ constructRemoteActions,
7
+ createHeaders,
8
+ } from "../remote-action-constructors";
9
+ import { execute } from "../remote-lg-action";
10
+
11
+ // Mock external dependencies
12
+ jest.mock("../remote-lg-action", () => ({
13
+ execute: jest.fn(),
14
+ }));
15
+
16
+ jest.mock("../../telemetry-client", () => ({
17
+ capture: jest.fn(),
18
+ }));
19
+
20
+ jest.mock("../../../agents/langgraph/event-source", () => ({
21
+ RemoteLangGraphEventSource: jest.fn(),
22
+ }));
23
+
24
+ // Dummy logger
25
+ const logger = {
26
+ debug: jest.fn(),
27
+ error: jest.fn(),
28
+ child: jest.fn(() => logger),
29
+ };
30
+
31
+ // Dummy graphqlContext
32
+ const graphqlContext = { properties: { dummyProp: "value" } } as any;
33
+
34
+ // Dummy agent state
35
+ const agentStates = [{ agentName: "agent1", state: "{}", configurable: "{}" }];
36
+
37
+ // Dummy agent used in constructLGCRemoteAction
38
+ const dummyAgent = { name: "agent1", description: "test agent" };
39
+ const endpoint = {
40
+ agents: [dummyAgent],
41
+ deploymentUrl: "http://dummy.deployment",
42
+ langsmithApiKey: "dummykey",
43
+ };
44
+
45
+ // Clear mocks before each test
46
+ beforeEach(() => {
47
+ jest.clearAllMocks();
48
+ });
49
+
50
+ describe("remote action constructors", () => {
51
+ describe("constructLGCRemoteAction", () => {
52
+ it("should create an agent with langGraphAgentHandler that processes events", async () => {
53
+ // Arrange: simulate execute returning a dummy ReadableStream
54
+ const dummyEncodedEvent = new TextEncoder().encode(JSON.stringify({ event: "test" }) + "\n");
55
+ const readerMock = {
56
+ read: jest
57
+ .fn()
58
+ .mockResolvedValueOnce({ done: false, value: dummyEncodedEvent })
59
+ .mockResolvedValueOnce({ done: true, value: new Uint8Array() }),
60
+ };
61
+
62
+ const dummyResponse = {
63
+ getReader: () => readerMock,
64
+ };
65
+
66
+ (execute as jest.Mock).mockResolvedValue(dummyResponse);
67
+
68
+ // Mock RemoteLangGraphEventSource to return a dummy processed result
69
+ const processLangGraphEventsMock = jest.fn(() => "processed events");
70
+ (RemoteLangGraphEventSource as jest.Mock).mockImplementation(() => ({
71
+ eventStream$: { next: jest.fn(), complete: jest.fn(), error: jest.fn() },
72
+ processLangGraphEvents: processLangGraphEventsMock,
73
+ }));
74
+
75
+ // Act: build the action and call langGraphAgentHandler
76
+ const actions = constructLGCRemoteAction({
77
+ endpoint,
78
+ graphqlContext,
79
+ logger,
80
+ messages: [],
81
+ agentStates,
82
+ });
83
+ expect(actions).toHaveLength(1);
84
+ const action = actions[0];
85
+ expect(action.name).toEqual(dummyAgent.name);
86
+
87
+ const result = await action.langGraphAgentHandler({
88
+ name: dummyAgent.name,
89
+ actionInputsWithoutAgents: [],
90
+ threadId: "thread1",
91
+ nodeName: "node1",
92
+ additionalMessages: [],
93
+ metaEvents: [],
94
+ });
95
+
96
+ // Assert: processLangGraphEvents is called and result returned
97
+ expect(processLangGraphEventsMock).toHaveBeenCalled();
98
+ expect(result).toBe("processed events");
99
+
100
+ // Check telemetry.capture was called with agentExecution true
101
+ expect(telemetry.capture).toHaveBeenCalledWith(
102
+ "oss.runtime.remote_action_executed",
103
+ expect.objectContaining({
104
+ agentExecution: true,
105
+ type: "langgraph-platform",
106
+ agentsAmount: 1,
107
+ }),
108
+ );
109
+ });
110
+ });
111
+
112
+ describe("constructRemoteActions", () => {
113
+ const json = {
114
+ agents: [{ name: "agent2", description: "agent desc" }],
115
+ actions: [
116
+ {
117
+ name: "action1",
118
+ description: "action desc",
119
+ parameters: { param: "value" },
120
+ },
121
+ ],
122
+ };
123
+ const url = "http://dummy.api";
124
+ const onBeforeRequest = jest.fn(() => ({ headers: { Authorization: "Bearer token" } }));
125
+
126
+ it("should create remote action handler that calls fetch and returns the result", async () => {
127
+ // Arrange: mock fetch for action handler
128
+ global.fetch = jest.fn().mockResolvedValue({
129
+ ok: true,
130
+ json: jest.fn().mockResolvedValue({ result: "action result" }),
131
+ });
132
+
133
+ const actionsArray = constructRemoteActions({
134
+ json,
135
+ url,
136
+ onBeforeRequest,
137
+ graphqlContext,
138
+ logger,
139
+ messages: [],
140
+ agentStates,
141
+ });
142
+ // There should be one action (from json.actions) and one agent (from json.agents)
143
+ expect(actionsArray).toHaveLength(2);
144
+ const actionHandler = actionsArray[0].handler;
145
+
146
+ const result = await actionHandler({ foo: "bar" });
147
+ expect(result).toEqual("action result");
148
+
149
+ expect(global.fetch).toHaveBeenCalledWith(
150
+ `${url}/actions/execute`,
151
+ expect.objectContaining({
152
+ method: "POST",
153
+ headers: expect.objectContaining({
154
+ "Content-Type": "application/json",
155
+ Authorization: "Bearer token",
156
+ }),
157
+ body: expect.any(String),
158
+ }),
159
+ );
160
+ });
161
+
162
+ it("should create remote agent handler that processes events", async () => {
163
+ // Arrange: mock fetch for agent handler to return a dummy stream
164
+ const dummyEncodedAgentEvent = new TextEncoder().encode('{"event":"data"}\n');
165
+ const agentReaderMock = {
166
+ read: jest
167
+ .fn()
168
+ .mockResolvedValueOnce({ done: false, value: dummyEncodedAgentEvent })
169
+ .mockResolvedValueOnce({ done: true, value: new Uint8Array() }),
170
+ };
171
+ const dummyStreamResponse = {
172
+ getReader: () => agentReaderMock,
173
+ };
174
+ global.fetch = jest.fn().mockResolvedValue({
175
+ ok: true,
176
+ text: jest.fn().mockResolvedValue("ok"),
177
+ body: dummyStreamResponse,
178
+ });
179
+
180
+ const processLangGraphEventsMock = jest.fn(() => "agent events processed");
181
+ (RemoteLangGraphEventSource as jest.Mock).mockImplementation(() => ({
182
+ eventStream$: { next: jest.fn(), complete: jest.fn(), error: jest.fn() },
183
+ processLangGraphEvents: processLangGraphEventsMock,
184
+ }));
185
+
186
+ const actionsArray = constructRemoteActions({
187
+ json,
188
+ url,
189
+ onBeforeRequest,
190
+ graphqlContext,
191
+ logger,
192
+ messages: [],
193
+ agentStates,
194
+ });
195
+ // The remote agent is the second item in the array
196
+ expect(actionsArray).toHaveLength(2);
197
+ const remoteAgentHandler = (actionsArray[1] as any).langGraphAgentHandler;
198
+ const result = await remoteAgentHandler({
199
+ name: "agent2",
200
+ actionInputsWithoutAgents: [],
201
+ threadId: "thread2",
202
+ nodeName: "node2",
203
+ additionalMessages: [],
204
+ metaEvents: [],
205
+ });
206
+ expect(processLangGraphEventsMock).toHaveBeenCalled();
207
+ expect(result).toBe("agent events processed");
208
+
209
+ // Check telemetry.capture for agent execution
210
+ expect(telemetry.capture).toHaveBeenCalledWith(
211
+ "oss.runtime.remote_action_executed",
212
+ expect.objectContaining({
213
+ agentExecution: true,
214
+ type: "self-hosted",
215
+ agentsAmount: 1,
216
+ }),
217
+ );
218
+ });
219
+ });
220
+
221
+ describe("createHeaders", () => {
222
+ it("should merge headers from onBeforeRequest", () => {
223
+ const onBeforeRequest = jest.fn(() => ({ headers: { "X-Test": "123" } }));
224
+ const headers = createHeaders(onBeforeRequest, graphqlContext);
225
+ expect(headers).toEqual({
226
+ "Content-Type": "application/json",
227
+ "X-Test": "123",
228
+ });
229
+ });
230
+
231
+ it("should return only Content-Type if no additional headers", () => {
232
+ const headers = createHeaders(undefined, graphqlContext);
233
+ expect(headers).toEqual({ "Content-Type": "application/json" });
234
+ });
235
+ });
236
+ });
@@ -153,7 +153,7 @@ export interface CopilotRuntimeConstructorParams<T extends Parameter[] | [] = []
153
153
  middleware?: Middleware;
154
154
 
155
155
  /*
156
- * A list of server side actions that can be executed.
156
+ * A list of server side actions that can be executed. Will be ignored when remoteActions are set
157
157
  */
158
158
  actions?: ActionsConfiguration<T>;
159
159
 
@@ -190,7 +190,13 @@ export class CopilotRuntime<const T extends Parameter[] | [] = []> {
190
190
  private delegateAgentProcessingToServiceAdapter: boolean;
191
191
 
192
192
  constructor(params?: CopilotRuntimeConstructorParams<T>) {
193
- this.actions = params?.actions || [];
193
+ // Do not register actions if endpoints are set
194
+ if (params?.actions && params?.remoteEndpoints) {
195
+ console.warn("Actions set in runtime instance will be ignored when remote endpoints are set");
196
+ this.actions = [];
197
+ } else {
198
+ this.actions = params?.actions || [];
199
+ }
194
200
 
195
201
  for (const chain of params?.langserve || []) {
196
202
  const remoteChain = new RemoteChain(chain);
@@ -572,7 +578,7 @@ please use an LLM adapter instead.`,
572
578
  threadId,
573
579
  runId: undefined,
574
580
  eventSource,
575
- serverSideActions: [],
581
+ serverSideActions,
576
582
  actionInputsWithoutAgents: allAvailableActions,
577
583
  };
578
584
  } catch (error) {
@@ -18,6 +18,8 @@ import { LangGraphEvent } from "../../agents/langgraph/events";
18
18
  import { execute } from "./remote-lg-action";
19
19
  import { CopilotKitError, CopilotKitLowLevelError } from "@copilotkit/shared";
20
20
  import { CopilotKitApiDiscoveryError, ResolvedCopilotKitError } from "@copilotkit/shared";
21
+ import { parseJson, tryMap } from "@copilotkit/shared";
22
+ import { ActionInput } from "../../graphql/inputs/action.input";
21
23
 
22
24
  export function constructLGCRemoteAction({
23
25
  endpoint,
@@ -59,8 +61,8 @@ export function constructLGCRemoteAction({
59
61
  if (agentStates) {
60
62
  const jsonState = agentStates.find((state) => state.agentName === name);
61
63
  if (jsonState) {
62
- state = JSON.parse(jsonState.state);
63
- configurable = JSON.parse(jsonState.configurable);
64
+ state = parseJson(jsonState.state, {});
65
+ configurable = parseJson(jsonState.configurable, {});
64
66
  }
65
67
  }
66
68
 
@@ -76,10 +78,10 @@ export function constructLGCRemoteAction({
76
78
  state,
77
79
  configurable,
78
80
  properties: graphqlContext.properties,
79
- actions: actionInputsWithoutAgents.map((action) => ({
81
+ actions: tryMap(actionInputsWithoutAgents, (action: ActionInput) => ({
80
82
  name: action.name,
81
83
  description: action.description,
82
- parameters: JSON.parse(action.jsonSchema) as string,
84
+ parameters: JSON.parse(action.jsonSchema),
83
85
  })),
84
86
  metaEvents,
85
87
  });
@@ -203,8 +205,8 @@ export function constructRemoteActions({
203
205
  if (agentStates) {
204
206
  const jsonState = agentStates.find((state) => state.agentName === name);
205
207
  if (jsonState) {
206
- state = JSON.parse(jsonState.state);
207
- configurable = JSON.parse(jsonState.configurable);
208
+ state = parseJson(jsonState.state, {});
209
+ configurable = parseJson(jsonState.configurable, {});
208
210
  }
209
211
  }
210
212
 
@@ -221,7 +223,7 @@ export function constructRemoteActions({
221
223
  state,
222
224
  configurable,
223
225
  properties: graphqlContext.properties,
224
- actions: actionInputsWithoutAgents.map((action) => ({
226
+ actions: tryMap(actionInputsWithoutAgents, (action: ActionInput) => ({
225
227
  name: action.name,
226
228
  description: action.description,
227
229
  parameters: JSON.parse(action.jsonSchema),
@@ -1,4 +1,4 @@
1
- import { Client as LangGraphClient } from "@langchain/langgraph-sdk";
1
+ import { AssistantGraph, Client as LangGraphClient, GraphSchema } from "@langchain/langgraph-sdk";
2
2
  import { createHash } from "node:crypto";
3
3
  import { isValidUUID, randomUUID } from "@copilotkit/shared";
4
4
  import { parse as parsePartialJson } from "partial-json";
@@ -13,6 +13,7 @@ import telemetry from "../telemetry-client";
13
13
  import { MetaEventInput } from "../../graphql/inputs/meta-event.input";
14
14
  import { MetaEventName } from "../../graphql/types/meta-events.type";
15
15
  import { RunsStreamPayload } from "@langchain/langgraph-sdk/dist/types";
16
+ import { parseJson } from "@copilotkit/shared";
16
17
 
17
18
  type State = Record<string, any>;
18
19
 
@@ -164,12 +165,7 @@ async function streamEvents(controller: ReadableStreamDefaultController, args: E
164
165
  }
165
166
  if (lgInterruptMetaEvent?.response) {
166
167
  let response = lgInterruptMetaEvent.response;
167
- try {
168
- payload.command = { resume: JSON.parse(response) };
169
- // In case of unparsable string, we keep the event as is
170
- } catch (e) {
171
- payload.command = { resume: response };
172
- }
168
+ payload.command = { resume: parseJson(response, response) };
173
169
  }
174
170
 
175
171
  if (mode === "continue" && !activeInterruptEvent) {
@@ -208,6 +204,15 @@ async function streamEvents(controller: ReadableStreamDefaultController, args: E
208
204
  await client.assistants.update(assistantId, { config: { configurable } });
209
205
  }
210
206
  const graphInfo = await client.assistants.getGraph(assistantId);
207
+ const graphSchema = await client.assistants.getSchemas(assistantId);
208
+ const schemaKeys = getSchemaKeys(graphSchema);
209
+
210
+ // Do not input keys that are not part of the input schema
211
+ if (payload.input && schemaKeys.input) {
212
+ payload.input = Object.fromEntries(
213
+ Object.entries(payload.input).filter(([key]) => schemaKeys.input.includes(key)),
214
+ );
215
+ }
211
216
 
212
217
  let streamingStateExtractor = new StreamingStateExtractor([]);
213
218
  let prevNodeName = null;
@@ -334,6 +339,7 @@ async function streamEvents(controller: ReadableStreamDefaultController, args: E
334
339
  state: manuallyEmittedState,
335
340
  running: true,
336
341
  active: true,
342
+ schemaKeys,
337
343
  }),
338
344
  );
339
345
  continue;
@@ -384,6 +390,7 @@ async function streamEvents(controller: ReadableStreamDefaultController, args: E
384
390
  state,
385
391
  running: true,
386
392
  active: !exitingNode,
393
+ schemaKeys,
387
394
  }),
388
395
  );
389
396
  }
@@ -408,6 +415,7 @@ async function streamEvents(controller: ReadableStreamDefaultController, args: E
408
415
  running: !shouldExit,
409
416
  active: false,
410
417
  includeMessages: true,
418
+ schemaKeys,
411
419
  }),
412
420
  );
413
421
 
@@ -431,6 +439,7 @@ function getStateSyncEvent({
431
439
  running,
432
440
  active,
433
441
  includeMessages = false,
442
+ schemaKeys,
434
443
  }: {
435
444
  threadId: string;
436
445
  runId: string;
@@ -440,6 +449,7 @@ function getStateSyncEvent({
440
449
  running: boolean;
441
450
  active: boolean;
442
451
  includeMessages?: boolean;
452
+ schemaKeys: { input: string[] | null; output: string[] | null };
443
453
  }): string {
444
454
  if (!includeMessages) {
445
455
  state = Object.keys(state).reduce((acc, key) => {
@@ -455,6 +465,13 @@ function getStateSyncEvent({
455
465
  };
456
466
  }
457
467
 
468
+ // Do not emit state keys that are not part of the output schema
469
+ if (schemaKeys.output) {
470
+ state = Object.fromEntries(
471
+ Object.entries(state).filter(([key]) => schemaKeys.output.includes(key)),
472
+ );
473
+ }
474
+
458
475
  return (
459
476
  JSON.stringify({
460
477
  event: LangGraphEventTypes.OnCopilotKitStateSync,
@@ -743,3 +760,14 @@ function copilotkitMessagesToLangChain(messages: Message[]): LangGraphPlatformMe
743
760
 
744
761
  return result;
745
762
  }
763
+
764
+ function getSchemaKeys(graphSchema: GraphSchema) {
765
+ const CONSTANT_KEYS = ["messages", "copilotkit"];
766
+ const inputSchema = Object.keys(graphSchema.input_schema.properties);
767
+ const outputSchema = Object.keys(graphSchema.output_schema.properties);
768
+
769
+ return {
770
+ input: inputSchema && inputSchema.length ? [...inputSchema, ...CONSTANT_KEYS] : null,
771
+ output: outputSchema && outputSchema.length ? [...outputSchema, ...CONSTANT_KEYS] : null,
772
+ };
773
+ }
@@ -7,58 +7,51 @@ import {
7
7
  } from "../graphql/types/converted";
8
8
  import { MessageInput } from "../graphql/inputs/message.input";
9
9
  import { plainToInstance } from "class-transformer";
10
+ import { tryMap } from "@copilotkit/shared";
10
11
 
11
12
  export function convertGqlInputToMessages(inputMessages: MessageInput[]): Message[] {
12
- const messages: Message[] = [];
13
-
14
- for (const message of inputMessages) {
13
+ const messages = tryMap(inputMessages, (message) => {
15
14
  if (message.textMessage) {
16
- messages.push(
17
- plainToInstance(TextMessage, {
18
- id: message.id,
19
- createdAt: message.createdAt,
20
- role: message.textMessage.role,
21
- content: message.textMessage.content,
22
- parentMessageId: message.textMessage.parentMessageId,
23
- }),
24
- );
15
+ return plainToInstance(TextMessage, {
16
+ id: message.id,
17
+ createdAt: message.createdAt,
18
+ role: message.textMessage.role,
19
+ content: message.textMessage.content,
20
+ parentMessageId: message.textMessage.parentMessageId,
21
+ });
25
22
  } else if (message.actionExecutionMessage) {
26
- messages.push(
27
- plainToInstance(ActionExecutionMessage, {
28
- id: message.id,
29
- createdAt: message.createdAt,
30
- name: message.actionExecutionMessage.name,
31
- arguments: JSON.parse(message.actionExecutionMessage.arguments),
32
- parentMessageId: message.actionExecutionMessage.parentMessageId,
33
- }),
34
- );
23
+ return plainToInstance(ActionExecutionMessage, {
24
+ id: message.id,
25
+ createdAt: message.createdAt,
26
+ name: message.actionExecutionMessage.name,
27
+ arguments: JSON.parse(message.actionExecutionMessage.arguments),
28
+ parentMessageId: message.actionExecutionMessage.parentMessageId,
29
+ });
35
30
  } else if (message.resultMessage) {
36
- messages.push(
37
- plainToInstance(ResultMessage, {
38
- id: message.id,
39
- createdAt: message.createdAt,
40
- actionExecutionId: message.resultMessage.actionExecutionId,
41
- actionName: message.resultMessage.actionName,
42
- result: message.resultMessage.result,
43
- }),
44
- );
31
+ return plainToInstance(ResultMessage, {
32
+ id: message.id,
33
+ createdAt: message.createdAt,
34
+ actionExecutionId: message.resultMessage.actionExecutionId,
35
+ actionName: message.resultMessage.actionName,
36
+ result: message.resultMessage.result,
37
+ });
45
38
  } else if (message.agentStateMessage) {
46
- messages.push(
47
- plainToInstance(AgentStateMessage, {
48
- id: message.id,
49
- threadId: message.agentStateMessage.threadId,
50
- createdAt: message.createdAt,
51
- agentName: message.agentStateMessage.agentName,
52
- nodeName: message.agentStateMessage.nodeName,
53
- runId: message.agentStateMessage.runId,
54
- active: message.agentStateMessage.active,
55
- role: message.agentStateMessage.role,
56
- state: JSON.parse(message.agentStateMessage.state),
57
- running: message.agentStateMessage.running,
58
- }),
59
- );
39
+ return plainToInstance(AgentStateMessage, {
40
+ id: message.id,
41
+ threadId: message.agentStateMessage.threadId,
42
+ createdAt: message.createdAt,
43
+ agentName: message.agentStateMessage.agentName,
44
+ nodeName: message.agentStateMessage.nodeName,
45
+ runId: message.agentStateMessage.runId,
46
+ active: message.agentStateMessage.active,
47
+ role: message.agentStateMessage.role,
48
+ state: JSON.parse(message.agentStateMessage.state),
49
+ running: message.agentStateMessage.running,
50
+ });
51
+ } else {
52
+ return null;
60
53
  }
61
- }
54
+ });
62
55
 
63
- return messages;
56
+ return messages.filter((m) => m);
64
57
  }
@@ -15,6 +15,7 @@
15
15
  * ```
16
16
  */
17
17
  import { Groq } from "groq-sdk";
18
+ import type { ChatCompletionMessageParam } from "groq-sdk/resources/chat";
18
19
  import {
19
20
  CopilotServiceAdapter,
20
21
  CopilotRuntimeChatCompletionRequest,
@@ -27,7 +28,7 @@ import {
27
28
  } from "../openai/utils";
28
29
  import { randomUUID } from "@copilotkit/shared";
29
30
 
30
- const DEFAULT_MODEL = "llama3-groq-70b-8192-tool-use-preview";
31
+ const DEFAULT_MODEL = "llama-3.3-70b-versatile";
31
32
 
32
33
  export interface GroqAdapterParams {
33
34
  /**
@@ -81,7 +82,9 @@ export class GroqAdapter implements CopilotServiceAdapter {
81
82
  } = request;
82
83
  const tools = actions.map(convertActionInputToOpenAITool);
83
84
 
84
- let openaiMessages = messages.map(convertMessageToOpenAIMessage);
85
+ let openaiMessages = messages.map((m) =>
86
+ convertMessageToOpenAIMessage(m, { keepSystemRole: true }),
87
+ );
85
88
  openaiMessages = limitMessagesToTokenCount(openaiMessages, tools, model);
86
89
 
87
90
  let toolChoice: any = forwardedParameters?.toolChoice;
@@ -94,7 +97,7 @@ export class GroqAdapter implements CopilotServiceAdapter {
94
97
  const stream = await this.groq.chat.completions.create({
95
98
  model: model,
96
99
  stream: true,
97
- messages: openaiMessages,
100
+ messages: openaiMessages as unknown as ChatCompletionMessageParam[],
98
101
  ...(tools.length > 0 && { tools }),
99
102
  ...(forwardedParameters?.maxTokens && {
100
103
  max_tokens: forwardedParameters.maxTokens,
@@ -117,7 +117,7 @@ export class OpenAIAdapter implements CopilotServiceAdapter {
117
117
  const tools = actions.map(convertActionInputToOpenAITool);
118
118
  const threadId = threadIdFromRequest ?? randomUUID();
119
119
 
120
- let openaiMessages = messages.map(convertMessageToOpenAIMessage);
120
+ let openaiMessages = messages.map((m) => convertMessageToOpenAIMessage(m));
121
121
  openaiMessages = limitMessagesToTokenCount(openaiMessages, tools, model);
122
122
 
123
123
  let toolChoice: any = forwardedParameters?.toolChoice;
@@ -198,7 +198,7 @@ export class OpenAIAssistantAdapter implements CopilotServiceAdapter {
198
198
 
199
199
  // get the latest user message
200
200
  const userMessage = messages
201
- .map(convertMessageToOpenAIMessage)
201
+ .map((m) => convertMessageToOpenAIMessage(m))
202
202
  .map(convertSystemMessageToAssistantAPI)
203
203
  .at(-1);
204
204