@aj-archipelago/cortex 1.0.4 → 1.0.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +3 -3
- package/config/default.example.json +18 -0
- package/config.js +28 -8
- package/helper_apps/MediaFileChunker/Dockerfile +20 -0
- package/helper_apps/MediaFileChunker/package-lock.json +18 -18
- package/helper_apps/MediaFileChunker/package.json +1 -1
- package/helper_apps/WhisperX/.dockerignore +27 -0
- package/helper_apps/WhisperX/Dockerfile +31 -0
- package/helper_apps/WhisperX/app-ts.py +76 -0
- package/helper_apps/WhisperX/app.py +115 -0
- package/helper_apps/WhisperX/docker-compose.debug.yml +12 -0
- package/helper_apps/WhisperX/docker-compose.yml +10 -0
- package/helper_apps/WhisperX/requirements.txt +6 -0
- package/index.js +1 -1
- package/lib/gcpAuthTokenHelper.js +37 -0
- package/lib/redisSubscription.js +1 -1
- package/package.json +9 -7
- package/pathways/basePathway.js +2 -2
- package/pathways/index.js +8 -2
- package/pathways/summary.js +2 -2
- package/pathways/sys_openai_chat.js +19 -0
- package/pathways/sys_openai_completion.js +11 -0
- package/pathways/{lc_test.mjs → test_langchain.mjs} +1 -1
- package/pathways/test_palm_chat.js +31 -0
- package/pathways/transcribe.js +3 -1
- package/pathways/translate.js +2 -1
- package/{graphql → server}/graphql.js +64 -62
- package/{graphql → server}/pathwayPrompter.js +9 -1
- package/{graphql → server}/pathwayResolver.js +46 -47
- package/{graphql → server}/plugins/azureTranslatePlugin.js +22 -0
- package/{graphql → server}/plugins/modelPlugin.js +15 -42
- package/server/plugins/openAiChatPlugin.js +134 -0
- package/{graphql → server}/plugins/openAiCompletionPlugin.js +38 -2
- package/{graphql → server}/plugins/openAiWhisperPlugin.js +59 -7
- package/server/plugins/palmChatPlugin.js +229 -0
- package/server/plugins/palmCompletionPlugin.js +134 -0
- package/{graphql → server}/prompt.js +11 -4
- package/server/rest.js +321 -0
- package/{graphql → server}/typeDef.js +30 -13
- package/tests/chunkfunction.test.js +1 -1
- package/tests/config.test.js +1 -1
- package/tests/main.test.js +282 -43
- package/tests/mocks.js +1 -1
- package/tests/modelPlugin.test.js +3 -15
- package/tests/openAiChatPlugin.test.js +125 -0
- package/tests/openai_api.test.js +147 -0
- package/tests/palmChatPlugin.test.js +256 -0
- package/tests/palmCompletionPlugin.test.js +87 -0
- package/tests/pathwayResolver.test.js +1 -1
- package/tests/server.js +23 -0
- package/tests/truncateMessages.test.js +1 -1
- package/graphql/plugins/openAiChatPlugin.js +0 -46
- package/tests/chunking.test.js +0 -155
- package/tests/translate.test.js +0 -126
- /package/{graphql → server}/chunker.js +0 -0
- /package/{graphql → server}/parser.js +0 -0
- /package/{graphql → server}/pathwayResponseParser.js +0 -0
- /package/{graphql → server}/plugins/localModelPlugin.js +0 -0
- /package/{graphql → server}/pubsub.js +0 -0
- /package/{graphql → server}/requestState.js +0 -0
- /package/{graphql → server}/resolver.js +0 -0
- /package/{graphql → server}/subscriptions.js +0 -0
|
// OpenAIChatPlugin.js
import ModelPlugin from './modelPlugin.js';
import { encode } from 'gpt-3-encoder';

// Plugin that adapts a compiled pathway prompt into a request for the
// OpenAI Chat Completions API and parses the response back out.
class OpenAIChatPlugin extends ModelPlugin {
    constructor(config, pathway) {
        super(config, pathway);
    }

    // Convert PaLM-style prompt parts (context string, input/output examples,
    // and author-keyed messages) into the OpenAI role-keyed messages array.
    convertPalmToOpenAIMessages(context, examples, messages) {
        let openAIMessages = [];

        // Add context as a system message
        if (context) {
            openAIMessages.push({
                role: 'system',
                content: context,
            });
        }

        // Add examples to the messages array as alternating user/assistant turns
        examples.forEach(example => {
            openAIMessages.push({
                role: example.input.author || 'user',
                content: example.input.content,
            });
            openAIMessages.push({
                role: example.output.author || 'assistant',
                content: example.output.content,
            });
        });

        // Add remaining messages to the messages array
        messages.forEach(message => {
            openAIMessages.push({
                role: message.author,
                content: message.content,
            });
        });

        return openAIMessages;
    }

    // Set up parameters specific to the OpenAI Chat API
    getRequestParameters(text, parameters, prompt) {
        const { modelPromptText, modelPromptMessages, tokenLength, modelPrompt } = this.getCompiledPrompt(text, parameters, prompt);
        const { stream } = parameters;

        // Define the model's max token length
        const modelTargetTokenLength = this.getModelMaxTokenLength() * this.getPromptTokenRatio();

        let requestMessages = modelPromptMessages || [{ "role": "user", "content": modelPromptText }];

        // Check if the messages are in Palm format and convert them to OpenAI format if necessary
        const isPalmFormat = requestMessages.some(message => 'author' in message);
        if (isPalmFormat) {
            const context = modelPrompt.context || '';
            const examples = modelPrompt.examples || [];
            // BUG FIX: this previously passed the undefined identifier
            // `expandedMessages`, which threw a ReferenceError whenever
            // PaLM-format messages reached this plugin.
            requestMessages = this.convertPalmToOpenAIMessages(context, examples, requestMessages);
        }

        // Check if the token length exceeds the model's max token length
        if (tokenLength > modelTargetTokenLength) {
            // Remove older messages until the token length is within the model's limit
            requestMessages = this.truncateMessagesToTargetLength(requestMessages, modelTargetTokenLength);
        }

        const requestParameters = {
            messages: requestMessages,
            temperature: this.temperature ?? 0.7,
            // only include `stream` if the caller actually specified it
            ...(stream !== undefined ? { stream } : {}),
        };

        return requestParameters;
    }

    // Execute the request to the OpenAI Chat API
    async execute(text, parameters, prompt) {
        const url = this.requestUrl(text);
        const requestParameters = this.getRequestParameters(text, parameters, prompt);

        const data = { ...(this.model.params || {}), ...requestParameters };
        const params = {};
        const headers = this.model.headers || {};
        return this.executeRequest(url, data, params, headers, prompt);
    }

    // Parse the response from the OpenAI Chat API.
    // Returns: raw payload when no choices; the choices array when more than
    // one choice; otherwise the first choice's trimmed message content (or null).
    parseResponse(data) {
        const { choices } = data;
        if (!choices || !choices.length) {
            return data;
        }

        // if we got a choices array back with more than one choice, return the whole array
        if (choices.length > 1) {
            return choices;
        }

        // otherwise, return the first choice
        const messageResult = choices[0].message && choices[0].message.content && choices[0].message.content.trim();
        return messageResult ?? null;
    }

    // Override the logging function to display the messages and responses
    logRequestData(data, responseData, prompt) {
        const separator = `\n=== ${this.pathwayName}.${this.requestCount++} ===\n`;
        console.log(separator);

        const { stream, messages } = data;
        if (messages && messages.length > 1) {
            messages.forEach((message, index) => {
                const words = message.content.split(" ");
                const tokenCount = encode(message.content).length;
                // long messages are previewed as first 20 + last 20 words
                const preview = words.length < 41 ? message.content : words.slice(0, 20).join(" ") + " ... " + words.slice(-20).join(" ");

                console.log(`\x1b[36mMessage ${index + 1}: Role: ${message.role}, Tokens: ${tokenCount}, Content: "${preview}"\x1b[0m`);
            });
        } else {
            console.log(`\x1b[36m${messages[0].content}\x1b[0m`);
        }

        if (stream) {
            console.log(`\x1b[34m> Response is streaming...\x1b[0m`);
        } else {
            console.log(`\x1b[34m> ${this.parseResponse(responseData)}\x1b[0m`);
        }

        prompt && prompt.debugInfo && (prompt.debugInfo += `${separator}${JSON.stringify(data)}`);
    }
}

export default OpenAIChatPlugin;
@@ -1,7 +1,6 @@
|
|
|
1
1
|
// OpenAICompletionPlugin.js
|
|
2
2
|
|
|
3
3
|
import ModelPlugin from './modelPlugin.js';
|
|
4
|
-
|
|
5
4
|
import { encode } from 'gpt-3-encoder';
|
|
6
5
|
|
|
7
6
|
// Helper function to truncate the prompt if it is too long
|
|
@@ -52,7 +51,7 @@ class OpenAICompletionPlugin extends ModelPlugin {
|
|
|
52
51
|
frequency_penalty: 0,
|
|
53
52
|
presence_penalty: 0,
|
|
54
53
|
stop: ["<|im_end|>"],
|
|
55
|
-
stream
|
|
54
|
+
...(stream !== undefined ? { stream } : {}),
|
|
56
55
|
};
|
|
57
56
|
} else {
|
|
58
57
|
|
|
@@ -83,8 +82,45 @@ class OpenAICompletionPlugin extends ModelPlugin {
|
|
|
83
82
|
const data = { ...(this.model.params || {}), ...requestParameters };
|
|
84
83
|
const params = {};
|
|
85
84
|
const headers = this.model.headers || {};
|
|
85
|
+
|
|
86
86
|
return this.executeRequest(url, data, params, headers, prompt);
|
|
87
87
|
}
|
|
88
|
+
|
|
89
|
+
// Parse the response from the OpenAI Completion API
|
|
90
|
+
parseResponse(data) {
|
|
91
|
+
const { choices } = data;
|
|
92
|
+
if (!choices || !choices.length) {
|
|
93
|
+
return data;
|
|
94
|
+
}
|
|
95
|
+
|
|
96
|
+
// if we got a choices array back with more than one choice, return the whole array
|
|
97
|
+
if (choices.length > 1) {
|
|
98
|
+
return choices;
|
|
99
|
+
}
|
|
100
|
+
|
|
101
|
+
// otherwise, return the first choice
|
|
102
|
+
const textResult = choices[0].text && choices[0].text.trim();
|
|
103
|
+
return textResult ?? null;
|
|
104
|
+
}
|
|
105
|
+
|
|
106
|
+
// Override the logging function to log the prompt and response
|
|
107
|
+
logRequestData(data, responseData, prompt) {
|
|
108
|
+
const separator = `\n=== ${this.pathwayName}.${this.requestCount++} ===\n`;
|
|
109
|
+
console.log(separator);
|
|
110
|
+
|
|
111
|
+
const stream = data.stream;
|
|
112
|
+
const modelInput = data.prompt;
|
|
113
|
+
|
|
114
|
+
console.log(`\x1b[36m${modelInput}\x1b[0m`);
|
|
115
|
+
|
|
116
|
+
if (stream) {
|
|
117
|
+
console.log(`\x1b[34m> Response is streaming...\x1b[0m`);
|
|
118
|
+
} else {
|
|
119
|
+
console.log(`\x1b[34m> ${this.parseResponse(responseData)}\x1b[0m`);
|
|
120
|
+
}
|
|
121
|
+
|
|
122
|
+
prompt && prompt.debugInfo && (prompt.debugInfo += `${separator}${JSON.stringify(data)}`);
|
|
123
|
+
}
|
|
88
124
|
}
|
|
89
125
|
|
|
90
126
|
export default OpenAICompletionPlugin;
|
|
@@ -14,10 +14,33 @@ import http from 'http';
|
|
|
14
14
|
import https from 'https';
|
|
15
15
|
import url from 'url';
|
|
16
16
|
import { promisify } from 'util';
|
|
17
|
+
import subsrt from 'subsrt';
|
|
17
18
|
const pipeline = promisify(stream.pipeline);
|
|
18
19
|
|
|
19
20
|
|
|
20
21
|
const API_URL = config.get('whisperMediaApiUrl');
|
|
22
|
+
const WHISPER_TS_API_URL = config.get('whisperTSApiUrl');
|
|
23
|
+
|
|
24
|
+
// Merge per-chunk subtitle documents into a single timeline: each chunk's
// cues are shifted by a fixed per-chunk offset before being rebuilt.
function alignSubtitles(subtitles) {
    // 10 minutes for each chunk
    const chunkOffsetMs = 1000 * 60 * 10;

    // Normalize cue separators so the parser sees distinct captions.
    const preprocessStr = (str) => str.trim().replace(/(\n\n)(?!\n)/g, '\n\n\n');

    // Parse one chunk's subtitle text and shift its cue times by shiftOffset ms.
    const shiftSubtitles = (subtitle, shiftOffset) =>
        subsrt.resync(subsrt.parse(preprocessStr(subtitle)), { offset: shiftOffset });

    const mergedCaptions = [];
    subtitles.forEach((subtitle, index) => {
        mergedCaptions.push(...shiftSubtitles(subtitle, index * chunkOffsetMs));
    });

    return subsrt.build(mergedCaptions);
}
|
|
21
44
|
|
|
22
45
|
function generateUniqueFilename(extension) {
|
|
23
46
|
return `${uuidv4()}.${extension}`;
|
|
@@ -93,17 +116,37 @@ class OpenAIWhisperPlugin extends ModelPlugin {
|
|
|
93
116
|
|
|
94
117
|
// Execute the request to the OpenAI Whisper API
|
|
95
118
|
async execute(text, parameters, prompt, pathwayResolver) {
|
|
119
|
+
const { responseFormat, wordTimestamped } = parameters;
|
|
96
120
|
const url = this.requestUrl(text);
|
|
97
121
|
const params = {};
|
|
98
122
|
const { modelPromptText } = this.getCompiledPrompt(text, parameters, prompt);
|
|
99
123
|
|
|
124
|
+
const processTS = async (uri) => {
|
|
125
|
+
if (wordTimestamped) {
|
|
126
|
+
if (!WHISPER_TS_API_URL) {
|
|
127
|
+
throw new Error(`WHISPER_TS_API_URL not set for word timestamped processing`);
|
|
128
|
+
}
|
|
129
|
+
|
|
130
|
+
try {
|
|
131
|
+
// const res = await axios.post(WHISPER_TS_API_URL, { params: { fileurl: uri } });
|
|
132
|
+
const res = await this.executeRequest(WHISPER_TS_API_URL, {fileurl:uri},{},{});
|
|
133
|
+
return res;
|
|
134
|
+
} catch (err) {
|
|
135
|
+
console.log(`Error getting word timestamped data from api:`, err);
|
|
136
|
+
throw err;
|
|
137
|
+
}
|
|
138
|
+
}
|
|
139
|
+
}
|
|
140
|
+
|
|
100
141
|
const processChunk = async (chunk) => {
|
|
101
142
|
try {
|
|
102
|
-
const { language } = parameters;
|
|
143
|
+
const { language, responseFormat } = parameters;
|
|
144
|
+
const response_format = responseFormat || 'text';
|
|
145
|
+
|
|
103
146
|
const formData = new FormData();
|
|
104
147
|
formData.append('file', fs.createReadStream(chunk));
|
|
105
148
|
formData.append('model', this.model.params.model);
|
|
106
|
-
formData.append('response_format',
|
|
149
|
+
formData.append('response_format', response_format);
|
|
107
150
|
language && formData.append('language', language);
|
|
108
151
|
modelPromptText && formData.append('prompt', modelPromptText);
|
|
109
152
|
|
|
@@ -114,7 +157,7 @@ class OpenAIWhisperPlugin extends ModelPlugin {
|
|
|
114
157
|
}
|
|
115
158
|
}
|
|
116
159
|
|
|
117
|
-
let result =
|
|
160
|
+
let result = [];
|
|
118
161
|
let { file } = parameters;
|
|
119
162
|
let totalCount = 0;
|
|
120
163
|
let completedCount = 0;
|
|
@@ -134,7 +177,6 @@ class OpenAIWhisperPlugin extends ModelPlugin {
|
|
|
134
177
|
|
|
135
178
|
let chunks = []; // array of local file paths
|
|
136
179
|
try {
|
|
137
|
-
|
|
138
180
|
const uris = await this.getMediaChunks(file, requestId); // array of remote file uris
|
|
139
181
|
if (!uris || !uris.length) {
|
|
140
182
|
throw new Error(`Error in getting chunks from media helper for file ${file}`);
|
|
@@ -144,14 +186,20 @@ class OpenAIWhisperPlugin extends ModelPlugin {
|
|
|
144
186
|
|
|
145
187
|
// sequential download of chunks
|
|
146
188
|
for (const uri of uris) {
|
|
147
|
-
|
|
189
|
+
if (wordTimestamped) { // get word timestamped data
|
|
190
|
+
sendProgress(); // no download needed auto progress
|
|
191
|
+
const ts = await processTS(uri);
|
|
192
|
+
result.push(ts);
|
|
193
|
+
} else {
|
|
194
|
+
chunks.push(await downloadFile(uri));
|
|
195
|
+
}
|
|
148
196
|
sendProgress();
|
|
149
197
|
}
|
|
150
198
|
|
|
151
199
|
|
|
152
200
|
// sequential processing of chunks
|
|
153
201
|
for (const chunk of chunks) {
|
|
154
|
-
result
|
|
202
|
+
result.push(await processChunk(chunk));
|
|
155
203
|
sendProgress();
|
|
156
204
|
}
|
|
157
205
|
|
|
@@ -184,7 +232,11 @@ class OpenAIWhisperPlugin extends ModelPlugin {
|
|
|
184
232
|
console.error("An error occurred while deleting:", error);
|
|
185
233
|
}
|
|
186
234
|
}
|
|
187
|
-
|
|
235
|
+
|
|
236
|
+
if (['srt','vtt'].includes(responseFormat) || wordTimestamped) { // align subtitles for formats
|
|
237
|
+
return alignSubtitles(result);
|
|
238
|
+
}
|
|
239
|
+
return result.join(` `);
|
|
188
240
|
}
|
|
189
241
|
}
|
|
190
242
|
|
|
// palmChatPlugin.js
import ModelPlugin from './modelPlugin.js';
import { encode } from 'gpt-3-encoder';
import HandleBars from '../../lib/handleBars.js';

// Plugin that adapts a compiled pathway prompt into a request for the
// Google PaLM Chat API (Vertex AI predict endpoint) and parses the response.
class PalmChatPlugin extends ModelPlugin {
    constructor(config, pathway) {
        super(config, pathway);
    }

    // Convert to PaLM messages array format if necessary:
    // system messages are hoisted into a `context` string, consecutive
    // same-author messages are merged, and roles become `author` fields.
    convertMessagesToPalm(messages) {
        let context = '';
        let modifiedMessages = [];
        let lastAuthor = '';

        messages.forEach(message => {
            const { role, author, content } = message;

            // Extract system messages into the context string
            if (role === 'system') {
                context += (context.length > 0 ? '\n' : '') + content;
                return;
            }

            // Aggregate consecutive author messages, appending the content
            if ((role === lastAuthor || author === lastAuthor) && modifiedMessages.length > 0) {
                modifiedMessages[modifiedMessages.length - 1].content += '\n' + content;
            }
            // Only push messages with role 'user' or 'assistant' or existing author messages
            else if (role === 'user' || role === 'assistant' || author) {
                modifiedMessages.push({
                    author: author || role,
                    content,
                });
                lastAuthor = author || role;
            }
        });

        return {
            modifiedMessages,
            context,
        };
    }

    // Handlebars compiler for context (PaLM chat specific)
    getCompiledContext(text, parameters, context) {
        const combinedParameters = { ...this.promptParameters, ...parameters };
        return context ? HandleBars.compile(context)({ ...combinedParameters, text}) : '';
    }

    // Handlebars compiler for examples (PaLM chat specific)
    getCompiledExamples(text, parameters, examples = []) {
        const combinedParameters = { ...this.promptParameters, ...parameters };

        const compileContent = (content) => {
            const compile = HandleBars.compile(content);
            return compile({ ...combinedParameters, text });
        };

        const processExample = (example, key) => {
            if (example[key]?.content) {
                return { ...example[key], content: compileContent(example[key].content) };
            }
            return { ...example[key] };
        };

        return examples.map((example) => ({
            input: example.input ? processExample(example, 'input') : undefined,
            output: example.output ? processExample(example, 'output') : undefined,
        }));
    }

    // Set up parameters specific to the PaLM Chat API
    getRequestParameters(text, parameters, prompt) {
        const { modelPromptText, modelPromptMessages, tokenLength } = this.getCompiledPrompt(text, parameters, prompt);
        const { stream } = parameters;

        // Define the model's max token length
        const modelTargetTokenLength = this.getModelMaxTokenLength() * this.getPromptTokenRatio();

        const palmMessages = this.convertMessagesToPalm(modelPromptMessages || [{ "author": "user", "content": modelPromptText }]);

        let requestMessages = palmMessages.modifiedMessages;

        // Check if the token length exceeds the model's max token length
        if (tokenLength > modelTargetTokenLength) {
            // Remove older messages until the token length is within the model's limit
            requestMessages = this.truncateMessagesToTargetLength(requestMessages, modelTargetTokenLength);
        }

        const context = this.getCompiledContext(text, parameters, prompt.context || palmMessages.context || '');
        const examples = this.getCompiledExamples(text, parameters, prompt.examples || []);

        // For PaLM right now, the max return tokens is 1024, regardless of the max context length
        // I can't think of a time you'd want to constrain it to fewer at the moment.
        const max_tokens = 1024//this.getModelMaxTokenLength() - tokenLength;

        if (max_tokens < 0) {
            throw new Error(`Prompt is too long to successfully call the model at ${tokenLength} tokens. The model will not be called.`);
        }

        // FIXED COMMENT: the previous comment claimed PaLM requires an even
        // number of messages, which contradicts the code. The conversation must
        // end on a 'user' turn, so with alternating authors the count must be
        // ODD; drop the oldest message when the count is even.
        if (requestMessages.length % 2 === 0) {
            requestMessages = requestMessages.slice(1);
        }

        const requestParameters = {
            instances: [{
                context: context,
                examples: examples,
                messages: requestMessages,
            }],
            parameters: {
                temperature: this.temperature ?? 0.7,
                maxOutputTokens: max_tokens,
                topP: parameters.topP ?? 0.95,
                topK: parameters.topK ?? 40,
            }
        };

        return requestParameters;
    }

    // Get the safetyAttributes from the PaLM Chat API response data.
    // Simplified: the original had two branches that returned the identical
    // expression; the safety attributes always live on the first prediction.
    getSafetyAttributes(data) {
        const { predictions } = data;
        if (!predictions || !predictions.length) {
            return null;
        }

        return predictions[0].safetyAttributes ?? null;
    }

    // Execute the request to the PaLM Chat API
    async execute(text, parameters, prompt) {
        const url = this.requestUrl(text);
        const requestParameters = this.getRequestParameters(text, parameters, prompt);

        const data = { ...(this.model.params || {}), ...requestParameters };
        const params = {};
        const headers = this.model.headers || {};
        // PaLM (Vertex AI) calls require a GCP OAuth bearer token
        const gcpAuthTokenHelper = this.config.get('gcpAuthTokenHelper');
        const authToken = await gcpAuthTokenHelper.getAccessToken();
        headers.Authorization = `Bearer ${authToken}`;
        return this.executeRequest(url, data, params, headers, prompt);
    }

    // Parse the response from the PaLM Chat API
    parseResponse(data) {
        const { predictions } = data;
        if (!predictions || !predictions.length) {
            return null;
        }

        // Get the candidates array from the first prediction
        const { candidates } = predictions[0];

        // if it was blocked, return the blocked message
        if (predictions[0].safetyAttributes?.blocked) {
            return 'The response is blocked because the input or response potentially violates Google policies. Try rephrasing the prompt or adjusting the parameter settings. Currently, only English is supported.';
        }

        if (!candidates || !candidates.length) {
            return null;
        }

        // If we got a candidates array back with more than one candidate, return the whole array
        if (candidates.length > 1) {
            return candidates;
        }

        // Otherwise, return the content of the first candidate
        const messageResult = candidates[0].content && candidates[0].content.trim();
        return messageResult ?? null;
    }

    // Override the logging function to display the messages and responses
    logRequestData(data, responseData, prompt) {
        const separator = `\n=== ${this.pathwayName}.${this.requestCount++} ===\n`;
        console.log(separator);

        const instances = data && data.instances;
        const messages = instances && instances[0] && instances[0].messages;
        const { context, examples } = instances && instances[0] || {};

        if (context) {
            console.log(`\x1b[36mContext: ${context}\x1b[0m`);
        }

        if (examples && examples.length) {
            examples.forEach((example, index) => {
                console.log(`\x1b[36mExample ${index + 1}: Input: "${example.input.content}", Output: "${example.output.content}"\x1b[0m`);
            });
        }

        if (messages && messages.length > 1) {
            messages.forEach((message, index) => {
                const words = message.content.split(" ");
                const tokenCount = encode(message.content).length;
                // long messages are previewed as first 20 + last 20 words
                const preview = words.length < 41 ? message.content : words.slice(0, 20).join(" ") + " ... " + words.slice(-20).join(" ");

                console.log(`\x1b[36mMessage ${index + 1}: Author: ${message.author}, Tokens: ${tokenCount}, Content: "${preview}"\x1b[0m`);
            });
        } else if (messages && messages.length === 1) {
            console.log(`\x1b[36m${messages[0].content}\x1b[0m`);
        }

        const safetyAttributes = this.getSafetyAttributes(responseData);

        console.log(`\x1b[34m> ${this.parseResponse(responseData)}\x1b[0m`);

        if (safetyAttributes) {
            console.log(`\x1b[33mSafety Attributes: ${JSON.stringify(safetyAttributes, null, 2)}\x1b[0m`);
        }

        if (prompt && prompt.debugInfo) {
            prompt.debugInfo += `${separator}${JSON.stringify(data)}`;
        }
    }
}

export default PalmChatPlugin;
// palmCompletionPlugin.js

import ModelPlugin from './modelPlugin.js';

// Helper: trim an over-long prompt down to the token budget this pathway
// allows; returns the (possibly truncated) prompt text.
const truncatePromptIfNecessary = (text, textTokenCount, modelMaxTokenCount, targetTextTokenCount, pathwayResolver) => {
    const maxAllowedTokens = textTokenCount + ((modelMaxTokenCount - targetTextTokenCount) * 0.5);

    if (textTokenCount <= maxAllowedTokens) {
        return text;
    }

    pathwayResolver.logWarning(`Prompt is too long at ${textTokenCount} tokens (this target token length for this pathway is ${targetTextTokenCount} tokens because the response is expected to take up the rest of the model's max tokens (${modelMaxTokenCount}). Prompt will be truncated.`);
    return pathwayResolver.truncate(text, maxAllowedTokens);
};

// PalmCompletionPlugin class for handling requests and responses to the PaLM API Text Completion API
class PalmCompletionPlugin extends ModelPlugin {
    constructor(config, pathway) {
        super(config, pathway);
    }

    // Build the request body for the PaLM text completion endpoint.
    getRequestParameters(text, parameters, prompt, pathwayResolver) {
        const { modelPromptText, tokenLength } = this.getCompiledPrompt(text, parameters, prompt);

        // The prompt budget is a fixed ratio of the model's max context length.
        const maxModelTokens = this.getModelMaxTokenLength();
        const promptTokenBudget = maxModelTokens * this.getPromptTokenRatio();

        const finalPrompt = truncatePromptIfNecessary(modelPromptText, tokenLength, maxModelTokens, promptTokenBudget, pathwayResolver);

        // PaLM currently caps output at 1024 tokens regardless of context size.
        const max_tokens = 1024//this.getModelMaxTokenLength() - tokenLength;

        if (max_tokens < 0) {
            throw new Error(`Prompt is too long to successfully call the model at ${tokenLength} tokens. The model will not be called.`);
        }

        if (!finalPrompt) {
            throw new Error(`Prompt is empty. The model will not be called.`);
        }

        return {
            instances: [
                { prompt: finalPrompt }
            ],
            parameters: {
                temperature: this.temperature ?? 0.7,
                maxOutputTokens: max_tokens,
                topP: parameters.topP ?? 0.95,
                topK: parameters.topK ?? 40,
            }
        };
    }

    // Execute the request to the PaLM API Text Completion API
    async execute(text, parameters, prompt, pathwayResolver) {
        const url = this.requestUrl(text);
        const body = this.getRequestParameters(text, parameters, prompt, pathwayResolver);

        const data = { ...body };
        const params = {};
        const headers = this.model.headers || {};

        // Vertex AI calls require a GCP OAuth bearer token.
        const gcpAuthTokenHelper = this.config.get('gcpAuthTokenHelper');
        const authToken = await gcpAuthTokenHelper.getAccessToken();
        headers.Authorization = `Bearer ${authToken}`;

        return this.executeRequest(url, data, params, headers, prompt);
    }

    // Parse the response from the PaLM API Text Completion API.
    // No predictions -> raw payload; multiple -> the whole array; blocked ->
    // fixed notice; otherwise the first prediction's trimmed content (or null).
    parseResponse(data) {
        const { predictions } = data;

        if (!predictions || predictions.length === 0) {
            return data;
        }

        if (predictions.length > 1) {
            return predictions;
        }

        const [firstPrediction] = predictions;

        if (firstPrediction.safetyAttributes?.blocked) {
            return 'The response is blocked because the input or response potentially violates Google policies. Try rephrasing the prompt or adjusting the parameter settings. Currently, only English is supported.';
        }

        const contentResult = firstPrediction.content && firstPrediction.content.trim();
        return contentResult ?? null;
    }

    // Get the safetyAttributes from the PaLM API Text Completion API response data.
    // Whether one or many predictions came back, the safety data lives on the first.
    getSafetyAttributes(data) {
        const { predictions } = data;

        if (!predictions || predictions.length === 0) {
            return null;
        }

        return predictions[0].safetyAttributes ?? null;
    }

    // Override the logging function to log the prompt and response.
    logRequestData(data, responseData, prompt) {
        const separator = `\n=== ${this.pathwayName}.${this.requestCount++} ===\n`;
        console.log(separator);

        const safetyAttributes = this.getSafetyAttributes(responseData);

        const modelInput = data?.instances?.[0]?.prompt;
        if (modelInput) {
            console.log(`\x1b[36m${modelInput}\x1b[0m`);
        }

        console.log(`\x1b[34m> ${this.parseResponse(responseData)}\x1b[0m`);

        if (safetyAttributes) {
            console.log(`\x1b[33mSafety Attributes: ${JSON.stringify(safetyAttributes, null, 2)}\x1b[0m`);
        }

        if (prompt && prompt.debugInfo) {
            prompt.debugInfo += `${separator}${JSON.stringify(data)}`;
        }
    }
}

export default PalmCompletionPlugin;