@saltcorn/large-language-model 0.6.5 → 0.7.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (3) hide show
  1. package/generate.js +126 -6
  2. package/index.js +194 -4
  3. package/package.json +5 -2
package/generate.js CHANGED
@@ -2,6 +2,13 @@ const fetch = require("node-fetch");
2
2
  const util = require("util");
3
3
  const exec = util.promisify(require("child_process").exec);
4
4
  const db = require("@saltcorn/data/db");
5
+ const { VertexAI } = require("@google-cloud/vertexai");
6
+ const {
7
+ PredictionServiceClient,
8
+ helpers,
9
+ } = require("@google-cloud/aiplatform");
10
+ const { google } = require("googleapis");
11
+ const Plugin = require("@saltcorn/data/models/plugin");
5
12
 
6
13
  const { features, getState } = require("@saltcorn/data/db/state");
7
14
  let ollamaMod;
@@ -57,6 +64,13 @@ const getEmbedding = async (config, opts) => {
57
64
  //console.log("embedding response ", olres);
58
65
  return olres.embedding;
59
66
  }
67
+ case "Google Vertex AI":
68
+ const oauth2Client = await initOAuth2Client(config);
69
+ if (oauth2Client.isTokenExpiring()) {
70
+ const { credentials } = await oauth2Client.refreshAccessToken();
71
+ await updatePluginTokenCfg(credentials);
72
+ }
73
+ return await getEmbeddingGoogleVertex(config, opts, oauth2Client);
60
74
  default:
61
75
  throw new Error("Not implemented for this backend");
62
76
  }
@@ -117,6 +131,13 @@ const getCompletion = async (config, opts) => {
117
131
  { cwd: config.llama_dir }
118
132
  );
119
133
  return stdout;
134
+ case "Google Vertex AI":
135
+ const oauth2Client = await initOAuth2Client(config);
136
+ if (oauth2Client.isTokenExpiring()) {
137
+ const { credentials } = await oauth2Client.refreshAccessToken();
138
+ await updatePluginTokenCfg(credentials);
139
+ }
140
+ return await getCompletionGoogleVertex(config, opts, oauth2Client);
120
141
  default:
121
142
  break;
122
143
  }
@@ -174,12 +195,12 @@ const getCompletionOpenAICompatible = async (
174
195
  console.log("OpenAI response", JSON.stringify(results, null, 2));
175
196
  if (results.error) throw new Error(`OpenAI error: ${results.error.message}`);
176
197
 
177
- return (
178
- results?.choices?.[0]?.message?.content ||
179
- (results?.choices?.[0]?.message?.tool_calls
180
- ? { tool_calls: results?.choices?.[0]?.message?.tool_calls }
181
- : null)
182
- );
198
+ return results?.choices?.[0]?.message?.tool_calls
199
+ ? {
200
+ tool_calls: results?.choices?.[0]?.message?.tool_calls,
201
+ content: results?.choices?.[0]?.message?.content || null,
202
+ }
203
+ : results?.choices?.[0]?.message?.content || null;
183
204
  };
184
205
 
185
206
  const getEmbeddingOpenAICompatible = async (
@@ -211,4 +232,103 @@ const getEmbeddingOpenAICompatible = async (
211
232
  if (Array.isArray(prompt)) return results?.data?.map?.((d) => d?.embedding);
212
233
  return results?.data?.[0]?.embedding;
213
234
  };
235
+
236
+ const updatePluginTokenCfg = async (credentials) => {
237
+ let plugin = await Plugin.findOne({ name: "large-language-model" });
238
+ if (!plugin) {
239
+ plugin = await Plugin.findOne({
240
+ name: "@saltcorn/large-language-model",
241
+ });
242
+ }
243
+ const newConfig = {
244
+ ...(plugin.configuration || {}),
245
+ tokens: credentials,
246
+ };
247
+ plugin.configuration = newConfig;
248
+ await plugin.upsert();
249
+ getState().processSend({
250
+ refresh_plugin_cfg: plugin.name,
251
+ tenant: db.getTenantSchema(),
252
+ });
253
+ };
254
+
255
+ const initOAuth2Client = async (config) => {
256
+ const { client_id, client_secret } = config || {};
257
+ const state = getState();
258
+ const pluginCfg =
259
+ state.plugin_cfgs.large_language_model ||
260
+ state.plugin_cfgs["@saltcorn/large-language-model"];
261
+ const baseUrl = (
262
+ getState().getConfig("base_url") || "http://localhost:3000"
263
+ ).replace(/\/$/, "");
264
+ const redirect_uri = `${baseUrl}/callback`;
265
+
266
+ const oauth2Client = new google.auth.OAuth2(
267
+ client_id,
268
+ client_secret,
269
+ redirect_uri
270
+ );
271
+ oauth2Client.setCredentials(pluginCfg.tokens);
272
+ return oauth2Client;
273
+ };
274
+
275
+ const getCompletionGoogleVertex = async (config, opts, oauth2Client) => {
276
+ const vertexAI = new VertexAI({
277
+ project: config.project_id,
278
+ location: config.region || "us-central1",
279
+ googleAuthOptions: {
280
+ authClient: oauth2Client,
281
+ },
282
+ });
283
+ const generativeModel = vertexAI.getGenerativeModel({
284
+ model: config.model,
285
+ });
286
+ const chat = generativeModel.startChat();
287
+ const result = await chat.sendMessageStream(opts.prompt);
288
+ const chunks = [];
289
+ for await (const item of result.stream) {
290
+ chunks.push(item.candidates[0].content.parts[0].text);
291
+ }
292
+ return chunks.join();
293
+ };
294
+
295
+ const getEmbeddingGoogleVertex = async (config, opts, oauth2Client) => {
296
+ const predClient = new PredictionServiceClient({
297
+ apiEndpoint: "us-central1-aiplatform.googleapis.com",
298
+ authClient: oauth2Client,
299
+ });
300
+ const model = config.embed_model || "text-embedding-005";
301
+ let instances = null;
302
+ if (Array.isArray(opts.prompt)) {
303
+ instances = opts.prompt.map((p) =>
304
+ helpers.toValue({
305
+ content: p,
306
+ task_type: config.task_type || "RETRIEVAL_QUERY",
307
+ })
308
+ );
309
+ } else {
310
+ instances = [
311
+ helpers.toValue({
312
+ content: opts.prompt,
313
+ task_type: config.task_type || "RETRIEVAL_QUERY",
314
+ }),
315
+ ];
316
+ }
317
+ const [response] = await predClient.predict({
318
+ endpoint: `projects/${config.project_id}/locations/${
319
+ config.region || "us-central1"
320
+ }/publishers/google/models/${model}`,
321
+ instances,
322
+ // default outputDimensionality is 768, can be changed with:
323
+ // parameters: helpers.toValue({ outputDimensionality: parseInt(512) }),
324
+ });
325
+ const predictions = response.predictions;
326
+ const embeddings = predictions.map((p) => {
327
+ const embeddingsProto = p.structValue.fields.embeddings;
328
+ const valuesProto = embeddingsProto.structValue.fields.values;
329
+ return valuesProto.listValue.values.map((v) => v.numberValue);
330
+ });
331
+ return embeddings;
332
+ };
333
+
214
334
  module.exports = { getCompletion, getEmbedding };
package/index.js CHANGED
@@ -1,11 +1,15 @@
1
1
  const Workflow = require("@saltcorn/data/models/workflow");
2
2
  const Form = require("@saltcorn/data/models/form");
3
3
  const FieldRepeat = require("@saltcorn/data/models/fieldrepeat");
4
+ const Plugin = require("@saltcorn/data/models/plugin");
5
+ const { domReady } = require("@saltcorn/markup/tags");
4
6
  const db = require("@saltcorn/data/db");
5
7
  const { getCompletion, getEmbedding } = require("./generate");
6
8
  const { OPENAI_MODELS } = require("./constants.js");
7
9
  const { eval_expression } = require("@saltcorn/data/models/expression");
8
10
  const { interpolate } = require("@saltcorn/data/utils");
11
+ const { getState } = require("@saltcorn/data/db/state");
12
+ const { google } = require("googleapis");
9
13
 
10
14
  const configuration_workflow = () =>
11
15
  new Workflow({
@@ -15,6 +19,35 @@ const configuration_workflow = () =>
15
19
  form: async (context) => {
16
20
  const isRoot = db.getTenantSchema() === db.connectObj.default_schema;
17
21
  return new Form({
22
+ additionalHeaders: [
23
+ {
24
+ headerTag: `<script>
25
+ function backendChange(e) {
26
+ const val = e.value;
27
+ const authBtn = document.getElementById('vertex_authorize_btn');
28
+ if (val === 'Google Vertex AI') {
29
+ authBtn.classList.remove('d-none');
30
+ } else {
31
+ authBtn.classList.add('d-none');
32
+ }
33
+ }
34
+ ${domReady(`
35
+ const backend = document.getElementById('inputbackend');
36
+ if (backend) {
37
+ backendChange(backend);
38
+ }`)}
39
+ </script>`,
40
+ },
41
+ ],
42
+ additionalButtons: [
43
+ {
44
+ label: "authorize",
45
+ id: "vertex_authorize_btn",
46
+ onclick:
47
+ "location.href='/large-language-model/vertex/authorize'",
48
+ class: "btn btn-primary d-none",
49
+ },
50
+ ],
18
51
  fields: [
19
52
  {
20
53
  name: "backend",
@@ -27,8 +60,85 @@ const configuration_workflow = () =>
27
60
  "OpenAI-compatible API",
28
61
  "Local Ollama",
29
62
  ...(isRoot ? ["Local llama.cpp"] : []),
63
+ "Google Vertex AI",
64
+ ],
65
+ onChange: "backendChange(this)",
66
+ },
67
+ },
68
+ {
69
+ name: "client_id",
70
+ label: "Client ID",
71
+ sublabel: "OAuth2 client ID from your Google Cloud account",
72
+ type: "String",
73
+ required: true,
74
+ showIf: { backend: "Google Vertex AI" },
75
+ },
76
+ {
77
+ name: "client_secret",
78
+ label: "Client Secret",
79
+ sublabel: "Client secret from your Google Cloud account",
80
+ type: "String",
81
+ required: true,
82
+ showIf: { backend: "Google Vertex AI" },
83
+ },
84
+ {
85
+ name: "project_id",
86
+ label: "Project ID",
87
+ sublabel: "Google Cloud project ID",
88
+ type: "String",
89
+ required: true,
90
+ showIf: { backend: "Google Vertex AI" },
91
+ },
92
+ {
93
+ name: "model",
94
+ label: "Model",
95
+ type: "String",
96
+ showIf: { backend: "Google Vertex AI" },
97
+ attributes: {
98
+ options: ["gemini-1.5-pro", "gemini-1.5-flash"],
99
+ },
100
+ required: true,
101
+ },
102
+ {
103
+ name: "embed_model",
104
+ label: "Embedding model",
105
+ type: "String",
106
+ required: true,
107
+ showIf: { backend: "Google Vertex AI" },
108
+ attributes: {
109
+ options: [
110
+ "text-embedding-005",
111
+ "text-embedding-004",
112
+ "textembedding-gecko@003",
30
113
  ],
31
114
  },
115
+ default: "text-embedding-005",
116
+ },
117
+ {
118
+ name: "embed_task_type",
119
+ label: "Embedding task type",
120
+ type: "String",
121
+ showIf: { backend: "Google Vertex AI" },
122
+ attributes: {
123
+ options: [
124
+ "RETRIEVAL_QUERY",
125
+ "RETRIEVAL_DOCUMENT",
126
+ "SEMANTIC_SIMILARITY",
127
+ "CLASSIFICATION",
128
+ "CLUSTERING",
129
+ "QUESTION_ANSWERING",
130
+ "FACT_VERIFICATION",
131
+ "CODE_RETRIEVAL_QUERY",
132
+ ],
133
+ },
134
+ default: "RETRIEVAL_QUERY",
135
+ },
136
+ {
137
+ name: "region",
138
+ label: "Region",
139
+ sublabel: "Google Cloud region (default: us-central1)",
140
+ type: "String",
141
+ default: "us-central1",
32
142
  },
33
143
  {
34
144
  name: "api_key",
@@ -186,14 +296,90 @@ const functions = (config) => {
186
296
  };
187
297
  };
188
298
 
299
+ const routes = (config) => {
300
+ return [
301
+ {
302
+ url: "/large-language-model/vertex/authorize",
303
+ method: "get",
304
+ callback: async (req, res) => {
305
+ const { client_id, client_secret } = config || {};
306
+ const baseUrl = (
307
+ getState().getConfig("base_url") || "http://localhost:3000"
308
+ ).replace(/\/$/, "");
309
+ const redirect_uri = `${baseUrl}/large-language-model/vertex/callback`;
310
+ const oauth2Client = new google.auth.OAuth2(
311
+ client_id,
312
+ client_secret,
313
+ redirect_uri
314
+ );
315
+ const authUrl = oauth2Client.generateAuthUrl({
316
+ access_type: "offline",
317
+ scope: "https://www.googleapis.com/auth/cloud-platform",
318
+ });
319
+ res.redirect(authUrl);
320
+ },
321
+ },
322
+ {
323
+ url: "/large-language-model/vertex/callback",
324
+ method: "get",
325
+ callback: async (req, res) => {
326
+ const { client_id, client_secret } = config || {};
327
+ const baseUrl = (
328
+ getState().getConfig("base_url") || "http://localhost:3000"
329
+ ).replace(/\/$/, "");
330
+ const redirect_uri = `${baseUrl}/large-language-model/vertex/callback`;
331
+ const oauth2Client = new google.auth.OAuth2(
332
+ client_id,
333
+ client_secret,
334
+ redirect_uri
335
+ );
336
+ let plugin = await Plugin.findOne({ name: "large-language-model" });
337
+ if (!plugin) {
338
+ plugin = await Plugin.findOne({
339
+ name: "@saltcorn/large-language-model",
340
+ });
341
+ }
342
+ try {
343
+ const code = req.query.code;
344
+ if (!code) throw new Error("Missing code in query string.");
345
+ const { tokens } = await oauth2Client.getToken(code);
346
+ if (!tokens.refresh_token) {
347
+ req.flash(
348
+ "warning",
349
+ req.__(
350
+ "No refresh token received. Please revoke the plugin's access and try again."
351
+ )
352
+ );
353
+ } else {
354
+ const newConfig = { ...(plugin.configuration || {}), tokens };
355
+ plugin.configuration = newConfig;
356
+ await plugin.upsert();
357
+ req.flash(
358
+ "success",
359
+ req.__("Authentication successful! You can now use Vertex AI.")
360
+ );
361
+ }
362
+ } catch (error) {
363
+ console.error("Error retrieving access token:", error);
364
+ req.flash("error", req.__("Error retrieving access"));
365
+ } finally {
366
+ res.redirect(`/plugins/configure/${encodeURIComponent(plugin.name)}`);
367
+ }
368
+ },
369
+ },
370
+ ];
371
+ };
372
+
189
373
  module.exports = {
190
374
  sc_plugin_api_version: 1,
191
375
  configuration_workflow,
192
376
  functions,
193
377
  modelpatterns: require("./model.js"),
378
+ routes,
194
379
  actions: (config) => ({
195
380
  llm_function_call: require("./function-insert-action.js")(config),
196
381
  llm_generate: {
382
+ description: "Generate text with AI based on a text prompt",
197
383
  requireRow: true,
198
384
  configFields: ({ table, mode }) => {
199
385
  const override_fields =
@@ -317,7 +503,7 @@ module.exports = {
317
503
  upd[chat_history_field] = [
318
504
  ...history,
319
505
  { role: "user", content: prompt },
320
- { role: "system", content: ans },
506
+ { role: "assistant", content: ans },
321
507
  ];
322
508
  }
323
509
  if (mode === "workflow") return upd;
@@ -325,6 +511,8 @@ module.exports = {
325
511
  },
326
512
  },
327
513
  llm_generate_json: {
514
+ description:
515
+ "Generate JSON with AI based on a text prompt. You must specify the JSON fields in the configuration.",
328
516
  requireRow: true,
329
517
  configFields: ({ table, mode }) => {
330
518
  const override_fields =
@@ -427,7 +615,7 @@ module.exports = {
427
615
  label: "Multiple",
428
616
  type: "Bool",
429
617
  sublabel:
430
- "Select to generate an array of objects. Unselect for a single object",
618
+ "Select (true) to generate an array of objects. Unselect (false) for a single object",
431
619
  },
432
620
  {
433
621
  name: "gen_description",
@@ -506,13 +694,15 @@ module.exports = {
506
694
  ...opts,
507
695
  ...toolargs,
508
696
  });
509
- const ans = JSON.parse(compl.tool_calls[0].function.arguments)[answer_field];
697
+ const ans = JSON.parse(compl.tool_calls[0].function.arguments)[
698
+ answer_field
699
+ ];
510
700
  const upd = { [answer_field]: ans };
511
701
  if (chat_history_field) {
512
702
  upd[chat_history_field] = [
513
703
  ...history,
514
704
  { role: "user", content: prompt },
515
- { role: "system", content: ans },
705
+ { role: "assistant", content: ans },
516
706
  ];
517
707
  }
518
708
  if (mode === "workflow") return upd;
package/package.json CHANGED
@@ -1,13 +1,16 @@
1
1
  {
2
2
  "name": "@saltcorn/large-language-model",
3
- "version": "0.6.5",
3
+ "version": "0.7.0",
4
4
  "description": "Large language models and functionality for Saltcorn",
5
5
  "main": "index.js",
6
6
  "dependencies": {
7
7
  "@saltcorn/data": "^0.9.0",
8
8
  "node-fetch": "2.6.9",
9
9
  "underscore": "1.13.6",
10
- "ollama": "0.5.0"
10
+ "ollama": "0.5.0",
11
+ "@google-cloud/vertexai": "^1.9.3",
12
+ "@google-cloud/aiplatform": "^3.34.0",
13
+ "googleapis": "^144.0.0"
11
14
  },
12
15
  "author": "Tom Nielsen",
13
16
  "license": "MIT",