@langchain/google-common 0.0.0 → 0.0.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/chat_models.cjs CHANGED
@@ -9,6 +9,7 @@ const common_js_1 = require("./utils/common.cjs");
  const connection_js_1 = require("./connection.cjs");
  const gemini_js_1 = require("./utils/gemini.cjs");
  const auth_js_1 = require("./auth.cjs");
+ const failed_handler_js_1 = require("./utils/failed_handler.cjs");
  class ChatConnection extends connection_js_1.AbstractGoogleLLMConnection {
  formatContents(input, _parameters) {
  return input
@@ -25,7 +26,7 @@ class ChatGoogleBase extends chat_models_1.BaseChatModel {
  return "ChatGoogle";
  }
  constructor(fields) {
- super(fields ?? {});
+ super((0, failed_handler_js_1.ensureParams)(fields));
  Object.defineProperty(this, "lc_serializable", {
  enumerable: true,
  configurable: true,
@@ -74,6 +75,12 @@ class ChatGoogleBase extends chat_models_1.BaseChatModel {
  writable: true,
  value: []
  });
+ Object.defineProperty(this, "safetyHandler", {
+ enumerable: true,
+ configurable: true,
+ writable: true,
+ value: void 0
+ });
  Object.defineProperty(this, "connection", {
  enumerable: true,
  configurable: true,
@@ -87,6 +94,8 @@ class ChatGoogleBase extends chat_models_1.BaseChatModel {
  value: void 0
  });
  (0, common_js_1.copyAndValidateModelParamsInto)(fields, this);
+ this.safetyHandler =
+ fields?.safetyHandler ?? new gemini_js_1.DefaultGeminiSafetyHandler();
  const client = this.buildClient(fields);
  this.buildConnection(fields ?? {}, client);
  }
@@ -119,7 +128,7 @@ class ChatGoogleBase extends chat_models_1.BaseChatModel {
  async _generate(messages, options, _runManager) {
  const parameters = (0, common_js_1.copyAIModelParams)(this);
  const response = await this.connection.request(messages, parameters, options);
- const ret = (0, gemini_js_1.responseToChatResult)(response);
+ const ret = (0, gemini_js_1.safeResponseToChatResult)(response, this.safetyHandler);
  return ret;
  }
  async *_streamResponseChunks(_messages, _options, _runManager) {
@@ -134,7 +143,7 @@ class ChatGoogleBase extends chat_models_1.BaseChatModel {
  while (!stream.streamDone) {
  const output = await stream.nextChunk();
  const chunk = output !== null
- ? (0, gemini_js_1.responseToChatGeneration)({ data: output })
+ ? (0, gemini_js_1.safeResponseToChatGeneration)({ data: output }, this.safetyHandler)
  : new outputs_1.ChatGenerationChunk({
  text: "",
  generationInfo: { finishReason: "stop" },
package/dist/chat_models.d.ts CHANGED
@@ -6,14 +6,14 @@ import { ChatGenerationChunk, ChatResult } from "@langchain/core/outputs";
  import { GoogleAIBaseLLMInput, GoogleAIModelParams, GoogleAISafetySetting, GoogleConnectionParams, GooglePlatformType, GeminiContent } from "./types.js";
  import { AbstractGoogleLLMConnection } from "./connection.js";
  import { GoogleAbstractedClient } from "./auth.js";
- import { GoogleBaseLLMInput } from "./llms.js";
+ import type { GoogleBaseLLMInput, GoogleAISafetyHandler, GoogleAISafetyParams } from "./types.js";
  declare class ChatConnection<AuthOptions> extends AbstractGoogleLLMConnection<BaseMessage[], AuthOptions> {
  formatContents(input: BaseMessage[], _parameters: GoogleAIModelParams): GeminiContent[];
  }
  /**
  * Input to chat model class.
  */
- export interface ChatGoogleBaseInput<AuthOptions> extends BaseChatModelParams, GoogleConnectionParams<AuthOptions>, GoogleAIModelParams {
+ export interface ChatGoogleBaseInput<AuthOptions> extends BaseChatModelParams, GoogleConnectionParams<AuthOptions>, GoogleAIModelParams, GoogleAISafetyParams {
  }
  /**
  * Integration with a chat model.
@@ -28,6 +28,7 @@ export declare abstract class ChatGoogleBase<AuthOptions> extends BaseChatModel<
  topK: number;
  stopSequences: string[];
  safetySettings: GoogleAISafetySetting[];
+ safetyHandler: GoogleAISafetyHandler;
  protected connection: ChatConnection<AuthOptions>;
  protected streamedConnection: ChatConnection<AuthOptions>;
  constructor(fields?: ChatGoogleBaseInput<AuthOptions>);
package/dist/chat_models.js CHANGED
@@ -4,8 +4,9 @@ import { ChatGenerationChunk } from "@langchain/core/outputs";
  import { AIMessageChunk } from "@langchain/core/messages";
  import { copyAIModelParams, copyAndValidateModelParamsInto, } from "./utils/common.js";
  import { AbstractGoogleLLMConnection } from "./connection.js";
- import { baseMessageToContent, responseToChatGeneration, responseToChatResult, } from "./utils/gemini.js";
+ import { baseMessageToContent, safeResponseToChatGeneration, safeResponseToChatResult, DefaultGeminiSafetyHandler, } from "./utils/gemini.js";
  import { ApiKeyGoogleAuth } from "./auth.js";
+ import { ensureParams } from "./utils/failed_handler.js";
  class ChatConnection extends AbstractGoogleLLMConnection {
  formatContents(input, _parameters) {
  return input
@@ -22,7 +23,7 @@ export class ChatGoogleBase extends BaseChatModel {
  return "ChatGoogle";
  }
  constructor(fields) {
- super(fields ?? {});
+ super(ensureParams(fields));
  Object.defineProperty(this, "lc_serializable", {
  enumerable: true,
  configurable: true,
@@ -71,6 +72,12 @@ export class ChatGoogleBase extends BaseChatModel {
  writable: true,
  value: []
  });
+ Object.defineProperty(this, "safetyHandler", {
+ enumerable: true,
+ configurable: true,
+ writable: true,
+ value: void 0
+ });
  Object.defineProperty(this, "connection", {
  enumerable: true,
  configurable: true,
@@ -84,6 +91,8 @@ export class ChatGoogleBase extends BaseChatModel {
  value: void 0
  });
  copyAndValidateModelParamsInto(fields, this);
+ this.safetyHandler =
+ fields?.safetyHandler ?? new DefaultGeminiSafetyHandler();
  const client = this.buildClient(fields);
  this.buildConnection(fields ?? {}, client);
  }
@@ -116,7 +125,7 @@ export class ChatGoogleBase extends BaseChatModel {
  async _generate(messages, options, _runManager) {
  const parameters = copyAIModelParams(this);
  const response = await this.connection.request(messages, parameters, options);
- const ret = responseToChatResult(response);
+ const ret = safeResponseToChatResult(response, this.safetyHandler);
  return ret;
  }
  async *_streamResponseChunks(_messages, _options, _runManager) {
@@ -131,7 +140,7 @@ export class ChatGoogleBase extends BaseChatModel {
  while (!stream.streamDone) {
  const output = await stream.nextChunk();
  const chunk = output !== null
- ? responseToChatGeneration({ data: output })
+ ? safeResponseToChatGeneration({ data: output }, this.safetyHandler)
  : new ChatGenerationChunk({
  text: "",
  generationInfo: { finishReason: "stop" },
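
The net effect of the chat_models changes: every response now passes through a GoogleAISafetyHandler (defaulting to DefaultGeminiSafetyHandler), and callers can supply their own via the new safetyHandler field. A minimal sketch of a custom handler, assuming GoogleAISafetyHandler and GoogleLLMResponse are re-exported from the package root (they are declared in types.d.ts below) and that you instantiate a concrete platform subclass of ChatGoogleBase:

    import type {
      GoogleAISafetyHandler,
      GoogleLLMResponse,
    } from "@langchain/google-common";

    // Throws when the prompt itself was blocked; otherwise passes the
    // response through unchanged. promptFeedback.blockReason is the field
    // added to GeminiResponsePromptFeedback in this release.
    class ThrowOnBlockHandler implements GoogleAISafetyHandler {
      handle(response: GoogleLLMResponse): GoogleLLMResponse {
        const data = response.data as {
          promptFeedback?: { blockReason?: string };
        };
        if (data?.promptFeedback?.blockReason) {
          throw new Error(`Blocked: ${data.promptFeedback.blockReason}`);
        }
        return response;
      }
    }

    // const model = new ChatGoogle({ safetyHandler: new ThrowOnBlockHandler() });
    // ("ChatGoogle" stands in for whichever concrete subclass you use.)
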
package/dist/llms.cjs CHANGED
@@ -1,6 +1,7 @@
  "use strict";
  Object.defineProperty(exports, "__esModule", { value: true });
  exports.GoogleBaseLLM = void 0;
+ const manager_1 = require("@langchain/core/callbacks/manager");
  const llms_1 = require("@langchain/core/language_models/llms");
  const outputs_1 = require("@langchain/core/outputs");
  const env_1 = require("@langchain/core/utils/env");
@@ -8,6 +9,8 @@ const connection_js_1 = require("./connection.cjs");
  const common_js_1 = require("./utils/common.cjs");
  const gemini_js_1 = require("./utils/gemini.cjs");
  const auth_js_1 = require("./auth.cjs");
+ const failed_handler_js_1 = require("./utils/failed_handler.cjs");
+ const chat_models_js_1 = require("./chat_models.cjs");
  class GoogleLLMConnection extends connection_js_1.AbstractGoogleLLMConnection {
  formatContents(input, _parameters) {
  const parts = (0, gemini_js_1.messageContentToParts)(input);
@@ -20,6 +23,14 @@ class GoogleLLMConnection extends connection_js_1.AbstractGoogleLLMConnection {
  return contents;
  }
  }
+ class ProxyChatGoogle extends chat_models_js_1.ChatGoogleBase {
+ constructor(fields) {
+ super(fields);
+ }
+ buildAbstractedClient(fields) {
+ return fields.connection.client;
+ }
+ }
  /**
  * Integration with an LLM.
  */
@@ -29,7 +40,13 @@ class GoogleBaseLLM extends llms_1.LLM {
  return "GoogleLLM";
  }
  constructor(fields) {
- super(fields ?? {});
+ super((0, failed_handler_js_1.ensureParams)(fields));
+ Object.defineProperty(this, "originalFields", {
+ enumerable: true,
+ configurable: true,
+ writable: true,
+ value: void 0
+ });
  Object.defineProperty(this, "lc_serializable", {
  enumerable: true,
  configurable: true,
@@ -78,6 +95,12 @@ class GoogleBaseLLM extends llms_1.LLM {
  writable: true,
  value: []
  });
+ Object.defineProperty(this, "safetyHandler", {
+ enumerable: true,
+ configurable: true,
+ writable: true,
+ value: void 0
+ });
  Object.defineProperty(this, "connection", {
  enumerable: true,
  configurable: true,
@@ -90,7 +113,10 @@ class GoogleBaseLLM extends llms_1.LLM {
  writable: true,
  value: void 0
  });
+ this.originalFields = fields;
  (0, common_js_1.copyAndValidateModelParamsInto)(fields, this);
+ this.safetyHandler =
+ fields?.safetyHandler ?? new gemini_js_1.DefaultGeminiSafetyHandler();
  const client = this.buildClient(fields);
  this.buildConnection(fields ?? {}, client);
  }
@@ -125,38 +151,79 @@ class GoogleBaseLLM extends llms_1.LLM {
  }
  /**
  * For some given input string and options, return a string output.
+ *
+ * Despite the fact that `invoke` is overridden below, we still need this
+ * in order to handle public API calls to `generate()`.
  */
- async _call(_prompt, _options, _runManager) {
+ async _call(prompt, options) {
  const parameters = (0, common_js_1.copyAIModelParams)(this);
- const result = await this.connection.request(_prompt, parameters, _options);
- const ret = (0, gemini_js_1.responseToString)(result);
+ const result = await this.connection.request(prompt, parameters, options);
+ const ret = (0, gemini_js_1.safeResponseToString)(result, this.safetyHandler);
  return ret;
  }
- async *_streamResponseChunks(_prompt, _options, _runManager) {
- // Make the call as a streaming request
- const parameters = (0, common_js_1.copyAIModelParams)(this);
- const result = await this.streamedConnection.request(_prompt, parameters, _options);
- // Get the streaming parser of the response
- const stream = result.data;
- // Loop until the end of the stream
- // During the loop, yield each time we get a chunk from the streaming parser
- // that is either available or added to the queue
- while (!stream.streamDone) {
- const output = await stream.nextChunk();
- const chunk = output !== null
- ? new outputs_1.GenerationChunk((0, gemini_js_1.responseToGeneration)({ data: output }))
- : new outputs_1.GenerationChunk({
- text: "",
- generationInfo: { finishReason: "stop" },
+ // Normally, you should not override this method and instead should override
+ // _streamResponseChunks. We are doing so here to allow for multimodal inputs into
+ // the LLM.
+ async *_streamIterator(input, options) {
+ // TODO: Refactor callback setup and teardown code into core
+ const prompt = llms_1.BaseLLM._convertInputToPromptValue(input);
+ const [runnableConfig, callOptions] = this._separateRunnableConfigFromCallOptions(options);
+ const callbackManager_ = await manager_1.CallbackManager.configure(runnableConfig.callbacks, this.callbacks, runnableConfig.tags, this.tags, runnableConfig.metadata, this.metadata, { verbose: this.verbose });
+ const extra = {
+ options: callOptions,
+ invocation_params: this?.invocationParams(callOptions),
+ batch_size: 1,
+ };
+ const runManagers = await callbackManager_?.handleLLMStart(this.toJSON(), [prompt.toString()], undefined, undefined, extra, undefined, undefined, runnableConfig.runName);
+ let generation = new outputs_1.GenerationChunk({
+ text: "",
+ });
+ const proxyChat = this.createProxyChat();
+ try {
+ for await (const chunk of proxyChat._streamIterator(input, options)) {
+ const stringValue = (0, gemini_js_1.chunkToString)(chunk);
+ const generationChunk = new outputs_1.GenerationChunk({
+ text: stringValue,
  });
- yield chunk;
+ generation = generation.concat(generationChunk);
+ yield stringValue;
+ }
+ }
+ catch (err) {
+ await Promise.all((runManagers ?? []).map((runManager) => runManager?.handleLLMError(err)));
+ throw err;
  }
+ await Promise.all((runManagers ?? []).map((runManager) => runManager?.handleLLMEnd({
+ generations: [[generation]],
+ })));
  }
  async predictMessages(messages, options, _callbacks) {
  const { content } = messages[0];
  const result = await this.connection.request(content, {}, options);
- const ret = (0, gemini_js_1.responseToBaseMessage)(result);
+ const ret = (0, gemini_js_1.safeResponseToBaseMessage)(result, this.safetyHandler);
  return ret;
  }
+ /**
+ * Internal implementation detail to allow Google LLMs to support
+ * multimodal input by delegating to the chat model implementation.
+ *
+ * TODO: Replace with something less hacky.
+ */
+ createProxyChat() {
+ return new ProxyChatGoogle({
+ ...this.originalFields,
+ connection: this.connection,
+ });
+ }
+ // TODO: Remove the need to override this - we are doing it to
+ // allow the LLM to handle multimodal types of input.
+ async invoke(input, options) {
+ const stream = await this._streamIterator(input, options);
+ let generatedOutput = "";
+ for await (const chunk of stream) {
+ generatedOutput += chunk;
+ }
+ return generatedOutput;
+ }
  }
  exports.GoogleBaseLLM = GoogleBaseLLM;
package/dist/llms.d.ts CHANGED
@@ -1,24 +1,22 @@
- import { CallbackManagerForLLMRun, Callbacks } from "@langchain/core/callbacks/manager";
+ import { Callbacks } from "@langchain/core/callbacks/manager";
  import { LLM } from "@langchain/core/language_models/llms";
- import { type BaseLanguageModelCallOptions } from "@langchain/core/language_models/base";
+ import { type BaseLanguageModelCallOptions, BaseLanguageModelInput } from "@langchain/core/language_models/base";
  import { BaseMessage, MessageContent } from "@langchain/core/messages";
- import { GenerationChunk } from "@langchain/core/outputs";
  import { AbstractGoogleLLMConnection } from "./connection.js";
  import { GoogleAIBaseLLMInput, GoogleAIModelParams, GoogleAISafetySetting, GooglePlatformType, GeminiContent } from "./types.js";
  import { GoogleAbstractedClient } from "./auth.js";
+ import { ChatGoogleBase } from "./chat_models.js";
+ import type { GoogleBaseLLMInput, GoogleAISafetyHandler } from "./types.js";
+ export { GoogleBaseLLMInput };
  declare class GoogleLLMConnection<AuthOptions> extends AbstractGoogleLLMConnection<MessageContent, AuthOptions> {
  formatContents(input: MessageContent, _parameters: GoogleAIModelParams): GeminiContent[];
  }
- /**
- * Input to LLM class.
- */
- export interface GoogleBaseLLMInput<AuthOptions> extends GoogleAIBaseLLMInput<AuthOptions> {
- }
  /**
  * Integration with an LLM.
  */
  export declare abstract class GoogleBaseLLM<AuthOptions> extends LLM<BaseLanguageModelCallOptions> implements GoogleBaseLLMInput<AuthOptions> {
  static lc_name(): string;
+ originalFields?: GoogleBaseLLMInput<AuthOptions>;
  lc_serializable: boolean;
  model: string;
  temperature: number;
@@ -27,6 +25,7 @@ export declare abstract class GoogleBaseLLM<AuthOptions> extends LLM<BaseLanguag
  topK: number;
  stopSequences: string[];
  safetySettings: GoogleAISafetySetting[];
+ safetyHandler: GoogleAISafetyHandler;
  protected connection: GoogleLLMConnection<AuthOptions>;
  protected streamedConnection: GoogleLLMConnection<AuthOptions>;
  constructor(fields?: GoogleBaseLLMInput<AuthOptions>);
@@ -40,9 +39,19 @@ export declare abstract class GoogleBaseLLM<AuthOptions> extends LLM<BaseLanguag
  formatPrompt(prompt: string): MessageContent;
  /**
  * For some given input string and options, return a string output.
+ *
+ * Despite the fact that `invoke` is overridden below, we still need this
+ * in order to handle public API calls to `generate()`.
  */
- _call(_prompt: string, _options: this["ParsedCallOptions"], _runManager?: CallbackManagerForLLMRun): Promise<string>;
- _streamResponseChunks(_prompt: string, _options: this["ParsedCallOptions"], _runManager?: CallbackManagerForLLMRun): AsyncGenerator<GenerationChunk>;
+ _call(prompt: string, options: this["ParsedCallOptions"]): Promise<string>;
+ _streamIterator(input: BaseLanguageModelInput, options?: BaseLanguageModelCallOptions): AsyncGenerator<string>;
  predictMessages(messages: BaseMessage[], options?: string[] | BaseLanguageModelCallOptions, _callbacks?: Callbacks): Promise<BaseMessage>;
+ /**
+ * Internal implementation detail to allow Google LLMs to support
+ * multimodal input by delegating to the chat model implementation.
+ *
+ * TODO: Replace with something less hacky.
+ */
+ protected createProxyChat(): ChatGoogleBase<AuthOptions>;
+ invoke(input: BaseLanguageModelInput, options?: BaseLanguageModelCallOptions): Promise<string>;
  }
- export {};
package/dist/llms.js CHANGED
@@ -1,10 +1,13 @@
- import { LLM } from "@langchain/core/language_models/llms";
+ import { CallbackManager } from "@langchain/core/callbacks/manager";
+ import { BaseLLM, LLM } from "@langchain/core/language_models/llms";
  import { GenerationChunk } from "@langchain/core/outputs";
  import { getEnvironmentVariable } from "@langchain/core/utils/env";
  import { AbstractGoogleLLMConnection } from "./connection.js";
  import { copyAIModelParams, copyAndValidateModelParamsInto, } from "./utils/common.js";
- import { messageContentToParts, responseToBaseMessage, responseToGeneration, responseToString, } from "./utils/gemini.js";
+ import { chunkToString, messageContentToParts, safeResponseToBaseMessage, safeResponseToString, DefaultGeminiSafetyHandler, } from "./utils/gemini.js";
  import { ApiKeyGoogleAuth } from "./auth.js";
+ import { ensureParams } from "./utils/failed_handler.js";
+ import { ChatGoogleBase } from "./chat_models.js";
  class GoogleLLMConnection extends AbstractGoogleLLMConnection {
  formatContents(input, _parameters) {
  const parts = messageContentToParts(input);
@@ -17,6 +20,14 @@ class GoogleLLMConnection extends AbstractGoogleLLMConnection {
  return contents;
  }
  }
+ class ProxyChatGoogle extends ChatGoogleBase {
+ constructor(fields) {
+ super(fields);
+ }
+ buildAbstractedClient(fields) {
+ return fields.connection.client;
+ }
+ }
  /**
  * Integration with an LLM.
  */
@@ -26,7 +37,13 @@ export class GoogleBaseLLM extends LLM {
  return "GoogleLLM";
  }
  constructor(fields) {
- super(fields ?? {});
+ super(ensureParams(fields));
+ Object.defineProperty(this, "originalFields", {
+ enumerable: true,
+ configurable: true,
+ writable: true,
+ value: void 0
+ });
  Object.defineProperty(this, "lc_serializable", {
  enumerable: true,
  configurable: true,
@@ -75,6 +92,12 @@ export class GoogleBaseLLM extends LLM {
  writable: true,
  value: []
  });
+ Object.defineProperty(this, "safetyHandler", {
+ enumerable: true,
+ configurable: true,
+ writable: true,
+ value: void 0
+ });
  Object.defineProperty(this, "connection", {
  enumerable: true,
  configurable: true,
@@ -87,7 +110,10 @@ export class GoogleBaseLLM extends LLM {
  writable: true,
  value: void 0
  });
+ this.originalFields = fields;
  copyAndValidateModelParamsInto(fields, this);
+ this.safetyHandler =
+ fields?.safetyHandler ?? new DefaultGeminiSafetyHandler();
  const client = this.buildClient(fields);
  this.buildConnection(fields ?? {}, client);
  }
@@ -122,37 +148,78 @@ export class GoogleBaseLLM extends LLM {
  }
  /**
  * For some given input string and options, return a string output.
+ *
+ * Despite the fact that `invoke` is overridden below, we still need this
+ * in order to handle public API calls to `generate()`.
  */
- async _call(_prompt, _options, _runManager) {
+ async _call(prompt, options) {
  const parameters = copyAIModelParams(this);
- const result = await this.connection.request(_prompt, parameters, _options);
- const ret = responseToString(result);
+ const result = await this.connection.request(prompt, parameters, options);
+ const ret = safeResponseToString(result, this.safetyHandler);
  return ret;
  }
- async *_streamResponseChunks(_prompt, _options, _runManager) {
- // Make the call as a streaming request
- const parameters = copyAIModelParams(this);
- const result = await this.streamedConnection.request(_prompt, parameters, _options);
- // Get the streaming parser of the response
- const stream = result.data;
- // Loop until the end of the stream
- // During the loop, yield each time we get a chunk from the streaming parser
- // that is either available or added to the queue
- while (!stream.streamDone) {
- const output = await stream.nextChunk();
- const chunk = output !== null
- ? new GenerationChunk(responseToGeneration({ data: output }))
- : new GenerationChunk({
- text: "",
- generationInfo: { finishReason: "stop" },
+ // Normally, you should not override this method and instead should override
+ // _streamResponseChunks. We are doing so here to allow for multimodal inputs into
+ // the LLM.
+ async *_streamIterator(input, options) {
+ // TODO: Refactor callback setup and teardown code into core
+ const prompt = BaseLLM._convertInputToPromptValue(input);
+ const [runnableConfig, callOptions] = this._separateRunnableConfigFromCallOptions(options);
+ const callbackManager_ = await CallbackManager.configure(runnableConfig.callbacks, this.callbacks, runnableConfig.tags, this.tags, runnableConfig.metadata, this.metadata, { verbose: this.verbose });
+ const extra = {
+ options: callOptions,
+ invocation_params: this?.invocationParams(callOptions),
+ batch_size: 1,
+ };
+ const runManagers = await callbackManager_?.handleLLMStart(this.toJSON(), [prompt.toString()], undefined, undefined, extra, undefined, undefined, runnableConfig.runName);
+ let generation = new GenerationChunk({
+ text: "",
+ });
+ const proxyChat = this.createProxyChat();
+ try {
+ for await (const chunk of proxyChat._streamIterator(input, options)) {
+ const stringValue = chunkToString(chunk);
+ const generationChunk = new GenerationChunk({
+ text: stringValue,
  });
- yield chunk;
+ generation = generation.concat(generationChunk);
+ yield stringValue;
+ }
+ }
+ catch (err) {
+ await Promise.all((runManagers ?? []).map((runManager) => runManager?.handleLLMError(err)));
+ throw err;
  }
+ await Promise.all((runManagers ?? []).map((runManager) => runManager?.handleLLMEnd({
+ generations: [[generation]],
+ })));
  }
  async predictMessages(messages, options, _callbacks) {
  const { content } = messages[0];
  const result = await this.connection.request(content, {}, options);
- const ret = responseToBaseMessage(result);
+ const ret = safeResponseToBaseMessage(result, this.safetyHandler);
  return ret;
  }
+ /**
+ * Internal implementation detail to allow Google LLMs to support
+ * multimodal input by delegating to the chat model implementation.
+ *
+ * TODO: Replace with something less hacky.
+ */
+ createProxyChat() {
+ return new ProxyChatGoogle({
+ ...this.originalFields,
+ connection: this.connection,
+ });
+ }
+ // TODO: Remove the need to override this - we are doing it to
+ // allow the LLM to handle multimodal types of input.
+ async invoke(input, options) {
+ const stream = await this._streamIterator(input, options);
+ let generatedOutput = "";
+ for await (const chunk of stream) {
+ generatedOutput += chunk;
+ }
+ return generatedOutput;
+ }
  }
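
Because invoke() and _streamIterator() now delegate to a proxy chat model built from the LLM's own connection, the LLM accepts the same rich MessageContent that the chat path does. A usage sketch, assuming a concrete GoogleBaseLLM subclass and that messageContentToParts accepts image_url-style content (llm is a placeholder, not a real export):

    import { readFileSync } from "node:fs";

    // Placeholder for an instance of a concrete GoogleBaseLLM subclass.
    declare const llm: {
      invoke(input: unknown): Promise<string>;
    };

    const imageBase64 = readFileSync("scene.png").toString("base64");

    // invoke() streams through the proxy chat model and concatenates the
    // string chunks, so multimodal content works where a plain prompt did:
    const answer = await llm.invoke([
      {
        type: "image_url",
        image_url: `data:image/png;base64,${imageBase64}`,
      },
      { type: "text", text: "What is in this image?" },
    ]);
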
package/dist/types.d.ts CHANGED
@@ -67,7 +67,12 @@ export interface GoogleAIModelParams {
  stopSequences?: string[];
  safetySettings?: GoogleAISafetySetting[];
  }
- export interface GoogleAIBaseLLMInput<AuthOptions> extends BaseLLMParams, GoogleConnectionParams<AuthOptions>, GoogleAIModelParams {
+ export interface GoogleAIBaseLLMInput<AuthOptions> extends BaseLLMParams, GoogleConnectionParams<AuthOptions>, GoogleAIModelParams, GoogleAISafetyParams {
+ }
+ /**
+ * Input to LLM class.
+ */
+ export interface GoogleBaseLLMInput<AuthOptions> extends GoogleAIBaseLLMInput<AuthOptions> {
  }
  export interface GoogleResponse {
  data: any;
@@ -76,20 +81,28 @@ export interface GeminiPartText {
  text: string;
  }
  export interface GeminiPartInlineData {
- mimeType: string;
- data: string;
+ inlineData: {
+ mimeType: string;
+ data: string;
+ };
  }
  export interface GeminiPartFileData {
- mimeType: string;
- fileUri: string;
+ fileData: {
+ mimeType: string;
+ fileUri: string;
+ };
  }
  export interface GeminiPartFunctionCall {
- name: string;
- args?: object;
+ functionCall: {
+ name: string;
+ args?: object;
+ };
  }
  export interface GeminiPartFunctionResponse {
- name: string;
- response: object;
+ functionResponse: {
+ name: string;
+ response: object;
+ };
  }
  export type GeminiPart = GeminiPartText | GeminiPartInlineData | GeminiPartFileData | GeminiPartFunctionCall | GeminiPartFunctionResponse;
  export interface GeminiSafetySetting {
@@ -132,6 +145,7 @@ interface GeminiResponseCandidate {
  safetyRatings: GeminiSafetyRating[];
  }
  interface GeminiResponsePromptFeedback {
+ blockReason?: string;
  safetyRatings: GeminiSafetyRating[];
  }
  export interface GenerateContentResponseData {
@@ -143,4 +157,16 @@ export type GoogleLLMResponseData = JsonStream | GenerateContentResponseData | G
  export interface GoogleLLMResponse extends GoogleResponse {
  data: GoogleLLMResponseData;
  }
+ export interface GoogleAISafetyHandler {
+ /**
+ * A function that will take a response and return the, possibly modified,
+ * response or throw an exception if there are safety issues.
+ *
+ * @throws GoogleAISafetyError
+ */
+ handle(response: GoogleLLMResponse): GoogleLLMResponse;
+ }
+ export interface GoogleAISafetyParams {
+ safetyHandler?: GoogleAISafetyHandler;
+ }
  export {};
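
The GeminiPart* reshaping above is worth calling out: each part's payload now sits under a discriminating key (inlineData, fileData, functionCall, functionResponse), which matches the Gemini REST wire format and makes the members of the GeminiPart union distinguishable at runtime. Literal shapes valid under the new 0.0.1 definitions:

    // Unchanged: text parts keep their flat shape.
    const textPart = { text: "Describe the attachment." };

    const imagePart = {
      inlineData: {
        mimeType: "image/png",
        data: "<base64-encoded bytes>", // placeholder, not real data
      },
    };

    const callPart = {
      functionCall: {
        name: "get_weather", // hypothetical function name
        args: { city: "Paris" },
      },
    };

    // Under 0.0.0, mimeType/data and name/args sat at the top level of each
    // part, so the union members could not be told apart structurally.
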
package/dist/utils/failed_handler.cjs ADDED
@@ -0,0 +1,37 @@
+ "use strict";
+ Object.defineProperty(exports, "__esModule", { value: true });
+ exports.ensureParams = exports.failedAttemptHandler = void 0;
+ const STATUS_NO_RETRY = [
+ 400, // Bad Request
+ 401, // Unauthorized
+ 402, // Payment Required
+ 403, // Forbidden
+ 404, // Not Found
+ 405, // Method Not Allowed
+ 406, // Not Acceptable
+ 407, // Proxy Authentication Required
+ 408, // Request Timeout
+ 409, // Conflict
+ ];
+ // eslint-disable-next-line @typescript-eslint/no-explicit-any
+ function failedAttemptHandler(error) {
+ const status = error?.response?.status ?? 0;
+ if (status === 0) {
+ // What is this?
+ console.error("failedAttemptHandler", error);
+ }
+ // What errors shouldn't be retried?
+ if (STATUS_NO_RETRY.includes(+status)) {
+ throw error;
+ }
+ throw error;
+ }
+ exports.failedAttemptHandler = failedAttemptHandler;
+ function ensureParams(params) {
+ const base = params ?? {};
+ return {
+ onFailedAttempt: failedAttemptHandler,
+ ...base,
+ };
+ }
+ exports.ensureParams = ensureParams;
package/dist/utils/failed_handler.d.ts ADDED
@@ -0,0 +1,3 @@
+ import { AsyncCallerParams } from "@langchain/core/utils/async_caller";
+ export declare function failedAttemptHandler(error: any): void;
+ export declare function ensureParams(params?: AsyncCallerParams): AsyncCallerParams;
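
ensureParams is what both constructors now call in place of fields ?? {}: it injects failedAttemptHandler as the default onFailedAttempt for the retry machinery. Note the spread order — the caller's fields are spread after the default, so an explicit onFailedAttempt (or any other AsyncCallerParams field) wins. Also note that as shipped, failedAttemptHandler rethrows on every path (the unconditional throw after the STATUS_NO_RETRY check), so the default appears to surface every failed attempt immediately rather than retrying. A sketch of the merge semantics; the deep import path is an assumption based on the dist layout, not a documented entry point:

    import { ensureParams } from "@langchain/google-common/dist/utils/failed_handler.js";

    // No params: the package's failedAttemptHandler becomes the default.
    const withDefault = ensureParams();

    // Caller-supplied fields override the default because they spread last:
    const withOverride = ensureParams({
      maxRetries: 2,
      onFailedAttempt: (error: Error) => {
        throw error; // fail fast, never retry
      },
    });
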