@aj-archipelago/cortex 0.0.5 → 0.0.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +108 -72
- package/config.js +25 -0
- package/graphql/graphql.js +56 -13
- package/graphql/pathwayPrompter.js +10 -6
- package/graphql/pathwayResolver.js +128 -63
- package/graphql/plugins/azureTranslatePlugin.js +16 -8
- package/graphql/plugins/modelPlugin.js +67 -9
- package/graphql/plugins/openAiChatPlugin.js +34 -7
- package/graphql/plugins/openAiCompletionPlugin.js +53 -33
- package/graphql/plugins/openAiWhisperPlugin.js +79 -0
- package/graphql/prompt.js +1 -0
- package/graphql/requestState.js +5 -0
- package/graphql/resolver.js +8 -8
- package/graphql/subscriptions.js +15 -2
- package/graphql/typeDef.js +47 -38
- package/lib/fileChunker.js +152 -0
- package/lib/request.js +65 -8
- package/lib/requestMonitor.js +43 -0
- package/package.json +18 -6
- package/pathways/basePathway.js +3 -4
- package/pathways/bias.js +7 -0
- package/pathways/chat.js +4 -1
- package/pathways/complete.js +4 -0
- package/pathways/edit.js +6 -0
- package/pathways/entities.js +12 -0
- package/pathways/index.js +1 -1
- package/pathways/paraphrase.js +4 -0
- package/pathways/sentiment.js +5 -1
- package/pathways/summary.js +25 -8
- package/pathways/transcribe.js +8 -0
- package/pathways/translate.js +10 -1
- package/tests/chunking.test.js +5 -0
- package/tests/main.test.js +5 -13
- package/tests/translate.test.js +5 -0
- package/pathways/topics.js +0 -9
|
@@ -8,19 +8,20 @@ const { getFirstNToken, getLastNToken, getSemanticChunks } = require('./chunker'
|
|
|
8
8
|
const { PathwayResponseParser } = require('./pathwayResponseParser');
|
|
9
9
|
const { Prompt } = require('./prompt');
|
|
10
10
|
const { getv, setv } = require('../lib/keyValueStorageClient');
|
|
11
|
+
const { requestState } = require('./requestState');
|
|
11
12
|
|
|
12
13
|
const MAX_PREVIOUS_RESULT_TOKEN_LENGTH = 1000;
|
|
13
14
|
|
|
14
|
-
const callPathway = async (config, pathwayName, requestState, { text, ...parameters }) => {
|
|
15
|
-
const pathwayResolver = new PathwayResolver({ config, pathway: config.get(`pathways.${pathwayName}`), requestState });
|
|
15
|
+
const callPathway = async (config, pathwayName, args, requestState, { text, ...parameters }) => {
|
|
16
|
+
const pathwayResolver = new PathwayResolver({ config, pathway: config.get(`pathways.${pathwayName}`), args, requestState });
|
|
16
17
|
return await pathwayResolver.resolve({ text, ...parameters });
|
|
17
18
|
}
|
|
18
19
|
|
|
19
20
|
class PathwayResolver {
|
|
20
|
-
constructor({ config, pathway,
|
|
21
|
+
constructor({ config, pathway, args }) {
|
|
21
22
|
this.config = config;
|
|
22
|
-
this.requestState = requestState;
|
|
23
23
|
this.pathway = pathway;
|
|
24
|
+
this.args = args;
|
|
24
25
|
this.useInputChunking = pathway.useInputChunking;
|
|
25
26
|
this.chunkMaxTokenLength = 0;
|
|
26
27
|
this.warnings = [];
|
|
@@ -29,38 +30,89 @@ class PathwayResolver {
|
|
|
29
30
|
this.pathwayPrompter = new PathwayPrompter({ config, pathway });
|
|
30
31
|
this.previousResult = '';
|
|
31
32
|
this.prompts = [];
|
|
32
|
-
this._pathwayPrompt = '';
|
|
33
33
|
|
|
34
34
|
Object.defineProperty(this, 'pathwayPrompt', {
|
|
35
35
|
get() {
|
|
36
|
-
return this.
|
|
36
|
+
return this.prompts
|
|
37
37
|
},
|
|
38
38
|
set(value) {
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
this._pathwayPrompt = [this._pathwayPrompt];
|
|
39
|
+
if (!Array.isArray(value)) {
|
|
40
|
+
value = [value];
|
|
42
41
|
}
|
|
43
|
-
this.prompts =
|
|
42
|
+
this.prompts = value.map(p => (p instanceof Prompt) ? p : new Prompt({ prompt:p }));
|
|
44
43
|
this.chunkMaxTokenLength = this.getChunkMaxTokenLength();
|
|
45
44
|
}
|
|
46
45
|
});
|
|
47
46
|
|
|
47
|
+
// set up initial prompt
|
|
48
48
|
this.pathwayPrompt = pathway.prompt;
|
|
49
49
|
}
|
|
50
50
|
|
|
51
|
-
async
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
51
|
+
async asyncResolve(args) {
|
|
52
|
+
// Wait with a sleep promise for the race condition to resolve
|
|
53
|
+
// const results = await Promise.all([this.promptAndParse(args), await new Promise(resolve => setTimeout(resolve, 250))]);
|
|
54
|
+
const data = await this.promptAndParse(args);
|
|
55
|
+
// Process the results for async
|
|
56
|
+
if(args.async || typeof data === 'string') { // if async flag set or processed async and got string response
|
|
57
|
+
const { completedCount, totalCount } = requestState[this.requestId];
|
|
58
|
+
requestState[this.requestId].data = data;
|
|
59
|
+
pubsub.publish('REQUEST_PROGRESS', {
|
|
60
|
+
requestProgress: {
|
|
61
|
+
requestId: this.requestId,
|
|
62
|
+
progress: completedCount / totalCount,
|
|
63
|
+
data: JSON.stringify(data),
|
|
64
|
+
}
|
|
65
|
+
});
|
|
66
|
+
} else { //stream
|
|
67
|
+
for (const handle of data) {
|
|
68
|
+
handle.on('data', data => {
|
|
69
|
+
console.log(data.toString());
|
|
70
|
+
const lines = data.toString().split('\n').filter(line => line.trim() !== '');
|
|
71
|
+
for (const line of lines) {
|
|
72
|
+
const message = line.replace(/^data: /, '');
|
|
73
|
+
if (message === '[DONE]') {
|
|
74
|
+
// Send stream finished message
|
|
75
|
+
pubsub.publish('REQUEST_PROGRESS', {
|
|
76
|
+
requestProgress: {
|
|
77
|
+
requestId: this.requestId,
|
|
78
|
+
data: null,
|
|
79
|
+
progress: 1,
|
|
80
|
+
}
|
|
81
|
+
});
|
|
82
|
+
return; // Stream finished
|
|
83
|
+
}
|
|
84
|
+
try {
|
|
85
|
+
const parsed = JSON.parse(message);
|
|
86
|
+
const result = this.pathwayPrompter.plugin.parseResponse(parsed)
|
|
87
|
+
|
|
88
|
+
pubsub.publish('REQUEST_PROGRESS', {
|
|
89
|
+
requestProgress: {
|
|
90
|
+
requestId: this.requestId,
|
|
91
|
+
data: JSON.stringify(result)
|
|
92
|
+
}
|
|
93
|
+
});
|
|
94
|
+
} catch (error) {
|
|
95
|
+
console.error('Could not JSON parse stream message', message, error);
|
|
96
|
+
}
|
|
60
97
|
}
|
|
61
98
|
});
|
|
62
|
-
});
|
|
63
99
|
|
|
100
|
+
// data.on('end', () => {
|
|
101
|
+
// console.log("stream done");
|
|
102
|
+
// });
|
|
103
|
+
}
|
|
104
|
+
|
|
105
|
+
}
|
|
106
|
+
}
|
|
107
|
+
|
|
108
|
+
async resolve(args) {
|
|
109
|
+
if (args.async || args.stream) {
|
|
110
|
+
// Asynchronously process the request
|
|
111
|
+
// this.asyncResolve(args);
|
|
112
|
+
if (!requestState[this.requestId]) {
|
|
113
|
+
requestState[this.requestId] = {}
|
|
114
|
+
}
|
|
115
|
+
requestState[this.requestId] = { ...requestState[this.requestId], args, resolver: this.asyncResolve.bind(this) };
|
|
64
116
|
return this.requestId;
|
|
65
117
|
}
|
|
66
118
|
else {
|
|
@@ -70,7 +122,6 @@ class PathwayResolver {
|
|
|
70
122
|
}
|
|
71
123
|
|
|
72
124
|
async promptAndParse(args) {
|
|
73
|
-
|
|
74
125
|
// Get saved context from contextId or change contextId if needed
|
|
75
126
|
const { contextId } = args;
|
|
76
127
|
this.savedContextId = contextId ? contextId : null;
|
|
@@ -94,25 +145,25 @@ class PathwayResolver {
|
|
|
94
145
|
|
|
95
146
|
// Here we choose how to handle long input - either summarize or chunk
|
|
96
147
|
processInputText(text) {
|
|
97
|
-
let
|
|
148
|
+
let chunkTokenLength = 0;
|
|
98
149
|
if (this.pathway.inputChunkSize) {
|
|
99
|
-
|
|
150
|
+
chunkTokenLength = Math.min(this.pathway.inputChunkSize, this.chunkMaxTokenLength);
|
|
100
151
|
} else {
|
|
101
|
-
|
|
152
|
+
chunkTokenLength = this.chunkMaxTokenLength;
|
|
102
153
|
}
|
|
103
154
|
const encoded = encode(text);
|
|
104
|
-
if (!this.useInputChunking || encoded.length <=
|
|
105
|
-
if (encoded.length >=
|
|
106
|
-
const warnText = `
|
|
155
|
+
if (!this.useInputChunking || encoded.length <= chunkTokenLength) { // no chunking, return as is
|
|
156
|
+
if (encoded.length >= chunkTokenLength) {
|
|
157
|
+
const warnText = `Truncating long input text. Text length: ${text.length}`;
|
|
107
158
|
this.warnings.push(warnText);
|
|
108
159
|
console.warn(warnText);
|
|
109
|
-
text = this.truncate(text,
|
|
160
|
+
text = this.truncate(text, chunkTokenLength);
|
|
110
161
|
}
|
|
111
162
|
return [text];
|
|
112
163
|
}
|
|
113
164
|
|
|
114
165
|
// chunk the text and return the chunks with newline separators
|
|
115
|
-
return getSemanticChunks({ text, maxChunkToken:
|
|
166
|
+
return getSemanticChunks({ text, maxChunkToken: chunkTokenLength });
|
|
116
167
|
}
|
|
117
168
|
|
|
118
169
|
truncate(str, n) {
|
|
@@ -124,7 +175,7 @@ class PathwayResolver {
|
|
|
124
175
|
|
|
125
176
|
async summarizeIfEnabled({ text, ...parameters }) {
|
|
126
177
|
if (this.pathway.useInputSummarization) {
|
|
127
|
-
return await callPathway(this.config, 'summary', this.requestState, { text, targetLength: 1000, ...parameters });
|
|
178
|
+
return await callPathway(this.config, 'summary', this.args, requestState, { text, targetLength: 1000, ...parameters });
|
|
128
179
|
}
|
|
129
180
|
return text;
|
|
130
181
|
}
|
|
@@ -132,46 +183,44 @@ class PathwayResolver {
|
|
|
132
183
|
// Calculate the maximum token length for a chunk
|
|
133
184
|
getChunkMaxTokenLength() {
|
|
134
185
|
// find the longest prompt
|
|
135
|
-
const maxPromptTokenLength = Math.max(...this.prompts.map((
|
|
136
|
-
|
|
137
|
-
return (role && content) ? acc + encode(role).length + encode(content).length : acc;
|
|
138
|
-
}, 0) : 0));
|
|
139
|
-
|
|
140
|
-
const maxTokenLength = Math.max(maxPromptTokenLength, maxMessagesTokenLength);
|
|
141
|
-
|
|
186
|
+
const maxPromptTokenLength = Math.max(...this.prompts.map((promptData) => this.pathwayPrompter.plugin.getCompiledPrompt('', this.args, promptData).tokenLength));
|
|
187
|
+
|
|
142
188
|
// find out if any prompts use both text input and previous result
|
|
143
|
-
const hasBothProperties = this.prompts.some(prompt => prompt.
|
|
189
|
+
const hasBothProperties = this.prompts.some(prompt => prompt.usesTextInput && prompt.usesPreviousResult);
|
|
144
190
|
|
|
145
191
|
// the token ratio is the ratio of the total prompt to the result text - both have to be included
|
|
146
192
|
// in computing the max token length
|
|
147
193
|
const promptRatio = this.pathwayPrompter.plugin.getPromptTokenRatio();
|
|
148
|
-
let
|
|
149
|
-
|
|
194
|
+
let chunkMaxTokenLength = promptRatio * this.pathwayPrompter.plugin.getModelMaxTokenLength() - maxPromptTokenLength;
|
|
195
|
+
|
|
150
196
|
// if we have to deal with prompts that have both text input
|
|
151
197
|
// and previous result, we need to split the maxChunkToken in half
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
if (maxChunkToken && maxChunkToken <= 0) {
|
|
156
|
-
throw new Error(`Your prompt is too long! Split to multiple prompts or reduce length of your prompt, prompt length: ${maxPromptLength}`);
|
|
157
|
-
}
|
|
158
|
-
return maxChunkToken;
|
|
198
|
+
chunkMaxTokenLength = hasBothProperties ? chunkMaxTokenLength / 2 : chunkMaxTokenLength;
|
|
199
|
+
|
|
200
|
+
return chunkMaxTokenLength;
|
|
159
201
|
}
|
|
160
202
|
|
|
161
203
|
// Process the request and return the result
|
|
162
204
|
async processRequest({ text, ...parameters }) {
|
|
163
|
-
|
|
164
205
|
text = await this.summarizeIfEnabled({ text, ...parameters }); // summarize if flag enabled
|
|
165
206
|
const chunks = this.processInputText(text);
|
|
166
207
|
|
|
167
208
|
const anticipatedRequestCount = chunks.length * this.prompts.length;
|
|
168
209
|
|
|
169
|
-
if ((
|
|
210
|
+
if ((requestState[this.requestId] || {}).canceled) {
|
|
170
211
|
throw new Error('Request canceled');
|
|
171
212
|
}
|
|
172
213
|
|
|
173
214
|
// Store the request state
|
|
174
|
-
|
|
215
|
+
requestState[this.requestId] = { ...requestState[this.requestId], totalCount: anticipatedRequestCount, completedCount: 0 };
|
|
216
|
+
|
|
217
|
+
if (chunks.length > 1) {
|
|
218
|
+
// stream behaves as async if there are multiple chunks
|
|
219
|
+
if (parameters.stream) {
|
|
220
|
+
parameters.async = true;
|
|
221
|
+
parameters.stream = false;
|
|
222
|
+
}
|
|
223
|
+
}
|
|
175
224
|
|
|
176
225
|
// If pre information is needed, apply current prompt with previous prompt info, only parallelize current call
|
|
177
226
|
if (this.pathway.useParallelChunkProcessing) {
|
|
@@ -189,17 +238,31 @@ class PathwayResolver {
|
|
|
189
238
|
let result = '';
|
|
190
239
|
|
|
191
240
|
for (let i = 0; i < this.prompts.length; i++) {
|
|
241
|
+
const currentParameters = { ...parameters, previousResult };
|
|
242
|
+
|
|
243
|
+
if (currentParameters.stream) { // stream special flow
|
|
244
|
+
if (i < this.prompts.length - 1) {
|
|
245
|
+
currentParameters.stream = false; // if not the last prompt then don't stream
|
|
246
|
+
}
|
|
247
|
+
else {
|
|
248
|
+
// use the stream parameter if not async
|
|
249
|
+
currentParameters.stream = currentParameters.async ? false : currentParameters.stream;
|
|
250
|
+
}
|
|
251
|
+
}
|
|
252
|
+
|
|
192
253
|
// If the prompt doesn't contain {{text}} then we can skip the chunking, and also give that token space to the previous result
|
|
193
254
|
if (!this.prompts[i].usesTextInput) {
|
|
194
255
|
// Limit context to its N + text's characters
|
|
195
256
|
previousResult = this.truncate(previousResult, 2 * this.chunkMaxTokenLength);
|
|
196
|
-
result = await this.applyPrompt(this.prompts[i], null,
|
|
257
|
+
result = await this.applyPrompt(this.prompts[i], null, currentParameters);
|
|
197
258
|
} else {
|
|
198
259
|
// Limit context to N characters
|
|
199
260
|
previousResult = this.truncate(previousResult, this.chunkMaxTokenLength);
|
|
200
261
|
result = await Promise.all(chunks.map(chunk =>
|
|
201
|
-
this.applyPrompt(this.prompts[i], chunk,
|
|
202
|
-
|
|
262
|
+
this.applyPrompt(this.prompts[i], chunk, currentParameters)));
|
|
263
|
+
if (!currentParameters.stream) {
|
|
264
|
+
result = result.join("\n\n")
|
|
265
|
+
}
|
|
203
266
|
}
|
|
204
267
|
|
|
205
268
|
// If this is any prompt other than the last, use the result as the previous context
|
|
@@ -225,20 +288,22 @@ class PathwayResolver {
|
|
|
225
288
|
}
|
|
226
289
|
|
|
227
290
|
async applyPrompt(prompt, text, parameters) {
|
|
228
|
-
if (
|
|
291
|
+
if (requestState[this.requestId].canceled) {
|
|
229
292
|
return;
|
|
230
293
|
}
|
|
231
|
-
const result = await this.pathwayPrompter.execute(text, { ...parameters, ...this.savedContext }, prompt);
|
|
232
|
-
|
|
294
|
+
const result = await this.pathwayPrompter.execute(text, { ...parameters, ...this.savedContext }, prompt, this);
|
|
295
|
+
requestState[this.requestId].completedCount++;
|
|
233
296
|
|
|
234
|
-
const { completedCount, totalCount } =
|
|
297
|
+
const { completedCount, totalCount } = requestState[this.requestId];
|
|
235
298
|
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
299
|
+
if (completedCount < totalCount) {
|
|
300
|
+
pubsub.publish('REQUEST_PROGRESS', {
|
|
301
|
+
requestProgress: {
|
|
302
|
+
requestId: this.requestId,
|
|
303
|
+
progress: completedCount / totalCount,
|
|
304
|
+
}
|
|
305
|
+
});
|
|
306
|
+
}
|
|
242
307
|
|
|
243
308
|
if (prompt.saveResultTo) {
|
|
244
309
|
this.savedContext[prompt.saveResultTo] = result;
|
|
@@ -1,19 +1,26 @@
|
|
|
1
1
|
// AzureTranslatePlugin.js
|
|
2
2
|
const ModelPlugin = require('./modelPlugin');
|
|
3
3
|
const handlebars = require("handlebars");
|
|
4
|
+
const { encode } = require("gpt-3-encoder");
|
|
4
5
|
|
|
5
6
|
class AzureTranslatePlugin extends ModelPlugin {
|
|
6
|
-
constructor(config,
|
|
7
|
-
super(config,
|
|
7
|
+
constructor(config, pathway) {
|
|
8
|
+
super(config, pathway);
|
|
8
9
|
}
|
|
9
10
|
|
|
10
|
-
|
|
11
|
-
requestParameters(text, parameters, prompt) {
|
|
11
|
+
getCompiledPrompt(text, parameters, prompt) {
|
|
12
12
|
const combinedParameters = { ...this.promptParameters, ...parameters };
|
|
13
13
|
const modelPrompt = this.getModelPrompt(prompt, parameters);
|
|
14
14
|
const modelPromptText = modelPrompt.prompt ? handlebars.compile(modelPrompt.prompt)({ ...combinedParameters, text }) : '';
|
|
15
|
-
|
|
16
|
-
return {
|
|
15
|
+
|
|
16
|
+
return { modelPromptText, tokenLength: encode(modelPromptText).length };
|
|
17
|
+
}
|
|
18
|
+
|
|
19
|
+
// Set up parameters specific to the Azure Translate API
|
|
20
|
+
getRequestParameters(text, parameters, prompt) {
|
|
21
|
+
const combinedParameters = { ...this.promptParameters, ...parameters };
|
|
22
|
+
const { modelPromptText } = this.getCompiledPrompt(text, parameters, prompt);
|
|
23
|
+
const requestParameters = {
|
|
17
24
|
data: [
|
|
18
25
|
{
|
|
19
26
|
Text: modelPromptText,
|
|
@@ -23,11 +30,12 @@ class AzureTranslatePlugin extends ModelPlugin {
|
|
|
23
30
|
to: combinedParameters.to
|
|
24
31
|
}
|
|
25
32
|
};
|
|
33
|
+
return requestParameters;
|
|
26
34
|
}
|
|
27
35
|
|
|
28
36
|
// Execute the request to the Azure Translate API
|
|
29
37
|
async execute(text, parameters, prompt) {
|
|
30
|
-
const requestParameters = this.
|
|
38
|
+
const requestParameters = this.getRequestParameters(text, parameters, prompt);
|
|
31
39
|
|
|
32
40
|
const url = this.requestUrl(text);
|
|
33
41
|
|
|
@@ -35,7 +43,7 @@ class AzureTranslatePlugin extends ModelPlugin {
|
|
|
35
43
|
const params = requestParameters.params;
|
|
36
44
|
const headers = this.model.headers || {};
|
|
37
45
|
|
|
38
|
-
return this.executeRequest(url, data, params, headers);
|
|
46
|
+
return this.executeRequest(url, data, params, headers, prompt);
|
|
39
47
|
}
|
|
40
48
|
}
|
|
41
49
|
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
// ModelPlugin.js
|
|
2
2
|
const handlebars = require('handlebars');
|
|
3
3
|
const { request } = require("../../lib/request");
|
|
4
|
-
const {
|
|
4
|
+
const { encode } = require("gpt-3-encoder");
|
|
5
5
|
|
|
6
6
|
const DEFAULT_MAX_TOKENS = 4096;
|
|
7
7
|
const DEFAULT_PROMPT_TOKEN_RATIO = 0.5;
|
|
@@ -35,6 +35,42 @@ class ModelPlugin {
|
|
|
35
35
|
}
|
|
36
36
|
|
|
37
37
|
this.requestCount = 1;
|
|
38
|
+
this.shouldCache = config.get('enableCache') && (pathway.enableCache || pathway.temperature == 0);
|
|
39
|
+
}
|
|
40
|
+
|
|
41
|
+
// Function to remove non-system messages until token length is less than target
|
|
42
|
+
removeMessagesUntilTarget = (messages, targetTokenLength) => {
|
|
43
|
+
let chatML = this.messagesToChatML(messages);
|
|
44
|
+
let tokenLength = encode(chatML).length;
|
|
45
|
+
|
|
46
|
+
while (tokenLength > targetTokenLength) {
|
|
47
|
+
for (let i = 0; i < messages.length; i++) {
|
|
48
|
+
if (messages[i].role !== 'system') {
|
|
49
|
+
messages.splice(i, 1);
|
|
50
|
+
chatML = this.messagesToChatML(messages);
|
|
51
|
+
tokenLength = encode(chatML).length;
|
|
52
|
+
break;
|
|
53
|
+
}
|
|
54
|
+
}
|
|
55
|
+
if (messages.every(message => message.role === 'system')) {
|
|
56
|
+
break; // All remaining messages are 'system', stop removing messages
|
|
57
|
+
}
|
|
58
|
+
}
|
|
59
|
+
return messages;
|
|
60
|
+
}
|
|
61
|
+
|
|
62
|
+
//convert a messages array to a simple chatML format
|
|
63
|
+
messagesToChatML = (messages) => {
|
|
64
|
+
let output = "";
|
|
65
|
+
if (messages && messages.length) {
|
|
66
|
+
for (let message of messages) {
|
|
67
|
+
output += (message.role && message.content) ? `<|im_start|>${message.role}\n${message.content}\n<|im_end|>\n` : `${message}\n`;
|
|
68
|
+
}
|
|
69
|
+
// you always want the assistant to respond next so add a
|
|
70
|
+
// directive for that
|
|
71
|
+
output += "<|im_start|>assistant\n";
|
|
72
|
+
}
|
|
73
|
+
return output;
|
|
38
74
|
}
|
|
39
75
|
|
|
40
76
|
getModelMaxTokenLength() {
|
|
@@ -102,6 +138,8 @@ class ModelPlugin {
|
|
|
102
138
|
if (!choices || !choices.length) {
|
|
103
139
|
if (Array.isArray(data) && data.length > 0 && data[0].translations) {
|
|
104
140
|
return data[0].translations[0].text.trim();
|
|
141
|
+
} else {
|
|
142
|
+
return data;
|
|
105
143
|
}
|
|
106
144
|
}
|
|
107
145
|
|
|
@@ -114,20 +152,40 @@ class ModelPlugin {
|
|
|
114
152
|
const textResult = choices[0].text && choices[0].text.trim();
|
|
115
153
|
const messageResult = choices[0].message && choices[0].message.content && choices[0].message.content.trim();
|
|
116
154
|
|
|
117
|
-
return messageResult
|
|
155
|
+
return messageResult ?? textResult ?? null;
|
|
118
156
|
}
|
|
119
157
|
|
|
120
|
-
|
|
121
|
-
const
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
158
|
+
logRequestData(data, responseData, prompt) {
|
|
159
|
+
const separator = `\n=== ${this.pathwayName}.${this.requestCount++} ===\n`;
|
|
160
|
+
console.log(separator);
|
|
161
|
+
|
|
162
|
+
const modelInput = data.prompt || (data.messages && data.messages[0].content) || (data.length > 0 && data[0].Text) || null;
|
|
163
|
+
|
|
164
|
+
if (data.messages && data.messages.length > 1) {
|
|
165
|
+
data.messages.forEach((message, index) => {
|
|
166
|
+
const words = message.content.split(" ");
|
|
167
|
+
const tokenCount = encode(message.content).length;
|
|
168
|
+
const preview = words.length < 41 ? message.content : words.slice(0, 20).join(" ") + " ... " + words.slice(-20).join(" ");
|
|
169
|
+
|
|
170
|
+
console.log(`\x1b[36mMessage ${index + 1}: Role: ${message.role}, Tokens: ${tokenCount}, Content: "${preview}"\x1b[0m`);
|
|
171
|
+
});
|
|
172
|
+
} else {
|
|
173
|
+
console.log(`\x1b[36m${modelInput}\x1b[0m`);
|
|
174
|
+
}
|
|
175
|
+
|
|
125
176
|
console.log(`\x1b[34m> ${this.parseResponse(responseData)}\x1b[0m`);
|
|
126
|
-
|
|
177
|
+
|
|
178
|
+
prompt.debugInfo += `${separator}${JSON.stringify(data)}`;
|
|
179
|
+
}
|
|
180
|
+
|
|
181
|
+
async executeRequest(url, data, params, headers, prompt) {
|
|
182
|
+
const responseData = await request({ url, data, params, headers, cache: this.shouldCache }, this.modelName);
|
|
183
|
+
|
|
127
184
|
if (responseData.error) {
|
|
128
185
|
throw new Exception(`An error was returned from the server: ${JSON.stringify(responseData.error)}`);
|
|
129
186
|
}
|
|
130
|
-
|
|
187
|
+
|
|
188
|
+
this.logRequestData(data, responseData, prompt);
|
|
131
189
|
return this.parseResponse(responseData);
|
|
132
190
|
}
|
|
133
191
|
|
|
@@ -1,34 +1,61 @@
|
|
|
1
1
|
// OpenAIChatPlugin.js
|
|
2
2
|
const ModelPlugin = require('./modelPlugin');
|
|
3
3
|
const handlebars = require("handlebars");
|
|
4
|
+
const { encode } = require("gpt-3-encoder");
|
|
4
5
|
|
|
5
6
|
class OpenAIChatPlugin extends ModelPlugin {
|
|
6
7
|
constructor(config, pathway) {
|
|
7
8
|
super(config, pathway);
|
|
8
9
|
}
|
|
9
10
|
|
|
10
|
-
|
|
11
|
-
requestParameters(text, parameters, prompt) {
|
|
11
|
+
getCompiledPrompt(text, parameters, prompt) {
|
|
12
12
|
const combinedParameters = { ...this.promptParameters, ...parameters };
|
|
13
13
|
const modelPrompt = this.getModelPrompt(prompt, parameters);
|
|
14
14
|
const modelPromptText = modelPrompt.prompt ? handlebars.compile(modelPrompt.prompt)({ ...combinedParameters, text }) : '';
|
|
15
15
|
const modelPromptMessages = this.getModelPromptMessages(modelPrompt, combinedParameters, text);
|
|
16
|
+
const modelPromptMessagesML = this.messagesToChatML(modelPromptMessages);
|
|
17
|
+
|
|
18
|
+
if (modelPromptMessagesML) {
|
|
19
|
+
return { modelPromptMessages, tokenLength: encode(modelPromptMessagesML).length };
|
|
20
|
+
} else {
|
|
21
|
+
return { modelPromptText, tokenLength: encode(modelPromptText).length };
|
|
22
|
+
}
|
|
23
|
+
}
|
|
16
24
|
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
25
|
+
// Set up parameters specific to the OpenAI Chat API
|
|
26
|
+
getRequestParameters(text, parameters, prompt) {
|
|
27
|
+
const { modelPromptText, modelPromptMessages, tokenLength } = this.getCompiledPrompt(text, parameters, prompt);
|
|
28
|
+
const { stream } = parameters;
|
|
29
|
+
|
|
30
|
+
// Define the model's max token length
|
|
31
|
+
const modelMaxTokenLength = this.getModelMaxTokenLength() * this.getPromptTokenRatio();
|
|
32
|
+
|
|
33
|
+
let requestMessages = modelPromptMessages || [{ "role": "user", "content": modelPromptText }];
|
|
34
|
+
|
|
35
|
+
// Check if the token length exceeds the model's max token length
|
|
36
|
+
if (tokenLength > modelMaxTokenLength) {
|
|
37
|
+
// Remove older messages until the token length is within the model's limit
|
|
38
|
+
requestMessages = this.removeMessagesUntilTarget(requestMessages, modelMaxTokenLength);
|
|
39
|
+
}
|
|
40
|
+
|
|
41
|
+
const requestParameters = {
|
|
42
|
+
messages: requestMessages,
|
|
43
|
+
temperature: this.temperature ?? 0.7,
|
|
44
|
+
stream
|
|
20
45
|
};
|
|
46
|
+
|
|
47
|
+
return requestParameters;
|
|
21
48
|
}
|
|
22
49
|
|
|
23
50
|
// Execute the request to the OpenAI Chat API
|
|
24
51
|
async execute(text, parameters, prompt) {
|
|
25
52
|
const url = this.requestUrl(text);
|
|
26
|
-
const requestParameters = this.
|
|
53
|
+
const requestParameters = this.getRequestParameters(text, parameters, prompt);
|
|
27
54
|
|
|
28
55
|
const data = { ...(this.model.params || {}), ...requestParameters };
|
|
29
56
|
const params = {};
|
|
30
57
|
const headers = this.model.headers || {};
|
|
31
|
-
return this.executeRequest(url, data, params, headers);
|
|
58
|
+
return this.executeRequest(url, data, params, headers, prompt);
|
|
32
59
|
}
|
|
33
60
|
}
|
|
34
61
|
|
|
@@ -3,61 +3,81 @@ const ModelPlugin = require('./modelPlugin');
|
|
|
3
3
|
const handlebars = require("handlebars");
|
|
4
4
|
const { encode } = require("gpt-3-encoder");
|
|
5
5
|
|
|
6
|
-
//convert a messages array to a simple chatML format
|
|
7
|
-
const messagesToChatML = (messages) => {
|
|
8
|
-
let output = "";
|
|
9
|
-
if (messages && messages.length) {
|
|
10
|
-
for (let message of messages) {
|
|
11
|
-
output += (message.role && message.content) ? `<|im_start|>${message.role}\n${message.content}\n<|im_end|>\n` : `${message}\n`;
|
|
12
|
-
}
|
|
13
|
-
// you always want the assistant to respond next so add a
|
|
14
|
-
// directive for that
|
|
15
|
-
output += "<|im_start|>assistant\n";
|
|
16
|
-
}
|
|
17
|
-
return output;
|
|
18
|
-
}
|
|
19
|
-
|
|
20
6
|
class OpenAICompletionPlugin extends ModelPlugin {
|
|
21
7
|
constructor(config, pathway) {
|
|
22
8
|
super(config, pathway);
|
|
23
9
|
}
|
|
24
10
|
|
|
25
|
-
|
|
26
|
-
requestParameters(text, parameters, prompt) {
|
|
11
|
+
getCompiledPrompt(text, parameters, prompt) {
|
|
27
12
|
const combinedParameters = { ...this.promptParameters, ...parameters };
|
|
28
13
|
const modelPrompt = this.getModelPrompt(prompt, parameters);
|
|
29
14
|
const modelPromptText = modelPrompt.prompt ? handlebars.compile(modelPrompt.prompt)({ ...combinedParameters, text }) : '';
|
|
30
15
|
const modelPromptMessages = this.getModelPromptMessages(modelPrompt, combinedParameters, text);
|
|
31
|
-
const modelPromptMessagesML = messagesToChatML(modelPromptMessages);
|
|
16
|
+
const modelPromptMessagesML = this.messagesToChatML(modelPromptMessages);
|
|
32
17
|
|
|
33
18
|
if (modelPromptMessagesML) {
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
};
|
|
19
|
+
return { modelPromptMessages, tokenLength: encode(modelPromptMessagesML).length };
|
|
20
|
+
} else {
|
|
21
|
+
return { modelPromptText, tokenLength: encode(modelPromptText).length };
|
|
22
|
+
}
|
|
23
|
+
}
|
|
24
|
+
|
|
25
|
+
// Set up parameters specific to the OpenAI Completion API
|
|
26
|
+
getRequestParameters(text, parameters, prompt) {
|
|
27
|
+
let { modelPromptMessages, modelPromptText, tokenLength } = this.getCompiledPrompt(text, parameters, prompt);
|
|
28
|
+
const { stream } = parameters;
|
|
29
|
+
let modelPromptMessagesML = '';
|
|
30
|
+
const modelMaxTokenLength = this.getModelMaxTokenLength();
|
|
31
|
+
let requestParameters = {};
|
|
32
|
+
|
|
33
|
+
if (modelPromptMessages) {
|
|
34
|
+
const requestMessages = this.removeMessagesUntilTarget(modelPromptMessages, modelMaxTokenLength - 1);
|
|
35
|
+
modelPromptMessagesML = this.messagesToChatML(requestMessages);
|
|
36
|
+
tokenLength = encode(modelPromptMessagesML).length;
|
|
37
|
+
|
|
38
|
+
if (tokenLength >= modelMaxTokenLength) {
|
|
39
|
+
throw new Error(`The maximum number of tokens for this model is ${modelMaxTokenLength}. Please reduce the number of messages in the prompt.`);
|
|
40
|
+
}
|
|
41
|
+
|
|
42
|
+
const max_tokens = modelMaxTokenLength - tokenLength - 1;
|
|
43
|
+
|
|
44
|
+
requestParameters = {
|
|
45
|
+
prompt: modelPromptMessagesML,
|
|
46
|
+
max_tokens: max_tokens,
|
|
47
|
+
temperature: this.temperature ?? 0.7,
|
|
48
|
+
top_p: 0.95,
|
|
49
|
+
frequency_penalty: 0,
|
|
50
|
+
presence_penalty: 0,
|
|
51
|
+
stop: ["<|im_end|>"],
|
|
52
|
+
stream
|
|
53
|
+
};
|
|
43
54
|
} else {
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
55
|
+
if (tokenLength >= modelMaxTokenLength) {
|
|
56
|
+
throw new Error(`The maximum number of tokens for this model is ${modelMaxTokenLength}. Please reduce the length of the prompt.`);
|
|
57
|
+
}
|
|
58
|
+
|
|
59
|
+
const max_tokens = modelMaxTokenLength - tokenLength - 1;
|
|
60
|
+
|
|
61
|
+
requestParameters = {
|
|
62
|
+
prompt: modelPromptText,
|
|
63
|
+
max_tokens: max_tokens,
|
|
64
|
+
temperature: this.temperature ?? 0.7,
|
|
65
|
+
stream
|
|
66
|
+
};
|
|
49
67
|
}
|
|
68
|
+
|
|
69
|
+
return requestParameters;
|
|
50
70
|
}
|
|
51
71
|
|
|
52
72
|
// Execute the request to the OpenAI Completion API
|
|
53
73
|
async execute(text, parameters, prompt) {
|
|
54
74
|
const url = this.requestUrl(text);
|
|
55
|
-
const requestParameters = this.
|
|
75
|
+
const requestParameters = this.getRequestParameters(text, parameters, prompt);
|
|
56
76
|
|
|
57
77
|
const data = { ...(this.model.params || {}), ...requestParameters };
|
|
58
78
|
const params = {};
|
|
59
79
|
const headers = this.model.headers || {};
|
|
60
|
-
return this.executeRequest(url, data, params, headers);
|
|
80
|
+
return this.executeRequest(url, data, params, headers, prompt);
|
|
61
81
|
}
|
|
62
82
|
}
|
|
63
83
|
|