@saltcorn/large-language-model 0.6.5 → 0.7.0
This diff shows the changes between the two package versions as published to their public registry; it is provided for informational purposes only.
- package/generate.js +126 -6
- package/index.js +194 -4
- package/package.json +5 -2
package/generate.js
CHANGED

@@ -2,6 +2,13 @@ const fetch = require("node-fetch");
 const util = require("util");
 const exec = util.promisify(require("child_process").exec);
 const db = require("@saltcorn/data/db");
+const { VertexAI } = require("@google-cloud/vertexai");
+const {
+  PredictionServiceClient,
+  helpers,
+} = require("@google-cloud/aiplatform");
+const { google } = require("googleapis");
+const Plugin = require("@saltcorn/data/models/plugin");
 
 const { features, getState } = require("@saltcorn/data/db/state");
 let ollamaMod;
@@ -57,6 +64,13 @@ const getEmbedding = async (config, opts) => {
       //console.log("embedding response ", olres);
       return olres.embedding;
     }
+    case "Google Vertex AI":
+      const oauth2Client = await initOAuth2Client(config);
+      if (oauth2Client.isTokenExpiring()) {
+        const { credentials } = await oauth2Client.refreshAccessToken();
+        await updatePluginTokenCfg(credentials);
+      }
+      return await getEmbeddingGoogleVertex(config, opts, oauth2Client);
     default:
       throw new Error("Not implemented for this backend");
   }
@@ -117,6 +131,13 @@ const getCompletion = async (config, opts) => {
         { cwd: config.llama_dir }
       );
       return stdout;
+    case "Google Vertex AI":
+      const oauth2Client = await initOAuth2Client(config);
+      if (oauth2Client.isTokenExpiring()) {
+        const { credentials } = await oauth2Client.refreshAccessToken();
+        await updatePluginTokenCfg(credentials);
+      }
+      return await getCompletionGoogleVertex(config, opts, oauth2Client);
     default:
       break;
   }
@@ -174,12 +195,12 @@ const getCompletionOpenAICompatible = async (
   console.log("OpenAI response", JSON.stringify(results, null, 2));
   if (results.error) throw new Error(`OpenAI error: ${results.error.message}`);
 
-  return
-
-
-
-
-
+  return results?.choices?.[0]?.message?.tool_calls
+    ? {
+        tool_calls: results?.choices?.[0]?.message?.tool_calls,
+        content: results?.choices?.[0]?.message?.content || null,
+      }
+    : results?.choices?.[0]?.message?.content || null;
 };
 
 const getEmbeddingOpenAICompatible = async (
@@ -211,4 +232,103 @@ const getEmbeddingOpenAICompatible = async (
   if (Array.isArray(prompt)) return results?.data?.map?.((d) => d?.embedding);
   return results?.data?.[0]?.embedding;
 };
+
+const updatePluginTokenCfg = async (credentials) => {
+  let plugin = await Plugin.findOne({ name: "large-language-model" });
+  if (!plugin) {
+    plugin = await Plugin.findOne({
+      name: "@saltcorn/large-language-model",
+    });
+  }
+  const newConfig = {
+    ...(plugin.configuration || {}),
+    tokens: credentials,
+  };
+  plugin.configuration = newConfig;
+  await plugin.upsert();
+  getState().processSend({
+    refresh_plugin_cfg: plugin.name,
+    tenant: db.getTenantSchema(),
+  });
+};
+
+const initOAuth2Client = async (config) => {
+  const { client_id, client_secret } = config || {};
+  const state = getState();
+  const pluginCfg =
+    state.plugin_cfgs.large_language_model ||
+    state.plugin_cfgs["@saltcorn/large-language-model"];
+  const baseUrl = (
+    getState().getConfig("base_url") || "http://localhost:3000"
+  ).replace(/\/$/, "");
+  const redirect_uri = `${baseUrl}/callback`;
+
+  const oauth2Client = new google.auth.OAuth2(
+    client_id,
+    client_secret,
+    redirect_uri
+  );
+  oauth2Client.setCredentials(pluginCfg.tokens);
+  return oauth2Client;
+};
+
+const getCompletionGoogleVertex = async (config, opts, oauth2Client) => {
+  const vertexAI = new VertexAI({
+    project: config.project_id,
+    location: config.region || "us-central1",
+    googleAuthOptions: {
+      authClient: oauth2Client,
+    },
+  });
+  const generativeModel = vertexAI.getGenerativeModel({
+    model: config.model,
+  });
+  const chat = generativeModel.startChat();
+  const result = await chat.sendMessageStream(opts.prompt);
+  const chunks = [];
+  for await (const item of result.stream) {
+    chunks.push(item.candidates[0].content.parts[0].text);
+  }
+  return chunks.join();
+};
+
+const getEmbeddingGoogleVertex = async (config, opts, oauth2Client) => {
+  const predClient = new PredictionServiceClient({
+    apiEndpoint: "us-central1-aiplatform.googleapis.com",
+    authClient: oauth2Client,
+  });
+  const model = config.embed_model || "text-embedding-005";
+  let instances = null;
+  if (Array.isArray(opts.prompt)) {
+    instances = opts.prompt.map((p) =>
+      helpers.toValue({
+        content: p,
+        task_type: config.task_type || "RETRIEVAL_QUERY",
+      })
+    );
+  } else {
+    instances = [
+      helpers.toValue({
+        content: opts.prompt,
+        task_type: config.task_type || "RETRIEVAL_QUERY",
+      }),
+    ];
+  }
+  const [response] = await predClient.predict({
+    endpoint: `projects/${config.project_id}/locations/${
+      config.region || "us-central1"
+    }/publishers/google/models/${model}`,
+    instances,
+    // default outputDimensionality is 768, can be changed with:
+    // parameters: helpers.toValue({ outputDimensionality: parseInt(512) }),
+  });
+  const predictions = response.predictions;
+  const embeddings = predictions.map((p) => {
+    const embeddingsProto = p.structValue.fields.embeddings;
+    const valuesProto = embeddingsProto.structValue.fields.values;
+    return valuesProto.listValue.values.map((v) => v.numberValue);
+  });
+  return embeddings;
+};
+
 module.exports = { getCompletion, getEmbedding };
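
The new Vertex AI branch is selected by the `backend` key of the plugin configuration: it refreshes the stored OAuth2 token when it is close to expiry, persists the refreshed credentials back into the plugin configuration, and then uses the Gemini chat API for completions and the aiplatform PredictionServiceClient for embeddings. A minimal sketch of how the two exported functions might be called with this backend, assuming it runs inside Saltcorn after the plugin has been authorized (so stored tokens are available) and with placeholder project, model and credential values:

const { getCompletion, getEmbedding } = require("./generate");

// Illustrative plugin configuration; in practice these values come from the
// configuration workflow, and the OAuth2 tokens are stored by the callback route.
const config = {
  backend: "Google Vertex AI",
  project_id: "my-gcp-project",                 // placeholder
  region: "us-central1",
  model: "gemini-1.5-flash",
  embed_model: "text-embedding-005",
  client_id: "GOOGLE_OAUTH_CLIENT_ID",          // placeholder
  client_secret: "GOOGLE_OAUTH_CLIENT_SECRET",  // placeholder
};

(async () => {
  // Completion: streams a Gemini chat response and joins the chunks into one string.
  const answer = await getCompletion(config, { prompt: "Say hello" });

  // Embeddings: accepts a string or an array of strings; one vector
  // (768 numbers by default) is returned per input.
  const vectors = await getEmbedding(config, {
    prompt: ["first document", "second document"],
  });
  console.log(answer, vectors.length);
})();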

package/index.js
CHANGED

@@ -1,11 +1,15 @@
 const Workflow = require("@saltcorn/data/models/workflow");
 const Form = require("@saltcorn/data/models/form");
 const FieldRepeat = require("@saltcorn/data/models/fieldrepeat");
+const Plugin = require("@saltcorn/data/models/plugin");
+const { domReady } = require("@saltcorn/markup/tags");
 const db = require("@saltcorn/data/db");
 const { getCompletion, getEmbedding } = require("./generate");
 const { OPENAI_MODELS } = require("./constants.js");
 const { eval_expression } = require("@saltcorn/data/models/expression");
 const { interpolate } = require("@saltcorn/data/utils");
+const { getState } = require("@saltcorn/data/db/state");
+const { google } = require("googleapis");
 
 const configuration_workflow = () =>
   new Workflow({
@@ -15,6 +19,35 @@ const configuration_workflow = () =>
     form: async (context) => {
       const isRoot = db.getTenantSchema() === db.connectObj.default_schema;
       return new Form({
+        additionalHeaders: [
+          {
+            headerTag: `<script>
+              function backendChange(e) {
+                const val = e.value;
+                const authBtn = document.getElementById('vertex_authorize_btn');
+                if (val === 'Google Vertex AI') {
+                  authBtn.classList.remove('d-none');
+                } else {
+                  authBtn.classList.add('d-none');
+                }
+              }
+              ${domReady(`
+                const backend = document.getElementById('inputbackend');
+                if (backend) {
+                  backendChange(backend);
+                }`)}
+            </script>`,
+          },
+        ],
+        additionalButtons: [
+          {
+            label: "authorize",
+            id: "vertex_authorize_btn",
+            onclick:
+              "location.href='/large-language-model/vertex/authorize'",
+            class: "btn btn-primary d-none",
+          },
+        ],
         fields: [
           {
             name: "backend",
@@ -27,8 +60,85 @@
                 "OpenAI-compatible API",
                 "Local Ollama",
                 ...(isRoot ? ["Local llama.cpp"] : []),
+                "Google Vertex AI",
+              ],
+              onChange: "backendChange(this)",
+            },
+          },
+          {
+            name: "client_id",
+            label: "Client ID",
+            sublabel: "OAuth2 client ID from your Google Cloud account",
+            type: "String",
+            required: true,
+            showIf: { backend: "Google Vertex AI" },
+          },
+          {
+            name: "client_secret",
+            label: "Client Secret",
+            sublabel: "Client secret from your Google Cloud account",
+            type: "String",
+            required: true,
+            showIf: { backend: "Google Vertex AI" },
+          },
+          {
+            name: "project_id",
+            label: "Project ID",
+            sublabel: "Google Cloud project ID",
+            type: "String",
+            required: true,
+            showIf: { backend: "Google Vertex AI" },
+          },
+          {
+            name: "model",
+            label: "Model",
+            type: "String",
+            showIf: { backend: "Google Vertex AI" },
+            attributes: {
+              options: ["gemini-1.5-pro", "gemini-1.5-flash"],
+            },
+            required: true,
+          },
+          {
+            name: "embed_model",
+            label: "Embedding model",
+            type: "String",
+            required: true,
+            showIf: { backend: "Google Vertex AI" },
+            attributes: {
+              options: [
+                "text-embedding-005",
+                "text-embedding-004",
+                "textembedding-gecko@003",
               ],
             },
+            default: "text-embedding-005",
+          },
+          {
+            name: "embed_task_type",
+            label: "Embedding task type",
+            type: "String",
+            showIf: { backend: "Google Vertex AI" },
+            attributes: {
+              options: [
+                "RETRIEVAL_QUERY",
+                "RETRIEVAL_DOCUMENT",
+                "SEMANTIC_SIMILARITY",
+                "CLASSIFICATION",
+                "CLUSTERING",
+                "QUESTION_ANSWERING",
+                "FACT_VERIFICATION",
+                "CODE_RETRIEVAL_QUERY",
+              ],
+            },
+            default: "RETRIEVAL_QUERY",
+          },
+          {
+            name: "region",
+            label: "Region",
+            sublabel: "Google Cloud region (default: us-central1)",
+            type: "String",
+            default: "us-central1",
           },
           {
             name: "api_key",
@@ -186,14 +296,90 @@ const functions = (config) => {
   };
 };
 
+const routes = (config) => {
+  return [
+    {
+      url: "/large-language-model/vertex/authorize",
+      method: "get",
+      callback: async (req, res) => {
+        const { client_id, client_secret } = config || {};
+        const baseUrl = (
+          getState().getConfig("base_url") || "http://localhost:3000"
+        ).replace(/\/$/, "");
+        const redirect_uri = `${baseUrl}/large-language-model/vertex/callback`;
+        const oauth2Client = new google.auth.OAuth2(
+          client_id,
+          client_secret,
+          redirect_uri
+        );
+        const authUrl = oauth2Client.generateAuthUrl({
+          access_type: "offline",
+          scope: "https://www.googleapis.com/auth/cloud-platform",
+        });
+        res.redirect(authUrl);
+      },
+    },
+    {
+      url: "/large-language-model/vertex/callback",
+      method: "get",
+      callback: async (req, res) => {
+        const { client_id, client_secret } = config || {};
+        const baseUrl = (
+          getState().getConfig("base_url") || "http://localhost:3000"
+        ).replace(/\/$/, "");
+        const redirect_uri = `${baseUrl}/large-language-model/vertex/callback`;
+        const oauth2Client = new google.auth.OAuth2(
+          client_id,
+          client_secret,
+          redirect_uri
+        );
+        let plugin = await Plugin.findOne({ name: "large-language-model" });
+        if (!plugin) {
+          plugin = await Plugin.findOne({
+            name: "@saltcorn/large-language-model",
+          });
+        }
+        try {
+          const code = req.query.code;
+          if (!code) throw new Error("Missing code in query string.");
+          const { tokens } = await oauth2Client.getToken(code);
+          if (!tokens.refresh_token) {
+            req.flash(
+              "warning",
+              req.__(
+                "No refresh token received. Please revoke the plugin's access and try again."
+              )
+            );
+          } else {
+            const newConfig = { ...(plugin.configuration || {}), tokens };
+            plugin.configuration = newConfig;
+            await plugin.upsert();
+            req.flash(
+              "success",
+              req.__("Authentication successful! You can now use Vertex AI.")
+            );
+          }
+        } catch (error) {
+          console.error("Error retrieving access token:", error);
+          req.flash("error", req.__("Error retrieving access"));
+        } finally {
+          res.redirect(`/plugins/configure/${encodeURIComponent(plugin.name)}`);
+        }
+      },
+    },
+  ];
+};
+
 module.exports = {
   sc_plugin_api_version: 1,
   configuration_workflow,
   functions,
   modelpatterns: require("./model.js"),
+  routes,
   actions: (config) => ({
     llm_function_call: require("./function-insert-action.js")(config),
     llm_generate: {
+      description: "Generate text with AI based on a text prompt",
       requireRow: true,
       configFields: ({ table, mode }) => {
         const override_fields =
@@ -317,7 +503,7 @@ module.exports = {
           upd[chat_history_field] = [
             ...history,
             { role: "user", content: prompt },
-            { role: "
+            { role: "assistant", content: ans },
           ];
         }
         if (mode === "workflow") return upd;
@@ -325,6 +511,8 @@
       },
     },
     llm_generate_json: {
+      description:
+        "Generate JSON with AI based on a text prompt. You must sppecify the JSON fields in the configuration.",
       requireRow: true,
       configFields: ({ table, mode }) => {
         const override_fields =
@@ -427,7 +615,7 @@
           label: "Multiple",
           type: "Bool",
           sublabel:
-            "Select to generate an array of objects. Unselect for a single object",
+            "Select (true) to generate an array of objects. Unselect (false) for a single object",
         },
         {
           name: "gen_description",
@@ -506,13 +694,15 @@ module.exports = {
           ...opts,
           ...toolargs,
        });
-        const ans = JSON.parse(compl.tool_calls[0].function.arguments)[
+        const ans = JSON.parse(compl.tool_calls[0].function.arguments)[
+          answer_field
+        ];
         const upd = { [answer_field]: ans };
         if (chat_history_field) {
           upd[chat_history_field] = [
             ...history,
             { role: "user", content: prompt },
-            { role: "
+            { role: "assistant", content: ans },
           ];
         }
         if (mode === "workflow") return upd;
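
The two new routes implement a standard googleapis authorization-code flow: /large-language-model/vertex/authorize redirects the administrator to Google's consent screen with offline access requested, and /large-language-model/vertex/callback exchanges the returned code for tokens and saves them into the plugin configuration. A rough standalone sketch of the same exchange, with placeholder client credentials and redirect URI:

const { google } = require("googleapis");

// Placeholders; in the plugin these come from the configuration form and the
// site's base_url setting.
const oauth2Client = new google.auth.OAuth2(
  "CLIENT_ID",
  "CLIENT_SECRET",
  "https://example.com/large-language-model/vertex/callback"
);

// Step 1 (the authorize route): build the consent URL. access_type "offline"
// is what makes Google return a refresh_token on the first authorization.
const authUrl = oauth2Client.generateAuthUrl({
  access_type: "offline",
  scope: "https://www.googleapis.com/auth/cloud-platform",
});

// Step 2 (the callback route): exchange ?code=... for tokens; the plugin then
// persists them under `tokens` in its configuration via plugin.upsert().
async function handleCallback(code) {
  const { tokens } = await oauth2Client.getToken(code);
  oauth2Client.setCredentials(tokens);
  return tokens;
}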

package/package.json
CHANGED

@@ -1,13 +1,16 @@
 {
   "name": "@saltcorn/large-language-model",
-  "version": "0.6.5",
+  "version": "0.7.0",
   "description": "Large language models and functionality for Saltcorn",
   "main": "index.js",
   "dependencies": {
     "@saltcorn/data": "^0.9.0",
     "node-fetch": "2.6.9",
     "underscore": "1.13.6",
-    "ollama": "0.5.0"
+    "ollama": "0.5.0",
+    "@google-cloud/vertexai": "^1.9.3",
+    "@google-cloud/aiplatform": "^3.34.0",
+    "googleapis": "^144.0.0"
   },
   "author": "Tom Nielsen",
   "license": "MIT",