@mastra/ai-sdk 1.1.0 → 1.1.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.js CHANGED
@@ -4616,6 +4616,142 @@ function withMastra(model, options = {}) {
4616
4616
  })
4617
4617
  });
4618
4618
  }
4619
/**
 * Accumulates Mastra stream chunks and rebuilds the assistant reply as one
 * `format: 2` message after the stream is done.
 *
 * Consecutive text deltas are merged into a single text part. Each tool call,
 * keyed by its toolCallId, occupies exactly one part, placed where the first
 * chunk mentioning that id was seen. Tool arguments may arrive either as a
 * complete `args` value or as streamed JSON text fragments (`argDeltas`).
 */
var StreamOutputAccumulator = class {
  /** Ordered part slots: `{ kind: "text", bufferIndex }` or `{ kind: "tool", toolCallId }`. */
  partOrder = [];
  /** Text fragments for each text slot, addressed by `bufferIndex`. */
  textBuffers = [];
  /** Per-tool-call state (toolCallId, toolName, providerMetadata, argDeltas, args, result). */
  toolStates = /* @__PURE__ */ new Map();

  /**
   * Record a single stream chunk. Chunk types not handled below are ignored.
   * @param {{ type: string, payload: object }} chunk
   */
  addChunk(chunk) {
    const kind = chunk.type;

    if (kind === "text-delta") {
      const delta = chunk.payload.text;
      if (delta) {
        this.appendTextDelta(delta);
      }
      return;
    }

    if (kind === "tool-call-input-streaming-start") {
      const { toolCallId, toolName, providerMetadata } = chunk.payload;
      const tool = this.ensureToolState(toolCallId);
      tool.toolName = tool.toolName || toolName;
      tool.providerMetadata = providerMetadata || tool.providerMetadata;
      if (!tool.argDeltas) {
        tool.argDeltas = [];
      }
      return;
    }

    if (kind === "tool-call-delta") {
      const { toolCallId, argsTextDelta } = chunk.payload;
      const tool = this.ensureToolState(toolCallId);
      if (!argsTextDelta) {
        return;
      }
      if (!tool.argDeltas) {
        tool.argDeltas = [];
      }
      tool.argDeltas.push(argsTextDelta);
      return;
    }

    if (kind === "tool-call-input-streaming-end") {
      this.finalizeToolArgs(this.ensureToolState(chunk.payload.toolCallId));
      return;
    }

    if (kind === "tool-call") {
      const { toolCallId, toolName, args, providerMetadata } = chunk.payload;
      const tool = this.ensureToolState(toolCallId);
      tool.toolName = tool.toolName || toolName;
      // An explicit args value wins; otherwise fall back to the streamed fragments.
      if (args === void 0) {
        this.finalizeToolArgs(tool);
      } else {
        tool.args = args;
      }
      tool.providerMetadata = providerMetadata || tool.providerMetadata;
      return;
    }

    if (kind === "tool-result") {
      const { toolCallId, toolName, result, providerMetadata } = chunk.payload;
      const tool = this.ensureToolState(toolCallId);
      tool.toolName = tool.toolName || toolName;
      tool.result = result;
      if (tool.args === void 0) {
        this.finalizeToolArgs(tool);
      }
      tool.providerMetadata = providerMetadata || tool.providerMetadata;
    }
  }

  /**
   * Build the assistant message from everything accumulated so far.
   * @param {{ threadId?: string, resourceId?: string }} [memory] - copied onto the message when truthy
   * @returns {object|null} the message, or null when no parts were recorded
   */
  buildResponseMessage(memory) {
    if (!this.partOrder.length) {
      return null;
    }

    const parts = this.partOrder.map((slot) => {
      if (slot.kind === "text") {
        return { type: "text", text: this.textBuffers[slot.bufferIndex].join("") };
      }
      const tool = this.toolStates.get(slot.toolCallId);
      this.finalizeToolArgs(tool);
      return this.buildToolPart(tool);
    });

    // Plain-text summary: every text part concatenated without separators.
    const combinedText = parts
      .filter((part) => part.type === "text")
      .map((part) => part.text)
      .join("");

    const content = { format: 2, parts };
    if (combinedText) {
      content.content = combinedText;
    }

    const message = {
      id: crypto.randomUUID(),
      role: "assistant",
      content,
      createdAt: /* @__PURE__ */ new Date()
    };
    if (memory?.threadId) {
      message.threadId = memory.threadId;
    }
    if (memory?.resourceId) {
      message.resourceId = memory.resourceId;
    }
    return message;
  }

  /**
   * Return the state for `toolCallId`. On first sight, create the state and
   * reserve its slot in `partOrder`.
   */
  ensureToolState(toolCallId) {
    const existing = this.toolStates.get(toolCallId);
    if (existing) {
      return existing;
    }
    const created = { toolCallId };
    this.toolStates.set(toolCallId, created);
    this.partOrder.push({ kind: "tool", toolCallId });
    return created;
  }

  /**
   * Parse the streamed argument fragments into `state.args`.
   *
   * Does nothing if args are already set or no fragments exist. A JSON parse
   * failure leaves `args` unset on purpose (best-effort).
   */
  finalizeToolArgs(state) {
    const alreadySet = state.args !== void 0;
    const fragments = state.argDeltas;
    if (alreadySet || !fragments?.length) {
      return;
    }
    let parsed;
    try {
      parsed = JSON.parse(fragments.join(""));
    } catch {
      return;
    }
    state.args = parsed;
  }

  /**
   * Append text to the open text slot. A new slot is started when the latest
   * slot is not text (or none exists).
   */
  appendTextDelta(delta) {
    const slotCount = this.partOrder.length;
    const tail = slotCount > 0 ? this.partOrder[slotCount - 1] : undefined;
    if (tail && tail.kind === "text") {
      this.textBuffers[tail.bufferIndex].push(delta);
      return;
    }
    const bufferIndex = this.textBuffers.push([delta]) - 1;
    this.partOrder.push({ kind: "text", bufferIndex });
  }

  /**
   * Convert one tool state into a `tool-invocation` part.
   *
   * The state is "result" once a result is present, otherwise "call". Missing
   * values default to `toolName: "unknown"` and `args: {}`.
   */
  buildToolPart(state) {
    const { toolCallId, toolName, args, result, providerMetadata } = state;
    const finished = result !== void 0;
    const toolInvocation = {
      state: finished ? "result" : "call",
      toolCallId,
      toolName: toolName || "unknown",
      args: args ?? {}
    };
    if (finished) {
      toolInvocation.result = result;
    }
    const part = { type: "tool-invocation", toolInvocation };
    if (providerMetadata) {
      part.providerMetadata = providerMetadata;
    }
    return part;
  }
};
4619
4755
  function createProcessorMiddleware(options) {
4620
4756
  const { inputProcessors = [], outputProcessors = [], memory } = options;
4621
4757
  const requestContext = new RequestContext();
@@ -4737,6 +4873,7 @@ function createProcessorMiddleware(options) {
4737
4873
  await processor.processOutputResult({
4738
4874
  messages: messageList.get.all.db(),
4739
4875
  messageList,
4876
+ state: {},
4740
4877
  requestContext,
4741
4878
  abort: (reason) => {
4742
4879
  throw new TripWire(reason || "Aborted by processor");
@@ -4774,8 +4911,12 @@ function createProcessorMiddleware(options) {
4774
4911
  }
4775
4912
  const { stream, ...rest } = await doStream();
4776
4913
  if (!outputProcessors.length) return { stream, ...rest };
4914
+ const outputResultProcessors = outputProcessors.filter((processor) => processor.processOutputResult);
4915
+ const streamAccumulator = outputResultProcessors.length ? new StreamOutputAccumulator() : null;
4777
4916
  const processorStates = /* @__PURE__ */ new Map();
4778
4917
  const runId = crypto.randomUUID();
4918
+ let streamAborted = false;
4919
+ let sawFinish = false;
4779
4920
  const transformedStream = stream.pipeThrough(
4780
4921
  new TransformStream$1({
4781
4922
  async transform(chunk, controller) {
@@ -4812,6 +4953,7 @@ function createProcessorMiddleware(options) {
4812
4953
  }
4813
4954
  } catch (error) {
4814
4955
  if (error instanceof TripWire) {
4956
+ streamAborted = true;
4815
4957
  controller.enqueue({
4816
4958
  type: "error",
4817
4959
  error: new Error(error.message)
@@ -4823,12 +4965,69 @@ function createProcessorMiddleware(options) {
4823
4965
  }
4824
4966
  }
4825
4967
  }
4968
+ if (mastraChunk) {
4969
+ if (mastraChunk.type === "finish") {
4970
+ sawFinish = true;
4971
+ }
4972
+ if (streamAccumulator) {
4973
+ streamAccumulator.addChunk(mastraChunk);
4974
+ }
4975
+ }
4826
4976
  if (mastraChunk) {
4827
4977
  const aiChunk = convertMastraChunkToAISDKStreamPart(mastraChunk);
4828
4978
  if (aiChunk) {
4829
4979
  controller.enqueue(aiChunk);
4830
4980
  }
4831
4981
  }
4982
+ },
4983
+ async flush(controller) {
4984
+ if (!streamAccumulator || streamAborted || !sawFinish) {
4985
+ return;
4986
+ }
4987
+ const messageList = new MessageList({
4988
+ threadId: memory?.threadId,
4989
+ resourceId: memory?.resourceId
4990
+ });
4991
+ const flushOriginalInputCount = processorState?.originalInputCount ?? params.prompt.filter((m) => m.role !== "system").length;
4992
+ const flushNonSystemTotal = params.prompt.filter((m) => m.role !== "system").length;
4993
+ const flushMemoryCount = flushNonSystemTotal - flushOriginalInputCount;
4994
+ let flushNonSystemIndex = 0;
4995
+ for (const msg of params.prompt) {
4996
+ if (msg.role === "system") {
4997
+ messageList.addSystem(msg.content);
4998
+ } else {
4999
+ messageList.add(msg, flushNonSystemIndex < flushMemoryCount ? "memory" : "input");
5000
+ flushNonSystemIndex++;
5001
+ }
5002
+ }
5003
+ const responseMessage = streamAccumulator.buildResponseMessage(memory);
5004
+ if (responseMessage) {
5005
+ messageList.add(responseMessage, "response");
5006
+ }
5007
+ for (const processor of outputResultProcessors) {
5008
+ if (!processor.processOutputResult) continue;
5009
+ try {
5010
+ const procState = processorStates.get(processor.id);
5011
+ await processor.processOutputResult({
5012
+ messages: messageList.get.all.db(),
5013
+ messageList,
5014
+ state: procState?.customState ?? {},
5015
+ requestContext,
5016
+ abort: (reason) => {
5017
+ throw new TripWire(reason || "Aborted by processor");
5018
+ }
5019
+ });
5020
+ } catch (error) {
5021
+ if (error instanceof TripWire) {
5022
+ controller.enqueue({
5023
+ type: "error",
5024
+ error: new Error(error.message)
5025
+ });
5026
+ return;
5027
+ }
5028
+ throw error;
5029
+ }
5030
+ }
4832
5031
  }
4833
5032
  })
4834
5033
  );