@mastra/ai-sdk 1.1.0 → 1.1.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +18 -0
- package/LICENSE.md +15 -0
- package/dist/index.cjs +199 -0
- package/dist/index.cjs.map +1 -1
- package/dist/index.js +199 -0
- package/dist/index.js.map +1 -1
- package/dist/middleware.d.ts.map +1 -1
- package/package.json +6 -5
package/dist/index.js
CHANGED
|
@@ -4616,6 +4616,142 @@ function withMastra(model, options = {}) {
|
|
|
4616
4616
|
})
|
|
4617
4617
|
});
|
|
4618
4618
|
}
|
|
4619
|
+
/**
 * Collects Mastra stream chunks (text deltas and tool-call lifecycle chunks)
 * so they can later be assembled into a single assistant message whose
 * `content` uses `format: 2` with an ordered `parts` array.
 */
var StreamOutputAccumulator = class {
  /**
   * Layout of the message being built, in arrival order. Each entry is either
   * `{ kind: "text", bufferIndex }` (index into `textBuffers`) or
   * `{ kind: "tool", toolCallId }` (key into `toolStates`).
   */
  partOrder = [];
  /** One array of raw text deltas per contiguous run of text. */
  textBuffers = [];
  /** Mutable per-tool-call state, keyed by toolCallId. */
  toolStates = /* @__PURE__ */ new Map();

  /**
   * Feed one chunk into the accumulator. Chunk types other than text deltas
   * and the tool-call chunks handled below are ignored.
   * @param {{ type: string, payload?: object }} chunk
   */
  addChunk(chunk) {
    const { type, payload } = chunk;

    if (type === "text-delta") {
      const delta = payload.text;
      if (delta) this.appendTextDelta(delta);
      return;
    }

    if (type === "tool-call-input-streaming-start") {
      const toolState = this.ensureToolState(payload.toolCallId);
      toolState.toolName ||= payload.toolName;
      toolState.providerMetadata = payload.providerMetadata || toolState.providerMetadata;
      // Prepare the delta list so streamed argument fragments can be collected.
      toolState.argDeltas ||= [];
      return;
    }

    if (type === "tool-call-delta") {
      const toolState = this.ensureToolState(payload.toolCallId);
      const fragment = payload.argsTextDelta;
      if (fragment) {
        toolState.argDeltas ||= [];
        toolState.argDeltas.push(fragment);
      }
      return;
    }

    if (type === "tool-call-input-streaming-end") {
      this.finalizeToolArgs(this.ensureToolState(payload.toolCallId));
      return;
    }

    if (type === "tool-call") {
      const toolState = this.ensureToolState(payload.toolCallId);
      toolState.toolName ||= payload.toolName;
      // Explicit args on the chunk win; otherwise fall back to the streamed deltas.
      if (payload.args === void 0) {
        this.finalizeToolArgs(toolState);
      } else {
        toolState.args = payload.args;
      }
      toolState.providerMetadata = payload.providerMetadata || toolState.providerMetadata;
      return;
    }

    if (type === "tool-result") {
      const toolState = this.ensureToolState(payload.toolCallId);
      toolState.toolName ||= payload.toolName;
      toolState.result = payload.result;
      if (toolState.args === void 0) this.finalizeToolArgs(toolState);
      toolState.providerMetadata = payload.providerMetadata || toolState.providerMetadata;
    }
  }

  /**
   * Assemble the accumulated parts into an assistant message.
   * @param {{ threadId?: string, resourceId?: string }} [memory] - optional ids
   *   copied onto the message when truthy.
   * @returns {object|null} the message, or null when nothing was accumulated.
   */
  buildResponseMessage(memory) {
    if (!this.partOrder.length) return null;

    const parts = [];
    let textContent = "";
    for (const entry of this.partOrder) {
      if (entry.kind === "text") {
        const segment = this.textBuffers[entry.bufferIndex].join("");
        textContent += segment;
        parts.push({ type: "text", text: segment });
        continue;
      }
      const toolState = this.toolStates.get(entry.toolCallId);
      // Last chance to turn streamed argument fragments into parsed args.
      this.finalizeToolArgs(toolState);
      parts.push(this.buildToolPart(toolState));
    }

    const content = { format: 2, parts };
    if (textContent) content.content = textContent;

    const message = {
      id: crypto.randomUUID(),
      role: "assistant",
      content,
      createdAt: /* @__PURE__ */ new Date()
    };
    if (memory?.threadId) message.threadId = memory.threadId;
    if (memory?.resourceId) message.resourceId = memory.resourceId;
    return message;
  }

  /**
   * Return the state for `toolCallId`, creating it (and reserving its slot in
   * `partOrder`) on first sight.
   */
  ensureToolState(toolCallId) {
    const existing = this.toolStates.get(toolCallId);
    if (existing) return existing;
    const created = { toolCallId };
    this.toolStates.set(toolCallId, created);
    this.partOrder.push({ kind: "tool", toolCallId });
    return created;
  }

  /**
   * Parse the concatenated argument deltas into `state.args`, unless args are
   * already set or no deltas exist. Unparseable JSON leaves `args` unset.
   */
  finalizeToolArgs(state) {
    if (state.args !== void 0) return;
    const fragments = state.argDeltas;
    if (!fragments || !fragments.length) return;
    let parsed;
    try {
      parsed = JSON.parse(fragments.join(""));
    } catch {
      // Best-effort: incomplete/malformed JSON simply leaves args undefined.
      return;
    }
    state.args = parsed;
  }

  /** Append text to the current text run, or open a new run if the last part was a tool. */
  appendTextDelta(delta) {
    const tail = this.partOrder.at(-1);
    if (tail && tail.kind === "text") {
      this.textBuffers[tail.bufferIndex].push(delta);
      return;
    }
    const newIndex = this.textBuffers.push([delta]) - 1;
    this.partOrder.push({ kind: "text", bufferIndex: newIndex });
  }

  /**
   * Convert a tool state into a `tool-invocation` part. The invocation is in
   * state "result" when a result was recorded, otherwise "call".
   */
  buildToolPart(state) {
    const { toolCallId, toolName, args, result, providerMetadata } = state;
    const toolInvocation = {
      state: result === void 0 ? "call" : "result",
      toolCallId,
      toolName: toolName || "unknown",
      args: args ?? {}
    };
    if (result !== void 0) toolInvocation.result = result;

    const part = { type: "tool-invocation", toolInvocation };
    if (providerMetadata) part.providerMetadata = providerMetadata;
    return part;
  }
};
|
|
4619
4755
|
function createProcessorMiddleware(options) {
|
|
4620
4756
|
const { inputProcessors = [], outputProcessors = [], memory } = options;
|
|
4621
4757
|
const requestContext = new RequestContext();
|
|
@@ -4737,6 +4873,7 @@ function createProcessorMiddleware(options) {
|
|
|
4737
4873
|
await processor.processOutputResult({
|
|
4738
4874
|
messages: messageList.get.all.db(),
|
|
4739
4875
|
messageList,
|
|
4876
|
+
state: {},
|
|
4740
4877
|
requestContext,
|
|
4741
4878
|
abort: (reason) => {
|
|
4742
4879
|
throw new TripWire(reason || "Aborted by processor");
|
|
@@ -4774,8 +4911,12 @@ function createProcessorMiddleware(options) {
|
|
|
4774
4911
|
}
|
|
4775
4912
|
const { stream, ...rest } = await doStream();
|
|
4776
4913
|
if (!outputProcessors.length) return { stream, ...rest };
|
|
4914
|
+
const outputResultProcessors = outputProcessors.filter((processor) => processor.processOutputResult);
|
|
4915
|
+
const streamAccumulator = outputResultProcessors.length ? new StreamOutputAccumulator() : null;
|
|
4777
4916
|
const processorStates = /* @__PURE__ */ new Map();
|
|
4778
4917
|
const runId = crypto.randomUUID();
|
|
4918
|
+
let streamAborted = false;
|
|
4919
|
+
let sawFinish = false;
|
|
4779
4920
|
const transformedStream = stream.pipeThrough(
|
|
4780
4921
|
new TransformStream$1({
|
|
4781
4922
|
async transform(chunk, controller) {
|
|
@@ -4812,6 +4953,7 @@ function createProcessorMiddleware(options) {
|
|
|
4812
4953
|
}
|
|
4813
4954
|
} catch (error) {
|
|
4814
4955
|
if (error instanceof TripWire) {
|
|
4956
|
+
streamAborted = true;
|
|
4815
4957
|
controller.enqueue({
|
|
4816
4958
|
type: "error",
|
|
4817
4959
|
error: new Error(error.message)
|
|
@@ -4823,12 +4965,69 @@ function createProcessorMiddleware(options) {
|
|
|
4823
4965
|
}
|
|
4824
4966
|
}
|
|
4825
4967
|
}
|
|
4968
|
+
if (mastraChunk) {
|
|
4969
|
+
if (mastraChunk.type === "finish") {
|
|
4970
|
+
sawFinish = true;
|
|
4971
|
+
}
|
|
4972
|
+
if (streamAccumulator) {
|
|
4973
|
+
streamAccumulator.addChunk(mastraChunk);
|
|
4974
|
+
}
|
|
4975
|
+
}
|
|
4826
4976
|
if (mastraChunk) {
|
|
4827
4977
|
const aiChunk = convertMastraChunkToAISDKStreamPart(mastraChunk);
|
|
4828
4978
|
if (aiChunk) {
|
|
4829
4979
|
controller.enqueue(aiChunk);
|
|
4830
4980
|
}
|
|
4831
4981
|
}
|
|
4982
|
+
},
|
|
4983
|
+
async flush(controller) {
|
|
4984
|
+
if (!streamAccumulator || streamAborted || !sawFinish) {
|
|
4985
|
+
return;
|
|
4986
|
+
}
|
|
4987
|
+
const messageList = new MessageList({
|
|
4988
|
+
threadId: memory?.threadId,
|
|
4989
|
+
resourceId: memory?.resourceId
|
|
4990
|
+
});
|
|
4991
|
+
const flushOriginalInputCount = processorState?.originalInputCount ?? params.prompt.filter((m) => m.role !== "system").length;
|
|
4992
|
+
const flushNonSystemTotal = params.prompt.filter((m) => m.role !== "system").length;
|
|
4993
|
+
const flushMemoryCount = flushNonSystemTotal - flushOriginalInputCount;
|
|
4994
|
+
let flushNonSystemIndex = 0;
|
|
4995
|
+
for (const msg of params.prompt) {
|
|
4996
|
+
if (msg.role === "system") {
|
|
4997
|
+
messageList.addSystem(msg.content);
|
|
4998
|
+
} else {
|
|
4999
|
+
messageList.add(msg, flushNonSystemIndex < flushMemoryCount ? "memory" : "input");
|
|
5000
|
+
flushNonSystemIndex++;
|
|
5001
|
+
}
|
|
5002
|
+
}
|
|
5003
|
+
const responseMessage = streamAccumulator.buildResponseMessage(memory);
|
|
5004
|
+
if (responseMessage) {
|
|
5005
|
+
messageList.add(responseMessage, "response");
|
|
5006
|
+
}
|
|
5007
|
+
for (const processor of outputResultProcessors) {
|
|
5008
|
+
if (!processor.processOutputResult) continue;
|
|
5009
|
+
try {
|
|
5010
|
+
const procState = processorStates.get(processor.id);
|
|
5011
|
+
await processor.processOutputResult({
|
|
5012
|
+
messages: messageList.get.all.db(),
|
|
5013
|
+
messageList,
|
|
5014
|
+
state: procState?.customState ?? {},
|
|
5015
|
+
requestContext,
|
|
5016
|
+
abort: (reason) => {
|
|
5017
|
+
throw new TripWire(reason || "Aborted by processor");
|
|
5018
|
+
}
|
|
5019
|
+
});
|
|
5020
|
+
} catch (error) {
|
|
5021
|
+
if (error instanceof TripWire) {
|
|
5022
|
+
controller.enqueue({
|
|
5023
|
+
type: "error",
|
|
5024
|
+
error: new Error(error.message)
|
|
5025
|
+
});
|
|
5026
|
+
return;
|
|
5027
|
+
}
|
|
5028
|
+
throw error;
|
|
5029
|
+
}
|
|
5030
|
+
}
|
|
4832
5031
|
}
|
|
4833
5032
|
})
|
|
4834
5033
|
);
|