@aj-archipelago/cortex 1.0.4 → 1.0.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62) hide show
  1. package/README.md +3 -3
  2. package/config/default.example.json +18 -0
  3. package/config.js +28 -8
  4. package/helper_apps/MediaFileChunker/Dockerfile +20 -0
  5. package/helper_apps/MediaFileChunker/package-lock.json +18 -18
  6. package/helper_apps/MediaFileChunker/package.json +1 -1
  7. package/helper_apps/WhisperX/.dockerignore +27 -0
  8. package/helper_apps/WhisperX/Dockerfile +31 -0
  9. package/helper_apps/WhisperX/app-ts.py +76 -0
  10. package/helper_apps/WhisperX/app.py +115 -0
  11. package/helper_apps/WhisperX/docker-compose.debug.yml +12 -0
  12. package/helper_apps/WhisperX/docker-compose.yml +10 -0
  13. package/helper_apps/WhisperX/requirements.txt +6 -0
  14. package/index.js +1 -1
  15. package/lib/gcpAuthTokenHelper.js +37 -0
  16. package/lib/redisSubscription.js +1 -1
  17. package/package.json +9 -7
  18. package/pathways/basePathway.js +2 -2
  19. package/pathways/index.js +8 -2
  20. package/pathways/summary.js +2 -2
  21. package/pathways/sys_openai_chat.js +19 -0
  22. package/pathways/sys_openai_completion.js +11 -0
  23. package/pathways/{lc_test.mjs → test_langchain.mjs} +1 -1
  24. package/pathways/test_palm_chat.js +31 -0
  25. package/pathways/transcribe.js +3 -1
  26. package/pathways/translate.js +2 -1
  27. package/{graphql → server}/graphql.js +64 -62
  28. package/{graphql → server}/pathwayPrompter.js +9 -1
  29. package/{graphql → server}/pathwayResolver.js +46 -47
  30. package/{graphql → server}/plugins/azureTranslatePlugin.js +22 -0
  31. package/{graphql → server}/plugins/modelPlugin.js +15 -42
  32. package/server/plugins/openAiChatPlugin.js +134 -0
  33. package/{graphql → server}/plugins/openAiCompletionPlugin.js +38 -2
  34. package/{graphql → server}/plugins/openAiWhisperPlugin.js +59 -7
  35. package/server/plugins/palmChatPlugin.js +229 -0
  36. package/server/plugins/palmCompletionPlugin.js +134 -0
  37. package/{graphql → server}/prompt.js +11 -4
  38. package/server/rest.js +321 -0
  39. package/{graphql → server}/typeDef.js +30 -13
  40. package/tests/chunkfunction.test.js +1 -1
  41. package/tests/config.test.js +1 -1
  42. package/tests/main.test.js +282 -43
  43. package/tests/mocks.js +1 -1
  44. package/tests/modelPlugin.test.js +3 -15
  45. package/tests/openAiChatPlugin.test.js +125 -0
  46. package/tests/openai_api.test.js +147 -0
  47. package/tests/palmChatPlugin.test.js +256 -0
  48. package/tests/palmCompletionPlugin.test.js +87 -0
  49. package/tests/pathwayResolver.test.js +1 -1
  50. package/tests/server.js +23 -0
  51. package/tests/truncateMessages.test.js +1 -1
  52. package/graphql/plugins/openAiChatPlugin.js +0 -46
  53. package/tests/chunking.test.js +0 -155
  54. package/tests/translate.test.js +0 -126
  55. /package/{graphql → server}/chunker.js +0 -0
  56. /package/{graphql → server}/parser.js +0 -0
  57. /package/{graphql → server}/pathwayResponseParser.js +0 -0
  58. /package/{graphql → server}/plugins/localModelPlugin.js +0 -0
  59. /package/{graphql → server}/pubsub.js +0 -0
  60. /package/{graphql → server}/requestState.js +0 -0
  61. /package/{graphql → server}/resolver.js +0 -0
  62. /package/{graphql → server}/subscriptions.js +0 -0
@@ -0,0 +1,134 @@
1
// OpenAIChatPlugin.js
import ModelPlugin from './modelPlugin.js';
import { encode } from 'gpt-3-encoder';

// Plugin that formats requests for and parses responses from the OpenAI Chat API.
class OpenAIChatPlugin extends ModelPlugin {
    constructor(config, pathway) {
        super(config, pathway);
    }

    // Convert a PaLM-style prompt (context string, input/output examples, and
    // author-keyed messages) into the OpenAI chat messages array format.
    convertPalmToOpenAIMessages(context, examples, messages) {
        let openAIMessages = [];

        // PaLM "context" maps to an OpenAI system message
        if (context) {
            openAIMessages.push({
                role: 'system',
                content: context,
            });
        }

        // Each PaLM example becomes a user/assistant message pair
        examples.forEach(example => {
            openAIMessages.push({
                role: example.input.author || 'user',
                content: example.input.content,
            });
            openAIMessages.push({
                role: example.output.author || 'assistant',
                content: example.output.content,
            });
        });

        // Remaining conversation messages: PaLM "author" becomes the OpenAI "role"
        messages.forEach(message => {
            openAIMessages.push({
                role: message.author,
                content: message.content,
            });
        });

        return openAIMessages;
    }

    // Set up parameters specific to the OpenAI Chat API
    getRequestParameters(text, parameters, prompt) {
        const { modelPromptText, modelPromptMessages, tokenLength, modelPrompt } = this.getCompiledPrompt(text, parameters, prompt);
        const { stream } = parameters;

        // Token budget reserved for the prompt portion of the request
        const modelTargetTokenLength = this.getModelMaxTokenLength() * this.getPromptTokenRatio();

        let requestMessages = modelPromptMessages || [{ "role": "user", "content": modelPromptText }];

        // Check if the messages are in PaLM format and convert them to OpenAI format if necessary
        const isPalmFormat = requestMessages.some(message => 'author' in message);
        if (isPalmFormat) {
            const context = modelPrompt.context || '';
            const examples = modelPrompt.examples || [];
            // BUG FIX: previously passed the undefined identifier `expandedMessages`,
            // which threw a ReferenceError whenever a PaLM-format prompt was used.
            requestMessages = this.convertPalmToOpenAIMessages(context, examples, requestMessages);
        }

        // Check if the token length exceeds the model's max token length
        if (tokenLength > modelTargetTokenLength) {
            // Remove older messages until the token length is within the model's limit
            requestMessages = this.truncateMessagesToTargetLength(requestMessages, modelTargetTokenLength);
        }

        const requestParameters = {
            messages: requestMessages,
            temperature: this.temperature ?? 0.7,
            // only include `stream` when the caller explicitly set it
            ...(stream !== undefined ? { stream } : {}),
        };

        return requestParameters;
    }

    // Execute the request to the OpenAI Chat API
    async execute(text, parameters, prompt) {
        const url = this.requestUrl(text);
        const requestParameters = this.getRequestParameters(text, parameters, prompt);

        // model-level params are defaults; computed request parameters win
        const data = { ...(this.model.params || {}), ...requestParameters };
        const params = {};
        const headers = this.model.headers || {};
        return this.executeRequest(url, data, params, headers, prompt);
    }

    // Parse the response from the OpenAI Chat API.
    // Returns the raw data when no choices are present, the full choices
    // array when there is more than one, otherwise the trimmed content of
    // the single choice (or null when it has no content).
    parseResponse(data) {
        const { choices } = data;
        if (!choices || !choices.length) {
            return data;
        }

        // if we got a choices array back with more than one choice, return the whole array
        if (choices.length > 1) {
            return choices;
        }

        // otherwise, return the first choice
        const messageResult = choices[0].message && choices[0].message.content && choices[0].message.content.trim();
        return messageResult ?? null;
    }

    // Override the logging function to display the messages and responses
    logRequestData(data, responseData, prompt) {
        const separator = `\n=== ${this.pathwayName}.${this.requestCount++} ===\n`;
        console.log(separator);

        const { stream, messages } = data;
        if (messages && messages.length > 1) {
            messages.forEach((message, index) => {
                const words = message.content.split(" ");
                const tokenCount = encode(message.content).length;
                // long messages are previewed as first 20 + last 20 words
                const preview = words.length < 41 ? message.content : words.slice(0, 20).join(" ") + " ... " + words.slice(-20).join(" ");

                console.log(`\x1b[36mMessage ${index + 1}: Role: ${message.role}, Tokens: ${tokenCount}, Content: "${preview}"\x1b[0m`);
            });
        } else if (messages && messages.length === 1) {
            // guard against an empty/missing messages array (previously crashed here)
            console.log(`\x1b[36m${messages[0].content}\x1b[0m`);
        }

        if (stream) {
            console.log(`\x1b[34m> Response is streaming...\x1b[0m`);
        } else {
            console.log(`\x1b[34m> ${this.parseResponse(responseData)}\x1b[0m`);
        }

        prompt && prompt.debugInfo && (prompt.debugInfo += `${separator}${JSON.stringify(data)}`);
    }
}

export default OpenAIChatPlugin;
@@ -1,7 +1,6 @@
1
1
  // OpenAICompletionPlugin.js
2
2
 
3
3
  import ModelPlugin from './modelPlugin.js';
4
-
5
4
  import { encode } from 'gpt-3-encoder';
6
5
 
7
6
  // Helper function to truncate the prompt if it is too long
@@ -52,7 +51,7 @@ class OpenAICompletionPlugin extends ModelPlugin {
52
51
  frequency_penalty: 0,
53
52
  presence_penalty: 0,
54
53
  stop: ["<|im_end|>"],
55
- stream
54
+ ...(stream !== undefined ? { stream } : {}),
56
55
  };
57
56
  } else {
58
57
 
@@ -83,8 +82,45 @@ class OpenAICompletionPlugin extends ModelPlugin {
83
82
  const data = { ...(this.model.params || {}), ...requestParameters };
84
83
  const params = {};
85
84
  const headers = this.model.headers || {};
85
+
86
86
  return this.executeRequest(url, data, params, headers, prompt);
87
87
  }
88
+
89
+ // Parse the response from the OpenAI Completion API
90
+ parseResponse(data) {
91
+ const { choices } = data;
92
+ if (!choices || !choices.length) {
93
+ return data;
94
+ }
95
+
96
+ // if we got a choices array back with more than one choice, return the whole array
97
+ if (choices.length > 1) {
98
+ return choices;
99
+ }
100
+
101
+ // otherwise, return the first choice
102
+ const textResult = choices[0].text && choices[0].text.trim();
103
+ return textResult ?? null;
104
+ }
105
+
106
+ // Override the logging function to log the prompt and response
107
+ logRequestData(data, responseData, prompt) {
108
+ const separator = `\n=== ${this.pathwayName}.${this.requestCount++} ===\n`;
109
+ console.log(separator);
110
+
111
+ const stream = data.stream;
112
+ const modelInput = data.prompt;
113
+
114
+ console.log(`\x1b[36m${modelInput}\x1b[0m`);
115
+
116
+ if (stream) {
117
+ console.log(`\x1b[34m> Response is streaming...\x1b[0m`);
118
+ } else {
119
+ console.log(`\x1b[34m> ${this.parseResponse(responseData)}\x1b[0m`);
120
+ }
121
+
122
+ prompt && prompt.debugInfo && (prompt.debugInfo += `${separator}${JSON.stringify(data)}`);
123
+ }
88
124
  }
89
125
 
90
126
  export default OpenAICompletionPlugin;
@@ -14,10 +14,33 @@ import http from 'http';
14
14
  import https from 'https';
15
15
  import url from 'url';
16
16
  import { promisify } from 'util';
17
+ import subsrt from 'subsrt';
17
18
  const pipeline = promisify(stream.pipeline);
18
19
 
19
20
 
20
21
  const API_URL = config.get('whisperMediaApiUrl');
22
+ const WHISPER_TS_API_URL = config.get('whisperTSApiUrl');
23
+
24
+ function alignSubtitles(subtitles) {
25
+ const result = [];
26
+ const offset = 1000 * 60 * 10; // 10 minutes for each chunk
27
+
28
+ function preprocessStr(str) {
29
+ return str.trim().replace(/(\n\n)(?!\n)/g, '\n\n\n');
30
+ }
31
+
32
+ function shiftSubtitles(subtitle, shiftOffset) {
33
+ const captions = subsrt.parse(preprocessStr(subtitle));
34
+ const resynced = subsrt.resync(captions, { offset: shiftOffset });
35
+ return resynced;
36
+ }
37
+
38
+ for (let i = 0; i < subtitles.length; i++) {
39
+ const subtitle = subtitles[i];
40
+ result.push(...shiftSubtitles(subtitle, i * offset));
41
+ }
42
+ return subsrt.build(result);
43
+ }
21
44
 
22
45
  function generateUniqueFilename(extension) {
23
46
  return `${uuidv4()}.${extension}`;
@@ -93,17 +116,37 @@ class OpenAIWhisperPlugin extends ModelPlugin {
93
116
 
94
117
  // Execute the request to the OpenAI Whisper API
95
118
  async execute(text, parameters, prompt, pathwayResolver) {
119
+ const { responseFormat, wordTimestamped } = parameters;
96
120
  const url = this.requestUrl(text);
97
121
  const params = {};
98
122
  const { modelPromptText } = this.getCompiledPrompt(text, parameters, prompt);
99
123
 
124
+ const processTS = async (uri) => {
125
+ if (wordTimestamped) {
126
+ if (!WHISPER_TS_API_URL) {
127
+ throw new Error(`WHISPER_TS_API_URL not set for word timestamped processing`);
128
+ }
129
+
130
+ try {
131
+ // const res = await axios.post(WHISPER_TS_API_URL, { params: { fileurl: uri } });
132
+ const res = await this.executeRequest(WHISPER_TS_API_URL, {fileurl:uri},{},{});
133
+ return res;
134
+ } catch (err) {
135
+ console.log(`Error getting word timestamped data from api:`, err);
136
+ throw err;
137
+ }
138
+ }
139
+ }
140
+
100
141
  const processChunk = async (chunk) => {
101
142
  try {
102
- const { language } = parameters;
143
+ const { language, responseFormat } = parameters;
144
+ const response_format = responseFormat || 'text';
145
+
103
146
  const formData = new FormData();
104
147
  formData.append('file', fs.createReadStream(chunk));
105
148
  formData.append('model', this.model.params.model);
106
- formData.append('response_format', 'text');
149
+ formData.append('response_format', response_format);
107
150
  language && formData.append('language', language);
108
151
  modelPromptText && formData.append('prompt', modelPromptText);
109
152
 
@@ -114,7 +157,7 @@ class OpenAIWhisperPlugin extends ModelPlugin {
114
157
  }
115
158
  }
116
159
 
117
- let result = ``;
160
+ let result = [];
118
161
  let { file } = parameters;
119
162
  let totalCount = 0;
120
163
  let completedCount = 0;
@@ -134,7 +177,6 @@ class OpenAIWhisperPlugin extends ModelPlugin {
134
177
 
135
178
  let chunks = []; // array of local file paths
136
179
  try {
137
-
138
180
  const uris = await this.getMediaChunks(file, requestId); // array of remote file uris
139
181
  if (!uris || !uris.length) {
140
182
  throw new Error(`Error in getting chunks from media helper for file ${file}`);
@@ -144,14 +186,20 @@ class OpenAIWhisperPlugin extends ModelPlugin {
144
186
 
145
187
  // sequential download of chunks
146
188
  for (const uri of uris) {
147
- chunks.push(await downloadFile(uri));
189
+ if (wordTimestamped) { // get word timestamped data
190
+ sendProgress(); // no download needed auto progress
191
+ const ts = await processTS(uri);
192
+ result.push(ts);
193
+ } else {
194
+ chunks.push(await downloadFile(uri));
195
+ }
148
196
  sendProgress();
149
197
  }
150
198
 
151
199
 
152
200
  // sequential processing of chunks
153
201
  for (const chunk of chunks) {
154
- result += await processChunk(chunk);
202
+ result.push(await processChunk(chunk));
155
203
  sendProgress();
156
204
  }
157
205
 
@@ -184,7 +232,11 @@ class OpenAIWhisperPlugin extends ModelPlugin {
184
232
  console.error("An error occurred while deleting:", error);
185
233
  }
186
234
  }
187
- return result;
235
+
236
+ if (['srt','vtt'].includes(responseFormat) || wordTimestamped) { // align subtitles for formats
237
+ return alignSubtitles(result);
238
+ }
239
+ return result.join(` `);
188
240
  }
189
241
  }
190
242
 
@@ -0,0 +1,229 @@
1
// palmChatPlugin.js
import ModelPlugin from './modelPlugin.js';
import { encode } from 'gpt-3-encoder';
import HandleBars from '../../lib/handleBars.js';

// Plugin that formats requests for and parses responses from the Google PaLM Chat API.
class PalmChatPlugin extends ModelPlugin {
    constructor(config, pathway) {
        super(config, pathway);
    }

    // Convert OpenAI-style messages (role/content) to the PaLM format:
    // system messages are folded into a single context string, consecutive
    // messages from the same author are merged, and roles become authors.
    convertMessagesToPalm(messages) {
        let context = '';
        let modifiedMessages = [];
        let lastAuthor = '';

        messages.forEach(message => {
            const { role, author, content } = message;

            // Extract system messages into the context string
            if (role === 'system') {
                context += (context.length > 0 ? '\n' : '') + content;
                return;
            }

            // Aggregate consecutive author messages, appending the content
            if ((role === lastAuthor || author === lastAuthor) && modifiedMessages.length > 0) {
                modifiedMessages[modifiedMessages.length - 1].content += '\n' + content;
            }
            // Only push messages with role 'user' or 'assistant' or existing author messages
            else if (role === 'user' || role === 'assistant' || author) {
                modifiedMessages.push({
                    author: author || role,
                    content,
                });
                lastAuthor = author || role;
            }
        });

        return {
            modifiedMessages,
            context,
        };
    }

    // Handlebars compiler for context (PaLM chat specific)
    getCompiledContext(text, parameters, context) {
        const combinedParameters = { ...this.promptParameters, ...parameters };
        return context ? HandleBars.compile(context)({ ...combinedParameters, text}) : '';
    }

    // Handlebars compiler for examples (PaLM chat specific).
    // Compiles the input/output content of each example against the
    // combined prompt parameters; entries without content pass through.
    getCompiledExamples(text, parameters, examples = []) {
        const combinedParameters = { ...this.promptParameters, ...parameters };

        const compileContent = (content) => {
            const compile = HandleBars.compile(content);
            return compile({ ...combinedParameters, text });
        };

        const processExample = (example, key) => {
            if (example[key]?.content) {
                return { ...example[key], content: compileContent(example[key].content) };
            }
            return { ...example[key] };
        };

        return examples.map((example) => ({
            input: example.input ? processExample(example, 'input') : undefined,
            output: example.output ? processExample(example, 'output') : undefined,
        }));
    }

    // Set up parameters specific to the PaLM Chat API
    getRequestParameters(text, parameters, prompt) {
        const { modelPromptText, modelPromptMessages, tokenLength } = this.getCompiledPrompt(text, parameters, prompt);
        const { stream } = parameters;

        // Token budget reserved for the prompt portion of the request
        const modelTargetTokenLength = this.getModelMaxTokenLength() * this.getPromptTokenRatio();

        const palmMessages = this.convertMessagesToPalm(modelPromptMessages || [{ "author": "user", "content": modelPromptText }]);

        let requestMessages = palmMessages.modifiedMessages;

        // Check if the token length exceeds the model's max token length
        if (tokenLength > modelTargetTokenLength) {
            // Remove older messages until the token length is within the model's limit
            requestMessages = this.truncateMessagesToTargetLength(requestMessages, modelTargetTokenLength);
        }

        const context = this.getCompiledContext(text, parameters, prompt.context || palmMessages.context || '');
        const examples = this.getCompiledExamples(text, parameters, prompt.examples || []);

        // For PaLM right now, the max return tokens is 1024, regardless of the max context length
        // I can't think of a time you'd want to constrain it to fewer at the moment.
        const max_tokens = 1024//this.getModelMaxTokenLength() - tokenLength;

        if (max_tokens < 0) {
            throw new Error(`Prompt is too long to successfully call the model at ${tokenLength} tokens. The model will not be called.`);
        }

        // Ensure there are an odd number of messages: PaLM requires alternating
        // authors ending on a 'user' turn, so drop the oldest message when the
        // count is even. (The comment previously said "even", but the code
        // below has always enforced an odd count.)
        if (requestMessages.length % 2 === 0) {
            requestMessages = requestMessages.slice(1);
        }

        const requestParameters = {
            instances: [{
                context: context,
                examples: examples,
                messages: requestMessages,
            }],
            parameters: {
                temperature: this.temperature ?? 0.7,
                maxOutputTokens: max_tokens,
                topP: parameters.topP ?? 0.95,
                topK: parameters.topK ?? 40,
            }
        };

        return requestParameters;
    }

    // Get the safetyAttributes from the PaLM Chat API response data
    getSafetyAttributes(data) {
        const { predictions } = data;
        if (!predictions || !predictions.length) {
            return null;
        }

        // if we got a predictions array back with more than one prediction, return the safetyAttributes of the first prediction
        if (predictions.length > 1) {
            return predictions[0].safetyAttributes ?? null;
        }

        // otherwise, return the safetyAttributes of the content of the first prediction
        return predictions[0].safetyAttributes ?? null;
    }

    // Execute the request to the PaLM Chat API.
    // Fetches a fresh GCP bearer token via the configured gcpAuthTokenHelper
    // before every request.
    async execute(text, parameters, prompt) {
        const url = this.requestUrl(text);
        const requestParameters = this.getRequestParameters(text, parameters, prompt);

        const data = { ...(this.model.params || {}), ...requestParameters };
        const params = {};
        const headers = this.model.headers || {};
        const gcpAuthTokenHelper = this.config.get('gcpAuthTokenHelper');
        const authToken = await gcpAuthTokenHelper.getAccessToken();
        headers.Authorization = `Bearer ${authToken}`;
        return this.executeRequest(url, data, params, headers, prompt);
    }

    // Parse the response from the PaLM Chat API.
    // Returns null when there are no predictions/candidates, a fixed policy
    // message when the response was blocked, the full candidates array when
    // there is more than one, otherwise the trimmed content of the single candidate.
    parseResponse(data) {
        const { predictions } = data;
        if (!predictions || !predictions.length) {
            return null;
        }

        // Get the candidates array from the first prediction
        const { candidates } = predictions[0];

        // if it was blocked, return the blocked message
        if (predictions[0].safetyAttributes?.blocked) {
            return 'The response is blocked because the input or response potentially violates Google policies. Try rephrasing the prompt or adjusting the parameter settings. Currently, only English is supported.';
        }

        if (!candidates || !candidates.length) {
            return null;
        }

        // If we got a candidates array back with more than one candidate, return the whole array
        if (candidates.length > 1) {
            return candidates;
        }

        // Otherwise, return the content of the first candidate
        const messageResult = candidates[0].content && candidates[0].content.trim();
        return messageResult ?? null;
    }

    // Override the logging function to display the messages and responses
    logRequestData(data, responseData, prompt) {
        const separator = `\n=== ${this.pathwayName}.${this.requestCount++} ===\n`;
        console.log(separator);

        const instances = data && data.instances;
        const messages = instances && instances[0] && instances[0].messages;
        const { context, examples } = instances && instances[0] || {};

        if (context) {
            console.log(`\x1b[36mContext: ${context}\x1b[0m`);
        }

        if (examples && examples.length) {
            examples.forEach((example, index) => {
                console.log(`\x1b[36mExample ${index + 1}: Input: "${example.input.content}", Output: "${example.output.content}"\x1b[0m`);
            });
        }

        if (messages && messages.length > 1) {
            messages.forEach((message, index) => {
                const words = message.content.split(" ");
                const tokenCount = encode(message.content).length;
                // long messages are previewed as first 20 + last 20 words
                const preview = words.length < 41 ? message.content : words.slice(0, 20).join(" ") + " ... " + words.slice(-20).join(" ");

                console.log(`\x1b[36mMessage ${index + 1}: Author: ${message.author}, Tokens: ${tokenCount}, Content: "${preview}"\x1b[0m`);
            });
        } else if (messages && messages.length === 1) {
            console.log(`\x1b[36m${messages[0].content}\x1b[0m`);
        }

        const safetyAttributes = this.getSafetyAttributes(responseData);

        console.log(`\x1b[34m> ${this.parseResponse(responseData)}\x1b[0m`);

        if (safetyAttributes) {
            console.log(`\x1b[33mSafety Attributes: ${JSON.stringify(safetyAttributes, null, 2)}\x1b[0m`);
        }

        if (prompt && prompt.debugInfo) {
            prompt.debugInfo += `${separator}${JSON.stringify(data)}`;
        }
    }
}

export default PalmChatPlugin;
@@ -0,0 +1,134 @@
1
// palmCompletionPlugin.js

import ModelPlugin from './modelPlugin.js';

// Helper function to truncate the prompt if it is too long.
// Allows the prompt to exceed the pathway's target length by up to half of
// the remaining (response) budget before truncating via the pathwayResolver.
const truncatePromptIfNecessary = (text, textTokenCount, modelMaxTokenCount, targetTextTokenCount, pathwayResolver) => {
    const maxAllowedTokens = textTokenCount + ((modelMaxTokenCount - targetTextTokenCount) * 0.5);

    if (textTokenCount > maxAllowedTokens) {
        // fixed unbalanced parenthesis / grammar in the original warning text
        pathwayResolver.logWarning(`Prompt is too long at ${textTokenCount} tokens (the target token length for this pathway is ${targetTextTokenCount} tokens because the response is expected to take up the rest of the model's max tokens (${modelMaxTokenCount})). Prompt will be truncated.`);
        return pathwayResolver.truncate(text, maxAllowedTokens);
    }
    return text;
}

// PalmCompletionPlugin class for handling requests and responses to the PaLM API Text Completion API
class PalmCompletionPlugin extends ModelPlugin {
    constructor(config, pathway) {
        super(config, pathway);
    }

    // Set up parameters specific to the PaLM API Text Completion API
    getRequestParameters(text, parameters, prompt, pathwayResolver) {
        const { modelPromptText, tokenLength } = this.getCompiledPrompt(text, parameters, prompt);
        // Token budget reserved for the prompt portion of the request
        const modelTargetTokenLength = this.getModelMaxTokenLength() * this.getPromptTokenRatio();

        const truncatedPrompt = truncatePromptIfNecessary(modelPromptText, tokenLength, this.getModelMaxTokenLength(), modelTargetTokenLength, pathwayResolver);

        // For PaLM right now, the max return tokens is 1024, regardless of the max context length
        const max_tokens = 1024//this.getModelMaxTokenLength() - tokenLength;

        if (max_tokens < 0) {
            throw new Error(`Prompt is too long to successfully call the model at ${tokenLength} tokens. The model will not be called.`);
        }

        if (!truncatedPrompt) {
            throw new Error(`Prompt is empty. The model will not be called.`);
        }

        const requestParameters = {
            instances: [
                { prompt: truncatedPrompt }
            ],
            parameters: {
                temperature: this.temperature ?? 0.7,
                maxOutputTokens: max_tokens,
                topP: parameters.topP ?? 0.95,
                topK: parameters.topK ?? 40,
            }
        };

        return requestParameters;
    }

    // Execute the request to the PaLM API Text Completion API.
    // Fetches a fresh GCP bearer token via the configured gcpAuthTokenHelper
    // before every request.
    async execute(text, parameters, prompt, pathwayResolver) {
        const url = this.requestUrl(text);
        const requestParameters = this.getRequestParameters(text, parameters, prompt, pathwayResolver);

        // NOTE(review): unlike the chat plugins this does not merge
        // this.model.params into the payload — presumably intentional; confirm.
        const data = { ...requestParameters };
        const params = {};
        const headers = this.model.headers || {};
        const gcpAuthTokenHelper = this.config.get('gcpAuthTokenHelper');
        const authToken = await gcpAuthTokenHelper.getAccessToken();
        headers.Authorization = `Bearer ${authToken}`;
        return this.executeRequest(url, data, params, headers, prompt);
    }

    // Parse the response from the PaLM API Text Completion API.
    // Returns the raw data when no predictions are present, the full array
    // when there is more than one, a fixed policy message when blocked,
    // otherwise the trimmed content of the single prediction.
    parseResponse(data) {
        const { predictions } = data;
        if (!predictions || !predictions.length) {
            return data;
        }

        // if we got a predictions array back with more than one prediction, return the whole array
        if (predictions.length > 1) {
            return predictions;
        }

        // otherwise, return the content of the first prediction
        // if it was blocked, return the blocked message
        if (predictions[0].safetyAttributes?.blocked) {
            return 'The response is blocked because the input or response potentially violates Google policies. Try rephrasing the prompt or adjusting the parameter settings. Currently, only English is supported.';
        }

        const contentResult = predictions[0].content && predictions[0].content.trim();
        return contentResult ?? null;
    }

    // Get the safetyAttributes from the PaLM API Text Completion API response data
    getSafetyAttributes(data) {
        const { predictions } = data;
        if (!predictions || !predictions.length) {
            return null;
        }

        // if we got a predictions array back with more than one prediction, return the safetyAttributes of the first prediction
        if (predictions.length > 1) {
            return predictions[0].safetyAttributes ?? null;
        }

        // otherwise, return the safetyAttributes of the content of the first prediction
        return predictions[0].safetyAttributes ?? null;
    }

    // Override the logging function to log the prompt and response
    logRequestData(data, responseData, prompt) {
        const separator = `\n=== ${this.pathwayName}.${this.requestCount++} ===\n`;
        console.log(separator);

        const safetyAttributes = this.getSafetyAttributes(responseData);

        const instances = data && data.instances;
        const modelInput = instances && instances[0] && instances[0].prompt;

        if (modelInput) {
            console.log(`\x1b[36m${modelInput}\x1b[0m`);
        }

        console.log(`\x1b[34m> ${this.parseResponse(responseData)}\x1b[0m`);

        if (safetyAttributes) {
            console.log(`\x1b[33mSafety Attributes: ${JSON.stringify(safetyAttributes, null, 2)}\x1b[0m`);
        }

        if (prompt && prompt.debugInfo) {
            prompt.debugInfo += `${separator}${JSON.stringify(data)}`;
        }
    }
}

export default PalmCompletionPlugin;