@langchain/google-common 0.1.1 → 0.1.3

This diff shows the changes between publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
@@ -1,10 +1,119 @@
  "use strict";
  Object.defineProperty(exports, "__esModule", { value: true });
- exports.MessageGeminiSafetyHandler = exports.DefaultGeminiSafetyHandler = exports.isModelGemini = exports.validateGeminiParams = exports.getGeminiAPI = void 0;
+ exports.isModelGemini = exports.validateGeminiParams = exports.getGeminiAPI = exports.MessageGeminiSafetyHandler = exports.DefaultGeminiSafetyHandler = void 0;
  const uuid_1 = require("uuid");
  const messages_1 = require("@langchain/core/messages");
  const outputs_1 = require("@langchain/core/outputs");
+ const function_calling_1 = require("@langchain/core/utils/function_calling");
  const safety_js_1 = require("./safety.cjs");
+ const zod_to_gemini_parameters_js_1 = require("./zod_to_gemini_parameters.cjs");
+ class DefaultGeminiSafetyHandler {
+ constructor(settings) {
+ Object.defineProperty(this, "errorFinish", {
+ enumerable: true,
+ configurable: true,
+ writable: true,
+ value: ["SAFETY", "RECITATION", "OTHER"]
+ });
+ this.errorFinish = settings?.errorFinish ?? this.errorFinish;
+ }
+ handleDataPromptFeedback(response, data) {
+ // Check to see if our prompt was blocked in the first place
+ const promptFeedback = data?.promptFeedback;
+ const blockReason = promptFeedback?.blockReason;
+ if (blockReason) {
+ throw new safety_js_1.GoogleAISafetyError(response, `Prompt blocked: ${blockReason}`);
+ }
+ return data;
+ }
+ handleDataFinishReason(response, data) {
+ const firstCandidate = data?.candidates?.[0];
+ const finishReason = firstCandidate?.finishReason;
+ if (this.errorFinish.includes(finishReason)) {
+ throw new safety_js_1.GoogleAISafetyError(response, `Finish reason: ${finishReason}`);
+ }
+ return data;
+ }
+ handleData(response, data) {
+ let ret = data;
+ ret = this.handleDataPromptFeedback(response, ret);
+ ret = this.handleDataFinishReason(response, ret);
+ return ret;
+ }
+ handle(response) {
+ let newdata;
+ if ("nextChunk" in response.data) {
+ // TODO: This is a stream. How to handle?
+ newdata = response.data;
+ }
+ else if (Array.isArray(response.data)) {
+ // If it is an array, try to handle every item in the array
+ try {
+ newdata = response.data.map((item) => this.handleData(response, item));
+ }
+ catch (xx) {
+ // eslint-disable-next-line no-instanceof/no-instanceof
+ if (xx instanceof safety_js_1.GoogleAISafetyError) {
+ throw new safety_js_1.GoogleAISafetyError(response, xx.message);
+ }
+ else {
+ throw xx;
+ }
+ }
+ }
+ else {
+ const data = response.data;
+ newdata = this.handleData(response, data);
+ }
+ return {
+ ...response,
+ data: newdata,
+ };
+ }
+ }
+ exports.DefaultGeminiSafetyHandler = DefaultGeminiSafetyHandler;
+ class MessageGeminiSafetyHandler extends DefaultGeminiSafetyHandler {
+ constructor(settings) {
+ super(settings);
+ Object.defineProperty(this, "msg", {
+ enumerable: true,
+ configurable: true,
+ writable: true,
+ value: ""
+ });
+ Object.defineProperty(this, "forceNewMessage", {
+ enumerable: true,
+ configurable: true,
+ writable: true,
+ value: false
+ });
+ this.msg = settings?.msg ?? this.msg;
+ this.forceNewMessage = settings?.forceNewMessage ?? this.forceNewMessage;
+ }
+ setMessage(data) {
+ const ret = data;
+ if (this.forceNewMessage ||
+ !data?.candidates?.[0]?.content?.parts?.length) {
+ ret.candidates = data.candidates ?? [];
+ ret.candidates[0] = data.candidates[0] ?? {};
+ ret.candidates[0].content = data.candidates[0].content ?? {};
+ ret.candidates[0].content = {
+ role: "model",
+ parts: [{ text: this.msg }],
+ };
+ }
+ return ret;
+ }
+ handleData(response, data) {
+ try {
+ return super.handleData(response, data);
+ }
+ catch (xx) {
+ return this.setMessage(data);
+ }
+ }
+ }
+ exports.MessageGeminiSafetyHandler = MessageGeminiSafetyHandler;
  const extractMimeType = (str) => {
  if (str.startsWith("data:")) {
  return {
@@ -85,7 +194,7 @@ function getGeminiAPI(config) {
  return await blobToFileData(blob);
  }
  }
- throw new Error("Invalid media content");
+ throw new Error(`Invalid media content: ${JSON.stringify(content, null, 1)}`);
  }
  async function messageContentComplexToPart(content) {
  switch (content.type) {
@@ -103,7 +212,7 @@ function getGeminiAPI(config) {
  case "media":
  return await messageContentMedia(content);
  default:
- throw new Error(`Unsupported type received while converting message to message parts`);
+ throw new Error(`Unsupported type "${content.type}" received while converting message to message parts: ${content}`);
  }
  throw new Error(`Cannot coerce "${content.type}" message part into a string.`);
  }
@@ -181,8 +290,8 @@ function getGeminiAPI(config) {
  },
  ];
  }
- async function systemMessageToContent(message, useSystemInstruction) {
- return useSystemInstruction
+ async function systemMessageToContent(message) {
+ return config?.useSystemInstruction
  ? roleMessageToContent("system", message)
  : [
  ...(await roleMessageToContent("user", message)),
@@ -236,11 +345,11 @@ function getGeminiAPI(config) {
  ];
  }
  }
- async function baseMessageToContent(message, prevMessage, useSystemInstruction) {
+ async function baseMessageToContent(message, prevMessage) {
  const type = message._getType();
  switch (type) {
  case "system":
- return systemMessageToContent(message, useSystemInstruction);
+ return systemMessageToContent(message);
  case "human":
  return roleMessageToContent("user", message);
  case "ai":
@@ -378,7 +487,8 @@ function getGeminiAPI(config) {
  }, "");
  return ret;
  }
- function safeResponseTo(response, safetyHandler, responseTo) {
+ function safeResponseTo(response, responseTo) {
+ const safetyHandler = config?.safetyHandler ?? new DefaultGeminiSafetyHandler();
  try {
  const safeResponse = safetyHandler.handle(response);
  return responseTo(safeResponse);
@@ -392,8 +502,8 @@ function getGeminiAPI(config) {
  throw xx;
  }
  }
- function safeResponseToString(response, safetyHandler) {
- return safeResponseTo(response, safetyHandler, responseToString);
+ function safeResponseToString(response) {
+ return safeResponseTo(response, responseToString);
  }
  function responseToGenerationInfo(response) {
  if (!Array.isArray(response.data)) {
@@ -423,8 +533,8 @@ function getGeminiAPI(config) {
  generationInfo: responseToGenerationInfo(response),
  });
  }
- function safeResponseToChatGeneration(response, safetyHandler) {
- return safeResponseTo(response, safetyHandler, responseToChatGeneration);
+ function safeResponseToChatGeneration(response) {
+ return safeResponseTo(response, responseToChatGeneration);
  }
  function chunkToString(chunk) {
  if (chunk === null) {
@@ -556,8 +666,8 @@ function getGeminiAPI(config) {
  const fields = responseToBaseMessageFields(response);
  return new messages_1.AIMessage(fields);
  }
- function safeResponseToBaseMessage(response, safetyHandler) {
- return safeResponseTo(response, safetyHandler, responseToBaseMessage);
+ function safeResponseToBaseMessage(response) {
+ return safeResponseTo(response, responseToBaseMessage);
  }
  function responseToChatResult(response) {
  const generations = responseToChatGenerations(response);
@@ -566,17 +676,190 @@ function getGeminiAPI(config) {
  llmOutput: responseToGenerationInfo(response),
  };
  }
- function safeResponseToChatResult(response, safetyHandler) {
- return safeResponseTo(response, safetyHandler, responseToChatResult);
+ function safeResponseToChatResult(response) {
+ return safeResponseTo(response, responseToChatResult);
+ }
+ function inputType(input) {
+ if (typeof input === "string") {
+ return "MessageContent";
+ }
+ else {
+ const firstItem = input[0];
+ if (Object.hasOwn(firstItem, "content")) {
+ return "BaseMessageArray";
+ }
+ else {
+ return "MessageContent";
+ }
+ }
+ }
+ async function formatMessageContents(input, _parameters) {
+ const parts = await messageContentToParts(input);
+ const contents = [
+ {
+ role: "user",
+ parts,
+ },
+ ];
+ return contents;
+ }
+ async function formatBaseMessageContents(input, _parameters) {
+ const inputPromises = input.map((msg, i) => baseMessageToContent(msg, input[i - 1]));
+ const inputs = await Promise.all(inputPromises);
+ return inputs.reduce((acc, cur) => {
+ // Filter out the system content
+ if (cur.every((content) => content.role === "system")) {
+ return acc;
+ }
+ // Combine adjacent function messages
+ if (cur[0]?.role === "function" &&
+ acc.length > 0 &&
+ acc[acc.length - 1].role === "function") {
+ acc[acc.length - 1].parts = [
+ ...acc[acc.length - 1].parts,
+ ...cur[0].parts,
+ ];
+ }
+ else {
+ acc.push(...cur);
+ }
+ return acc;
+ }, []);
+ }
+ async function formatContents(input, parameters) {
+ const it = inputType(input);
+ switch (it) {
+ case "MessageContent":
+ return formatMessageContents(input, parameters);
+ case "BaseMessageArray":
+ return formatBaseMessageContents(input, parameters);
+ default:
+ throw new Error(`Unknown input type "${it}": ${input}`);
+ }
+ }
+ function formatGenerationConfig(parameters) {
+ return {
+ temperature: parameters.temperature,
+ topK: parameters.topK,
+ topP: parameters.topP,
+ maxOutputTokens: parameters.maxOutputTokens,
+ stopSequences: parameters.stopSequences,
+ responseMimeType: parameters.responseMimeType,
+ };
+ }
+ function formatSafetySettings(parameters) {
+ return parameters.safetySettings ?? [];
+ }
+ async function formatBaseMessageSystemInstruction(input) {
+ let ret = {};
+ for (let index = 0; index < input.length; index += 1) {
+ const message = input[index];
+ if (message._getType() === "system") {
+ // For system types, we only want it if it is the first message,
+ // if it appears anywhere else, it should be an error.
+ if (index === 0) {
+ // eslint-disable-next-line prefer-destructuring
+ ret = (await baseMessageToContent(message, undefined))[0];
+ }
+ else {
+ throw new Error("System messages are only permitted as the first passed message.");
+ }
+ }
+ }
+ return ret;
+ }
+ async function formatSystemInstruction(input) {
+ if (!config?.useSystemInstruction) {
+ return {};
+ }
+ const it = inputType(input);
+ switch (it) {
+ case "BaseMessageArray":
+ return formatBaseMessageSystemInstruction(input);
+ default:
+ return {};
+ }
+ }
+ function structuredToolToFunctionDeclaration(tool) {
+ const jsonSchema = (0, zod_to_gemini_parameters_js_1.zodToGeminiParameters)(tool.schema);
+ return {
+ name: tool.name,
+ description: tool.description ?? `A function available to call.`,
+ parameters: jsonSchema,
+ };
+ }
+ function structuredToolsToGeminiTools(tools) {
+ return [
+ {
+ functionDeclarations: tools.map(structuredToolToFunctionDeclaration),
+ },
+ ];
+ }
+ function formatTools(parameters) {
+ const tools = parameters?.tools;
+ if (!tools || tools.length === 0) {
+ return [];
+ }
+ if (tools.every(function_calling_1.isLangChainTool)) {
+ return structuredToolsToGeminiTools(tools);
+ }
+ else {
+ if (tools.length === 1 &&
+ (!("functionDeclarations" in tools[0]) ||
+ !tools[0].functionDeclarations?.length)) {
+ return [];
+ }
+ return tools;
+ }
+ }
+ function formatToolConfig(parameters) {
+ if (!parameters.tool_choice || typeof parameters.tool_choice !== "string") {
+ return undefined;
+ }
+ return {
+ functionCallingConfig: {
+ mode: parameters.tool_choice,
+ allowedFunctionNames: parameters.allowed_function_names,
+ },
+ };
+ }
+ async function formatData(input, parameters) {
+ const typedInput = input;
+ const contents = await formatContents(typedInput, parameters);
+ const generationConfig = formatGenerationConfig(parameters);
+ const tools = formatTools(parameters);
+ const toolConfig = formatToolConfig(parameters);
+ const safetySettings = formatSafetySettings(parameters);
+ const systemInstruction = await formatSystemInstruction(typedInput);
+ const ret = {
+ contents,
+ generationConfig,
+ };
+ if (tools && tools.length) {
+ ret.tools = tools;
+ }
+ if (toolConfig) {
+ ret.toolConfig = toolConfig;
+ }
+ if (safetySettings && safetySettings.length) {
+ ret.safetySettings = safetySettings;
+ }
+ if (systemInstruction?.role &&
+ systemInstruction?.parts &&
+ systemInstruction?.parts?.length) {
+ ret.systemInstruction = systemInstruction;
+ }
+ return ret;
  }
  return {
  messageContentToParts,
  baseMessageToContent,
- safeResponseToString,
- safeResponseToChatGeneration,
+ responseToString: safeResponseToString,
+ responseToChatGeneration: safeResponseToChatGeneration,
  chunkToString,
- safeResponseToBaseMessage,
- safeResponseToChatResult,
+ responseToBaseMessage: safeResponseToBaseMessage,
+ responseToChatResult: safeResponseToChatResult,
+ formatData,
  };
  }
  exports.getGeminiAPI = getGeminiAPI;
@@ -600,110 +883,3 @@ function isModelGemini(modelName) {
  return modelName.toLowerCase().startsWith("gemini");
  }
  exports.isModelGemini = isModelGemini;
- class DefaultGeminiSafetyHandler {
- constructor(settings) {
- Object.defineProperty(this, "errorFinish", {
- enumerable: true,
- configurable: true,
- writable: true,
- value: ["SAFETY", "RECITATION", "OTHER"]
- });
- this.errorFinish = settings?.errorFinish ?? this.errorFinish;
- }
- handleDataPromptFeedback(response, data) {
- // Check to see if our prompt was blocked in the first place
- const promptFeedback = data?.promptFeedback;
- const blockReason = promptFeedback?.blockReason;
- if (blockReason) {
- throw new safety_js_1.GoogleAISafetyError(response, `Prompt blocked: ${blockReason}`);
- }
- return data;
- }
- handleDataFinishReason(response, data) {
- const firstCandidate = data?.candidates?.[0];
- const finishReason = firstCandidate?.finishReason;
- if (this.errorFinish.includes(finishReason)) {
- throw new safety_js_1.GoogleAISafetyError(response, `Finish reason: ${finishReason}`);
- }
- return data;
- }
- handleData(response, data) {
- let ret = data;
- ret = this.handleDataPromptFeedback(response, ret);
- ret = this.handleDataFinishReason(response, ret);
- return ret;
- }
- handle(response) {
- let newdata;
- if ("nextChunk" in response.data) {
- // TODO: This is a stream. How to handle?
- newdata = response.data;
- }
- else if (Array.isArray(response.data)) {
- // If it is an array, try to handle every item in the array
- try {
- newdata = response.data.map((item) => this.handleData(response, item));
- }
- catch (xx) {
- // eslint-disable-next-line no-instanceof/no-instanceof
- if (xx instanceof safety_js_1.GoogleAISafetyError) {
- throw new safety_js_1.GoogleAISafetyError(response, xx.message);
- }
- else {
- throw xx;
- }
- }
- }
- else {
- const data = response.data;
- newdata = this.handleData(response, data);
- }
- return {
- ...response,
- data: newdata,
- };
- }
- }
- exports.DefaultGeminiSafetyHandler = DefaultGeminiSafetyHandler;
- class MessageGeminiSafetyHandler extends DefaultGeminiSafetyHandler {
- constructor(settings) {
- super(settings);
- Object.defineProperty(this, "msg", {
- enumerable: true,
- configurable: true,
- writable: true,
- value: ""
- });
- Object.defineProperty(this, "forceNewMessage", {
- enumerable: true,
- configurable: true,
- writable: true,
- value: false
- });
- this.msg = settings?.msg ?? this.msg;
- this.forceNewMessage = settings?.forceNewMessage ?? this.forceNewMessage;
- }
- setMessage(data) {
- const ret = data;
- if (this.forceNewMessage ||
- !data?.candidates?.[0]?.content?.parts?.length) {
- ret.candidates = data.candidates ?? [];
- ret.candidates[0] = data.candidates[0] ?? {};
- ret.candidates[0].content = data.candidates[0].content ?? {};
- ret.candidates[0].content = {
- role: "model",
- parts: [{ text: this.msg }],
- };
- }
- return ret;
- }
- handleData(response, data) {
- try {
- return super.handleData(response, data);
- }
- catch (xx) {
- return this.setMessage(data);
- }
- }
- }
- exports.MessageGeminiSafetyHandler = MessageGeminiSafetyHandler;
@@ -1,6 +1,4 @@
- import { BaseMessage, BaseMessageChunk, MessageContent } from "@langchain/core/messages";
- import { ChatGenerationChunk, ChatResult } from "@langchain/core/outputs";
- import type { GoogleLLMResponse, GoogleAIModelParams, GeminiPart, GeminiContent, GenerateContentResponseData, GoogleAISafetyHandler, GeminiAPIConfig } from "../types.js";
+ import type { GoogleLLMResponse, GoogleAIModelParams, GenerateContentResponseData, GoogleAISafetyHandler, GoogleAIAPI, GeminiAPIConfig } from "../types.js";
  export interface FunctionCall {
  name: string;
  arguments: string;
@@ -19,17 +17,6 @@ export interface ToolCallRaw {
  type: "function";
  function: FunctionCallRaw;
  }
- export declare function getGeminiAPI(config?: GeminiAPIConfig): {
- messageContentToParts: (content: MessageContent) => Promise<GeminiPart[]>;
- baseMessageToContent: (message: BaseMessage, prevMessage: BaseMessage | undefined, useSystemInstruction: boolean) => Promise<GeminiContent[]>;
- safeResponseToString: (response: GoogleLLMResponse, safetyHandler: GoogleAISafetyHandler) => string;
- safeResponseToChatGeneration: (response: GoogleLLMResponse, safetyHandler: GoogleAISafetyHandler) => ChatGenerationChunk;
- chunkToString: (chunk: BaseMessageChunk) => string;
- safeResponseToBaseMessage: (response: GoogleLLMResponse, safetyHandler: GoogleAISafetyHandler) => BaseMessage;
- safeResponseToChatResult: (response: GoogleLLMResponse, safetyHandler: GoogleAISafetyHandler) => ChatResult;
- };
- export declare function validateGeminiParams(params: GoogleAIModelParams): void;
- export declare function isModelGemini(modelName: string): boolean;
  export interface DefaultGeminiSafetySettings {
  errorFinish?: string[];
  }
@@ -52,3 +39,6 @@ export declare class MessageGeminiSafetyHandler extends DefaultGeminiSafetyHandl
  setMessage(data: GenerateContentResponseData): GenerateContentResponseData;
  handleData(response: GoogleLLMResponse, data: GenerateContentResponseData): GenerateContentResponseData;
  }
+ export declare function getGeminiAPI(config?: GeminiAPIConfig): GoogleAIAPI;
+ export declare function validateGeminiParams(params: GoogleAIModelParams): void;
+ export declare function isModelGemini(modelName: string): boolean;
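
For orientation, here is a minimal usage sketch of the reorganized helper surface shown in this diff. It is not taken from the package documentation: the root import path, the example messages, and the parameter values are assumptions for illustration; only the names themselves (getGeminiAPI, useSystemInstruction, safetyHandler, MessageGeminiSafetyHandler, formatData, responseToChatResult) come from the code above.

// Sketch only: assumes getGeminiAPI and MessageGeminiSafetyHandler are
// reachable from the package root; adjust the import path if they are not.
import { getGeminiAPI, MessageGeminiSafetyHandler } from "@langchain/google-common";
import { HumanMessage, SystemMessage } from "@langchain/core/messages";

async function buildGeminiRequest() {
  // As of 0.1.3 the safety handler and the useSystemInstruction flag are
  // supplied once through the config closure instead of being passed to
  // each safeResponseTo* call.
  const api = getGeminiAPI({
    useSystemInstruction: true,
    safetyHandler: new MessageGeminiSafetyHandler({ msg: "Response withheld." }),
  });

  // formatData assembles contents, generationConfig, tools, toolConfig,
  // safetySettings, and systemInstruction into a single request body.
  const data = await api.formatData(
    [new SystemMessage("You are terse."), new HumanMessage("Say hi.")],
    { temperature: 0.2, maxOutputTokens: 64 },
  );

  // Send `data` to the Gemini endpoint, then convert the raw response
  // with api.responseToChatResult(...).
  return data;
}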