@langchain/google-common 0.0.0 → 0.0.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +7 -2
- package/dist/chat_models.cjs +115 -7
- package/dist/chat_models.d.ts +16 -6
- package/dist/chat_models.js +116 -8
- package/dist/connection.cjs +48 -4
- package/dist/connection.d.ts +14 -7
- package/dist/connection.js +48 -4
- package/dist/index.cjs +1 -0
- package/dist/index.d.ts +1 -0
- package/dist/index.js +1 -0
- package/dist/llms.cjs +91 -24
- package/dist/llms.d.ts +21 -12
- package/dist/llms.js +93 -26
- package/dist/types.d.ts +66 -11
- package/dist/utils/common.cjs +19 -12
- package/dist/utils/common.d.ts +3 -3
- package/dist/utils/common.js +19 -12
- package/dist/utils/failed_handler.cjs +37 -0
- package/dist/utils/failed_handler.d.ts +3 -0
- package/dist/utils/failed_handler.js +32 -0
- package/dist/utils/gemini.cjs +321 -25
- package/dist/utils/gemini.d.ts +53 -3
- package/dist/utils/gemini.js +308 -23
- package/dist/utils/index.cjs +23 -0
- package/dist/utils/index.d.ts +7 -0
- package/dist/utils/index.js +7 -0
- package/dist/utils/safety.cjs +23 -0
- package/dist/utils/safety.d.ts +6 -0
- package/dist/utils/safety.js +19 -0
- package/dist/utils/zod_to_gemini_parameters.cjs +15 -0
- package/dist/utils/zod_to_gemini_parameters.d.ts +3 -0
- package/dist/utils/zod_to_gemini_parameters.js +11 -0
- package/index.d.cts +1 -0
- package/package.json +42 -7
- package/types.cjs +1 -0
- package/types.d.cts +1 -0
- package/types.d.ts +1 -0
- package/types.js +1 -0
- package/utils.cjs +1 -0
- package/utils.d.cts +1 -0
- package/utils.d.ts +1 -0
- package/utils.js +1 -0
package/dist/connection.js
CHANGED
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import { getRuntimeEnvironment } from "@langchain/core/utils/env";
|
|
2
|
+
import { zodToGeminiParameters } from "./utils/zod_to_gemini_parameters.js";
|
|
2
3
|
export class GoogleConnection {
|
|
3
4
|
constructor(caller, client, streaming) {
|
|
4
5
|
Object.defineProperty(this, "caller", {
|
|
@@ -118,12 +119,19 @@ export class GoogleHostConnection extends GoogleConnection {
|
|
|
118
119
|
export class GoogleAIConnection extends GoogleHostConnection {
|
|
119
120
|
constructor(fields, caller, client, streaming) {
|
|
120
121
|
super(fields, caller, client, streaming);
|
|
122
|
+
/** @deprecated Prefer `modelName` */
|
|
121
123
|
Object.defineProperty(this, "model", {
|
|
122
124
|
enumerable: true,
|
|
123
125
|
configurable: true,
|
|
124
126
|
writable: true,
|
|
125
127
|
value: void 0
|
|
126
128
|
});
|
|
129
|
+
Object.defineProperty(this, "modelName", {
|
|
130
|
+
enumerable: true,
|
|
131
|
+
configurable: true,
|
|
132
|
+
writable: true,
|
|
133
|
+
value: void 0
|
|
134
|
+
});
|
|
127
135
|
Object.defineProperty(this, "client", {
|
|
128
136
|
enumerable: true,
|
|
129
137
|
configurable: true,
|
|
@@ -131,10 +139,10 @@ export class GoogleAIConnection extends GoogleHostConnection {
|
|
|
131
139
|
value: void 0
|
|
132
140
|
});
|
|
133
141
|
this.client = client;
|
|
134
|
-
this.
|
|
142
|
+
this.modelName = fields?.modelName ?? fields?.model ?? this.modelName;
|
|
135
143
|
}
|
|
136
144
|
get modelFamily() {
|
|
137
|
-
if (this.
|
|
145
|
+
if (this.modelName.startsWith("gemini")) {
|
|
138
146
|
return "gemini";
|
|
139
147
|
}
|
|
140
148
|
else {
|
|
@@ -151,13 +159,13 @@ export class GoogleAIConnection extends GoogleHostConnection {
|
|
|
151
159
|
}
|
|
152
160
|
async buildUrlGenerativeLanguage() {
|
|
153
161
|
const method = await this.buildUrlMethod();
|
|
154
|
-
const url = `https://generativelanguage.googleapis.com/${this.apiVersion}/models/${this.
|
|
162
|
+
const url = `https://generativelanguage.googleapis.com/${this.apiVersion}/models/${this.modelName}:${method}`;
|
|
155
163
|
return url;
|
|
156
164
|
}
|
|
157
165
|
async buildUrlVertex() {
|
|
158
166
|
const projectId = await this.client.getProjectId();
|
|
159
167
|
const method = await this.buildUrlMethod();
|
|
160
|
-
const url = `https://${this.endpoint}/${this.apiVersion}/projects/${projectId}/locations/${this.location}/publishers/google/models/${this.
|
|
168
|
+
const url = `https://${this.endpoint}/${this.apiVersion}/projects/${projectId}/locations/${this.location}/publishers/google/models/${this.modelName}:${method}`;
|
|
161
169
|
return url;
|
|
162
170
|
}
|
|
163
171
|
async buildUrl() {
|
|
@@ -199,6 +207,38 @@ export class AbstractGoogleLLMConnection extends GoogleAIConnection {
|
|
|
199
207
|
formatSafetySettings(_input, parameters) {
|
|
200
208
|
return parameters.safetySettings ?? [];
|
|
201
209
|
}
|
|
210
|
+
// Borrowed from the OpenAI invocation params test
|
|
211
|
+
isStructuredToolArray(tools) {
|
|
212
|
+
return (tools !== undefined &&
|
|
213
|
+
tools.every((tool) => Array.isArray(tool.lc_namespace)));
|
|
214
|
+
}
|
|
215
|
+
structuredToolToFunctionDeclaration(tool) {
|
|
216
|
+
const jsonSchema = zodToGeminiParameters(tool.schema);
|
|
217
|
+
return {
|
|
218
|
+
name: tool.name,
|
|
219
|
+
description: tool.description,
|
|
220
|
+
parameters: jsonSchema,
|
|
221
|
+
};
|
|
222
|
+
}
|
|
223
|
+
structuredToolsToGeminiTools(tools) {
|
|
224
|
+
return [
|
|
225
|
+
{
|
|
226
|
+
functionDeclarations: tools.map(this.structuredToolToFunctionDeclaration),
|
|
227
|
+
},
|
|
228
|
+
];
|
|
229
|
+
}
|
|
230
|
+
formatTools(_input, parameters) {
|
|
231
|
+
const tools = parameters?.tools;
|
|
232
|
+
if (!tools || tools.length === 0) {
|
|
233
|
+
return [];
|
|
234
|
+
}
|
|
235
|
+
if (this.isStructuredToolArray(tools)) {
|
|
236
|
+
return this.structuredToolsToGeminiTools(tools);
|
|
237
|
+
}
|
|
238
|
+
else {
|
|
239
|
+
return tools;
|
|
240
|
+
}
|
|
241
|
+
}
|
|
202
242
|
formatData(input, parameters) {
|
|
203
243
|
/*
|
|
204
244
|
const parts = messageContentToParts(input);
|
|
@@ -211,11 +251,15 @@ export class AbstractGoogleLLMConnection extends GoogleAIConnection {
|
|
|
211
251
|
*/
|
|
212
252
|
const contents = this.formatContents(input, parameters);
|
|
213
253
|
const generationConfig = this.formatGenerationConfig(input, parameters);
|
|
254
|
+
const tools = this.formatTools(input, parameters);
|
|
214
255
|
const safetySettings = this.formatSafetySettings(input, parameters);
|
|
215
256
|
const ret = {
|
|
216
257
|
contents,
|
|
217
258
|
generationConfig,
|
|
218
259
|
};
|
|
260
|
+
if (tools && tools.length) {
|
|
261
|
+
ret.tools = tools;
|
|
262
|
+
}
|
|
219
263
|
if (safetySettings && safetySettings.length) {
|
|
220
264
|
ret.safetySettings = safetySettings;
|
|
221
265
|
}
|
package/dist/index.cjs
CHANGED
|
@@ -21,3 +21,4 @@ __exportStar(require("./connection.cjs"), exports);
|
|
|
21
21
|
__exportStar(require("./types.cjs"), exports);
|
|
22
22
|
__exportStar(require("./utils/stream.cjs"), exports);
|
|
23
23
|
__exportStar(require("./utils/common.cjs"), exports);
|
|
24
|
+
__exportStar(require("./utils/zod_to_gemini_parameters.cjs"), exports);
|
package/dist/index.d.ts
CHANGED
package/dist/index.js
CHANGED
package/dist/llms.cjs
CHANGED
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
"use strict";
|
|
2
2
|
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
3
|
exports.GoogleBaseLLM = void 0;
|
|
4
|
+
const manager_1 = require("@langchain/core/callbacks/manager");
|
|
4
5
|
const llms_1 = require("@langchain/core/language_models/llms");
|
|
5
6
|
const outputs_1 = require("@langchain/core/outputs");
|
|
6
7
|
const env_1 = require("@langchain/core/utils/env");
|
|
@@ -8,6 +9,8 @@ const connection_js_1 = require("./connection.cjs");
|
|
|
8
9
|
const common_js_1 = require("./utils/common.cjs");
|
|
9
10
|
const gemini_js_1 = require("./utils/gemini.cjs");
|
|
10
11
|
const auth_js_1 = require("./auth.cjs");
|
|
12
|
+
const failed_handler_js_1 = require("./utils/failed_handler.cjs");
|
|
13
|
+
const chat_models_js_1 = require("./chat_models.cjs");
|
|
11
14
|
class GoogleLLMConnection extends connection_js_1.AbstractGoogleLLMConnection {
|
|
12
15
|
formatContents(input, _parameters) {
|
|
13
16
|
const parts = (0, gemini_js_1.messageContentToParts)(input);
|
|
@@ -20,6 +23,14 @@ class GoogleLLMConnection extends connection_js_1.AbstractGoogleLLMConnection {
|
|
|
20
23
|
return contents;
|
|
21
24
|
}
|
|
22
25
|
}
|
|
26
|
+
class ProxyChatGoogle extends chat_models_js_1.ChatGoogleBase {
|
|
27
|
+
constructor(fields) {
|
|
28
|
+
super(fields);
|
|
29
|
+
}
|
|
30
|
+
buildAbstractedClient(fields) {
|
|
31
|
+
return fields.connection.client;
|
|
32
|
+
}
|
|
33
|
+
}
|
|
23
34
|
/**
|
|
24
35
|
* Integration with an LLM.
|
|
25
36
|
*/
|
|
@@ -29,14 +40,20 @@ class GoogleBaseLLM extends llms_1.LLM {
|
|
|
29
40
|
return "GoogleLLM";
|
|
30
41
|
}
|
|
31
42
|
constructor(fields) {
|
|
32
|
-
super(
|
|
43
|
+
super((0, failed_handler_js_1.ensureParams)(fields));
|
|
44
|
+
Object.defineProperty(this, "originalFields", {
|
|
45
|
+
enumerable: true,
|
|
46
|
+
configurable: true,
|
|
47
|
+
writable: true,
|
|
48
|
+
value: void 0
|
|
49
|
+
});
|
|
33
50
|
Object.defineProperty(this, "lc_serializable", {
|
|
34
51
|
enumerable: true,
|
|
35
52
|
configurable: true,
|
|
36
53
|
writable: true,
|
|
37
54
|
value: true
|
|
38
55
|
});
|
|
39
|
-
Object.defineProperty(this, "
|
|
56
|
+
Object.defineProperty(this, "modelName", {
|
|
40
57
|
enumerable: true,
|
|
41
58
|
configurable: true,
|
|
42
59
|
writable: true,
|
|
@@ -78,6 +95,12 @@ class GoogleBaseLLM extends llms_1.LLM {
|
|
|
78
95
|
writable: true,
|
|
79
96
|
value: []
|
|
80
97
|
});
|
|
98
|
+
Object.defineProperty(this, "safetyHandler", {
|
|
99
|
+
enumerable: true,
|
|
100
|
+
configurable: true,
|
|
101
|
+
writable: true,
|
|
102
|
+
value: void 0
|
|
103
|
+
});
|
|
81
104
|
Object.defineProperty(this, "connection", {
|
|
82
105
|
enumerable: true,
|
|
83
106
|
configurable: true,
|
|
@@ -90,7 +113,10 @@ class GoogleBaseLLM extends llms_1.LLM {
|
|
|
90
113
|
writable: true,
|
|
91
114
|
value: void 0
|
|
92
115
|
});
|
|
116
|
+
this.originalFields = fields;
|
|
93
117
|
(0, common_js_1.copyAndValidateModelParamsInto)(fields, this);
|
|
118
|
+
this.safetyHandler =
|
|
119
|
+
fields?.safetyHandler ?? new gemini_js_1.DefaultGeminiSafetyHandler();
|
|
94
120
|
const client = this.buildClient(fields);
|
|
95
121
|
this.buildConnection(fields ?? {}, client);
|
|
96
122
|
}
|
|
@@ -125,38 +151,79 @@ class GoogleBaseLLM extends llms_1.LLM {
|
|
|
125
151
|
}
|
|
126
152
|
/**
|
|
127
153
|
* For some given input string and options, return a string output.
|
|
154
|
+
*
|
|
155
|
+
* Despite the fact that `invoke` is overridden below, we still need this
|
|
156
|
+
* in order to handle public APi calls to `generate()`.
|
|
128
157
|
*/
|
|
129
|
-
async _call(
|
|
130
|
-
const parameters = (0, common_js_1.copyAIModelParams)(this);
|
|
131
|
-
const result = await this.connection.request(
|
|
132
|
-
const ret = (0, gemini_js_1.
|
|
158
|
+
async _call(prompt, options) {
|
|
159
|
+
const parameters = (0, common_js_1.copyAIModelParams)(this, options);
|
|
160
|
+
const result = await this.connection.request(prompt, parameters, options);
|
|
161
|
+
const ret = (0, gemini_js_1.safeResponseToString)(result, this.safetyHandler);
|
|
133
162
|
return ret;
|
|
134
163
|
}
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
//
|
|
140
|
-
const
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
164
|
+
// Normally, you should not override this method and instead should override
|
|
165
|
+
// _streamResponseChunks. We are doing so here to allow for multimodal inputs into
|
|
166
|
+
// the LLM.
|
|
167
|
+
async *_streamIterator(input, options) {
|
|
168
|
+
// TODO: Refactor callback setup and teardown code into core
|
|
169
|
+
const prompt = llms_1.BaseLLM._convertInputToPromptValue(input);
|
|
170
|
+
const [runnableConfig, callOptions] = this._separateRunnableConfigFromCallOptions(options);
|
|
171
|
+
const callbackManager_ = await manager_1.CallbackManager.configure(runnableConfig.callbacks, this.callbacks, runnableConfig.tags, this.tags, runnableConfig.metadata, this.metadata, { verbose: this.verbose });
|
|
172
|
+
const extra = {
|
|
173
|
+
options: callOptions,
|
|
174
|
+
invocation_params: this?.invocationParams(callOptions),
|
|
175
|
+
batch_size: 1,
|
|
176
|
+
};
|
|
177
|
+
const runManagers = await callbackManager_?.handleLLMStart(this.toJSON(), [prompt.toString()], undefined, undefined, extra, undefined, undefined, runnableConfig.runName);
|
|
178
|
+
let generation = new outputs_1.GenerationChunk({
|
|
179
|
+
text: "",
|
|
180
|
+
});
|
|
181
|
+
const proxyChat = this.createProxyChat();
|
|
182
|
+
try {
|
|
183
|
+
for await (const chunk of proxyChat._streamIterator(input, options)) {
|
|
184
|
+
const stringValue = (0, gemini_js_1.chunkToString)(chunk);
|
|
185
|
+
const generationChunk = new outputs_1.GenerationChunk({
|
|
186
|
+
text: stringValue,
|
|
151
187
|
});
|
|
152
|
-
|
|
188
|
+
generation = generation.concat(generationChunk);
|
|
189
|
+
yield stringValue;
|
|
190
|
+
}
|
|
191
|
+
}
|
|
192
|
+
catch (err) {
|
|
193
|
+
await Promise.all((runManagers ?? []).map((runManager) => runManager?.handleLLMError(err)));
|
|
194
|
+
throw err;
|
|
153
195
|
}
|
|
196
|
+
await Promise.all((runManagers ?? []).map((runManager) => runManager?.handleLLMEnd({
|
|
197
|
+
generations: [[generation]],
|
|
198
|
+
})));
|
|
154
199
|
}
|
|
155
200
|
async predictMessages(messages, options, _callbacks) {
|
|
156
201
|
const { content } = messages[0];
|
|
157
202
|
const result = await this.connection.request(content, {}, options);
|
|
158
|
-
const ret = (0, gemini_js_1.
|
|
203
|
+
const ret = (0, gemini_js_1.safeResponseToBaseMessage)(result, this.safetyHandler);
|
|
159
204
|
return ret;
|
|
160
205
|
}
|
|
206
|
+
/**
|
|
207
|
+
* Internal implementation detail to allow Google LLMs to support
|
|
208
|
+
* multimodal input by delegating to the chat model implementation.
|
|
209
|
+
*
|
|
210
|
+
* TODO: Replace with something less hacky.
|
|
211
|
+
*/
|
|
212
|
+
createProxyChat() {
|
|
213
|
+
return new ProxyChatGoogle({
|
|
214
|
+
...this.originalFields,
|
|
215
|
+
connection: this.connection,
|
|
216
|
+
});
|
|
217
|
+
}
|
|
218
|
+
// TODO: Remove the need to override this - we are doing it to
|
|
219
|
+
// allow the LLM to handle multimodal types of input.
|
|
220
|
+
async invoke(input, options) {
|
|
221
|
+
const stream = await this._streamIterator(input, options);
|
|
222
|
+
let generatedOutput = "";
|
|
223
|
+
for await (const chunk of stream) {
|
|
224
|
+
generatedOutput += chunk;
|
|
225
|
+
}
|
|
226
|
+
return generatedOutput;
|
|
227
|
+
}
|
|
161
228
|
}
|
|
162
229
|
exports.GoogleBaseLLM = GoogleBaseLLM;
|
package/dist/llms.d.ts
CHANGED
|
@@ -1,32 +1,31 @@
|
|
|
1
|
-
import {
|
|
1
|
+
import { Callbacks } from "@langchain/core/callbacks/manager";
|
|
2
2
|
import { LLM } from "@langchain/core/language_models/llms";
|
|
3
|
-
import { type BaseLanguageModelCallOptions } from "@langchain/core/language_models/base";
|
|
3
|
+
import { type BaseLanguageModelCallOptions, BaseLanguageModelInput } from "@langchain/core/language_models/base";
|
|
4
4
|
import { BaseMessage, MessageContent } from "@langchain/core/messages";
|
|
5
|
-
import { GenerationChunk } from "@langchain/core/outputs";
|
|
6
5
|
import { AbstractGoogleLLMConnection } from "./connection.js";
|
|
7
6
|
import { GoogleAIBaseLLMInput, GoogleAIModelParams, GoogleAISafetySetting, GooglePlatformType, GeminiContent } from "./types.js";
|
|
8
7
|
import { GoogleAbstractedClient } from "./auth.js";
|
|
8
|
+
import { ChatGoogleBase } from "./chat_models.js";
|
|
9
|
+
import type { GoogleBaseLLMInput, GoogleAISafetyHandler } from "./types.js";
|
|
10
|
+
export { GoogleBaseLLMInput };
|
|
9
11
|
declare class GoogleLLMConnection<AuthOptions> extends AbstractGoogleLLMConnection<MessageContent, AuthOptions> {
|
|
10
12
|
formatContents(input: MessageContent, _parameters: GoogleAIModelParams): GeminiContent[];
|
|
11
13
|
}
|
|
12
|
-
/**
|
|
13
|
-
* Input to LLM class.
|
|
14
|
-
*/
|
|
15
|
-
export interface GoogleBaseLLMInput<AuthOptions> extends GoogleAIBaseLLMInput<AuthOptions> {
|
|
16
|
-
}
|
|
17
14
|
/**
|
|
18
15
|
* Integration with an LLM.
|
|
19
16
|
*/
|
|
20
17
|
export declare abstract class GoogleBaseLLM<AuthOptions> extends LLM<BaseLanguageModelCallOptions> implements GoogleBaseLLMInput<AuthOptions> {
|
|
21
18
|
static lc_name(): string;
|
|
19
|
+
originalFields?: GoogleBaseLLMInput<AuthOptions>;
|
|
22
20
|
lc_serializable: boolean;
|
|
23
|
-
|
|
21
|
+
modelName: string;
|
|
24
22
|
temperature: number;
|
|
25
23
|
maxOutputTokens: number;
|
|
26
24
|
topP: number;
|
|
27
25
|
topK: number;
|
|
28
26
|
stopSequences: string[];
|
|
29
27
|
safetySettings: GoogleAISafetySetting[];
|
|
28
|
+
safetyHandler: GoogleAISafetyHandler;
|
|
30
29
|
protected connection: GoogleLLMConnection<AuthOptions>;
|
|
31
30
|
protected streamedConnection: GoogleLLMConnection<AuthOptions>;
|
|
32
31
|
constructor(fields?: GoogleBaseLLMInput<AuthOptions>);
|
|
@@ -40,9 +39,19 @@ export declare abstract class GoogleBaseLLM<AuthOptions> extends LLM<BaseLanguag
|
|
|
40
39
|
formatPrompt(prompt: string): MessageContent;
|
|
41
40
|
/**
|
|
42
41
|
* For some given input string and options, return a string output.
|
|
42
|
+
*
|
|
43
|
+
* Despite the fact that `invoke` is overridden below, we still need this
|
|
44
|
+
* in order to handle public APi calls to `generate()`.
|
|
43
45
|
*/
|
|
44
|
-
_call(
|
|
45
|
-
|
|
46
|
+
_call(prompt: string, options: this["ParsedCallOptions"]): Promise<string>;
|
|
47
|
+
_streamIterator(input: BaseLanguageModelInput, options?: BaseLanguageModelCallOptions): AsyncGenerator<string>;
|
|
46
48
|
predictMessages(messages: BaseMessage[], options?: string[] | BaseLanguageModelCallOptions, _callbacks?: Callbacks): Promise<BaseMessage>;
|
|
49
|
+
/**
|
|
50
|
+
* Internal implementation detail to allow Google LLMs to support
|
|
51
|
+
* multimodal input by delegating to the chat model implementation.
|
|
52
|
+
*
|
|
53
|
+
* TODO: Replace with something less hacky.
|
|
54
|
+
*/
|
|
55
|
+
protected createProxyChat(): ChatGoogleBase<AuthOptions>;
|
|
56
|
+
invoke(input: BaseLanguageModelInput, options?: BaseLanguageModelCallOptions): Promise<string>;
|
|
47
57
|
}
|
|
48
|
-
export {};
|
package/dist/llms.js
CHANGED
|
@@ -1,10 +1,13 @@
|
|
|
1
|
-
import {
|
|
1
|
+
import { CallbackManager } from "@langchain/core/callbacks/manager";
|
|
2
|
+
import { BaseLLM, LLM } from "@langchain/core/language_models/llms";
|
|
2
3
|
import { GenerationChunk } from "@langchain/core/outputs";
|
|
3
4
|
import { getEnvironmentVariable } from "@langchain/core/utils/env";
|
|
4
5
|
import { AbstractGoogleLLMConnection } from "./connection.js";
|
|
5
6
|
import { copyAIModelParams, copyAndValidateModelParamsInto, } from "./utils/common.js";
|
|
6
|
-
import { messageContentToParts,
|
|
7
|
+
import { chunkToString, messageContentToParts, safeResponseToBaseMessage, safeResponseToString, DefaultGeminiSafetyHandler, } from "./utils/gemini.js";
|
|
7
8
|
import { ApiKeyGoogleAuth } from "./auth.js";
|
|
9
|
+
import { ensureParams } from "./utils/failed_handler.js";
|
|
10
|
+
import { ChatGoogleBase } from "./chat_models.js";
|
|
8
11
|
class GoogleLLMConnection extends AbstractGoogleLLMConnection {
|
|
9
12
|
formatContents(input, _parameters) {
|
|
10
13
|
const parts = messageContentToParts(input);
|
|
@@ -17,6 +20,14 @@ class GoogleLLMConnection extends AbstractGoogleLLMConnection {
|
|
|
17
20
|
return contents;
|
|
18
21
|
}
|
|
19
22
|
}
|
|
23
|
+
class ProxyChatGoogle extends ChatGoogleBase {
|
|
24
|
+
constructor(fields) {
|
|
25
|
+
super(fields);
|
|
26
|
+
}
|
|
27
|
+
buildAbstractedClient(fields) {
|
|
28
|
+
return fields.connection.client;
|
|
29
|
+
}
|
|
30
|
+
}
|
|
20
31
|
/**
|
|
21
32
|
* Integration with an LLM.
|
|
22
33
|
*/
|
|
@@ -26,14 +37,20 @@ export class GoogleBaseLLM extends LLM {
|
|
|
26
37
|
return "GoogleLLM";
|
|
27
38
|
}
|
|
28
39
|
constructor(fields) {
|
|
29
|
-
super(fields
|
|
40
|
+
super(ensureParams(fields));
|
|
41
|
+
Object.defineProperty(this, "originalFields", {
|
|
42
|
+
enumerable: true,
|
|
43
|
+
configurable: true,
|
|
44
|
+
writable: true,
|
|
45
|
+
value: void 0
|
|
46
|
+
});
|
|
30
47
|
Object.defineProperty(this, "lc_serializable", {
|
|
31
48
|
enumerable: true,
|
|
32
49
|
configurable: true,
|
|
33
50
|
writable: true,
|
|
34
51
|
value: true
|
|
35
52
|
});
|
|
36
|
-
Object.defineProperty(this, "
|
|
53
|
+
Object.defineProperty(this, "modelName", {
|
|
37
54
|
enumerable: true,
|
|
38
55
|
configurable: true,
|
|
39
56
|
writable: true,
|
|
@@ -75,6 +92,12 @@ export class GoogleBaseLLM extends LLM {
|
|
|
75
92
|
writable: true,
|
|
76
93
|
value: []
|
|
77
94
|
});
|
|
95
|
+
Object.defineProperty(this, "safetyHandler", {
|
|
96
|
+
enumerable: true,
|
|
97
|
+
configurable: true,
|
|
98
|
+
writable: true,
|
|
99
|
+
value: void 0
|
|
100
|
+
});
|
|
78
101
|
Object.defineProperty(this, "connection", {
|
|
79
102
|
enumerable: true,
|
|
80
103
|
configurable: true,
|
|
@@ -87,7 +110,10 @@ export class GoogleBaseLLM extends LLM {
|
|
|
87
110
|
writable: true,
|
|
88
111
|
value: void 0
|
|
89
112
|
});
|
|
113
|
+
this.originalFields = fields;
|
|
90
114
|
copyAndValidateModelParamsInto(fields, this);
|
|
115
|
+
this.safetyHandler =
|
|
116
|
+
fields?.safetyHandler ?? new DefaultGeminiSafetyHandler();
|
|
91
117
|
const client = this.buildClient(fields);
|
|
92
118
|
this.buildConnection(fields ?? {}, client);
|
|
93
119
|
}
|
|
@@ -122,37 +148,78 @@ export class GoogleBaseLLM extends LLM {
|
|
|
122
148
|
}
|
|
123
149
|
/**
|
|
124
150
|
* For some given input string and options, return a string output.
|
|
151
|
+
*
|
|
152
|
+
* Despite the fact that `invoke` is overridden below, we still need this
|
|
153
|
+
* in order to handle public APi calls to `generate()`.
|
|
125
154
|
*/
|
|
126
|
-
async _call(
|
|
127
|
-
const parameters = copyAIModelParams(this);
|
|
128
|
-
const result = await this.connection.request(
|
|
129
|
-
const ret =
|
|
155
|
+
async _call(prompt, options) {
|
|
156
|
+
const parameters = copyAIModelParams(this, options);
|
|
157
|
+
const result = await this.connection.request(prompt, parameters, options);
|
|
158
|
+
const ret = safeResponseToString(result, this.safetyHandler);
|
|
130
159
|
return ret;
|
|
131
160
|
}
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
//
|
|
137
|
-
const
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
161
|
+
// Normally, you should not override this method and instead should override
|
|
162
|
+
// _streamResponseChunks. We are doing so here to allow for multimodal inputs into
|
|
163
|
+
// the LLM.
|
|
164
|
+
async *_streamIterator(input, options) {
|
|
165
|
+
// TODO: Refactor callback setup and teardown code into core
|
|
166
|
+
const prompt = BaseLLM._convertInputToPromptValue(input);
|
|
167
|
+
const [runnableConfig, callOptions] = this._separateRunnableConfigFromCallOptions(options);
|
|
168
|
+
const callbackManager_ = await CallbackManager.configure(runnableConfig.callbacks, this.callbacks, runnableConfig.tags, this.tags, runnableConfig.metadata, this.metadata, { verbose: this.verbose });
|
|
169
|
+
const extra = {
|
|
170
|
+
options: callOptions,
|
|
171
|
+
invocation_params: this?.invocationParams(callOptions),
|
|
172
|
+
batch_size: 1,
|
|
173
|
+
};
|
|
174
|
+
const runManagers = await callbackManager_?.handleLLMStart(this.toJSON(), [prompt.toString()], undefined, undefined, extra, undefined, undefined, runnableConfig.runName);
|
|
175
|
+
let generation = new GenerationChunk({
|
|
176
|
+
text: "",
|
|
177
|
+
});
|
|
178
|
+
const proxyChat = this.createProxyChat();
|
|
179
|
+
try {
|
|
180
|
+
for await (const chunk of proxyChat._streamIterator(input, options)) {
|
|
181
|
+
const stringValue = chunkToString(chunk);
|
|
182
|
+
const generationChunk = new GenerationChunk({
|
|
183
|
+
text: stringValue,
|
|
148
184
|
});
|
|
149
|
-
|
|
185
|
+
generation = generation.concat(generationChunk);
|
|
186
|
+
yield stringValue;
|
|
187
|
+
}
|
|
188
|
+
}
|
|
189
|
+
catch (err) {
|
|
190
|
+
await Promise.all((runManagers ?? []).map((runManager) => runManager?.handleLLMError(err)));
|
|
191
|
+
throw err;
|
|
150
192
|
}
|
|
193
|
+
await Promise.all((runManagers ?? []).map((runManager) => runManager?.handleLLMEnd({
|
|
194
|
+
generations: [[generation]],
|
|
195
|
+
})));
|
|
151
196
|
}
|
|
152
197
|
async predictMessages(messages, options, _callbacks) {
|
|
153
198
|
const { content } = messages[0];
|
|
154
199
|
const result = await this.connection.request(content, {}, options);
|
|
155
|
-
const ret =
|
|
200
|
+
const ret = safeResponseToBaseMessage(result, this.safetyHandler);
|
|
156
201
|
return ret;
|
|
157
202
|
}
|
|
203
|
+
/**
|
|
204
|
+
* Internal implementation detail to allow Google LLMs to support
|
|
205
|
+
* multimodal input by delegating to the chat model implementation.
|
|
206
|
+
*
|
|
207
|
+
* TODO: Replace with something less hacky.
|
|
208
|
+
*/
|
|
209
|
+
createProxyChat() {
|
|
210
|
+
return new ProxyChatGoogle({
|
|
211
|
+
...this.originalFields,
|
|
212
|
+
connection: this.connection,
|
|
213
|
+
});
|
|
214
|
+
}
|
|
215
|
+
// TODO: Remove the need to override this - we are doing it to
|
|
216
|
+
// allow the LLM to handle multimodal types of input.
|
|
217
|
+
async invoke(input, options) {
|
|
218
|
+
const stream = await this._streamIterator(input, options);
|
|
219
|
+
let generatedOutput = "";
|
|
220
|
+
for await (const chunk of stream) {
|
|
221
|
+
generatedOutput += chunk;
|
|
222
|
+
}
|
|
223
|
+
return generatedOutput;
|
|
224
|
+
}
|
|
158
225
|
}
|