@mastra/claude 0.1.0 → 0.2.0-alpha.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.js CHANGED
@@ -6,6 +6,7 @@ import { RequestContext } from '@mastra/core/request-context';
6
6
  import { ChunkFrom, MastraModelOutput } from '@mastra/core/stream';
7
7
  import { MessageList } from '@mastra/core/agent/message-list';
8
8
  import { getOrCreateSpan, EntityType, SpanType, executeWithContext } from '@mastra/core/observability';
9
+ import { toStandardSchema, standardSchemaToJSONSchema } from '@mastra/core/schema';
9
10
 
10
11
  // src/index.ts
11
12
  function createNoopModel({ modelId, provider }) {
@@ -33,7 +34,8 @@ function createCompletedMastraStream({
33
34
  modelId,
34
35
  usage,
35
36
  providerMetadata,
36
- costContext
37
+ costContext,
38
+ object
37
39
  }) {
38
40
  return new ReadableStream({
39
41
  start(controller) {
@@ -58,7 +60,8 @@ function createCompletedMastraStream({
58
60
  modelId,
59
61
  usage,
60
62
  providerMetadata,
61
- costContext
63
+ costContext,
64
+ object
62
65
  });
63
66
  controller.close();
64
67
  }
@@ -70,11 +73,12 @@ function createMastraOutput({
70
73
  modelId,
71
74
  provider,
72
75
  stream,
76
+ responseText = "",
73
77
  options
74
78
  }) {
75
79
  const messageList = new MessageList();
76
80
  messageList.add(messages, "input");
77
- messageList.add([{ role: "assistant", content: "" }], "response");
81
+ messageList.add([{ role: "assistant", content: responseText }], "response");
78
82
  return new MastraModelOutput({
79
83
  model: {
80
84
  modelId,
@@ -106,7 +110,8 @@ function toFullOutput({
106
110
  modelId: result.response.modelId,
107
111
  usage: toLanguageModelUsage(result.usage),
108
112
  providerMetadata: result.providerMetadata,
109
- costContext: result.costContext
113
+ costContext: result.costContext,
114
+ object: result.object
110
115
  });
111
116
  return createMastraOutput({
112
117
  messages,
@@ -114,6 +119,7 @@ function toFullOutput({
114
119
  modelId: result.response.modelId,
115
120
  provider,
116
121
  stream,
122
+ responseText: text,
117
123
  options
118
124
  }).getFullOutput();
119
125
  }
@@ -484,7 +490,8 @@ function enqueueFinishChunks(controller, {
484
490
  modelId,
485
491
  usage,
486
492
  providerMetadata,
487
- costContext
493
+ costContext,
494
+ object
488
495
  }) {
489
496
  const timestamp = /* @__PURE__ */ new Date();
490
497
  const response = {
@@ -508,6 +515,14 @@ function enqueueFinishChunks(controller, {
508
515
  providerMetadata
509
516
  }
510
517
  });
518
+ if (object !== void 0) {
519
+ controller.enqueue({
520
+ type: "object-result",
521
+ runId,
522
+ from: ChunkFrom.AGENT,
523
+ object
524
+ });
525
+ }
511
526
  controller.enqueue({
512
527
  type: "step-finish",
513
528
  runId,
@@ -612,6 +627,48 @@ function promptToText(prompt) {
612
627
  }
613
628
  return "";
614
629
  }
630
/**
 * Convert the run's structured-output schema into a JSON Schema for the Claude SDK.
 *
 * @param {object | undefined} structuredOutput - The run's `structuredOutput` options.
 * @returns {object | undefined} The JSON Schema, or `undefined` when no schema was requested.
 */
function getStructuredOutputSchema(structuredOutput) {
  const schema = structuredOutput?.schema;
  if (!schema) {
    // No schema means structured output is disabled for this run.
    return void 0;
  }
  const standardSchema = toStandardSchema(schema);
  return standardSchemaToJSONSchema(standardSchema);
}
636
/**
 * Parse and validate a structured-output candidate against the run's schema.
 *
 * @param {unknown} value - The candidate value. Strings are parsed as JSON first; any other
 *   value is validated as-is.
 * @param {object | undefined} structuredOutput - The run's `structuredOutput` options. When it
 *   has no `schema`, structured output is disabled and `undefined` is returned.
 * @returns {Promise<unknown>} The validated value on success. On a parse or validation failure,
 *   whatever `handleStructuredOutputError` returns for the configured error strategy.
 * @throws {Error} When parsing or validation fails and `handleStructuredOutputError` rethrows.
 */
async function getStructuredOutputFromValue(value, structuredOutput) {
  if (!structuredOutput?.schema) {
    return void 0;
  }
  let parsed;
  if (typeof value === "string") {
    try {
      parsed = JSON.parse(value);
    } catch (error) {
      return handleStructuredOutputError(
        new Error("Structured output must be valid JSON.", { cause: error }),
        structuredOutput
      );
    }
  } else {
    parsed = value;
  }
  const schema = toStandardSchema(structuredOutput.schema);
  // `await` covers validators that return the result synchronously as well as via a Promise.
  const result = await schema["~standard"].validate(parsed);
  if (!result.issues) {
    return result.value;
  }
  const message = result.issues
    .map((issue) => `- ${formatStandardSchemaIssuePath(issue.path)}: ${issue.message}`)
    .join("\n");
  return handleStructuredOutputError(new Error(`Structured output validation failed:\n${message}`), structuredOutput);
}

/**
 * Render a Standard Schema issue path as a dotted string, or "root" when the path is missing
 * or renders empty.
 *
 * Path entries may be plain property keys (including symbols) or `{ key }` segment objects.
 * Plain `Array.prototype.join` would print segment objects as "[object Object]" and throw on
 * symbols, so each entry is unwrapped and converted with `String()` first.
 *
 * @param {ReadonlyArray<PropertyKey | { key: PropertyKey }> | undefined} path
 * @returns {string}
 */
function formatStandardSchemaIssuePath(path) {
  if (!path) {
    return "root";
  }
  const rendered = path
    .map((segment) => String(typeof segment === "object" && segment !== null ? segment.key : segment))
    .join(".");
  return rendered || "root";
}
662
/**
 * Apply the run's structured-output error strategy to a parse or validation failure.
 *
 * - `"fallback"`: return `structuredOutput.fallbackValue`.
 * - `"warn"`: pass the error message to `structuredOutput.logger.warn` (if a logger is present)
 *   and return `undefined`.
 * - any other value: rethrow `error`.
 *
 * @param {Error} error - The parse or validation error.
 * @param {object} structuredOutput - The run's `structuredOutput` options.
 * @returns {unknown}
 * @throws {Error} `error` itself, unless the strategy is "fallback" or "warn".
 */
function handleStructuredOutputError(error, structuredOutput) {
  switch (structuredOutput.errorStrategy) {
    case "fallback":
      return structuredOutput.fallbackValue;
    case "warn":
      structuredOutput.logger?.warn(error.message);
      return void 0;
    default:
      throw error;
  }
}
615
672
  function sumDefined(...values) {
616
673
  const defined = values.filter((value) => typeof value === "number");
617
674
  if (defined.length === 0) {
@@ -672,9 +729,7 @@ var ClaudeSDKAgent = class extends Agent {
672
729
  });
673
730
  let result;
674
731
  try {
675
- result = await telemetry.execute(
676
- () => runClaudeGenerate(prompt, this.options, telemetry, options?.abortSignal ?? options?.signal)
677
- );
732
+ result = await telemetry.execute(() => runClaudeGenerate(prompt, this.options, telemetry, options));
678
733
  telemetry.endGenerate(result);
679
734
  } catch (error) {
680
735
  telemetry.fail(error);
@@ -685,7 +740,7 @@ var ClaudeSDKAgent = class extends Agent {
685
740
  runId,
686
741
  provider: PROVIDER,
687
742
  result,
688
- options: telemetry.outputOptions()
743
+ options: { ...telemetry.outputOptions(), structuredOutput: options?.structuredOutput }
689
744
  });
690
745
  }
691
746
  async stream(messages, options) {
@@ -718,26 +773,82 @@ var ClaudeSDKAgent = class extends Agent {
718
773
  runId,
719
774
  modelId,
720
775
  provider: PROVIDER,
721
- stream: telemetry.wrapStream(
722
- runClaudeAsMastraStream(prompt, this.options, runId, telemetry, options?.abortSignal ?? options?.signal)
723
- ),
724
- options: telemetry.outputOptions()
776
+ stream: telemetry.wrapStream(runClaudeAsMastraStream(prompt, this.options, runId, telemetry, options)),
777
+ options: { ...telemetry.outputOptions(), structuredOutput: options?.structuredOutput }
725
778
  });
726
779
  }
780
+ async resumeGenerate(resumeData, options) {
781
+ const data = validateClaudeResumeData(resumeData);
782
+ return this.generate(data.message, createClaudeResumeRunOptions(data, options));
783
+ }
784
+ async resumeStream(resumeData, options) {
785
+ const data = validateClaudeResumeData(resumeData);
786
+ return this.stream(data.message, createClaudeResumeRunOptions(data, options));
787
+ }
727
788
  };
728
- async function runClaudeGenerate(prompt, options, telemetry, signal) {
789
/**
 * Validate the `resumeData` payload accepted by `resumeGenerate` / `resumeStream`.
 *
 * The payload must be a record with a `message` key. It must also select exactly one resume
 * mode: `sessionId` (a string) or `continue` (which must be exactly `true`). Keys are checked
 * with the `in` operator, so a key explicitly set to `undefined` still counts as provided.
 *
 * @param {unknown} resumeData
 * @returns {object} The same `resumeData` object, unchanged.
 * @throws {Error} When `message` is missing, or the resume mode is absent, ambiguous, or mistyped.
 */
function validateClaudeResumeData(resumeData) {
  if (!isRecord(resumeData) || !("message" in resumeData)) {
    throw new Error("ClaudeSDKAgent resumeData must include a message.");
  }
  const wantsSession = "sessionId" in resumeData;
  const wantsContinue = "continue" in resumeData;
  if (!wantsSession && !wantsContinue) {
    throw new Error("ClaudeSDKAgent resumeData must include sessionId or continue: true.");
  }
  if (wantsSession && wantsContinue) {
    throw new Error("ClaudeSDKAgent resumeData must include either sessionId or continue: true, not both.");
  }
  if (wantsSession && typeof resumeData.sessionId !== "string") {
    throw new Error("ClaudeSDKAgent resumeData.sessionId must be a string.");
  }
  if (wantsContinue && resumeData.continue !== true) {
    throw new Error("ClaudeSDKAgent resumeData.continue must be true when provided.");
  }
  return resumeData;
}
812
/**
 * Build run options that tell the Claude SDK to resume or continue a session.
 *
 * The caller's `options` and `options.sdkOptions` are shallow-copied, never mutated.
 * - With a string `sessionId`: sets `sdkOptions.resume`, and copies `forkSession` /
 *   `resumeSessionAt` when they are not `undefined`.
 * - Otherwise: sets `sdkOptions.continue = true`.
 *
 * @param {object} resumeData - Payload already checked by `validateClaudeResumeData`.
 * @param {object | undefined} options - Caller run options.
 * @returns {object} A new options object with the merged `sdkOptions`.
 */
function createClaudeResumeRunOptions(resumeData, options) {
  const mergedSdkOptions = { ...options?.sdkOptions };
  const resumeBySession = "sessionId" in resumeData && typeof resumeData.sessionId === "string";
  if (!resumeBySession) {
    mergedSdkOptions.continue = true;
    return { ...options, sdkOptions: mergedSdkOptions };
  }
  mergedSdkOptions.resume = resumeData.sessionId;
  const { forkSession, resumeSessionAt } = resumeData;
  if (forkSession !== void 0) {
    mergedSdkOptions.forkSession = forkSession;
  }
  if (resumeSessionAt !== void 0) {
    mergedSdkOptions.resumeSessionAt = resumeSessionAt;
  }
  return { ...options, sdkOptions: mergedSdkOptions };
}
830
+ async function runClaudeGenerate(prompt, options, telemetry, runOptions) {
729
831
  let text = "";
832
+ let structuredOutputValue;
730
833
  const usage = createClaudeUsageCollector();
731
- for await (const message of observeClaudeMessages(runClaude(prompt, options, signal), telemetry)) {
834
+ for await (const message of observeClaudeMessages(
835
+ runClaude(prompt, options, runOptions?.abortSignal ?? runOptions?.signal, runOptions),
836
+ telemetry
837
+ )) {
732
838
  usage.record(message);
733
839
  if (message.type === "result") {
734
840
  if (message.subtype !== "success") {
735
841
  throw new Error(message.errors.join("\n") || `Claude Agent SDK failed with ${message.subtype}`);
736
842
  }
737
843
  text = message.result;
844
+ structuredOutputValue = getClaudeStructuredOutput(message);
738
845
  }
739
846
  }
740
847
  const totals = usage.totals();
848
+ const object = await getStructuredOutputFromValue(
849
+ structuredOutputValue === void 0 ? text : structuredOutputValue,
850
+ runOptions?.structuredOutput
851
+ );
741
852
  return {
742
853
  content: [{ type: "text", text }],
743
854
  finishReason: { unified: "stop", raw: "stop" },
@@ -748,10 +859,11 @@ async function runClaudeGenerate(prompt, options, telemetry, signal) {
748
859
  timestamp: /* @__PURE__ */ new Date()
749
860
  },
750
861
  providerMetadata: getClaudeProviderMetadata(options, totals),
751
- costContext: getClaudeCostContext(options, totals)
862
+ costContext: getClaudeCostContext(options, totals),
863
+ object
752
864
  };
753
865
  }
754
- function runClaudeAsMastraStream(prompt, options, runId, telemetry, signal) {
866
+ function runClaudeAsMastraStream(prompt, options, runId, telemetry, runOptions) {
755
867
  return new ReadableStream({
756
868
  start: async (controller) => {
757
869
  const textId = randomUUID();
@@ -759,6 +871,7 @@ function runClaudeAsMastraStream(prompt, options, runId, telemetry, signal) {
759
871
  const modelId = getModelId(options);
760
872
  const usage = createClaudeUsageCollector();
761
873
  let text = "";
874
+ let structuredOutputValue;
762
875
  let sawDelta = false;
763
876
  try {
764
877
  enqueueStartChunks(controller, {
@@ -769,7 +882,10 @@ function runClaudeAsMastraStream(prompt, options, runId, telemetry, signal) {
769
882
  modelId,
770
883
  providerMetadata: getClaudeProviderMetadata(options, usage.totals())
771
884
  });
772
- for await (const message of observeClaudeMessages(runClaude(prompt, options, signal), telemetry)) {
885
+ for await (const message of observeClaudeMessages(
886
+ runClaude(prompt, options, runOptions?.abortSignal ?? runOptions?.signal, runOptions),
887
+ telemetry
888
+ )) {
773
889
  usage.record(message);
774
890
  const delta = getTextDelta(message);
775
891
  if (delta) {
@@ -785,6 +901,7 @@ function runClaudeAsMastraStream(prompt, options, runId, telemetry, signal) {
785
901
  text += message.result;
786
902
  enqueueTextDelta(controller, runId, textId, message.result);
787
903
  }
904
+ structuredOutputValue = getClaudeStructuredOutput(message);
788
905
  }
789
906
  }
790
907
  const totals = usage.totals();
@@ -798,7 +915,11 @@ function runClaudeAsMastraStream(prompt, options, runId, telemetry, signal) {
798
915
  modelId,
799
916
  usage: usage.toLanguageModelUsage(),
800
917
  providerMetadata,
801
- costContext: getClaudeCostContext(options, totals)
918
+ costContext: getClaudeCostContext(options, totals),
919
+ object: await getStructuredOutputFromValue(
920
+ structuredOutputValue === void 0 ? text : structuredOutputValue,
921
+ runOptions?.structuredOutput
922
+ )
802
923
  });
803
924
  controller.close();
804
925
  } catch (error) {
@@ -813,11 +934,19 @@ function runClaudeAsMastraStream(prompt, options, runId, telemetry, signal) {
813
934
  }
814
935
  });
815
936
  }
816
- function runClaude(prompt, options, signal) {
937
+ function runClaude(prompt, options, signal, runOptions) {
817
938
  const abortController = createAbortController(signal);
818
939
  const queryOptions = {
819
- ...options.sdkOptions
940
+ ...options.sdkOptions,
941
+ ...runOptions?.sdkOptions
820
942
  };
943
+ const outputSchema = getStructuredOutputSchema(runOptions?.structuredOutput);
944
+ if (outputSchema) {
945
+ queryOptions.outputFormat = {
946
+ type: "json_schema",
947
+ schema: outputSchema
948
+ };
949
+ }
821
950
  if (abortController) {
822
951
  queryOptions.abortController = abortController;
823
952
  }
@@ -826,6 +955,12 @@ function runClaude(prompt, options, signal) {
826
955
  options: queryOptions
827
956
  });
828
957
  }
958
/**
 * Read the `structured_output` field from a Claude SDK message.
 *
 * @param {{ type: string, structured_output?: unknown }} message
 * @returns {unknown} The field's value for messages with `type === "result"`; `undefined` for
 *   every other message type.
 */
function getClaudeStructuredOutput(message) {
  return message.type === "result" ? message.structured_output : void 0;
}
829
964
  async function* observeClaudeMessages(messages, telemetry) {
830
965
  for await (const message of messages) {
831
966
  recordClaudeToolTelemetry(message, telemetry);