@saltcorn/large-language-model 0.6.6 → 0.7.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (3)
  1. package/generate.js +120 -0
  2. package/index.js +190 -2
  3. package/package.json +5 -2
package/generate.js CHANGED
@@ -2,6 +2,13 @@ const fetch = require("node-fetch");
2
2
  const util = require("util");
3
3
  const exec = util.promisify(require("child_process").exec);
4
4
  const db = require("@saltcorn/data/db");
5
+ const { VertexAI } = require("@google-cloud/vertexai");
6
+ const {
7
+ PredictionServiceClient,
8
+ helpers,
9
+ } = require("@google-cloud/aiplatform");
10
+ const { google } = require("googleapis");
11
+ const Plugin = require("@saltcorn/data/models/plugin");
5
12
 
6
13
  const { features, getState } = require("@saltcorn/data/db/state");
7
14
  let ollamaMod;
@@ -57,6 +64,13 @@ const getEmbedding = async (config, opts) => {
57
64
  //console.log("embedding response ", olres);
58
65
  return olres.embedding;
59
66
  }
67
+ case "Google Vertex AI":
68
+ const oauth2Client = await initOAuth2Client(config);
69
+ if (oauth2Client.isTokenExpiring()) {
70
+ const { credentials } = await oauth2Client.refreshAccessToken();
71
+ await updatePluginTokenCfg(credentials);
72
+ }
73
+ return await getEmbeddingGoogleVertex(config, opts, oauth2Client);
60
74
  default:
61
75
  throw new Error("Not implemented for this backend");
62
76
  }
@@ -117,6 +131,13 @@ const getCompletion = async (config, opts) => {
117
131
  { cwd: config.llama_dir }
118
132
  );
119
133
  return stdout;
134
+ case "Google Vertex AI":
135
+ const oauth2Client = await initOAuth2Client(config);
136
+ if (oauth2Client.isTokenExpiring()) {
137
+ const { credentials } = await oauth2Client.refreshAccessToken();
138
+ await updatePluginTokenCfg(credentials);
139
+ }
140
+ return await getCompletionGoogleVertex(config, opts, oauth2Client);
120
141
  default:
121
142
  break;
122
143
  }
@@ -211,4 +232,103 @@ const getEmbeddingOpenAICompatible = async (
211
232
  if (Array.isArray(prompt)) return results?.data?.map?.((d) => d?.embedding);
212
233
  return results?.data?.[0]?.embedding;
213
234
  };
235
+
236
+ const updatePluginTokenCfg = async (credentials) => {
237
+ let plugin = await Plugin.findOne({ name: "large-language-model" });
238
+ if (!plugin) {
239
+ plugin = await Plugin.findOne({
240
+ name: "@saltcorn/large-language-model",
241
+ });
242
+ }
243
+ const newConfig = {
244
+ ...(plugin.configuration || {}),
245
+ tokens: credentials,
246
+ };
247
+ plugin.configuration = newConfig;
248
+ await plugin.upsert();
249
+ getState().processSend({
250
+ refresh_plugin_cfg: plugin.name,
251
+ tenant: db.getTenantSchema(),
252
+ });
253
+ };
254
+
255
+ const initOAuth2Client = async (config) => {
256
+ const { client_id, client_secret } = config || {};
257
+ const state = getState();
258
+ const pluginCfg =
259
+ state.plugin_cfgs.large_language_model ||
260
+ state.plugin_cfgs["@saltcorn/large-language-model"];
261
+ const baseUrl = (
262
+ getState().getConfig("base_url") || "http://localhost:3000"
263
+ ).replace(/\/$/, "");
264
+ const redirect_uri = `${baseUrl}/callback`;
265
+
266
+ const oauth2Client = new google.auth.OAuth2(
267
+ client_id,
268
+ client_secret,
269
+ redirect_uri
270
+ );
271
+ oauth2Client.setCredentials(pluginCfg.tokens);
272
+ return oauth2Client;
273
+ };
274
+
275
+ const getCompletionGoogleVertex = async (config, opts, oauth2Client) => {
276
+ const vertexAI = new VertexAI({
277
+ project: config.project_id,
278
+ location: config.region || "us-central1",
279
+ googleAuthOptions: {
280
+ authClient: oauth2Client,
281
+ },
282
+ });
283
+ const generativeModel = vertexAI.getGenerativeModel({
284
+ model: config.model,
285
+ });
286
+ const chat = generativeModel.startChat();
287
+ const result = await chat.sendMessageStream(opts.prompt);
288
+ const chunks = [];
289
+ for await (const item of result.stream) {
290
+ chunks.push(item.candidates[0].content.parts[0].text);
291
+ }
292
+ return chunks.join();
293
+ };
294
+
295
+ const getEmbeddingGoogleVertex = async (config, opts, oauth2Client) => {
296
+ const predClient = new PredictionServiceClient({
297
+ apiEndpoint: "us-central1-aiplatform.googleapis.com",
298
+ authClient: oauth2Client,
299
+ });
300
+ const model = config.embed_model || "text-embedding-005";
301
+ let instances = null;
302
+ if (Array.isArray(opts.prompt)) {
303
+ instances = opts.prompt.map((p) =>
304
+ helpers.toValue({
305
+ content: p,
306
+ task_type: config.task_type || "RETRIEVAL_QUERY",
307
+ })
308
+ );
309
+ } else {
310
+ instances = [
311
+ helpers.toValue({
312
+ content: opts.prompt,
313
+ task_type: config.task_type || "RETRIEVAL_QUERY",
314
+ }),
315
+ ];
316
+ }
317
+ const [response] = await predClient.predict({
318
+ endpoint: `projects/${config.project_id}/locations/${
319
+ config.region || "us-central1"
320
+ }/publishers/google/models/${model}`,
321
+ instances,
322
+ // default outputDimensionality is 768, can be changed with:
323
+ // parameters: helpers.toValue({ outputDimensionality: parseInt(512) }),
324
+ });
325
+ const predictions = response.predictions;
326
+ const embeddings = predictions.map((p) => {
327
+ const embeddingsProto = p.structValue.fields.embeddings;
328
+ const valuesProto = embeddingsProto.structValue.fields.values;
329
+ return valuesProto.listValue.values.map((v) => v.numberValue);
330
+ });
331
+ return embeddings;
332
+ };
333
+
214
334
  module.exports = { getCompletion, getEmbedding };
package/index.js CHANGED
@@ -1,11 +1,15 @@
1
1
  const Workflow = require("@saltcorn/data/models/workflow");
2
2
  const Form = require("@saltcorn/data/models/form");
3
3
  const FieldRepeat = require("@saltcorn/data/models/fieldrepeat");
4
+ const Plugin = require("@saltcorn/data/models/plugin");
5
+ const { domReady } = require("@saltcorn/markup/tags");
4
6
  const db = require("@saltcorn/data/db");
5
7
  const { getCompletion, getEmbedding } = require("./generate");
6
8
  const { OPENAI_MODELS } = require("./constants.js");
7
9
  const { eval_expression } = require("@saltcorn/data/models/expression");
8
10
  const { interpolate } = require("@saltcorn/data/utils");
11
+ const { getState } = require("@saltcorn/data/db/state");
12
+ const { google } = require("googleapis");
9
13
 
10
14
  const configuration_workflow = () =>
11
15
  new Workflow({
@@ -15,6 +19,35 @@ const configuration_workflow = () =>
15
19
  form: async (context) => {
16
20
  const isRoot = db.getTenantSchema() === db.connectObj.default_schema;
17
21
  return new Form({
22
+ additionalHeaders: [
23
+ {
24
+ headerTag: `<script>
25
+ function backendChange(e) {
26
+ const val = e.value;
27
+ const authBtn = document.getElementById('vertex_authorize_btn');
28
+ if (val === 'Google Vertex AI') {
29
+ authBtn.classList.remove('d-none');
30
+ } else {
31
+ authBtn.classList.add('d-none');
32
+ }
33
+ }
34
+ ${domReady(`
35
+ const backend = document.getElementById('inputbackend');
36
+ if (backend) {
37
+ backendChange(backend);
38
+ }`)}
39
+ </script>`,
40
+ },
41
+ ],
42
+ additionalButtons: [
43
+ {
44
+ label: "authorize",
45
+ id: "vertex_authorize_btn",
46
+ onclick:
47
+ "location.href='/large-language-model/vertex/authorize'",
48
+ class: "btn btn-primary d-none",
49
+ },
50
+ ],
18
51
  fields: [
19
52
  {
20
53
  name: "backend",
@@ -27,8 +60,85 @@ const configuration_workflow = () =>
27
60
  "OpenAI-compatible API",
28
61
  "Local Ollama",
29
62
  ...(isRoot ? ["Local llama.cpp"] : []),
63
+ "Google Vertex AI",
64
+ ],
65
+ onChange: "backendChange(this)",
66
+ },
67
+ },
68
+ {
69
+ name: "client_id",
70
+ label: "Client ID",
71
+ sublabel: "OAuth2 client ID from your Google Cloud account",
72
+ type: "String",
73
+ required: true,
74
+ showIf: { backend: "Google Vertex AI" },
75
+ },
76
+ {
77
+ name: "client_secret",
78
+ label: "Client Secret",
79
+ sublabel: "Client secret from your Google Cloud account",
80
+ type: "String",
81
+ required: true,
82
+ showIf: { backend: "Google Vertex AI" },
83
+ },
84
+ {
85
+ name: "project_id",
86
+ label: "Project ID",
87
+ sublabel: "Google Cloud project ID",
88
+ type: "String",
89
+ required: true,
90
+ showIf: { backend: "Google Vertex AI" },
91
+ },
92
+ {
93
+ name: "model",
94
+ label: "Model",
95
+ type: "String",
96
+ showIf: { backend: "Google Vertex AI" },
97
+ attributes: {
98
+ options: ["gemini-1.5-pro", "gemini-1.5-flash"],
99
+ },
100
+ required: true,
101
+ },
102
+ {
103
+ name: "embed_model",
104
+ label: "Embedding model",
105
+ type: "String",
106
+ required: true,
107
+ showIf: { backend: "Google Vertex AI" },
108
+ attributes: {
109
+ options: [
110
+ "text-embedding-005",
111
+ "text-embedding-004",
112
+ "textembedding-gecko@003",
30
113
  ],
31
114
  },
115
+ default: "text-embedding-005",
116
+ },
117
+ {
118
+ name: "embed_task_type",
119
+ label: "Embedding task type",
120
+ type: "String",
121
+ showIf: { backend: "Google Vertex AI" },
122
+ attributes: {
123
+ options: [
124
+ "RETRIEVAL_QUERY",
125
+ "RETRIEVAL_DOCUMENT",
126
+ "SEMANTIC_SIMILARITY",
127
+ "CLASSIFICATION",
128
+ "CLUSTERING",
129
+ "QUESTION_ANSWERING",
130
+ "FACT_VERIFICATION",
131
+ "CODE_RETRIEVAL_QUERY",
132
+ ],
133
+ },
134
+ default: "RETRIEVAL_QUERY",
135
+ },
136
+ {
137
+ name: "region",
138
+ label: "Region",
139
+ sublabel: "Google Cloud region (default: us-central1)",
140
+ type: "String",
141
+ default: "us-central1",
32
142
  },
33
143
  {
34
144
  name: "api_key",
@@ -186,11 +296,86 @@ const functions = (config) => {
186
296
  };
187
297
  };
188
298
 
299
+ const routes = (config) => {
300
+ return [
301
+ {
302
+ url: "/large-language-model/vertex/authorize",
303
+ method: "get",
304
+ callback: async (req, res) => {
305
+ const { client_id, client_secret } = config || {};
306
+ const baseUrl = (
307
+ getState().getConfig("base_url") || "http://localhost:3000"
308
+ ).replace(/\/$/, "");
309
+ const redirect_uri = `${baseUrl}/large-language-model/vertex/callback`;
310
+ const oauth2Client = new google.auth.OAuth2(
311
+ client_id,
312
+ client_secret,
313
+ redirect_uri
314
+ );
315
+ const authUrl = oauth2Client.generateAuthUrl({
316
+ access_type: "offline",
317
+ scope: "https://www.googleapis.com/auth/cloud-platform",
318
+ });
319
+ res.redirect(authUrl);
320
+ },
321
+ },
322
+ {
323
+ url: "/large-language-model/vertex/callback",
324
+ method: "get",
325
+ callback: async (req, res) => {
326
+ const { client_id, client_secret } = config || {};
327
+ const baseUrl = (
328
+ getState().getConfig("base_url") || "http://localhost:3000"
329
+ ).replace(/\/$/, "");
330
+ const redirect_uri = `${baseUrl}/large-language-model/vertex/callback`;
331
+ const oauth2Client = new google.auth.OAuth2(
332
+ client_id,
333
+ client_secret,
334
+ redirect_uri
335
+ );
336
+ let plugin = await Plugin.findOne({ name: "large-language-model" });
337
+ if (!plugin) {
338
+ plugin = await Plugin.findOne({
339
+ name: "@saltcorn/large-language-model",
340
+ });
341
+ }
342
+ try {
343
+ const code = req.query.code;
344
+ if (!code) throw new Error("Missing code in query string.");
345
+ const { tokens } = await oauth2Client.getToken(code);
346
+ if (!tokens.refresh_token) {
347
+ req.flash(
348
+ "warning",
349
+ req.__(
350
+ "No refresh token received. Please revoke the plugin's access and try again."
351
+ )
352
+ );
353
+ } else {
354
+ const newConfig = { ...(plugin.configuration || {}), tokens };
355
+ plugin.configuration = newConfig;
356
+ await plugin.upsert();
357
+ req.flash(
358
+ "success",
359
+ req.__("Authentication successful! You can now use Vertex AI.")
360
+ );
361
+ }
362
+ } catch (error) {
363
+ console.error("Error retrieving access token:", error);
364
+ req.flash("error", req.__("Error retrieving access"));
365
+ } finally {
366
+ res.redirect(`/plugins/configure/${encodeURIComponent(plugin.name)}`);
367
+ }
368
+ },
369
+ },
370
+ ];
371
+ };
372
+
189
373
  module.exports = {
190
374
  sc_plugin_api_version: 1,
191
375
  configuration_workflow,
192
376
  functions,
193
377
  modelpatterns: require("./model.js"),
378
+ routes,
194
379
  actions: (config) => ({
195
380
  llm_function_call: require("./function-insert-action.js")(config),
196
381
  llm_generate: {
@@ -326,7 +511,8 @@ module.exports = {
326
511
  },
327
512
  },
328
513
  llm_generate_json: {
329
- description: "Generate JSON with AI based on a text prompt. You must sppecify the JSON fields in the configuration.",
514
+ description:
515
+ "Generate JSON with AI based on a text prompt. You must sppecify the JSON fields in the configuration.",
330
516
  requireRow: true,
331
517
  configFields: ({ table, mode }) => {
332
518
  const override_fields =
@@ -508,7 +694,9 @@ module.exports = {
508
694
  ...opts,
509
695
  ...toolargs,
510
696
  });
511
- const ans = JSON.parse(compl.tool_calls[0].function.arguments)[answer_field];
697
+ const ans = JSON.parse(compl.tool_calls[0].function.arguments)[
698
+ answer_field
699
+ ];
512
700
  const upd = { [answer_field]: ans };
513
701
  if (chat_history_field) {
514
702
  upd[chat_history_field] = [
package/package.json CHANGED
@@ -1,13 +1,16 @@
1
1
  {
2
2
  "name": "@saltcorn/large-language-model",
3
- "version": "0.6.6",
3
+ "version": "0.7.1",
4
4
  "description": "Large language models and functionality for Saltcorn",
5
5
  "main": "index.js",
6
6
  "dependencies": {
7
7
  "@saltcorn/data": "^0.9.0",
8
8
  "node-fetch": "2.6.9",
9
9
  "underscore": "1.13.6",
10
- "ollama": "0.5.0"
10
+ "ollama": "0.5.0",
11
+ "@google-cloud/vertexai": "^1.9.3",
12
+ "@google-cloud/aiplatform": "^3.34.0",
13
+ "googleapis": "^144.0.0"
11
14
  },
12
15
  "author": "Tom Nielsen",
13
16
  "license": "MIT",