@aj-archipelago/cortex 1.1.3 → 1.1.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/.eslintignore +3 -3
- package/README.md +17 -4
- package/config.js +45 -9
- package/{helper_apps/CortexFileHandler → helper-apps/cortex-file-handler}/Dockerfile +1 -1
- package/{helper_apps/CortexFileHandler → helper-apps/cortex-file-handler}/fileChunker.js +4 -1
- package/{helper_apps/CortexFileHandler → helper-apps/cortex-file-handler}/package-lock.json +25 -216
- package/{helper_apps/CortexFileHandler → helper-apps/cortex-file-handler}/package.json +2 -2
- package/helper-apps/cortex-whisper-wrapper/.dockerignore +27 -0
- package/helper-apps/cortex-whisper-wrapper/Dockerfile +32 -0
- package/helper-apps/cortex-whisper-wrapper/app.py +104 -0
- package/helper-apps/cortex-whisper-wrapper/docker-compose.debug.yml +12 -0
- package/helper-apps/cortex-whisper-wrapper/docker-compose.yml +10 -0
- package/helper-apps/cortex-whisper-wrapper/models/.gitkeep +0 -0
- package/helper-apps/cortex-whisper-wrapper/requirements.txt +5 -0
- package/lib/cortexRequest.js +117 -0
- package/lib/pathwayTools.js +2 -1
- package/lib/redisSubscription.js +2 -2
- package/lib/requestExecutor.js +360 -0
- package/lib/requestMonitor.js +131 -28
- package/package.json +2 -1
- package/pathways/summary.js +3 -3
- package/server/graphql.js +6 -6
- package/server/{pathwayPrompter.js → modelExecutor.js} +24 -21
- package/server/pathwayResolver.js +22 -17
- package/server/plugins/azureCognitivePlugin.js +25 -20
- package/server/plugins/azureTranslatePlugin.js +6 -10
- package/server/plugins/cohereGeneratePlugin.js +5 -12
- package/server/plugins/cohereSummarizePlugin.js +5 -12
- package/server/plugins/localModelPlugin.js +3 -3
- package/server/plugins/modelPlugin.js +18 -12
- package/server/plugins/openAiChatExtensionPlugin.js +5 -5
- package/server/plugins/openAiChatPlugin.js +8 -10
- package/server/plugins/openAiCompletionPlugin.js +9 -12
- package/server/plugins/openAiDallE3Plugin.js +14 -31
- package/server/plugins/openAiEmbeddingsPlugin.js +6 -9
- package/server/plugins/openAiImagePlugin.js +19 -15
- package/server/plugins/openAiWhisperPlugin.js +168 -100
- package/server/plugins/palmChatPlugin.js +9 -10
- package/server/plugins/palmCodeCompletionPlugin.js +2 -2
- package/server/plugins/palmCompletionPlugin.js +11 -12
- package/server/resolver.js +2 -2
- package/server/rest.js +1 -1
- package/tests/config.test.js +1 -1
- package/tests/mocks.js +5 -0
- package/tests/modelPlugin.test.js +3 -10
- package/tests/openAiChatPlugin.test.js +9 -8
- package/tests/openai_api.test.js +3 -3
- package/tests/palmChatPlugin.test.js +1 -1
- package/tests/palmCompletionPlugin.test.js +1 -1
- package/tests/pathwayResolver.test.js +2 -1
- package/tests/requestMonitor.test.js +94 -0
- package/tests/{requestDurationEstimator.test.js → requestMonitorDurationEstimator.test.js} +21 -17
- package/tests/truncateMessages.test.js +1 -1
- package/lib/request.js +0 -259
- package/lib/requestDurationEstimator.js +0 -90
- /package/{helper_apps/CortexFileHandler → helper-apps/cortex-file-handler}/blobHandler.js +0 -0
- /package/{helper_apps/CortexFileHandler → helper-apps/cortex-file-handler}/docHelper.js +0 -0
- /package/{helper_apps/CortexFileHandler → helper-apps/cortex-file-handler}/function.json +0 -0
- /package/{helper_apps/CortexFileHandler → helper-apps/cortex-file-handler}/helper.js +0 -0
- /package/{helper_apps/CortexFileHandler → helper-apps/cortex-file-handler}/index.js +0 -0
- /package/{helper_apps/CortexFileHandler → helper-apps/cortex-file-handler}/localFileHandler.js +0 -0
- /package/{helper_apps/CortexFileHandler → helper-apps/cortex-file-handler}/redis.js +0 -0
- /package/{helper_apps/CortexFileHandler → helper-apps/cortex-file-handler}/start.js +0 -0
|
@@ -1,57 +1,53 @@
|
|
|
1
1
|
// openAiWhisperPlugin.js
|
|
2
2
|
import ModelPlugin from './modelPlugin.js';
|
|
3
|
+
import { config } from '../../config.js';
|
|
4
|
+
import subsrt from 'subsrt';
|
|
3
5
|
import FormData from 'form-data';
|
|
4
6
|
import fs from 'fs';
|
|
5
|
-
import { axios } from '../../lib/request.js';
|
|
7
|
+
import { axios } from '../../lib/requestExecutor.js';
|
|
6
8
|
import stream from 'stream';
|
|
7
9
|
import os from 'os';
|
|
8
10
|
import path from 'path';
|
|
9
|
-
import { v4 as uuidv4 } from 'uuid';
|
|
10
|
-
import { config } from '../../config.js';
|
|
11
|
-
import { deleteTempPath } from '../../helper_apps/CortexFileHandler/helper.js';
|
|
12
11
|
import http from 'http';
|
|
13
12
|
import https from 'https';
|
|
13
|
+
import { URL } from 'url';
|
|
14
|
+
import { v4 as uuidv4 } from 'uuid';
|
|
14
15
|
import { promisify } from 'util';
|
|
15
|
-
import subsrt from 'subsrt';
|
|
16
16
|
import { publishRequestProgress } from '../../lib/redisSubscription.js';
|
|
17
17
|
import logger from '../../lib/logger.js';
|
|
18
|
-
|
|
19
18
|
const pipeline = promisify(stream.pipeline);
|
|
20
19
|
|
|
21
20
|
const API_URL = config.get('whisperMediaApiUrl');
|
|
22
21
|
const WHISPER_TS_API_URL = config.get('whisperTSApiUrl');
|
|
22
|
+
if(WHISPER_TS_API_URL){
|
|
23
|
+
logger.info(`WHISPER API URL using ${WHISPER_TS_API_URL}`);
|
|
24
|
+
}else{
|
|
25
|
+
logger.warn(`WHISPER API URL not set using default OpenAI API Whisper`);
|
|
26
|
+
}
|
|
23
27
|
|
|
24
28
|
const OFFSET_CHUNK = 1000 * 60 * 10; // 10 minutes for each chunk
|
|
25
29
|
|
|
26
|
-
function alignSubtitles(subtitles, format) {
|
|
27
|
-
const result = [];
|
|
28
|
-
|
|
29
|
-
function preprocessStr(str) {
|
|
30
|
-
return str.trim().replace(/(\n\n)(?!\n)/g, '\n\n\n');
|
|
31
|
-
}
|
|
32
|
-
|
|
33
|
-
function shiftSubtitles(subtitle, shiftOffset) {
|
|
34
|
-
const captions = subsrt.parse(preprocessStr(subtitle));
|
|
35
|
-
const resynced = subsrt.resync(captions, { offset: shiftOffset });
|
|
36
|
-
return resynced;
|
|
37
|
-
}
|
|
38
|
-
|
|
39
|
-
for (let i = 0; i < subtitles.length; i++) {
|
|
40
|
-
result.push(...shiftSubtitles(subtitles[i], i * OFFSET_CHUNK));
|
|
41
|
-
}
|
|
42
|
-
|
|
30
|
+
async function deleteTempPath(path) {
|
|
43
31
|
try {
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
obj.text = obj.content;
|
|
48
|
-
}
|
|
32
|
+
if (!path) {
|
|
33
|
+
logger.warn('Temporary path is not defined.');
|
|
34
|
+
return;
|
|
49
35
|
}
|
|
50
|
-
|
|
51
|
-
|
|
36
|
+
if (!fs.existsSync(path)) {
|
|
37
|
+
logger.warn(`Temporary path ${path} does not exist.`);
|
|
38
|
+
return;
|
|
39
|
+
}
|
|
40
|
+
const stats = fs.statSync(path);
|
|
41
|
+
if (stats.isFile()) {
|
|
42
|
+
fs.unlinkSync(path);
|
|
43
|
+
logger.info(`Temporary file ${path} deleted successfully.`);
|
|
44
|
+
} else if (stats.isDirectory()) {
|
|
45
|
+
fs.rmSync(path, { recursive: true });
|
|
46
|
+
logger.info(`Temporary folder ${path} and its contents deleted successfully.`);
|
|
47
|
+
}
|
|
48
|
+
} catch (err) {
|
|
49
|
+
logger.error(`Error occurred while deleting the temporary path: ${err}`);
|
|
52
50
|
}
|
|
53
|
-
|
|
54
|
-
return subsrt.build(result, { format: format === 'vtt' ? 'vtt' : 'srt' });
|
|
55
51
|
}
|
|
56
52
|
|
|
57
53
|
function generateUniqueFilename(extension) {
|
|
@@ -92,9 +88,49 @@ const downloadFile = async (fileUrl) => {
|
|
|
92
88
|
});
|
|
93
89
|
};
|
|
94
90
|
|
|
91
|
+
// convert srt format to text
|
|
92
|
+
function convertToText(str) {
|
|
93
|
+
return str
|
|
94
|
+
.split('\n')
|
|
95
|
+
.filter(line => !line.match(/^\d+$/) && !line.match(/^\d{2}:\d{2}:\d{2},\d{3} --> \d{2}:\d{2}:\d{2},\d{3}$/) && line !== '')
|
|
96
|
+
.join(' ');
|
|
97
|
+
}
|
|
98
|
+
|
|
99
|
+
function alignSubtitles(subtitles, format) {
|
|
100
|
+
const result = [];
|
|
101
|
+
|
|
102
|
+
function preprocessStr(str) {
|
|
103
|
+
return str.trim().replace(/(\n\n)(?!\n)/g, '\n\n\n');
|
|
104
|
+
}
|
|
105
|
+
|
|
106
|
+
function shiftSubtitles(subtitle, shiftOffset) {
|
|
107
|
+
const captions = subsrt.parse(preprocessStr(subtitle));
|
|
108
|
+
const resynced = subsrt.resync(captions, { offset: shiftOffset });
|
|
109
|
+
return resynced;
|
|
110
|
+
}
|
|
111
|
+
|
|
112
|
+
for (let i = 0; i < subtitles.length; i++) {
|
|
113
|
+
result.push(...shiftSubtitles(subtitles[i], i * OFFSET_CHUNK));
|
|
114
|
+
}
|
|
115
|
+
|
|
116
|
+
try {
|
|
117
|
+
//if content has needed html style tags, keep them
|
|
118
|
+
for(const obj of result) {
|
|
119
|
+
if(obj && obj.content){
|
|
120
|
+
obj.text = obj.content;
|
|
121
|
+
}
|
|
122
|
+
}
|
|
123
|
+
} catch (error) {
|
|
124
|
+
logger.error(`An error occurred in content text parsing: ${error}`);
|
|
125
|
+
}
|
|
126
|
+
|
|
127
|
+
return subsrt.build(result, { format: format === 'vtt' ? 'vtt' : 'srt' });
|
|
128
|
+
}
|
|
129
|
+
|
|
130
|
+
|
|
95
131
|
class OpenAIWhisperPlugin extends ModelPlugin {
|
|
96
|
-
constructor(
|
|
97
|
-
super(
|
|
132
|
+
constructor(pathway, model) {
|
|
133
|
+
super(pathway, model);
|
|
98
134
|
}
|
|
99
135
|
|
|
100
136
|
async getMediaChunks(file, requestId) {
|
|
@@ -108,7 +144,7 @@ class OpenAIWhisperPlugin extends ModelPlugin {
|
|
|
108
144
|
return [file];
|
|
109
145
|
}
|
|
110
146
|
} catch (err) {
|
|
111
|
-
logger.error(`Error getting media chunks list from api
|
|
147
|
+
logger.error(`Error getting media chunks list from api: ${err}`);
|
|
112
148
|
throw err;
|
|
113
149
|
}
|
|
114
150
|
}
|
|
@@ -118,7 +154,7 @@ class OpenAIWhisperPlugin extends ModelPlugin {
|
|
|
118
154
|
if (API_URL) {
|
|
119
155
|
//call helper api to mark processing as completed
|
|
120
156
|
const res = await axios.delete(API_URL, { params: { requestId } });
|
|
121
|
-
logger.info(`Marked request ${requestId} as completed
|
|
157
|
+
logger.info(`Marked request ${requestId} as completed:`, res.data);
|
|
122
158
|
return res.data;
|
|
123
159
|
}
|
|
124
160
|
} catch (err) {
|
|
@@ -127,38 +163,21 @@ class OpenAIWhisperPlugin extends ModelPlugin {
|
|
|
127
163
|
}
|
|
128
164
|
|
|
129
165
|
// Execute the request to the OpenAI Whisper API
|
|
130
|
-
async execute(text, parameters, prompt,
|
|
131
|
-
const {
|
|
132
|
-
const
|
|
133
|
-
|
|
134
|
-
const { modelPromptText } = this.getCompiledPrompt(text, parameters, prompt);
|
|
166
|
+
async execute(text, parameters, prompt, cortexRequest) {
|
|
167
|
+
const { pathwayResolver } = cortexRequest;
|
|
168
|
+
const { responseFormat, wordTimestamped, highlightWords, maxLineWidth, maxLineCount, maxWordsPerLine } = parameters;
|
|
169
|
+
cortexRequest.url = this.requestUrl(text);
|
|
135
170
|
|
|
136
|
-
const
|
|
137
|
-
|
|
138
|
-
if (!WHISPER_TS_API_URL) {
|
|
139
|
-
throw new Error(`WHISPER_TS_API_URL not set for word timestamped processing`);
|
|
140
|
-
}
|
|
141
|
-
|
|
142
|
-
try {
|
|
143
|
-
const tsparams = { fileurl:uri };
|
|
144
|
-
if(highlightWords) tsparams.highlight_words = highlightWords ? "True" : "False";
|
|
145
|
-
if(maxLineWidth) tsparams.max_line_width = maxLineWidth;
|
|
146
|
-
if(maxLineCount) tsparams.max_line_count = maxLineCount;
|
|
147
|
-
if(maxWordsPerLine) tsparams.max_words_per_line = maxWordsPerLine;
|
|
148
|
-
if(wordTimestamped!=null) tsparams.word_timestamps = wordTimestamped;
|
|
149
|
-
|
|
150
|
-
const res = await this.executeRequest(WHISPER_TS_API_URL, tsparams, {}, {}, {}, requestId, pathway);
|
|
151
|
-
return res;
|
|
152
|
-
} catch (err) {
|
|
153
|
-
logger.error(`Error getting word timestamped data from api: ${err}`);
|
|
154
|
-
throw err;
|
|
155
|
-
}
|
|
156
|
-
}
|
|
157
|
-
}
|
|
158
|
-
|
|
159
|
-
const processChunk = async (chunk) => {
|
|
171
|
+
const chunks = [];
|
|
172
|
+
const processChunk = async (uri) => {
|
|
160
173
|
try {
|
|
174
|
+
const chunk = await downloadFile(uri);
|
|
175
|
+
chunks.push(chunk);
|
|
176
|
+
|
|
161
177
|
const { language, responseFormat } = parameters;
|
|
178
|
+
cortexRequest.url = this.requestUrl(text);
|
|
179
|
+
const params = {};
|
|
180
|
+
const { modelPromptText } = this.getCompiledPrompt(text, parameters, prompt);
|
|
162
181
|
const response_format = responseFormat || 'text';
|
|
163
182
|
|
|
164
183
|
const formData = new FormData();
|
|
@@ -168,9 +187,44 @@ class OpenAIWhisperPlugin extends ModelPlugin {
|
|
|
168
187
|
language && formData.append('language', language);
|
|
169
188
|
modelPromptText && formData.append('prompt', modelPromptText);
|
|
170
189
|
|
|
171
|
-
|
|
190
|
+
cortexRequest.data = formData;
|
|
191
|
+
cortexRequest.params = params;
|
|
192
|
+
cortexRequest.headers = { ...cortexRequest.headers, ...formData.getHeaders() };
|
|
193
|
+
|
|
194
|
+
return this.executeRequest(cortexRequest);
|
|
172
195
|
} catch (err) {
|
|
173
|
-
logger.error(err);
|
|
196
|
+
logger.error(`Error getting word timestamped data from api: ${err}`);
|
|
197
|
+
throw err;
|
|
198
|
+
}
|
|
199
|
+
}
|
|
200
|
+
|
|
201
|
+
const processTS = async (uri) => {
|
|
202
|
+
try {
|
|
203
|
+
const tsparams = { fileurl:uri };
|
|
204
|
+
if(highlightWords) tsparams.highlight_words = highlightWords ? "True" : "False";
|
|
205
|
+
if(maxLineWidth) tsparams.max_line_width = maxLineWidth;
|
|
206
|
+
if(maxLineCount) tsparams.max_line_count = maxLineCount;
|
|
207
|
+
if(maxWordsPerLine) tsparams.max_words_per_line = maxWordsPerLine;
|
|
208
|
+
if(wordTimestamped!=null) {
|
|
209
|
+
if(!wordTimestamped) {
|
|
210
|
+
tsparams.word_timestamps = "False";
|
|
211
|
+
}else{
|
|
212
|
+
tsparams.word_timestamps = wordTimestamped;
|
|
213
|
+
}
|
|
214
|
+
}
|
|
215
|
+
|
|
216
|
+
cortexRequest.url = WHISPER_TS_API_URL;
|
|
217
|
+
cortexRequest.data = tsparams;
|
|
218
|
+
|
|
219
|
+
const res = await this.executeRequest(cortexRequest);
|
|
220
|
+
|
|
221
|
+
if(!wordTimestamped && !responseFormat){
|
|
222
|
+
//if no response format, convert to text
|
|
223
|
+
return convertToText(res);
|
|
224
|
+
}
|
|
225
|
+
return res;
|
|
226
|
+
} catch (err) {
|
|
227
|
+
logger.error(`Error getting word timestamped data from api: ${err}`);
|
|
174
228
|
throw err;
|
|
175
229
|
}
|
|
176
230
|
}
|
|
@@ -179,59 +233,74 @@ class OpenAIWhisperPlugin extends ModelPlugin {
|
|
|
179
233
|
let { file } = parameters;
|
|
180
234
|
let totalCount = 0;
|
|
181
235
|
let completedCount = 0;
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
236
|
+
let partialCount = 0;
|
|
237
|
+
const { requestId } = pathwayResolver;
|
|
238
|
+
|
|
239
|
+
const MAXPARTIALCOUNT = 60;
|
|
240
|
+
const sendProgress = (partial=false) => {
|
|
241
|
+
if(partial){
|
|
242
|
+
partialCount = Math.min(partialCount + 1, MAXPARTIALCOUNT-1);
|
|
243
|
+
}else {
|
|
244
|
+
partialCount = 0;
|
|
245
|
+
completedCount++;
|
|
246
|
+
}
|
|
186
247
|
if (completedCount >= totalCount) return;
|
|
248
|
+
|
|
249
|
+
const progress = (partialCount / MAXPARTIALCOUNT + completedCount) / totalCount;
|
|
250
|
+
logger.info(`Progress for ${requestId}: ${progress}`);
|
|
251
|
+
|
|
187
252
|
publishRequestProgress({
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
253
|
+
requestId,
|
|
254
|
+
progress,
|
|
255
|
+
data: null,
|
|
191
256
|
});
|
|
192
|
-
|
|
193
257
|
}
|
|
194
258
|
|
|
195
|
-
|
|
259
|
+
async function processURI(uri) {
|
|
260
|
+
let result = null;
|
|
261
|
+
let _promise = null;
|
|
262
|
+
if(WHISPER_TS_API_URL){
|
|
263
|
+
_promise = processTS
|
|
264
|
+
}else {
|
|
265
|
+
_promise = processChunk;
|
|
266
|
+
}
|
|
267
|
+
_promise(uri).then((ts) => { result = ts;});
|
|
268
|
+
|
|
269
|
+
//send updates while waiting for result
|
|
270
|
+
while(!result) {
|
|
271
|
+
sendProgress(true);
|
|
272
|
+
await new Promise(r => setTimeout(r, 3000));
|
|
273
|
+
}
|
|
274
|
+
return result;
|
|
275
|
+
}
|
|
276
|
+
|
|
196
277
|
try {
|
|
197
278
|
const uris = await this.getMediaChunks(file, requestId); // array of remote file uris
|
|
198
279
|
if (!uris || !uris.length) {
|
|
199
280
|
throw new Error(`Error in getting chunks from media helper for file ${file}`);
|
|
200
281
|
}
|
|
201
|
-
totalCount = uris.length
|
|
202
|
-
API_URL && (completedCount = uris.length); // api progress is already calculated
|
|
282
|
+
totalCount = uris.length + 1; // total number of chunks that will be processed
|
|
203
283
|
|
|
204
|
-
// sequential
|
|
284
|
+
// sequential process of chunks
|
|
205
285
|
for (const uri of uris) {
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
result.push(ts);
|
|
210
|
-
} else {
|
|
211
|
-
chunks.push(await downloadFile(uri));
|
|
212
|
-
}
|
|
213
|
-
sendProgress();
|
|
214
|
-
}
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
// sequential processing of chunks
|
|
218
|
-
for (const chunk of chunks) {
|
|
219
|
-
result.push(await processChunk(chunk));
|
|
220
|
-
sendProgress();
|
|
286
|
+
sendProgress();
|
|
287
|
+
const ts = await processURI(uri);
|
|
288
|
+
result.push(ts);
|
|
221
289
|
}
|
|
222
290
|
|
|
223
|
-
// parallel processing, dropped
|
|
224
|
-
// result = await Promise.all(mediaSplit.chunks.map(processChunk));
|
|
225
|
-
|
|
226
291
|
} catch (error) {
|
|
227
|
-
const errMsg = `Transcribe error: ${error?.message || error}`;
|
|
292
|
+
const errMsg = `Transcribe error: ${error?.response?.data || error?.message || error}`;
|
|
228
293
|
logger.error(errMsg);
|
|
229
294
|
return errMsg;
|
|
230
295
|
}
|
|
231
296
|
finally {
|
|
232
297
|
try {
|
|
233
298
|
for (const chunk of chunks) {
|
|
234
|
-
|
|
299
|
+
try {
|
|
300
|
+
await deleteTempPath(chunk);
|
|
301
|
+
} catch (error) {
|
|
302
|
+
//ignore error
|
|
303
|
+
}
|
|
235
304
|
}
|
|
236
305
|
|
|
237
306
|
await this.markCompletedForCleanUp(requestId);
|
|
@@ -258,4 +327,3 @@ class OpenAIWhisperPlugin extends ModelPlugin {
|
|
|
258
327
|
}
|
|
259
328
|
|
|
260
329
|
export default OpenAIWhisperPlugin;
|
|
261
|
-
|
|
@@ -5,8 +5,8 @@ import HandleBars from '../../lib/handleBars.js';
|
|
|
5
5
|
import logger from '../../lib/logger.js';
|
|
6
6
|
|
|
7
7
|
class PalmChatPlugin extends ModelPlugin {
|
|
8
|
-
constructor(
|
|
9
|
-
super(
|
|
8
|
+
constructor(pathway, model) {
|
|
9
|
+
super(pathway, model);
|
|
10
10
|
}
|
|
11
11
|
|
|
12
12
|
// Convert to PaLM messages array format if necessary
|
|
@@ -137,18 +137,17 @@ class PalmChatPlugin extends ModelPlugin {
|
|
|
137
137
|
}
|
|
138
138
|
|
|
139
139
|
// Execute the request to the PaLM Chat API
|
|
140
|
-
async execute(text, parameters, prompt, pathwayResolver) {
|
|
141
|
-
const url = this.requestUrl(text);
|
|
140
|
+
async execute(text, parameters, prompt, cortexRequest) {
|
|
142
141
|
const requestParameters = this.getRequestParameters(text, parameters, prompt);
|
|
143
|
-
const { requestId, pathway} = pathwayResolver;
|
|
144
142
|
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
143
|
+
cortexRequest.data = { ...(cortexRequest.data || {}), ...requestParameters };
|
|
144
|
+
cortexRequest.params = {}; // query params
|
|
145
|
+
|
|
148
146
|
const gcpAuthTokenHelper = this.config.get('gcpAuthTokenHelper');
|
|
149
147
|
const authToken = await gcpAuthTokenHelper.getAccessToken();
|
|
150
|
-
headers.Authorization = `Bearer ${authToken}`;
|
|
151
|
-
|
|
148
|
+
cortexRequest.headers.Authorization = `Bearer ${authToken}`;
|
|
149
|
+
|
|
150
|
+
return this.executeRequest(cortexRequest);
|
|
152
151
|
}
|
|
153
152
|
|
|
154
153
|
// Parse the response from the PaLM Chat API
|
|
@@ -4,8 +4,8 @@ import PalmCompletionPlugin from './palmCompletionPlugin.js';
|
|
|
4
4
|
|
|
5
5
|
// PalmCodeCompletionPlugin class for handling requests and responses to the PaLM API Code Completion API
|
|
6
6
|
class PalmCodeCompletionPlugin extends PalmCompletionPlugin {
|
|
7
|
-
constructor(
|
|
8
|
-
super(
|
|
7
|
+
constructor(pathway, model) {
|
|
8
|
+
super(pathway, model);
|
|
9
9
|
}
|
|
10
10
|
|
|
11
11
|
// Set up parameters specific to the PaLM API Code Completion API
|
|
@@ -6,8 +6,8 @@ import logger from '../../lib/logger.js';
|
|
|
6
6
|
|
|
7
7
|
// PalmCompletionPlugin class for handling requests and responses to the PaLM API Text Completion API
|
|
8
8
|
class PalmCompletionPlugin extends ModelPlugin {
|
|
9
|
-
constructor(
|
|
10
|
-
super(
|
|
9
|
+
constructor(pathway, model) {
|
|
10
|
+
super(pathway, model);
|
|
11
11
|
}
|
|
12
12
|
|
|
13
13
|
truncatePromptIfNecessary (text, textTokenCount, modelMaxTokenCount, targetTextTokenCount, pathwayResolver) {
|
|
@@ -54,18 +54,17 @@ class PalmCompletionPlugin extends ModelPlugin {
|
|
|
54
54
|
}
|
|
55
55
|
|
|
56
56
|
// Execute the request to the PaLM API Text Completion API
|
|
57
|
-
async execute(text, parameters, prompt,
|
|
58
|
-
const
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
const params = {};
|
|
64
|
-
const headers = this.model.headers || {};
|
|
57
|
+
async execute(text, parameters, prompt, cortexRequest) {
|
|
58
|
+
const requestParameters = this.getRequestParameters(text, parameters, prompt, cortexRequest.pathwayResolver);
|
|
59
|
+
|
|
60
|
+
cortexRequest.data = { ...(cortexRequest.data || {}), ...requestParameters };
|
|
61
|
+
cortexRequest.params = {}; // query params
|
|
62
|
+
|
|
65
63
|
const gcpAuthTokenHelper = this.config.get('gcpAuthTokenHelper');
|
|
66
64
|
const authToken = await gcpAuthTokenHelper.getAccessToken();
|
|
67
|
-
headers.Authorization = `Bearer ${authToken}`;
|
|
68
|
-
|
|
65
|
+
cortexRequest.headers.Authorization = `Bearer ${authToken}`;
|
|
66
|
+
|
|
67
|
+
return this.executeRequest(cortexRequest);
|
|
69
68
|
}
|
|
70
69
|
|
|
71
70
|
// Parse the response from the PaLM API Text Completion API
|
package/server/resolver.js
CHANGED
|
@@ -4,7 +4,7 @@ import { PathwayResolver } from './pathwayResolver.js';
|
|
|
4
4
|
// This resolver uses standard parameters required by Apollo server:
|
|
5
5
|
// (parent, args, contextValue, info)
|
|
6
6
|
const rootResolver = async (parent, args, contextValue, info) => {
|
|
7
|
-
const { config, pathway
|
|
7
|
+
const { config, pathway } = contextValue;
|
|
8
8
|
const { temperature, enableGraphqlCache } = pathway;
|
|
9
9
|
|
|
10
10
|
// Turn on graphql caching if enableGraphqlCache true and temperature is 0
|
|
@@ -12,7 +12,7 @@ const rootResolver = async (parent, args, contextValue, info) => {
|
|
|
12
12
|
info.cacheControl.setCacheHint({ maxAge: 60 * 60 * 24, scope: 'PUBLIC' });
|
|
13
13
|
}
|
|
14
14
|
|
|
15
|
-
const pathwayResolver = new PathwayResolver({ config, pathway, args
|
|
15
|
+
const pathwayResolver = new PathwayResolver({ config, pathway, args });
|
|
16
16
|
contextValue.pathwayResolver = pathwayResolver;
|
|
17
17
|
|
|
18
18
|
// Execute the request with timeout
|
package/server/rest.js
CHANGED
package/tests/config.test.js
CHANGED
package/tests/mocks.js
CHANGED
|
@@ -6,6 +6,7 @@ export const mockConfig = {
|
|
|
6
6
|
defaultModelName: 'testModel',
|
|
7
7
|
models: {
|
|
8
8
|
testModel: {
|
|
9
|
+
name: 'testModel',
|
|
9
10
|
url: 'https://api.example.com/testModel',
|
|
10
11
|
type: 'OPENAI-COMPLETION',
|
|
11
12
|
},
|
|
@@ -40,6 +41,7 @@ export const mockConfig = {
|
|
|
40
41
|
|
|
41
42
|
export const mockPathwayResolverString = {
|
|
42
43
|
model: {
|
|
44
|
+
name: 'testModel',
|
|
43
45
|
url: 'https://api.example.com/testModel',
|
|
44
46
|
type: 'OPENAI-COMPLETION',
|
|
45
47
|
},
|
|
@@ -51,6 +53,7 @@ export const mockConfig = {
|
|
|
51
53
|
|
|
52
54
|
export const mockPathwayResolverFunction = {
|
|
53
55
|
model: {
|
|
56
|
+
name: 'testModel',
|
|
54
57
|
url: 'https://api.example.com/testModel',
|
|
55
58
|
type: 'OPENAI-COMPLETION',
|
|
56
59
|
},
|
|
@@ -64,6 +67,7 @@ export const mockConfig = {
|
|
|
64
67
|
|
|
65
68
|
export const mockPathwayResolverMessages = {
|
|
66
69
|
model: {
|
|
70
|
+
name: 'testModel',
|
|
67
71
|
url: 'https://api.example.com/testModel',
|
|
68
72
|
type: 'OPENAI-COMPLETION',
|
|
69
73
|
},
|
|
@@ -78,3 +82,4 @@ export const mockConfig = {
|
|
|
78
82
|
}),
|
|
79
83
|
};
|
|
80
84
|
|
|
85
|
+
export const mockModelEndpoints = { testModel: { name: 'testModel', url: 'https://api.example.com/testModel', type: 'OPENAI-COMPLETION' }};
|
|
@@ -8,10 +8,10 @@ const DEFAULT_MAX_TOKENS = 4096;
|
|
|
8
8
|
const DEFAULT_PROMPT_TOKEN_RATIO = 0.5;
|
|
9
9
|
|
|
10
10
|
// Mock configuration and pathway objects
|
|
11
|
-
const { config, pathway,
|
|
11
|
+
const { config, pathway, model } = mockPathwayResolverString;
|
|
12
12
|
|
|
13
13
|
test('ModelPlugin constructor', (t) => {
|
|
14
|
-
const modelPlugin = new ModelPlugin(
|
|
14
|
+
const modelPlugin = new ModelPlugin(pathway, model);
|
|
15
15
|
|
|
16
16
|
t.is(modelPlugin.modelName, pathway.model, 'modelName should be set from pathway');
|
|
17
17
|
t.deepEqual(modelPlugin.model, config.get('models')[pathway.model], 'model should be set from config');
|
|
@@ -20,7 +20,7 @@ test('ModelPlugin constructor', (t) => {
|
|
|
20
20
|
});
|
|
21
21
|
|
|
22
22
|
test.beforeEach((t) => {
|
|
23
|
-
t.context.modelPlugin = new ModelPlugin(
|
|
23
|
+
t.context.modelPlugin = new ModelPlugin(pathway, model);
|
|
24
24
|
});
|
|
25
25
|
|
|
26
26
|
test('getCompiledPrompt - text and parameters', (t) => {
|
|
@@ -71,13 +71,6 @@ test('getPromptTokenRatio', (t) => {
|
|
|
71
71
|
t.is(modelPlugin.getPromptTokenRatio(), DEFAULT_PROMPT_TOKEN_RATIO, 'getPromptTokenRatio should return default prompt token ratio');
|
|
72
72
|
});
|
|
73
73
|
|
|
74
|
-
test('requestUrl', (t) => {
|
|
75
|
-
const { modelPlugin } = t.context;
|
|
76
|
-
|
|
77
|
-
const expectedUrl = HandleBars.compile(modelPlugin.model.url)({ ...modelPlugin.model, ...config.getEnv(), ...config });
|
|
78
|
-
t.is(modelPlugin.requestUrl(), expectedUrl, 'requestUrl should return the correct URL');
|
|
79
|
-
});
|
|
80
|
-
|
|
81
74
|
test('default parseResponse', (t) => {
|
|
82
75
|
const { modelPlugin } = t.context;
|
|
83
76
|
const multipleChoicesResponse = {
|
|
@@ -1,19 +1,20 @@
|
|
|
1
1
|
import test from 'ava';
|
|
2
2
|
import OpenAIChatPlugin from '../server/plugins/openAiChatPlugin.js';
|
|
3
3
|
import { mockPathwayResolverMessages } from './mocks.js';
|
|
4
|
+
import { config } from '../config.js';
|
|
4
5
|
|
|
5
|
-
const {
|
|
6
|
+
const { pathway, modelName, model } = mockPathwayResolverMessages;
|
|
6
7
|
|
|
7
8
|
// Test the constructor
|
|
8
9
|
test('constructor', (t) => {
|
|
9
|
-
const plugin = new OpenAIChatPlugin(
|
|
10
|
-
t.is(plugin.config,
|
|
10
|
+
const plugin = new OpenAIChatPlugin(pathway, model);
|
|
11
|
+
t.is(plugin.config, config);
|
|
11
12
|
t.is(plugin.pathwayPrompt, mockPathwayResolverMessages.pathway.prompt);
|
|
12
13
|
});
|
|
13
14
|
|
|
14
15
|
// Test the convertPalmToOpenAIMessages function
|
|
15
16
|
test('convertPalmToOpenAIMessages', (t) => {
|
|
16
|
-
const plugin = new OpenAIChatPlugin(
|
|
17
|
+
const plugin = new OpenAIChatPlugin(pathway, model);
|
|
17
18
|
const context = 'This is a test context.';
|
|
18
19
|
const examples = [
|
|
19
20
|
{
|
|
@@ -37,7 +38,7 @@ test('convertPalmToOpenAIMessages', (t) => {
|
|
|
37
38
|
|
|
38
39
|
// Test the getRequestParameters function
|
|
39
40
|
test('getRequestParameters', async (t) => {
|
|
40
|
-
const plugin = new OpenAIChatPlugin(
|
|
41
|
+
const plugin = new OpenAIChatPlugin(pathway, model);
|
|
41
42
|
const text = 'Help me';
|
|
42
43
|
const parameters = { name: 'John', age: 30 };
|
|
43
44
|
const prompt = mockPathwayResolverMessages.pathway.prompt;
|
|
@@ -59,7 +60,7 @@ test('getRequestParameters', async (t) => {
|
|
|
59
60
|
|
|
60
61
|
// Test the execute function
|
|
61
62
|
test('execute', async (t) => {
|
|
62
|
-
const plugin = new OpenAIChatPlugin(
|
|
63
|
+
const plugin = new OpenAIChatPlugin(pathway, model);
|
|
63
64
|
const text = 'Help me';
|
|
64
65
|
const parameters = { name: 'John', age: 30 };
|
|
65
66
|
const prompt = mockPathwayResolverMessages.pathway.prompt;
|
|
@@ -91,7 +92,7 @@ test('execute', async (t) => {
|
|
|
91
92
|
|
|
92
93
|
// Test the parseResponse function
|
|
93
94
|
test('parseResponse', (t) => {
|
|
94
|
-
const plugin = new OpenAIChatPlugin(
|
|
95
|
+
const plugin = new OpenAIChatPlugin(pathway, model);
|
|
95
96
|
const data = {
|
|
96
97
|
choices: [
|
|
97
98
|
{
|
|
@@ -107,7 +108,7 @@ test('parseResponse', (t) => {
|
|
|
107
108
|
|
|
108
109
|
// Test the logRequestData function
|
|
109
110
|
test('logRequestData', (t) => {
|
|
110
|
-
const plugin = new OpenAIChatPlugin(
|
|
111
|
+
const plugin = new OpenAIChatPlugin(pathway, model);
|
|
111
112
|
const data = {
|
|
112
113
|
messages: [
|
|
113
114
|
{ role: 'user', content: 'User: Help me\nAssistant: Please help John who is 30 years old.' },
|
package/tests/openai_api.test.js
CHANGED
|
@@ -5,7 +5,7 @@ import got from 'got';
|
|
|
5
5
|
import axios from 'axios';
|
|
6
6
|
import serverFactory from '../index.js';
|
|
7
7
|
|
|
8
|
-
const API_BASE =
|
|
8
|
+
const API_BASE = `http://localhost:${process.env.CORTEX_PORT}/v1`;
|
|
9
9
|
|
|
10
10
|
let testServer;
|
|
11
11
|
|
|
@@ -110,7 +110,7 @@ test('POST SSE: /v1/completions should send a series of events and a [DONE] even
|
|
|
110
110
|
stream: true,
|
|
111
111
|
};
|
|
112
112
|
|
|
113
|
-
const url =
|
|
113
|
+
const url = `http://localhost:${process.env.CORTEX_PORT}/v1`;
|
|
114
114
|
|
|
115
115
|
const completionsAssertions = (t, messageJson) => {
|
|
116
116
|
t.truthy(messageJson.id);
|
|
@@ -133,7 +133,7 @@ test('POST SSE: /v1/chat/completions should send a series of events and a [DONE]
|
|
|
133
133
|
stream: true,
|
|
134
134
|
};
|
|
135
135
|
|
|
136
|
-
const url =
|
|
136
|
+
const url = `http://localhost:${process.env.CORTEX_PORT}/v1`;
|
|
137
137
|
|
|
138
138
|
const chatCompletionsAssertions = (t, messageJson) => {
|
|
139
139
|
t.truthy(messageJson.id);
|
|
@@ -6,7 +6,7 @@ import { mockPathwayResolverMessages } from './mocks.js';
|
|
|
6
6
|
const { config, pathway, modelName, model } = mockPathwayResolverMessages;
|
|
7
7
|
|
|
8
8
|
test.beforeEach((t) => {
|
|
9
|
-
const palmChatPlugin = new PalmChatPlugin(
|
|
9
|
+
const palmChatPlugin = new PalmChatPlugin(pathway, model);
|
|
10
10
|
t.context = { palmChatPlugin };
|
|
11
11
|
});
|
|
12
12
|
|
|
@@ -7,7 +7,7 @@ import { mockPathwayResolverString } from './mocks.js';
|
|
|
7
7
|
const { config, pathway, modelName, model } = mockPathwayResolverString;
|
|
8
8
|
|
|
9
9
|
test.beforeEach((t) => {
|
|
10
|
-
const palmCompletionPlugin = new PalmCompletionPlugin(
|
|
10
|
+
const palmCompletionPlugin = new PalmCompletionPlugin(pathway, model);
|
|
11
11
|
t.context = { palmCompletionPlugin };
|
|
12
12
|
});
|
|
13
13
|
|