@langchain/google-common 0.1.0 → 0.1.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,10 +1,119 @@
1
1
  "use strict";
2
2
  Object.defineProperty(exports, "__esModule", { value: true });
3
- exports.MessageGeminiSafetyHandler = exports.DefaultGeminiSafetyHandler = exports.isModelGemini = exports.validateGeminiParams = exports.getGeminiAPI = void 0;
3
+ exports.isModelGemini = exports.validateGeminiParams = exports.getGeminiAPI = exports.MessageGeminiSafetyHandler = exports.DefaultGeminiSafetyHandler = void 0;
4
4
  const uuid_1 = require("uuid");
5
5
  const messages_1 = require("@langchain/core/messages");
6
6
  const outputs_1 = require("@langchain/core/outputs");
7
+ const function_calling_1 = require("@langchain/core/utils/function_calling");
7
8
  const safety_js_1 = require("./safety.cjs");
9
+ const zod_to_gemini_parameters_js_1 = require("./zod_to_gemini_parameters.cjs");
10
/**
 * Default safety handler for Gemini responses.
 *
 * Each response payload is checked twice:
 *  - if `promptFeedback.blockReason` is set, the prompt was blocked;
 *  - if the first candidate's `finishReason` is listed in `errorFinish`,
 *    generation was stopped for a disallowed reason.
 * Either case throws a `GoogleAISafetyError`; otherwise the payload is
 * returned unchanged.
 */
class DefaultGeminiSafetyHandler {
    constructor(settings) {
        // Finish reasons treated as errors; may be replaced via `settings.errorFinish`.
        Object.defineProperty(this, "errorFinish", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: ["SAFETY", "RECITATION", "OTHER"]
        });
        const overrideFinish = settings?.errorFinish;
        this.errorFinish = overrideFinish ?? this.errorFinish;
    }
    /** Throws when the prompt itself was blocked; otherwise returns `data` as-is. */
    handleDataPromptFeedback(response, data) {
        const reason = data?.promptFeedback?.blockReason;
        if (!reason) {
            return data;
        }
        throw new safety_js_1.GoogleAISafetyError(response, `Prompt blocked: ${reason}`);
    }
    /** Throws when the first candidate finished with a reason in `errorFinish`. */
    handleDataFinishReason(response, data) {
        const reason = data?.candidates?.[0]?.finishReason;
        if (!this.errorFinish.includes(reason)) {
            return data;
        }
        throw new safety_js_1.GoogleAISafetyError(response, `Finish reason: ${reason}`);
    }
    /** Applies the prompt-feedback check, then the finish-reason check. */
    handleData(response, data) {
        const afterPrompt = this.handleDataPromptFeedback(response, data);
        return this.handleDataFinishReason(response, afterPrompt);
    }
    /**
     * Returns a shallow copy of `response` whose `data` has been checked.
     * Streams (objects with `nextChunk`) pass through unchecked; arrays are
     * checked element by element.
     */
    handle(response) {
        const payload = response.data;
        let checked;
        if ("nextChunk" in payload) {
            // TODO: This is a stream. How to handle?
            checked = payload;
        }
        else if (Array.isArray(payload)) {
            try {
                checked = payload.map((item) => this.handleData(response, item));
            }
            catch (err) {
                // eslint-disable-next-line no-instanceof/no-instanceof
                if (!(err instanceof safety_js_1.GoogleAISafetyError)) {
                    throw err;
                }
                // Re-raise against the whole response rather than the single item.
                throw new safety_js_1.GoogleAISafetyError(response, err.message);
            }
        }
        else {
            checked = this.handleData(response, payload);
        }
        return { ...response, data: checked };
    }
}
exports.DefaultGeminiSafetyHandler = DefaultGeminiSafetyHandler;
75
/**
 * Safety handler that substitutes a fixed model message instead of throwing.
 *
 * Settings:
 *  - `msg`: text used for the substituted message (default "").
 *  - `forceNewMessage`: when true, overwrite the first candidate's content
 *    even if it already has parts (default false).
 *  - `errorFinish`: forwarded to DefaultGeminiSafetyHandler.
 */
class MessageGeminiSafetyHandler extends DefaultGeminiSafetyHandler {
    constructor(settings) {
        super(settings);
        Object.defineProperty(this, "msg", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: ""
        });
        Object.defineProperty(this, "forceNewMessage", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: false
        });
        this.msg = settings?.msg ?? this.msg;
        this.forceNewMessage = settings?.forceNewMessage ?? this.forceNewMessage;
    }
    /**
     * Replaces `candidates[0].content` with a single model part holding
     * `this.msg` when `forceNewMessage` is set or the first candidate has no
     * content parts. Mutates and returns `data`; a nullish `data` is replaced
     * by a fresh object instead of throwing a TypeError.
     */
    setMessage(data) {
        const ret = data ?? {};
        const hasParts = Boolean(ret.candidates?.[0]?.content?.parts?.length);
        if (this.forceNewMessage || !hasParts) {
            ret.candidates = ret.candidates ?? [];
            ret.candidates[0] = ret.candidates[0] ?? {};
            // Any existing content is discarded and replaced outright.
            ret.candidates[0].content = {
                role: "model",
                parts: [{ text: this.msg }],
            };
        }
        return ret;
    }
    /**
     * Runs the default checks; if they throw for any reason, the payload is
     * rewritten with the substitute message instead of propagating the error.
     */
    handleData(response, data) {
        try {
            return super.handleData(response, data);
        }
        catch {
            return this.setMessage(data);
        }
    }
}
exports.MessageGeminiSafetyHandler = MessageGeminiSafetyHandler;
8
117
  const extractMimeType = (str) => {
9
118
  if (str.startsWith("data:")) {
10
119
  return {
@@ -85,7 +194,7 @@ function getGeminiAPI(config) {
85
194
  return await blobToFileData(blob);
86
195
  }
87
196
  }
88
- throw new Error("Invalid media content");
197
+ throw new Error(`Invalid media content: ${JSON.stringify(content, null, 1)}`);
89
198
  }
90
199
  async function messageContentComplexToPart(content) {
91
200
  switch (content.type) {
@@ -103,7 +212,7 @@ function getGeminiAPI(config) {
103
212
  case "media":
104
213
  return await messageContentMedia(content);
105
214
  default:
106
- throw new Error(`Unsupported type received while converting message to message parts`);
215
+ throw new Error(`Unsupported type "${content.type}" received while converting message to message parts: ${content}`);
107
216
  }
108
217
  throw new Error(`Cannot coerce "${content.type}" message part into a string.`);
109
218
  }
@@ -181,8 +290,8 @@ function getGeminiAPI(config) {
181
290
  },
182
291
  ];
183
292
  }
184
- async function systemMessageToContent(message, useSystemInstruction) {
185
- return useSystemInstruction
293
+ async function systemMessageToContent(message) {
294
+ return config?.useSystemInstruction
186
295
  ? roleMessageToContent("system", message)
187
296
  : [
188
297
  ...(await roleMessageToContent("user", message)),
@@ -236,11 +345,11 @@ function getGeminiAPI(config) {
236
345
  ];
237
346
  }
238
347
  }
239
- async function baseMessageToContent(message, prevMessage, useSystemInstruction) {
348
+ async function baseMessageToContent(message, prevMessage) {
240
349
  const type = message._getType();
241
350
  switch (type) {
242
351
  case "system":
243
- return systemMessageToContent(message, useSystemInstruction);
352
+ return systemMessageToContent(message);
244
353
  case "human":
245
354
  return roleMessageToContent("user", message);
246
355
  case "ai":
@@ -378,7 +487,8 @@ function getGeminiAPI(config) {
378
487
  }, "");
379
488
  return ret;
380
489
  }
381
- function safeResponseTo(response, safetyHandler, responseTo) {
490
+ function safeResponseTo(response, responseTo) {
491
+ const safetyHandler = config?.safetyHandler ?? new DefaultGeminiSafetyHandler();
382
492
  try {
383
493
  const safeResponse = safetyHandler.handle(response);
384
494
  return responseTo(safeResponse);
@@ -392,8 +502,8 @@ function getGeminiAPI(config) {
392
502
  throw xx;
393
503
  }
394
504
  }
395
- function safeResponseToString(response, safetyHandler) {
396
- return safeResponseTo(response, safetyHandler, responseToString);
505
/**
 * Passes `response` through the configured safety handler, then converts it
 * to plain text with `responseToString`.
 */
function safeResponseToString(response) {
    const toText = responseToString;
    return safeResponseTo(response, toText);
}
398
508
  function responseToGenerationInfo(response) {
399
509
  if (!Array.isArray(response.data)) {
@@ -423,8 +533,8 @@ function getGeminiAPI(config) {
423
533
  generationInfo: responseToGenerationInfo(response),
424
534
  });
425
535
  }
426
- function safeResponseToChatGeneration(response, safetyHandler) {
427
- return safeResponseTo(response, safetyHandler, responseToChatGeneration);
536
/**
 * Passes `response` through the configured safety handler, then converts it
 * with `responseToChatGeneration`.
 */
function safeResponseToChatGeneration(response) {
    const toGeneration = responseToChatGeneration;
    return safeResponseTo(response, toGeneration);
}
429
539
  function chunkToString(chunk) {
430
540
  if (chunk === null) {
@@ -469,6 +579,9 @@ function getGeminiAPI(config) {
469
579
  }
470
580
  function responseToChatGenerations(response) {
471
581
  const parts = responseToParts(response);
582
+ if (parts.length === 0) {
583
+ return [];
584
+ }
472
585
  let ret = parts.map((part) => partToChatGeneration(part));
473
586
  if (ret.every((item) => typeof item.message.content === "string")) {
474
587
  const combinedContent = ret.map((item) => item.message.content).join("");
@@ -553,8 +666,8 @@ function getGeminiAPI(config) {
553
666
  const fields = responseToBaseMessageFields(response);
554
667
  return new messages_1.AIMessage(fields);
555
668
  }
556
- function safeResponseToBaseMessage(response, safetyHandler) {
557
- return safeResponseTo(response, safetyHandler, responseToBaseMessage);
669
/**
 * Passes `response` through the configured safety handler, then converts it
 * with `responseToBaseMessage`.
 */
function safeResponseToBaseMessage(response) {
    const toMessage = responseToBaseMessage;
    return safeResponseTo(response, toMessage);
}
559
672
  function responseToChatResult(response) {
560
673
  const generations = responseToChatGenerations(response);
@@ -563,17 +676,190 @@ function getGeminiAPI(config) {
563
676
  llmOutput: responseToGenerationInfo(response),
564
677
  };
565
678
  }
566
- function safeResponseToChatResult(response, safetyHandler) {
567
- return safeResponseTo(response, safetyHandler, responseToChatResult);
679
/**
 * Passes `response` through the configured safety handler, then converts it
 * with `responseToChatResult`.
 */
function safeResponseToChatResult(response) {
    const toResult = responseToChatResult;
    return safeResponseTo(response, toResult);
}
682
/**
 * Classifies the input given to `formatData`.
 *
 * Returns "BaseMessageArray" when `input` is an array whose first element has
 * its own `content` property (a message object), and "MessageContent" for a
 * string or any other array (content parts).
 *
 * An empty array, or one whose first element is null/undefined, falls back to
 * "MessageContent" instead of throwing a TypeError. `hasOwnProperty.call` is
 * used in place of `Object.hasOwn` so this also runs on pre-ES2022 runtimes.
 */
function inputType(input) {
    if (typeof input === "string") {
        return "MessageContent";
    }
    const firstItem = input[0];
    const looksLikeMessage = firstItem != null &&
        Object.prototype.hasOwnProperty.call(firstItem, "content");
    return looksLikeMessage ? "BaseMessageArray" : "MessageContent";
}
696
/**
 * Wraps a single MessageContent value (string or content parts) into a
 * one-element Gemini `contents` array with the "user" role.
 * `_parameters` is accepted for signature parity and is not used.
 */
async function formatMessageContents(input, _parameters) {
    const userParts = await messageContentToParts(input);
    return [{ role: "user", parts: userParts }];
}
706
/**
 * Converts an array of BaseMessages into Gemini `contents`.
 *
 * Each message is converted with `baseMessageToContent`, which also receives
 * the preceding message. The converted results are then folded together:
 *  - a result made up only of "system" entries is dropped;
 *  - when a result starts with a "function" entry and the previous emitted
 *    entry is also "function", their parts are merged into one entry.
 *    Any remaining entries of that result are still appended; previously they
 *    were silently discarded.
 */
async function formatBaseMessageContents(input, _parameters) {
    const converted = await Promise.all(input.map((msg, i) => baseMessageToContent(msg, input[i - 1])));
    const contents = [];
    for (const cur of converted) {
        // Filter out the system content
        if (cur.every((content) => content.role === "system")) {
            continue;
        }
        const last = contents[contents.length - 1];
        if (cur[0]?.role === "function" && last?.role === "function") {
            // Combine adjacent function messages
            last.parts = [...last.parts, ...cur[0].parts];
            contents.push(...cur.slice(1));
        }
        else {
            contents.push(...cur);
        }
    }
    return contents;
}
729
/**
 * Builds Gemini `contents`, choosing the formatter according to
 * `inputType(input)`.
 */
async function formatContents(input, parameters) {
    const it = inputType(input);
    if (it === "MessageContent") {
        return formatMessageContents(input, parameters);
    }
    if (it === "BaseMessageArray") {
        return formatBaseMessageContents(input, parameters);
    }
    throw new Error(`Unknown input type "${it}": ${input}`);
}
740
/**
 * Copies the Gemini `generationConfig` fields out of the model parameters.
 * Every key is always present in the result, possibly with value `undefined`.
 */
function formatGenerationConfig(parameters) {
    const { temperature, topK, topP, maxOutputTokens, stopSequences, responseMimeType, } = parameters;
    return {
        temperature,
        topK,
        topP,
        maxOutputTokens,
        stopSequences,
        responseMimeType,
    };
}
750
/** Returns the configured safety settings, or an empty array when none are set. */
function formatSafetySettings(parameters) {
    const { safetySettings } = parameters;
    if (safetySettings === undefined || safetySettings === null) {
        return [];
    }
    return safetySettings;
}
753
/**
 * Extracts the Gemini system instruction from a BaseMessage array.
 *
 * A system message is accepted only at index 0, and the first converted
 * content entry is returned for it. A system message at any later index
 * throws. When there is no leading system message, `{}` is returned.
 */
async function formatBaseMessageSystemInstruction(input) {
    let instruction = {};
    for (const [index, message] of input.entries()) {
        if (message._getType() !== "system") {
            continue;
        }
        if (index !== 0) {
            throw new Error("System messages are only permitted as the first passed message.");
        }
        const converted = await baseMessageToContent(message, undefined);
        // eslint-disable-next-line prefer-destructuring
        instruction = converted[0];
    }
    return instruction;
}
771
/**
 * Builds the request's system instruction. One is only produced when
 * `config.useSystemInstruction` is enabled and the input is a BaseMessage
 * array; otherwise `{}` is returned.
 */
async function formatSystemInstruction(input) {
    const enabled = Boolean(config?.useSystemInstruction);
    if (enabled && inputType(input) === "BaseMessageArray") {
        return formatBaseMessageSystemInstruction(input);
    }
    return {};
}
783
/**
 * Converts a LangChain structured tool into a Gemini function declaration,
 * translating its zod `schema` into Gemini parameters. A missing description
 * is replaced with a generic one.
 */
function structuredToolToFunctionDeclaration(tool) {
    const parameters = (0, zod_to_gemini_parameters_js_1.zodToGeminiParameters)(tool.schema);
    const name = tool.name;
    const description = tool.description ?? `A function available to call.`;
    return { name, description, parameters };
}
791
/**
 * Converts every LangChain tool and groups all resulting declarations into a
 * single Gemini tool entry: `[{ functionDeclarations }]`.
 */
function structuredToolsToGeminiTools(tools) {
    const functionDeclarations = tools.map(structuredToolToFunctionDeclaration);
    return [{ functionDeclarations }];
}
798
/**
 * Produces the Gemini `tools` array from the model parameters:
 *  - no tools → `[]`;
 *  - only LangChain tools → converted into one `functionDeclarations` entry;
 *  - otherwise the tools are passed through unchanged, except that a single
 *    tool with no (or an empty) `functionDeclarations` collapses to `[]`.
 */
function formatTools(parameters) {
    const tools = parameters?.tools;
    const hasTools = Boolean(tools) && tools.length !== 0;
    if (!hasTools) {
        return [];
    }
    if (tools.every(function_calling_1.isLangChainTool)) {
        return structuredToolsToGeminiTools(tools);
    }
    if (tools.length === 1) {
        const single = tools[0];
        const declarationCount = "functionDeclarations" in single
            ? single.functionDeclarations?.length
            : 0;
        if (!declarationCount) {
            return [];
        }
    }
    return tools;
}
815
/**
 * Builds the Gemini `toolConfig` from a non-empty string `tool_choice`, which
 * is used verbatim as the function-calling mode. Returns `undefined` when
 * `tool_choice` is unset, empty, or not a string.
 */
function formatToolConfig(parameters) {
    const choice = parameters.tool_choice;
    if (typeof choice === "string" && choice) {
        return {
            functionCallingConfig: {
                mode: choice,
                allowedFunctionNames: parameters.allowed_function_names,
            },
        };
    }
    return undefined;
}
826
/**
 * Assembles the Gemini request body from the input and model parameters.
 *
 * `contents` and `generationConfig` are always present. `tools`,
 * `toolConfig`, `safetySettings` and `systemInstruction` are attached only
 * when they are non-empty.
 */
async function formatData(input, parameters) {
    const contents = await formatContents(input, parameters);
    const generationConfig = formatGenerationConfig(parameters);
    const tools = formatTools(parameters);
    const toolConfig = formatToolConfig(parameters);
    const safetySettings = formatSafetySettings(parameters);
    const systemInstruction = await formatSystemInstruction(input);
    const request = { contents, generationConfig };
    if (tools?.length) {
        request.tools = tools;
    }
    if (toolConfig) {
        request.toolConfig = toolConfig;
    }
    if (safetySettings?.length) {
        request.safetySettings = safetySettings;
    }
    // Only attach a system instruction that has a role and at least one part.
    if (systemInstruction?.role && systemInstruction?.parts?.length) {
        request.systemInstruction = systemInstruction;
    }
    return request;
}
569
854
  return {
570
855
  messageContentToParts,
571
856
  baseMessageToContent,
572
- safeResponseToString,
573
- safeResponseToChatGeneration,
857
+ responseToString: safeResponseToString,
858
+ responseToChatGeneration: safeResponseToChatGeneration,
574
859
  chunkToString,
575
- safeResponseToBaseMessage,
576
- safeResponseToChatResult,
860
+ responseToBaseMessage: safeResponseToBaseMessage,
861
+ responseToChatResult: safeResponseToChatResult,
862
+ formatData,
577
863
  };
578
864
  }
579
865
  exports.getGeminiAPI = getGeminiAPI;
@@ -597,110 +883,3 @@ function isModelGemini(modelName) {
597
883
  return modelName.toLowerCase().startsWith("gemini");
598
884
  }
599
885
  exports.isModelGemini = isModelGemini;
600
- class DefaultGeminiSafetyHandler {
601
- constructor(settings) {
602
- Object.defineProperty(this, "errorFinish", {
603
- enumerable: true,
604
- configurable: true,
605
- writable: true,
606
- value: ["SAFETY", "RECITATION", "OTHER"]
607
- });
608
- this.errorFinish = settings?.errorFinish ?? this.errorFinish;
609
- }
610
- handleDataPromptFeedback(response, data) {
611
- // Check to see if our prompt was blocked in the first place
612
- const promptFeedback = data?.promptFeedback;
613
- const blockReason = promptFeedback?.blockReason;
614
- if (blockReason) {
615
- throw new safety_js_1.GoogleAISafetyError(response, `Prompt blocked: ${blockReason}`);
616
- }
617
- return data;
618
- }
619
- handleDataFinishReason(response, data) {
620
- const firstCandidate = data?.candidates?.[0];
621
- const finishReason = firstCandidate?.finishReason;
622
- if (this.errorFinish.includes(finishReason)) {
623
- throw new safety_js_1.GoogleAISafetyError(response, `Finish reason: ${finishReason}`);
624
- }
625
- return data;
626
- }
627
- handleData(response, data) {
628
- let ret = data;
629
- ret = this.handleDataPromptFeedback(response, ret);
630
- ret = this.handleDataFinishReason(response, ret);
631
- return ret;
632
- }
633
- handle(response) {
634
- let newdata;
635
- if ("nextChunk" in response.data) {
636
- // TODO: This is a stream. How to handle?
637
- newdata = response.data;
638
- }
639
- else if (Array.isArray(response.data)) {
640
- // If it is an array, try to handle every item in the array
641
- try {
642
- newdata = response.data.map((item) => this.handleData(response, item));
643
- }
644
- catch (xx) {
645
- // eslint-disable-next-line no-instanceof/no-instanceof
646
- if (xx instanceof safety_js_1.GoogleAISafetyError) {
647
- throw new safety_js_1.GoogleAISafetyError(response, xx.message);
648
- }
649
- else {
650
- throw xx;
651
- }
652
- }
653
- }
654
- else {
655
- const data = response.data;
656
- newdata = this.handleData(response, data);
657
- }
658
- return {
659
- ...response,
660
- data: newdata,
661
- };
662
- }
663
- }
664
- exports.DefaultGeminiSafetyHandler = DefaultGeminiSafetyHandler;
665
- class MessageGeminiSafetyHandler extends DefaultGeminiSafetyHandler {
666
- constructor(settings) {
667
- super(settings);
668
- Object.defineProperty(this, "msg", {
669
- enumerable: true,
670
- configurable: true,
671
- writable: true,
672
- value: ""
673
- });
674
- Object.defineProperty(this, "forceNewMessage", {
675
- enumerable: true,
676
- configurable: true,
677
- writable: true,
678
- value: false
679
- });
680
- this.msg = settings?.msg ?? this.msg;
681
- this.forceNewMessage = settings?.forceNewMessage ?? this.forceNewMessage;
682
- }
683
- setMessage(data) {
684
- const ret = data;
685
- if (this.forceNewMessage ||
686
- !data?.candidates?.[0]?.content?.parts?.length) {
687
- ret.candidates = data.candidates ?? [];
688
- ret.candidates[0] = data.candidates[0] ?? {};
689
- ret.candidates[0].content = data.candidates[0].content ?? {};
690
- ret.candidates[0].content = {
691
- role: "model",
692
- parts: [{ text: this.msg }],
693
- };
694
- }
695
- return ret;
696
- }
697
- handleData(response, data) {
698
- try {
699
- return super.handleData(response, data);
700
- }
701
- catch (xx) {
702
- return this.setMessage(data);
703
- }
704
- }
705
- }
706
- exports.MessageGeminiSafetyHandler = MessageGeminiSafetyHandler;
@@ -1,6 +1,4 @@
1
- import { BaseMessage, BaseMessageChunk, MessageContent } from "@langchain/core/messages";
2
- import { ChatGenerationChunk, ChatResult } from "@langchain/core/outputs";
3
- import type { GoogleLLMResponse, GoogleAIModelParams, GeminiPart, GeminiContent, GenerateContentResponseData, GoogleAISafetyHandler, GeminiAPIConfig } from "../types.js";
1
+ import type { GoogleLLMResponse, GoogleAIModelParams, GenerateContentResponseData, GoogleAISafetyHandler, GoogleAIAPI, GeminiAPIConfig } from "../types.js";
4
2
  export interface FunctionCall {
5
3
  name: string;
6
4
  arguments: string;
@@ -19,17 +17,6 @@ export interface ToolCallRaw {
19
17
  type: "function";
20
18
  function: FunctionCallRaw;
21
19
  }
22
- export declare function getGeminiAPI(config?: GeminiAPIConfig): {
23
- messageContentToParts: (content: MessageContent) => Promise<GeminiPart[]>;
24
- baseMessageToContent: (message: BaseMessage, prevMessage: BaseMessage | undefined, useSystemInstruction: boolean) => Promise<GeminiContent[]>;
25
- safeResponseToString: (response: GoogleLLMResponse, safetyHandler: GoogleAISafetyHandler) => string;
26
- safeResponseToChatGeneration: (response: GoogleLLMResponse, safetyHandler: GoogleAISafetyHandler) => ChatGenerationChunk;
27
- chunkToString: (chunk: BaseMessageChunk) => string;
28
- safeResponseToBaseMessage: (response: GoogleLLMResponse, safetyHandler: GoogleAISafetyHandler) => BaseMessage;
29
- safeResponseToChatResult: (response: GoogleLLMResponse, safetyHandler: GoogleAISafetyHandler) => ChatResult;
30
- };
31
- export declare function validateGeminiParams(params: GoogleAIModelParams): void;
32
- export declare function isModelGemini(modelName: string): boolean;
33
20
  export interface DefaultGeminiSafetySettings {
34
21
  errorFinish?: string[];
35
22
  }
@@ -52,3 +39,6 @@ export declare class MessageGeminiSafetyHandler extends DefaultGeminiSafetyHandl
52
39
  setMessage(data: GenerateContentResponseData): GenerateContentResponseData;
53
40
  handleData(response: GoogleLLMResponse, data: GenerateContentResponseData): GenerateContentResponseData;
54
41
  }
42
+ export declare function getGeminiAPI(config?: GeminiAPIConfig): GoogleAIAPI;
43
+ export declare function validateGeminiParams(params: GoogleAIModelParams): void;
44
+ export declare function isModelGemini(modelName: string): boolean;