@aj-archipelago/cortex 1.0.9 → 1.0.11
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/config.js +52 -46
- package/helper_apps/MediaFileChunker/package-lock.json +7 -6
- package/helper_apps/MediaFileChunker/package.json +1 -1
- package/lib/request.js +149 -34
- package/package.json +1 -1
- package/pathways/basePathway.js +12 -6
- package/pathways/index.js +2 -0
- package/pathways/summary.js +1 -1
- package/pathways/test_cohere_summarize.js +10 -0
- package/pathways/transcribe.js +1 -0
- package/pathways/translate.js +1 -0
- package/server/chunker.js +1 -1
- package/server/graphql.js +2 -1
- package/server/parser.js +12 -0
- package/server/pathwayPrompter.js +12 -0
- package/server/pathwayResolver.js +121 -48
- package/server/pathwayResponseParser.js +5 -9
- package/server/plugins/azureTranslatePlugin.js +3 -3
- package/server/plugins/cohereGeneratePlugin.js +60 -0
- package/server/plugins/cohereSummarizePlugin.js +50 -0
- package/server/plugins/modelPlugin.js +34 -19
- package/server/plugins/openAiChatExtensionPlugin.js +56 -0
- package/server/plugins/openAiChatPlugin.js +5 -5
- package/server/plugins/openAiCompletionPlugin.js +4 -4
- package/server/plugins/openAiWhisperPlugin.js +3 -3
- package/server/plugins/palmChatPlugin.js +3 -3
- package/server/plugins/palmCompletionPlugin.js +3 -3
- package/server/resolver.js +2 -2
- package/server/rest.js +28 -16
- package/server/subscriptions.js +3 -4
- package/server/typeDef.js +2 -1
- package/tests/openAiChatPlugin.test.js +1 -1
- package/tests/server.js +23 -0
|
@@ -25,7 +25,7 @@ class PathwayResolver {
|
|
|
25
25
|
args?.model,
|
|
26
26
|
pathway.inputParameters?.model,
|
|
27
27
|
config.get('defaultModelName')
|
|
28
|
-
].find(modelName => modelName && config.get('models')
|
|
28
|
+
].find(modelName => modelName && Object.prototype.hasOwnProperty.call(config.get('models'), modelName));
|
|
29
29
|
this.model = config.get('models')[this.modelName];
|
|
30
30
|
|
|
31
31
|
if (!this.model) {
|
|
@@ -63,57 +63,118 @@ class PathwayResolver {
|
|
|
63
63
|
this.pathwayPrompt = pathway.prompt;
|
|
64
64
|
}
|
|
65
65
|
|
|
66
|
+
// This code handles async and streaming responses. In either case, we use
|
|
67
|
+
// the graphql subscription to send progress updates to the client. Most of
|
|
68
|
+
// the time the client will be an external client, but it could also be the
|
|
69
|
+
// Cortex REST api code.
|
|
66
70
|
async asyncResolve(args) {
|
|
67
|
-
const
|
|
71
|
+
const MAX_RETRY_COUNT = 3;
|
|
72
|
+
let attempt = 0;
|
|
73
|
+
let streamErrorOccurred = false;
|
|
74
|
+
|
|
75
|
+
while (attempt < MAX_RETRY_COUNT) {
|
|
76
|
+
const responseData = await this.executePathway(args);
|
|
77
|
+
|
|
78
|
+
if (args.async || typeof responseData === 'string') {
|
|
79
|
+
const { completedCount, totalCount } = requestState[this.requestId];
|
|
80
|
+
requestState[this.requestId].data = responseData;
|
|
81
|
+
pubsub.publish('REQUEST_PROGRESS', {
|
|
82
|
+
requestProgress: {
|
|
83
|
+
requestId: this.requestId,
|
|
84
|
+
progress: completedCount / totalCount,
|
|
85
|
+
data: JSON.stringify(responseData),
|
|
86
|
+
}
|
|
87
|
+
});
|
|
88
|
+
} else {
|
|
89
|
+
try {
|
|
90
|
+
const incomingMessage = responseData;
|
|
68
91
|
|
|
69
|
-
|
|
70
|
-
if(args.async || typeof responseData === 'string') {
|
|
71
|
-
const { completedCount, totalCount } = requestState[this.requestId];
|
|
72
|
-
requestState[this.requestId].data = responseData;
|
|
73
|
-
pubsub.publish('REQUEST_PROGRESS', {
|
|
74
|
-
requestProgress: {
|
|
75
|
-
requestId: this.requestId,
|
|
76
|
-
progress: completedCount / totalCount,
|
|
77
|
-
data: JSON.stringify(responseData),
|
|
78
|
-
}
|
|
79
|
-
});
|
|
80
|
-
} else { // stream
|
|
81
|
-
try {
|
|
82
|
-
const incomingMessage = Array.isArray(responseData) && responseData.length > 0 ? responseData[0] : responseData;
|
|
83
|
-
incomingMessage.on('data', data => {
|
|
84
|
-
const events = data.toString().split('\n');
|
|
85
|
-
|
|
86
|
-
events.forEach(event => {
|
|
87
|
-
if (event.trim() === '') return; // Skip empty lines
|
|
88
|
-
|
|
89
|
-
const message = event.replace(/^data: /, '');
|
|
90
|
-
|
|
91
|
-
//console.log(`====================================`);
|
|
92
|
-
//console.log(`STREAM EVENT: ${event}`);
|
|
93
|
-
//console.log(`MESSAGE: ${message}`);
|
|
94
|
-
|
|
95
|
-
const requestProgress = {
|
|
96
|
-
requestId: this.requestId,
|
|
97
|
-
data: message,
|
|
98
|
-
}
|
|
99
|
-
|
|
100
|
-
if (message.trim() === '[DONE]') {
|
|
101
|
-
requestProgress.progress = 1;
|
|
102
|
-
}
|
|
103
|
-
|
|
92
|
+
const processData = (data) => {
|
|
104
93
|
try {
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
94
|
+
//console.log(`\n\nReceived stream data for requestId ${this.requestId}`, data.toString());
|
|
95
|
+
let events = data.toString().split('\n');
|
|
96
|
+
|
|
97
|
+
//events = "data: {\"id\":\"chatcmpl-20bf1895-2fa7-4ef9-abfe-4d142aba5817\",\"object\":\"chat.completion.chunk\",\"created\":1689303423723,\"model\":\"gpt-4\",\"choices\":[{\"delta\":{\"role\":\"assistant\",\"content\":{\"error\":{\"message\":\"The server had an error while processing your request. Sorry about that!\",\"type\":\"server_error\",\"param\":null,\"code\":null}}},\"finish_reason\":null}]}\n\n".split("\n");
|
|
98
|
+
|
|
99
|
+
for (let event of events) {
|
|
100
|
+
if (streamErrorOccurred) break;
|
|
101
|
+
|
|
102
|
+
// skip empty events
|
|
103
|
+
if (!(event.trim() === '')) {
|
|
104
|
+
//console.log(`Processing stream event for requestId ${this.requestId}`, event);
|
|
105
|
+
|
|
106
|
+
let message = event.replace(/^data: /, '');
|
|
107
|
+
|
|
108
|
+
const requestProgress = {
|
|
109
|
+
requestId: this.requestId,
|
|
110
|
+
data: message,
|
|
111
|
+
}
|
|
112
|
+
|
|
113
|
+
// check for end of stream or in-stream errors
|
|
114
|
+
if (message.trim() === '[DONE]') {
|
|
115
|
+
requestProgress.progress = 1;
|
|
116
|
+
} else {
|
|
117
|
+
let parsedMessage;
|
|
118
|
+
try {
|
|
119
|
+
parsedMessage = JSON.parse(message);
|
|
120
|
+
} catch (error) {
|
|
121
|
+
console.error('Could not JSON parse stream message', message, error);
|
|
122
|
+
return;
|
|
123
|
+
}
|
|
124
|
+
|
|
125
|
+
const streamError = parsedMessage.error || parsedMessage?.choices?.[0]?.delta?.content?.error || parsedMessage?.choices?.[0]?.text?.error;
|
|
126
|
+
if (streamError) {
|
|
127
|
+
streamErrorOccurred = true;
|
|
128
|
+
console.error(`Stream error: ${streamError.message}`);
|
|
129
|
+
incomingMessage.off('data', processData); // Stop listening to 'data'
|
|
130
|
+
return;
|
|
131
|
+
}
|
|
132
|
+
}
|
|
133
|
+
|
|
134
|
+
try {
|
|
135
|
+
//console.log(`Publishing stream message to requestId ${this.requestId}`, message);
|
|
136
|
+
pubsub.publish('REQUEST_PROGRESS', {
|
|
137
|
+
requestProgress: requestProgress
|
|
138
|
+
});
|
|
139
|
+
} catch (error) {
|
|
140
|
+
console.error('Could not publish the stream message', message, error);
|
|
141
|
+
}
|
|
142
|
+
}
|
|
143
|
+
}
|
|
108
144
|
} catch (error) {
|
|
109
|
-
console.error('Could not
|
|
145
|
+
console.error('Could not process stream data', error);
|
|
110
146
|
}
|
|
111
|
-
}
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
147
|
+
};
|
|
148
|
+
|
|
149
|
+
if (incomingMessage) {
|
|
150
|
+
await new Promise((resolve, reject) => {
|
|
151
|
+
incomingMessage.on('data', processData);
|
|
152
|
+
incomingMessage.on('end', resolve);
|
|
153
|
+
incomingMessage.on('error', reject);
|
|
154
|
+
});
|
|
155
|
+
}
|
|
156
|
+
|
|
157
|
+
} catch (error) {
|
|
158
|
+
console.error('Could not subscribe to stream', error);
|
|
159
|
+
}
|
|
160
|
+
}
|
|
161
|
+
|
|
162
|
+
if (streamErrorOccurred) {
|
|
163
|
+
attempt++;
|
|
164
|
+
console.error(`Stream attempt ${attempt} failed. Retrying...`);
|
|
165
|
+
streamErrorOccurred = false; // Reset the flag for the next attempt
|
|
166
|
+
} else {
|
|
167
|
+
return;
|
|
115
168
|
}
|
|
116
169
|
}
|
|
170
|
+
// if all retries failed, publish the stream end message
|
|
171
|
+
pubsub.publish('REQUEST_PROGRESS', {
|
|
172
|
+
requestProgress: {
|
|
173
|
+
requestId: this.requestId,
|
|
174
|
+
progress: 1,
|
|
175
|
+
data: '[DONE]',
|
|
176
|
+
}
|
|
177
|
+
});
|
|
117
178
|
}
|
|
118
179
|
|
|
119
180
|
async resolve(args) {
|
|
@@ -127,6 +188,15 @@ class PathwayResolver {
|
|
|
127
188
|
}
|
|
128
189
|
else {
|
|
129
190
|
// Syncronously process the request
|
|
191
|
+
return await this.executePathway(args);
|
|
192
|
+
}
|
|
193
|
+
}
|
|
194
|
+
|
|
195
|
+
async executePathway(args) {
|
|
196
|
+
if (this.pathway.executePathway && typeof this.pathway.executePathway === 'function') {
|
|
197
|
+
return await this.pathway.executePathway({ args, runAllPrompts: this.promptAndParse.bind(this) });
|
|
198
|
+
}
|
|
199
|
+
else {
|
|
130
200
|
return await this.promptAndParse(args);
|
|
131
201
|
}
|
|
132
202
|
}
|
|
@@ -167,7 +237,7 @@ class PathwayResolver {
|
|
|
167
237
|
} else {
|
|
168
238
|
chunkTokenLength = this.chunkMaxTokenLength;
|
|
169
239
|
}
|
|
170
|
-
const encoded = encode(text);
|
|
240
|
+
const encoded = text ? encode(text) : [];
|
|
171
241
|
if (!this.useInputChunking || encoded.length <= chunkTokenLength) { // no chunking, return as is
|
|
172
242
|
if (encoded.length > 0 && encoded.length >= chunkTokenLength) {
|
|
173
243
|
const warnText = `Truncating long input text. Text length: ${text.length}`;
|
|
@@ -275,8 +345,11 @@ class PathwayResolver {
|
|
|
275
345
|
previousResult = this.truncate(previousResult, this.chunkMaxTokenLength);
|
|
276
346
|
result = await Promise.all(chunks.map(chunk =>
|
|
277
347
|
this.applyPrompt(this.prompts[i], chunk, currentParameters)));
|
|
278
|
-
|
|
279
|
-
|
|
348
|
+
|
|
349
|
+
if (result.length === 1) {
|
|
350
|
+
result = result[0];
|
|
351
|
+
} else if (!currentParameters.stream) {
|
|
352
|
+
result = result.join("\n\n");
|
|
280
353
|
}
|
|
281
354
|
}
|
|
282
355
|
|
|
@@ -1,29 +1,25 @@
|
|
|
1
|
-
import { parseNumberedList, parseNumberedObjectList, parseCommaSeparatedList } from './parser.js';
|
|
1
|
+
import { parseNumberedList, parseNumberedObjectList, parseCommaSeparatedList, isCommaSeparatedList, isNumberedList } from './parser.js';
|
|
2
2
|
|
|
3
3
|
class PathwayResponseParser {
|
|
4
4
|
constructor(pathway) {
|
|
5
5
|
this.pathway = pathway;
|
|
6
6
|
}
|
|
7
7
|
|
|
8
|
-
isCommaSeparatedList(data) {
|
|
9
|
-
const commaSeparatedPattern = /^([^,\n]+,)+[^,\n]+$/;
|
|
10
|
-
return commaSeparatedPattern.test(data.trim());
|
|
11
|
-
}
|
|
12
|
-
|
|
13
8
|
parse(data) {
|
|
14
9
|
if (this.pathway.parser) {
|
|
15
10
|
return this.pathway.parser(data);
|
|
16
11
|
}
|
|
17
12
|
|
|
18
13
|
if (this.pathway.list) {
|
|
19
|
-
if (
|
|
20
|
-
return parseCommaSeparatedList(data);
|
|
21
|
-
} else {
|
|
14
|
+
if (isNumberedList(data)) {
|
|
22
15
|
if (this.pathway.format) {
|
|
23
16
|
return parseNumberedObjectList(data, this.pathway.format);
|
|
24
17
|
}
|
|
25
18
|
return parseNumberedList(data);
|
|
19
|
+
} else if (isCommaSeparatedList(data)) {
|
|
20
|
+
return parseCommaSeparatedList(data);
|
|
26
21
|
}
|
|
22
|
+
return [data];
|
|
27
23
|
}
|
|
28
24
|
|
|
29
25
|
return data;
|
|
@@ -26,7 +26,7 @@ class AzureTranslatePlugin extends ModelPlugin {
|
|
|
26
26
|
// Execute the request to the Azure Translate API
|
|
27
27
|
async execute(text, parameters, prompt, pathwayResolver) {
|
|
28
28
|
const requestParameters = this.getRequestParameters(text, parameters, prompt);
|
|
29
|
-
const requestId = pathwayResolver
|
|
29
|
+
const { requestId, pathway} = pathwayResolver;
|
|
30
30
|
|
|
31
31
|
const url = this.requestUrl(text);
|
|
32
32
|
|
|
@@ -34,7 +34,7 @@ class AzureTranslatePlugin extends ModelPlugin {
|
|
|
34
34
|
const params = requestParameters.params;
|
|
35
35
|
const headers = this.model.headers || {};
|
|
36
36
|
|
|
37
|
-
return this.executeRequest(url, data, params, headers, prompt, requestId);
|
|
37
|
+
return this.executeRequest(url, data, params, headers, prompt, requestId, pathway);
|
|
38
38
|
}
|
|
39
39
|
|
|
40
40
|
// Parse the response from the Azure Translate API
|
|
@@ -55,7 +55,7 @@ class AzureTranslatePlugin extends ModelPlugin {
|
|
|
55
55
|
console.log(`\x1b[36m${modelInput}\x1b[0m`);
|
|
56
56
|
console.log(`\x1b[34m> ${this.parseResponse(responseData)}\x1b[0m`);
|
|
57
57
|
|
|
58
|
-
prompt && prompt.debugInfo && (prompt.debugInfo +=
|
|
58
|
+
prompt && prompt.debugInfo && (prompt.debugInfo += `\n${JSON.stringify(data)}`);
|
|
59
59
|
}
|
|
60
60
|
}
|
|
61
61
|
|
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
// CohereGeneratePlugin.js
|
|
2
|
+
import ModelPlugin from './modelPlugin.js';
|
|
3
|
+
|
|
4
|
+
class CohereGeneratePlugin extends ModelPlugin {
|
|
5
|
+
constructor(config, pathway, modelName, model) {
|
|
6
|
+
super(config, pathway, modelName, model);
|
|
7
|
+
}
|
|
8
|
+
|
|
9
|
+
// Set up parameters specific to the Cohere API
|
|
10
|
+
getRequestParameters(text, parameters, prompt) {
|
|
11
|
+
const { modelPromptText, tokenLength } = this.getCompiledPrompt(text, parameters, prompt);
|
|
12
|
+
|
|
13
|
+
// Define the model's max token length
|
|
14
|
+
const modelTargetTokenLength = this.getModelMaxTokenLength() * this.getPromptTokenRatio();
|
|
15
|
+
|
|
16
|
+
// Check if the token length exceeds the model's max token length
|
|
17
|
+
if (tokenLength > modelTargetTokenLength) {
|
|
18
|
+
// Truncate the prompt text to fit within the token length
|
|
19
|
+
modelPromptText = modelPromptText.substring(0, modelTargetTokenLength);
|
|
20
|
+
}
|
|
21
|
+
|
|
22
|
+
const requestParameters = {
|
|
23
|
+
model: "command",
|
|
24
|
+
prompt: modelPromptText,
|
|
25
|
+
max_tokens: this.getModelMaxReturnTokens(),
|
|
26
|
+
temperature: this.temperature ?? 0.7,
|
|
27
|
+
k: 0,
|
|
28
|
+
stop_sequences: parameters.stop_sequences || [],
|
|
29
|
+
return_likelihoods: parameters.return_likelihoods || "NONE"
|
|
30
|
+
};
|
|
31
|
+
|
|
32
|
+
return requestParameters;
|
|
33
|
+
}
|
|
34
|
+
|
|
35
|
+
// Execute the request to the Cohere API
|
|
36
|
+
async execute(text, parameters, prompt, pathwayResolver) {
|
|
37
|
+
const url = this.requestUrl();
|
|
38
|
+
const requestParameters = this.getRequestParameters(text, parameters, prompt);
|
|
39
|
+
const { requestId, pathway} = pathwayResolver;
|
|
40
|
+
|
|
41
|
+
const data = { ...(this.model.params || {}), ...requestParameters };
|
|
42
|
+
const params = {};
|
|
43
|
+
const headers = {
|
|
44
|
+
...this.model.headers || {}
|
|
45
|
+
};
|
|
46
|
+
return this.executeRequest(url, data, params, headers, prompt, requestId, pathway);
|
|
47
|
+
}
|
|
48
|
+
|
|
49
|
+
// Parse the response from the Cohere API
|
|
50
|
+
parseResponse(data) {
|
|
51
|
+
const { generations } = data;
|
|
52
|
+
if (!generations || !generations.length) {
|
|
53
|
+
return data;
|
|
54
|
+
}
|
|
55
|
+
// Return the text of the first generation
|
|
56
|
+
return generations[0].text || null;
|
|
57
|
+
}
|
|
58
|
+
}
|
|
59
|
+
|
|
60
|
+
export default CohereGeneratePlugin;
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
// CohereSummarizePlugin.js
|
|
2
|
+
import ModelPlugin from './modelPlugin.js';
|
|
3
|
+
|
|
4
|
+
class CohereSummarizePlugin extends ModelPlugin {
|
|
5
|
+
constructor(config, pathway, modelName, model) {
|
|
6
|
+
super(config, pathway, modelName, model);
|
|
7
|
+
}
|
|
8
|
+
|
|
9
|
+
// Set up parameters specific to the Cohere Summarize API
|
|
10
|
+
getRequestParameters(text, parameters, prompt) {
|
|
11
|
+
const { modelPromptText } = this.getCompiledPrompt(text, parameters, prompt);
|
|
12
|
+
|
|
13
|
+
const requestParameters = {
|
|
14
|
+
length: parameters.length || "medium",
|
|
15
|
+
format: parameters.format || "paragraph",
|
|
16
|
+
model: "summarize-xlarge",
|
|
17
|
+
extractiveness: parameters.extractiveness || "low",
|
|
18
|
+
temperature: this.temperature ?? 0.3,
|
|
19
|
+
text: modelPromptText
|
|
20
|
+
};
|
|
21
|
+
|
|
22
|
+
return requestParameters;
|
|
23
|
+
}
|
|
24
|
+
|
|
25
|
+
// Execute the request to the Cohere Summarize API
|
|
26
|
+
async execute(text, parameters, prompt, pathwayResolver) {
|
|
27
|
+
const url = this.requestUrl();
|
|
28
|
+
const requestParameters = this.getRequestParameters(text, parameters, prompt);
|
|
29
|
+
const { requestId, pathway} = pathwayResolver;
|
|
30
|
+
|
|
31
|
+
const data = { ...(this.model.params || {}), ...requestParameters };
|
|
32
|
+
const params = {};
|
|
33
|
+
const headers = {
|
|
34
|
+
...this.model.headers || {}
|
|
35
|
+
};
|
|
36
|
+
return this.executeRequest(url, data, params, headers, prompt, requestId, pathway);
|
|
37
|
+
}
|
|
38
|
+
|
|
39
|
+
// Parse the response from the Cohere Summarize API
|
|
40
|
+
parseResponse(data) {
|
|
41
|
+
const { summary } = data;
|
|
42
|
+
if (!summary) {
|
|
43
|
+
return data;
|
|
44
|
+
}
|
|
45
|
+
// Return the summary
|
|
46
|
+
return summary;
|
|
47
|
+
}
|
|
48
|
+
}
|
|
49
|
+
|
|
50
|
+
export default CohereSummarizePlugin;
|
|
@@ -22,11 +22,11 @@ class ModelPlugin {
|
|
|
22
22
|
|
|
23
23
|
// Make all of the parameters defined on the pathway itself available to the prompt
|
|
24
24
|
for (const [k, v] of Object.entries(pathway)) {
|
|
25
|
-
this.promptParameters[k] = v
|
|
25
|
+
this.promptParameters[k] = v?.default ?? v;
|
|
26
26
|
}
|
|
27
27
|
if (pathway.inputParameters) {
|
|
28
28
|
for (const [k, v] of Object.entries(pathway.inputParameters)) {
|
|
29
|
-
this.promptParameters[k] = v
|
|
29
|
+
this.promptParameters[k] = v?.default ?? v;
|
|
30
30
|
}
|
|
31
31
|
}
|
|
32
32
|
|
|
@@ -121,7 +121,16 @@ class ModelPlugin {
|
|
|
121
121
|
|
|
122
122
|
// compile the Prompt
|
|
123
123
|
getCompiledPrompt(text, parameters, prompt) {
|
|
124
|
-
|
|
124
|
+
|
|
125
|
+
const mergeParameters = (promptParameters, parameters) => {
|
|
126
|
+
let result = { ...promptParameters };
|
|
127
|
+
for (let key in parameters) {
|
|
128
|
+
if (parameters[key] !== null) result[key] = parameters[key];
|
|
129
|
+
}
|
|
130
|
+
return result;
|
|
131
|
+
}
|
|
132
|
+
|
|
133
|
+
const combinedParameters = mergeParameters(this.promptParameters, parameters);
|
|
125
134
|
const modelPrompt = this.getModelPrompt(prompt, parameters);
|
|
126
135
|
const modelPromptText = modelPrompt.prompt ? HandleBars.compile(modelPrompt.prompt)({ ...combinedParameters, text }) : '';
|
|
127
136
|
const modelPromptMessages = this.getModelPromptMessages(modelPrompt, combinedParameters, text);
|
|
@@ -197,24 +206,25 @@ class ModelPlugin {
|
|
|
197
206
|
}
|
|
198
207
|
|
|
199
208
|
// Default response parsing
|
|
200
|
-
parseResponse(data) { return data; }
|
|
209
|
+
parseResponse(data) { return data; }
|
|
201
210
|
|
|
202
211
|
// Default simple logging
|
|
203
|
-
logRequestStart(url,
|
|
212
|
+
logRequestStart(url, _data) {
|
|
204
213
|
this.requestCount++;
|
|
214
|
+
this.lastRequestStartTime = new Date();
|
|
205
215
|
const logMessage = `>>> [${this.requestId}: ${this.pathwayName}.${this.requestCount}] request`;
|
|
206
216
|
const header = '>'.repeat(logMessage.length);
|
|
207
217
|
console.log(`\n${header}\n${logMessage}`);
|
|
208
218
|
console.log(`>>> Making API request to ${url}`);
|
|
209
|
-
}
|
|
219
|
+
}
|
|
210
220
|
|
|
211
221
|
logAIRequestFinished() {
|
|
212
222
|
const currentTime = new Date();
|
|
213
223
|
const timeElapsed = (currentTime - this.lastRequestStartTime) / 1000;
|
|
214
|
-
const logMessage = `<<< [${this.requestId}: ${this.pathwayName}
|
|
224
|
+
const logMessage = `<<< [${this.requestId}: ${this.pathwayName}] response - complete in ${timeElapsed}s - data:`;
|
|
215
225
|
const header = '<'.repeat(logMessage.length);
|
|
216
226
|
console.log(`\n${header}\n${logMessage}\n`);
|
|
217
|
-
}
|
|
227
|
+
}
|
|
218
228
|
|
|
219
229
|
logRequestData(data, responseData, prompt) {
|
|
220
230
|
this.logAIRequestFinished();
|
|
@@ -226,21 +236,26 @@ class ModelPlugin {
|
|
|
226
236
|
|
|
227
237
|
console.log(`\x1b[34m> ${this.parseResponse(responseData)}\x1b[0m`);
|
|
228
238
|
|
|
229
|
-
prompt && prompt.debugInfo && (prompt.debugInfo +=
|
|
239
|
+
prompt && prompt.debugInfo && (prompt.debugInfo += `\n${JSON.stringify(data)}`);
|
|
230
240
|
}
|
|
231
241
|
|
|
232
|
-
async executeRequest(url, data, params, headers, prompt, requestId) {
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
242
|
+
async executeRequest(url, data, params, headers, prompt, requestId, pathway) {
|
|
243
|
+
try {
|
|
244
|
+
this.aiRequestStartTime = new Date();
|
|
245
|
+
this.requestId = requestId;
|
|
246
|
+
this.logRequestStart(url, data);
|
|
247
|
+
const responseData = await request({ url, data, params, headers, cache: this.shouldCache }, this.modelName, this.requestId, pathway);
|
|
248
|
+
|
|
249
|
+
if (responseData.error) {
|
|
250
|
+
throw new Error(`An error was returned from the server: ${JSON.stringify(responseData.error)}`);
|
|
251
|
+
}
|
|
237
252
|
|
|
238
|
-
|
|
239
|
-
|
|
253
|
+
this.logRequestData(data, responseData, prompt);
|
|
254
|
+
return this.parseResponse(responseData);
|
|
255
|
+
} catch (error) {
|
|
256
|
+
// Log the error and continue
|
|
257
|
+
console.error(error);
|
|
240
258
|
}
|
|
241
|
-
|
|
242
|
-
this.logRequestData(data, responseData, prompt);
|
|
243
|
-
return this.parseResponse(responseData);
|
|
244
259
|
}
|
|
245
260
|
|
|
246
261
|
}
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
// OpenAIChatPlugin.js
|
|
2
|
+
import OpenAIChatPlugin from './openAiChatPlugin.js';
|
|
3
|
+
|
|
4
|
+
class OpenAIChatExtensionPlugin extends OpenAIChatPlugin {
|
|
5
|
+
constructor(config, pathway, modelName, model) {
|
|
6
|
+
super(config, pathway, modelName, model);
|
|
7
|
+
this.tool = '';
|
|
8
|
+
}
|
|
9
|
+
|
|
10
|
+
// Parse the response from the OpenAI Extension API
|
|
11
|
+
parseResponse(data) {
|
|
12
|
+
const { choices } = data;
|
|
13
|
+
if (!choices || !choices.length) {
|
|
14
|
+
return data;
|
|
15
|
+
}
|
|
16
|
+
|
|
17
|
+
// if we got a choices array back with more than one choice, return the whole array
|
|
18
|
+
if (choices.length > 1) {
|
|
19
|
+
return choices;
|
|
20
|
+
}
|
|
21
|
+
|
|
22
|
+
// otherwise, return the first choice messages based on role
|
|
23
|
+
const messageResult = [];
|
|
24
|
+
for(const message of choices[0].messages) {
|
|
25
|
+
if(message.role === "tool"){
|
|
26
|
+
this.tool = message.content;
|
|
27
|
+
}else{
|
|
28
|
+
messageResult.push(message.content);
|
|
29
|
+
}
|
|
30
|
+
}
|
|
31
|
+
return messageResult.join('\n\n') ?? null;
|
|
32
|
+
}
|
|
33
|
+
|
|
34
|
+
// Set up parameters specific to the OpenAI Chat API
|
|
35
|
+
getRequestParameters(text, parameters, prompt) {
|
|
36
|
+
const reqParams = super.getRequestParameters(text, parameters, prompt);
|
|
37
|
+
reqParams.dataSources = this.model.dataSources || reqParams.dataSources || []; // add dataSources to the request parameters
|
|
38
|
+
const {roleInformation, indexName } = parameters; // add roleInformation and indexName to the dataSource if given
|
|
39
|
+
for(const dataSource of reqParams.dataSources) {
|
|
40
|
+
if(!dataSource) continue;
|
|
41
|
+
if(!dataSource.parameters) dataSource.parameters = {};
|
|
42
|
+
roleInformation && (dataSource.parameters.roleInformation = roleInformation);
|
|
43
|
+
indexName && (dataSource.parameters.indexName = indexName);
|
|
44
|
+
}
|
|
45
|
+
return reqParams;
|
|
46
|
+
}
|
|
47
|
+
|
|
48
|
+
async execute(text, parameters, prompt, pathwayResolver) {
|
|
49
|
+
const result = await super.execute(text, parameters, prompt, pathwayResolver);
|
|
50
|
+
pathwayResolver.tool = this.tool; // add tool info back
|
|
51
|
+
return result;
|
|
52
|
+
}
|
|
53
|
+
|
|
54
|
+
}
|
|
55
|
+
|
|
56
|
+
export default OpenAIChatExtensionPlugin;
|
|
@@ -79,12 +79,12 @@ class OpenAIChatPlugin extends ModelPlugin {
|
|
|
79
79
|
async execute(text, parameters, prompt, pathwayResolver) {
|
|
80
80
|
const url = this.requestUrl(text);
|
|
81
81
|
const requestParameters = this.getRequestParameters(text, parameters, prompt);
|
|
82
|
-
const requestId = pathwayResolver
|
|
82
|
+
const { requestId, pathway} = pathwayResolver;
|
|
83
83
|
|
|
84
84
|
const data = { ...(this.model.params || {}), ...requestParameters };
|
|
85
|
-
const params = {};
|
|
85
|
+
const params = {}; // query params
|
|
86
86
|
const headers = this.model.headers || {};
|
|
87
|
-
return this.executeRequest(url, data, params, headers, prompt, requestId);
|
|
87
|
+
return this.executeRequest(url, data, params, headers, prompt, requestId, pathway);
|
|
88
88
|
}
|
|
89
89
|
|
|
90
90
|
// Parse the response from the OpenAI Chat API
|
|
@@ -122,12 +122,12 @@ class OpenAIChatPlugin extends ModelPlugin {
|
|
|
122
122
|
}
|
|
123
123
|
|
|
124
124
|
if (stream) {
|
|
125
|
-
console.log(`\x1b[34m>
|
|
125
|
+
console.log(`\x1b[34m> [response is an SSE stream]\x1b[0m`);
|
|
126
126
|
} else {
|
|
127
127
|
console.log(`\x1b[34m> ${this.parseResponse(responseData)}\x1b[0m`);
|
|
128
128
|
}
|
|
129
129
|
|
|
130
|
-
prompt && prompt.debugInfo && (prompt.debugInfo +=
|
|
130
|
+
prompt && prompt.debugInfo && (prompt.debugInfo += `\n${JSON.stringify(data)}`);
|
|
131
131
|
}
|
|
132
132
|
}
|
|
133
133
|
|
|
@@ -79,13 +79,13 @@ class OpenAICompletionPlugin extends ModelPlugin {
|
|
|
79
79
|
async execute(text, parameters, prompt, pathwayResolver) {
|
|
80
80
|
const url = this.requestUrl(text);
|
|
81
81
|
const requestParameters = this.getRequestParameters(text, parameters, prompt, pathwayResolver);
|
|
82
|
-
const requestId = pathwayResolver
|
|
82
|
+
const { requestId, pathway} = pathwayResolver;
|
|
83
83
|
|
|
84
84
|
const data = { ...(this.model.params || {}), ...requestParameters };
|
|
85
85
|
const params = {};
|
|
86
86
|
const headers = this.model.headers || {};
|
|
87
87
|
|
|
88
|
-
return this.executeRequest(url, data, params, headers, prompt, requestId);
|
|
88
|
+
return this.executeRequest(url, data, params, headers, prompt, requestId, pathway);
|
|
89
89
|
}
|
|
90
90
|
|
|
91
91
|
// Parse the response from the OpenAI Completion API
|
|
@@ -115,12 +115,12 @@ class OpenAICompletionPlugin extends ModelPlugin {
|
|
|
115
115
|
console.log(`\x1b[36m${modelInput}\x1b[0m`);
|
|
116
116
|
|
|
117
117
|
if (stream) {
|
|
118
|
-
console.log(`\x1b[34m>
|
|
118
|
+
console.log(`\x1b[34m> [response is an SSE stream]\x1b[0m`);
|
|
119
119
|
} else {
|
|
120
120
|
console.log(`\x1b[34m> ${this.parseResponse(responseData)}\x1b[0m`);
|
|
121
121
|
}
|
|
122
122
|
|
|
123
|
-
prompt && prompt.debugInfo && (prompt.debugInfo +=
|
|
123
|
+
prompt && prompt.debugInfo && (prompt.debugInfo += `\n${JSON.stringify(data)}`);
|
|
124
124
|
}
|
|
125
125
|
}
|
|
126
126
|
|
|
@@ -129,7 +129,7 @@ class OpenAIWhisperPlugin extends ModelPlugin {
|
|
|
129
129
|
|
|
130
130
|
try {
|
|
131
131
|
// const res = await axios.post(WHISPER_TS_API_URL, { params: { fileurl: uri } });
|
|
132
|
-
const res = await this.executeRequest(WHISPER_TS_API_URL, {fileurl:uri},{},{});
|
|
132
|
+
const res = await this.executeRequest(WHISPER_TS_API_URL, {fileurl:uri}, {}, {}, {}, requestId, pathway);
|
|
133
133
|
return res;
|
|
134
134
|
} catch (err) {
|
|
135
135
|
console.log(`Error getting word timestamped data from api:`, err);
|
|
@@ -150,7 +150,7 @@ class OpenAIWhisperPlugin extends ModelPlugin {
|
|
|
150
150
|
language && formData.append('language', language);
|
|
151
151
|
modelPromptText && formData.append('prompt', modelPromptText);
|
|
152
152
|
|
|
153
|
-
return this.executeRequest(url, formData, params, { ...this.model.headers, ...formData.getHeaders() });
|
|
153
|
+
return this.executeRequest(url, formData, params, { ...this.model.headers, ...formData.getHeaders() }, {}, requestId, pathway);
|
|
154
154
|
} catch (err) {
|
|
155
155
|
console.log(err);
|
|
156
156
|
throw err;
|
|
@@ -161,7 +161,7 @@ class OpenAIWhisperPlugin extends ModelPlugin {
|
|
|
161
161
|
let { file } = parameters;
|
|
162
162
|
let totalCount = 0;
|
|
163
163
|
let completedCount = 0;
|
|
164
|
-
const { requestId } = pathwayResolver;
|
|
164
|
+
const { requestId, pathway } = pathwayResolver;
|
|
165
165
|
|
|
166
166
|
const sendProgress = () => {
|
|
167
167
|
completedCount++;
|
|
@@ -140,7 +140,7 @@ class PalmChatPlugin extends ModelPlugin {
|
|
|
140
140
|
async execute(text, parameters, prompt, pathwayResolver) {
|
|
141
141
|
const url = this.requestUrl(text);
|
|
142
142
|
const requestParameters = this.getRequestParameters(text, parameters, prompt);
|
|
143
|
-
const requestId = pathwayResolver
|
|
143
|
+
const { requestId, pathway} = pathwayResolver;
|
|
144
144
|
|
|
145
145
|
const data = { ...(this.model.params || {}), ...requestParameters };
|
|
146
146
|
const params = {};
|
|
@@ -148,7 +148,7 @@ class PalmChatPlugin extends ModelPlugin {
|
|
|
148
148
|
const gcpAuthTokenHelper = this.config.get('gcpAuthTokenHelper');
|
|
149
149
|
const authToken = await gcpAuthTokenHelper.getAccessToken();
|
|
150
150
|
headers.Authorization = `Bearer ${authToken}`;
|
|
151
|
-
return this.executeRequest(url, data, params, headers, prompt, requestId);
|
|
151
|
+
return this.executeRequest(url, data, params, headers, prompt, requestId, pathway);
|
|
152
152
|
}
|
|
153
153
|
|
|
154
154
|
// Parse the response from the PaLM Chat API
|
|
@@ -219,7 +219,7 @@ class PalmChatPlugin extends ModelPlugin {
|
|
|
219
219
|
}
|
|
220
220
|
|
|
221
221
|
if (prompt && prompt.debugInfo) {
|
|
222
|
-
prompt.debugInfo +=
|
|
222
|
+
prompt.debugInfo += `\n${JSON.stringify(data)}`;
|
|
223
223
|
}
|
|
224
224
|
}
|
|
225
225
|
}
|