@langchain/google-common 0.1.1 → 0.1.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,7 +1,114 @@
1
1
  import { v4 as uuidv4 } from "uuid";
2
2
  import { AIMessage, AIMessageChunk, isAIMessage, } from "@langchain/core/messages";
3
3
  import { ChatGenerationChunk, } from "@langchain/core/outputs";
4
+ import { isLangChainTool } from "@langchain/core/utils/function_calling";
4
5
  import { GoogleAISafetyError } from "./safety.js";
6
+ import { zodToGeminiParameters } from "./zod_to_gemini_parameters.js";
7
/**
 * Safety handler that inspects a Gemini response and throws a
 * GoogleAISafetyError when the prompt was blocked or when the first
 * candidate finished for one of the configured "error" finish reasons.
 */
export class DefaultGeminiSafetyHandler {
  /** Finish reasons that are reported as safety errors. */
  errorFinish = ["SAFETY", "RECITATION", "OTHER"];

  /**
   * @param {object} [settings] Optional settings.
   * @param {string[]} [settings.errorFinish] Replacement list of finish
   *   reasons that should raise an error.
   */
  constructor(settings) {
    this.errorFinish = settings?.errorFinish ?? this.errorFinish;
  }

  /**
   * Throws if the response data reports that the prompt itself was blocked.
   *
   * @param {object} response The response the data came from (attached to the error).
   * @param {object} data One response payload.
   * @returns {object} `data`, unchanged.
   * @throws {GoogleAISafetyError} When `data.promptFeedback.blockReason` is set.
   */
  handleDataPromptFeedback(response, data) {
    const blockReason = data?.promptFeedback?.blockReason;
    if (blockReason) {
      throw new GoogleAISafetyError(response, `Prompt blocked: ${blockReason}`);
    }
    return data;
  }

  /**
   * Throws if the first candidate's finish reason is listed in `errorFinish`.
   *
   * @param {object} response The response the data came from (attached to the error).
   * @param {object} data One response payload.
   * @returns {object} `data`, unchanged.
   * @throws {GoogleAISafetyError} When the finish reason is an error reason.
   */
  handleDataFinishReason(response, data) {
    const finishReason = data?.candidates?.[0]?.finishReason;
    if (this.errorFinish.includes(finishReason)) {
      throw new GoogleAISafetyError(response, `Finish reason: ${finishReason}`);
    }
    return data;
  }

  /**
   * Runs every check against a single payload: prompt feedback first,
   * then the finish reason.
   *
   * @param {object} response The response the data came from.
   * @param {object} data One response payload.
   * @returns {object} The payload returned by the last check.
   */
  handleData(response, data) {
    const afterPromptCheck = this.handleDataPromptFeedback(response, data);
    return this.handleDataFinishReason(response, afterPromptCheck);
  }

  /**
   * Checks a whole response and returns a shallow copy of it whose `data`
   * is the checked payload (or array of payloads).
   *
   * @param {object} response Response whose `data` is a stream-like object
   *   (has `nextChunk`), an array of payloads, or a single payload.
   * @returns {object} `{ ...response, data }`.
   * @throws {GoogleAISafetyError} When any checked payload fails a check.
   */
  handle(response) {
    const { data } = response;
    // The `in` operator throws a TypeError for null, undefined and
    // primitives, so only probe values that can actually carry properties.
    const canHaveProperties =
      data !== null &&
      (typeof data === "object" || typeof data === "function");

    let newdata;
    if (canHaveProperties && "nextChunk" in data) {
      // TODO: This is a stream. How to handle?
      newdata = data;
    } else if (Array.isArray(data)) {
      // If it is an array, try to handle every item in the array
      try {
        newdata = data.map((item) => this.handleData(response, item));
      } catch (xx) {
        // eslint-disable-next-line no-instanceof/no-instanceof
        if (xx instanceof GoogleAISafetyError) {
          throw new GoogleAISafetyError(response, xx.message);
        }
        throw xx;
      }
    } else {
      // Single payload; nullish data passes through the optional-chained checks.
      newdata = this.handleData(response, data);
    }

    return {
      ...response,
      data: newdata,
    };
  }
}
71
/**
 * Safety handler that, instead of surfacing an error from the default
 * checks, substitutes a fixed message as the first candidate's content.
 */
export class MessageGeminiSafetyHandler extends DefaultGeminiSafetyHandler {
  /** Text placed in the substituted message. */
  msg = "";

  /** When true, replace the content even if the candidate already has parts. */
  forceNewMessage = false;

  /**
   * @param {object} [settings] Optional settings (also passed to the parent).
   * @param {string} [settings.msg] Replacement message text.
   * @param {boolean} [settings.forceNewMessage] Always overwrite existing content.
   */
  constructor(settings) {
    super(settings);
    this.msg = settings?.msg ?? this.msg;
    this.forceNewMessage = settings?.forceNewMessage ?? this.forceNewMessage;
  }

  /**
   * Writes the replacement message into `data.candidates[0].content`.
   * Existing content is kept when it already has parts, unless
   * `forceNewMessage` is set. The passed object is modified in place and
   * returned; a new object is created only when `data` is null/undefined.
   *
   * @param {object} [data] One response payload.
   * @returns {object} The payload carrying the message.
   */
  setMessage(data) {
    const hasParts = Boolean(data?.candidates?.[0]?.content?.parts?.length);
    if (!this.forceNewMessage && hasParts) {
      return data;
    }
    // The guard above tolerates nullish data, so the writes below must too.
    const ret = data ?? {};
    ret.candidates ??= [];
    ret.candidates[0] ??= {};
    ret.candidates[0].content = {
      role: "model",
      parts: [{ text: this.msg }],
    };
    return ret;
  }

  /**
   * Runs the parent checks; if they throw, returns the payload with the
   * replacement message instead of propagating the error.
   *
   * @param {object} response The response the data came from.
   * @param {object} data One response payload.
   * @returns {object} The checked payload, or the payload with the message set.
   */
  handleData(response, data) {
    try {
      return super.handleData(response, data);
    } catch {
      // Any error raised by the parent checks is replaced by the message.
      return this.setMessage(data);
    }
  }
}
5
112
  const extractMimeType = (str) => {
6
113
  if (str.startsWith("data:")) {
7
114
  return {
@@ -82,7 +189,7 @@ export function getGeminiAPI(config) {
82
189
  return await blobToFileData(blob);
83
190
  }
84
191
  }
85
- throw new Error("Invalid media content");
192
+ throw new Error(`Invalid media content: ${JSON.stringify(content, null, 1)}`);
86
193
  }
87
194
  async function messageContentComplexToPart(content) {
88
195
  switch (content.type) {
@@ -100,7 +207,7 @@ export function getGeminiAPI(config) {
100
207
  case "media":
101
208
  return await messageContentMedia(content);
102
209
  default:
103
- throw new Error(`Unsupported type received while converting message to message parts`);
210
+ throw new Error(`Unsupported type "${content.type}" received while converting message to message parts: ${content}`);
104
211
  }
105
212
  throw new Error(`Cannot coerce "${content.type}" message part into a string.`);
106
213
  }
@@ -178,8 +285,8 @@ export function getGeminiAPI(config) {
178
285
  },
179
286
  ];
180
287
  }
181
- async function systemMessageToContent(message, useSystemInstruction) {
182
- return useSystemInstruction
288
+ async function systemMessageToContent(message) {
289
+ return config?.useSystemInstruction
183
290
  ? roleMessageToContent("system", message)
184
291
  : [
185
292
  ...(await roleMessageToContent("user", message)),
@@ -233,11 +340,11 @@ export function getGeminiAPI(config) {
233
340
  ];
234
341
  }
235
342
  }
236
- async function baseMessageToContent(message, prevMessage, useSystemInstruction) {
343
+ async function baseMessageToContent(message, prevMessage) {
237
344
  const type = message._getType();
238
345
  switch (type) {
239
346
  case "system":
240
- return systemMessageToContent(message, useSystemInstruction);
347
+ return systemMessageToContent(message);
241
348
  case "human":
242
349
  return roleMessageToContent("user", message);
243
350
  case "ai":
@@ -375,7 +482,8 @@ export function getGeminiAPI(config) {
375
482
  }, "");
376
483
  return ret;
377
484
  }
378
- function safeResponseTo(response, safetyHandler, responseTo) {
485
+ function safeResponseTo(response, responseTo) {
486
+ const safetyHandler = config?.safetyHandler ?? new DefaultGeminiSafetyHandler();
379
487
  try {
380
488
  const safeResponse = safetyHandler.handle(response);
381
489
  return responseTo(safeResponse);
@@ -389,8 +497,8 @@ export function getGeminiAPI(config) {
389
497
  throw xx;
390
498
  }
391
499
  }
392
- function safeResponseToString(response, safetyHandler) {
393
- return safeResponseTo(response, safetyHandler, responseToString);
500
+ function safeResponseToString(response) {
501
+ return safeResponseTo(response, responseToString);
394
502
  }
395
503
  function responseToGenerationInfo(response) {
396
504
  if (!Array.isArray(response.data)) {
@@ -420,8 +528,8 @@ export function getGeminiAPI(config) {
420
528
  generationInfo: responseToGenerationInfo(response),
421
529
  });
422
530
  }
423
- function safeResponseToChatGeneration(response, safetyHandler) {
424
- return safeResponseTo(response, safetyHandler, responseToChatGeneration);
531
+ function safeResponseToChatGeneration(response) {
532
+ return safeResponseTo(response, responseToChatGeneration);
425
533
  }
426
534
  function chunkToString(chunk) {
427
535
  if (chunk === null) {
@@ -553,8 +661,8 @@ export function getGeminiAPI(config) {
553
661
  const fields = responseToBaseMessageFields(response);
554
662
  return new AIMessage(fields);
555
663
  }
556
- function safeResponseToBaseMessage(response, safetyHandler) {
557
- return safeResponseTo(response, safetyHandler, responseToBaseMessage);
664
+ function safeResponseToBaseMessage(response) {
665
+ return safeResponseTo(response, responseToBaseMessage);
558
666
  }
559
667
  function responseToChatResult(response) {
560
668
  const generations = responseToChatGenerations(response);
@@ -563,17 +671,190 @@ export function getGeminiAPI(config) {
563
671
  llmOutput: responseToGenerationInfo(response),
564
672
  };
565
673
  }
566
- function safeResponseToChatResult(response, safetyHandler) {
567
- return safeResponseTo(response, safetyHandler, responseToChatResult);
674
+ function safeResponseToChatResult(response) {
675
+ return safeResponseTo(response, responseToChatResult);
676
+ }
677
/**
 * Classifies the input as plain message content or an array of messages.
 *
 * A string is always "MessageContent". For an array, the first element
 * decides: if it is an object with its own `content` property the input is
 * a "BaseMessageArray"; otherwise (including an empty array) it is
 * "MessageContent".
 *
 * @param {string | Array} input The input to classify.
 * @returns {"MessageContent" | "BaseMessageArray"} The detected input kind.
 */
function inputType(input) {
  if (typeof input === "string") {
    return "MessageContent";
  }
  const firstItem = input[0];
  // Object.hasOwn throws on null/undefined, which is what an empty array yields.
  const looksLikeMessage =
    firstItem !== null &&
    typeof firstItem === "object" &&
    Object.hasOwn(firstItem, "content");
  return looksLikeMessage ? "BaseMessageArray" : "MessageContent";
}
691
/**
 * Wraps message content in a single "user" content entry.
 *
 * @param {*} input Message content, converted with `messageContentToParts`.
 * @param {*} _parameters Unused.
 * @returns {Promise<Array<{role: string, parts: *}>>} One-element contents array.
 */
async function formatMessageContents(input, _parameters) {
  const userParts = await messageContentToParts(input);
  return [{ role: "user", parts: userParts }];
}
701
/**
 * Converts an array of messages into a flat contents array.
 *
 * Each message is converted with `baseMessageToContent` (which also receives
 * the previous message). Conversions made up only of "system" entries are
 * left out. When a conversion starts with a "function" entry and the last
 * collected entry is also a "function" entry, their parts are merged into
 * that last entry; any further entries of the conversion are appended
 * afterwards rather than being discarded.
 *
 * @param {Array} input Messages to convert.
 * @param {*} _parameters Unused.
 * @returns {Promise<Array>} The combined contents.
 */
async function formatBaseMessageContents(input, _parameters) {
  const conversions = await Promise.all(
    input.map((msg, i) => baseMessageToContent(msg, input[i - 1]))
  );

  const contents = [];
  for (const converted of conversions) {
    // Filter out the system content
    if (converted.every((content) => content.role === "system")) {
      continue;
    }

    const [head, ...rest] = converted;
    const previous = contents.at(-1);

    // Combine adjacent function messages
    if (head?.role === "function" && previous?.role === "function") {
      previous.parts = [...previous.parts, ...head.parts];
      // Keep whatever follows the merged entry instead of dropping it.
      contents.push(...rest);
    } else {
      contents.push(...converted);
    }
  }
  return contents;
}
724
/**
 * Builds the request contents, choosing the formatter from the kind
 * reported by `inputType`.
 *
 * @param {*} input Message content or an array of messages.
 * @param {*} parameters Passed through to the chosen formatter.
 * @returns {Promise<Array>} The formatted contents.
 * @throws {Error} When `inputType` reports a kind that is not handled here.
 */
async function formatContents(input, parameters) {
  const kind = inputType(input);
  if (kind === "MessageContent") {
    return formatMessageContents(input, parameters);
  }
  if (kind === "BaseMessageArray") {
    return formatBaseMessageContents(input, parameters);
  }
  throw new Error(`Unknown input type "${kind}": ${input}`);
}
735
/**
 * Picks the generation settings out of the parameters object.
 *
 * @param {object} parameters Source of the settings.
 * @returns {object} Object with `temperature`, `topK`, `topP`,
 *   `maxOutputTokens`, `stopSequences` and `responseMimeType`, each copied
 *   from `parameters` (undefined when absent).
 */
function formatGenerationConfig(parameters) {
  const {
    temperature,
    topK,
    topP,
    maxOutputTokens,
    stopSequences,
    responseMimeType,
  } = parameters;
  return {
    temperature,
    topK,
    topP,
    maxOutputTokens,
    stopSequences,
    responseMimeType,
  };
}
745
/**
 * Returns the safety settings from the parameters.
 *
 * @param {object} parameters Source of the settings.
 * @returns {*} `parameters.safetySettings`, or an empty array when it is
 *   null or undefined.
 */
function formatSafetySettings(parameters) {
  const { safetySettings } = parameters;
  if (safetySettings === null || safetySettings === undefined) {
    return [];
  }
  return safetySettings;
}
748
/**
 * Extracts the system instruction from an array of messages.
 *
 * A message whose `_getType()` is "system" is accepted only at index 0; it
 * is converted with `baseMessageToContent` and the first resulting entry is
 * used. A system message at any later index is an error.
 *
 * @param {Array} input Messages to scan.
 * @returns {Promise<object>} The converted system entry, or `{}` when the
 *   first message is not a system message.
 * @throws {Error} When a system message appears after the first position.
 */
async function formatBaseMessageSystemInstruction(input) {
  let instruction = {};
  for (const [position, message] of input.entries()) {
    if (message._getType() !== "system") {
      continue;
    }
    // Only the leading message may be a system message.
    if (position !== 0) {
      throw new Error("System messages are only permitted as the first passed message.");
    }
    const converted = await baseMessageToContent(message, undefined);
    instruction = converted[0];
  }
  return instruction;
}
766
/**
 * Produces the system instruction for a request.
 *
 * An instruction is built only when `config.useSystemInstruction` is truthy
 * and the input is classified as a "BaseMessageArray"; every other case
 * yields an empty object.
 *
 * @param {*} input Message content or an array of messages.
 * @returns {Promise<object>} The system instruction, or `{}`.
 */
async function formatSystemInstruction(input) {
  const enabled = Boolean(config?.useSystemInstruction);
  if (enabled && inputType(input) === "BaseMessageArray") {
    return formatBaseMessageSystemInstruction(input);
  }
  return {};
}
778
/**
 * Builds a function declaration from a tool.
 *
 * @param {object} tool Tool with `name`, optional `description`, and a
 *   `schema` that is passed to `zodToGeminiParameters`.
 * @returns {{name: *, description: *, parameters: *}} The declaration; the
 *   description falls back to a generic sentence when null or undefined.
 */
function structuredToolToFunctionDeclaration(tool) {
  const parameters = zodToGeminiParameters(tool.schema);
  const { name, description } = tool;
  return {
    name,
    description: description ?? `A function available to call.`,
    parameters,
  };
}
786
/**
 * Converts tools into a one-element tools array holding all of their
 * function declarations.
 *
 * @param {Array} tools Tools to convert with `structuredToolToFunctionDeclaration`.
 * @returns {Array<{functionDeclarations: Array}>} The tools array.
 */
function structuredToolsToGeminiTools(tools) {
  const functionDeclarations = tools.map(structuredToolToFunctionDeclaration);
  return [{ functionDeclarations }];
}
793
/**
 * Produces the `tools` value for a request.
 *
 * - No tools: empty array.
 * - Every tool passes `isLangChainTool`: converted to function declarations.
 * - Otherwise the tools are returned as given, except that a single tool
 *   with no (or empty) `functionDeclarations` yields an empty array.
 *
 * @param {object} [parameters] Object that may carry a `tools` array.
 * @returns {Array} The tools to send.
 */
function formatTools(parameters) {
  const tools = parameters?.tools;
  if (!tools || tools.length === 0) {
    return [];
  }
  if (tools.every(isLangChainTool)) {
    return structuredToolsToGeminiTools(tools);
  }
  const isLoneToolWithoutDeclarations =
    tools.length === 1 &&
    !("functionDeclarations" in tools[0] &&
      tools[0].functionDeclarations?.length);
  return isLoneToolWithoutDeclarations ? [] : tools;
}
810
/**
 * Builds the tool configuration from the parameters.
 *
 * @param {object} parameters Object that may carry `tool_choice` and
 *   `allowed_function_names`.
 * @returns {{functionCallingConfig: {mode: string, allowedFunctionNames: *}} | undefined}
 *   The configuration, or undefined unless `tool_choice` is a non-empty string.
 */
function formatToolConfig(parameters) {
  const { tool_choice: mode } = parameters;
  if (typeof mode !== "string" || mode.length === 0) {
    return undefined;
  }
  const { allowed_function_names: allowedFunctionNames } = parameters;
  return {
    functionCallingConfig: { mode, allowedFunctionNames },
  };
}
821
/**
 * Assembles the request body from the input and parameters.
 *
 * `contents` and `generationConfig` are always present. `tools`,
 * `toolConfig`, `safetySettings` and `systemInstruction` are added only
 * when the corresponding formatter produced something non-empty.
 *
 * @param {*} input Message content or an array of messages.
 * @param {object} parameters Request parameters handed to the formatters.
 * @returns {Promise<object>} The request body.
 */
async function formatData(input, parameters) {
  const contents = await formatContents(input, parameters);
  const generationConfig = formatGenerationConfig(parameters);
  const tools = formatTools(parameters);
  const toolConfig = formatToolConfig(parameters);
  const safetySettings = formatSafetySettings(parameters);
  const systemInstruction = await formatSystemInstruction(input);

  const hasTools = Boolean(tools && tools.length);
  const hasSafetySettings = Boolean(safetySettings && safetySettings.length);
  // A usable instruction needs a role and at least one part.
  const hasSystemInstruction = Boolean(
    systemInstruction?.role &&
      systemInstruction?.parts &&
      systemInstruction?.parts?.length
  );

  return {
    contents,
    generationConfig,
    ...(hasTools ? { tools } : {}),
    ...(toolConfig ? { toolConfig } : {}),
    ...(hasSafetySettings ? { safetySettings } : {}),
    ...(hasSystemInstruction ? { systemInstruction } : {}),
  };
}
569
849
  return {
570
850
  messageContentToParts,
571
851
  baseMessageToContent,
572
- safeResponseToString,
573
- safeResponseToChatGeneration,
852
+ responseToString: safeResponseToString,
853
+ responseToChatGeneration: safeResponseToChatGeneration,
574
854
  chunkToString,
575
- safeResponseToBaseMessage,
576
- safeResponseToChatResult,
855
+ responseToBaseMessage: safeResponseToBaseMessage,
856
+ responseToChatResult: safeResponseToChatResult,
857
+ formatData,
577
858
  };
578
859
  }
579
860
  export function validateGeminiParams(params) {
@@ -594,108 +875,3 @@ export function validateGeminiParams(params) {
594
875
  export function isModelGemini(modelName) {
595
876
  return modelName.toLowerCase().startsWith("gemini");
596
877
  }
597
- export class DefaultGeminiSafetyHandler {
598
- constructor(settings) {
599
- Object.defineProperty(this, "errorFinish", {
600
- enumerable: true,
601
- configurable: true,
602
- writable: true,
603
- value: ["SAFETY", "RECITATION", "OTHER"]
604
- });
605
- this.errorFinish = settings?.errorFinish ?? this.errorFinish;
606
- }
607
- handleDataPromptFeedback(response, data) {
608
- // Check to see if our prompt was blocked in the first place
609
- const promptFeedback = data?.promptFeedback;
610
- const blockReason = promptFeedback?.blockReason;
611
- if (blockReason) {
612
- throw new GoogleAISafetyError(response, `Prompt blocked: ${blockReason}`);
613
- }
614
- return data;
615
- }
616
- handleDataFinishReason(response, data) {
617
- const firstCandidate = data?.candidates?.[0];
618
- const finishReason = firstCandidate?.finishReason;
619
- if (this.errorFinish.includes(finishReason)) {
620
- throw new GoogleAISafetyError(response, `Finish reason: ${finishReason}`);
621
- }
622
- return data;
623
- }
624
- handleData(response, data) {
625
- let ret = data;
626
- ret = this.handleDataPromptFeedback(response, ret);
627
- ret = this.handleDataFinishReason(response, ret);
628
- return ret;
629
- }
630
- handle(response) {
631
- let newdata;
632
- if ("nextChunk" in response.data) {
633
- // TODO: This is a stream. How to handle?
634
- newdata = response.data;
635
- }
636
- else if (Array.isArray(response.data)) {
637
- // If it is an array, try to handle every item in the array
638
- try {
639
- newdata = response.data.map((item) => this.handleData(response, item));
640
- }
641
- catch (xx) {
642
- // eslint-disable-next-line no-instanceof/no-instanceof
643
- if (xx instanceof GoogleAISafetyError) {
644
- throw new GoogleAISafetyError(response, xx.message);
645
- }
646
- else {
647
- throw xx;
648
- }
649
- }
650
- }
651
- else {
652
- const data = response.data;
653
- newdata = this.handleData(response, data);
654
- }
655
- return {
656
- ...response,
657
- data: newdata,
658
- };
659
- }
660
- }
661
- export class MessageGeminiSafetyHandler extends DefaultGeminiSafetyHandler {
662
- constructor(settings) {
663
- super(settings);
664
- Object.defineProperty(this, "msg", {
665
- enumerable: true,
666
- configurable: true,
667
- writable: true,
668
- value: ""
669
- });
670
- Object.defineProperty(this, "forceNewMessage", {
671
- enumerable: true,
672
- configurable: true,
673
- writable: true,
674
- value: false
675
- });
676
- this.msg = settings?.msg ?? this.msg;
677
- this.forceNewMessage = settings?.forceNewMessage ?? this.forceNewMessage;
678
- }
679
- setMessage(data) {
680
- const ret = data;
681
- if (this.forceNewMessage ||
682
- !data?.candidates?.[0]?.content?.parts?.length) {
683
- ret.candidates = data.candidates ?? [];
684
- ret.candidates[0] = data.candidates[0] ?? {};
685
- ret.candidates[0].content = data.candidates[0].content ?? {};
686
- ret.candidates[0].content = {
687
- role: "model",
688
- parts: [{ text: this.msg }],
689
- };
690
- }
691
- return ret;
692
- }
693
- handleData(response, data) {
694
- try {
695
- return super.handleData(response, data);
696
- }
697
- catch (xx) {
698
- return this.setMessage(data);
699
- }
700
- }
701
- }