@aj-archipelago/cortex 0.0.9 → 0.0.11
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/config.js +17 -11
- package/graphql/chunker.js +97 -107
- package/graphql/graphql.js +19 -22
- package/graphql/parser.js +1 -1
- package/graphql/pathwayPrompter.js +8 -9
- package/graphql/pathwayResolver.js +12 -14
- package/graphql/pathwayResponseParser.js +2 -2
- package/graphql/plugins/azureTranslatePlugin.js +2 -2
- package/graphql/plugins/modelPlugin.js +67 -25
- package/graphql/plugins/openAiChatPlugin.js +3 -3
- package/graphql/plugins/openAiCompletionPlugin.js +5 -4
- package/graphql/plugins/openAiWhisperPlugin.js +7 -6
- package/graphql/prompt.js +1 -1
- package/graphql/pubsub.js +2 -2
- package/graphql/requestState.js +1 -1
- package/graphql/resolver.js +4 -4
- package/graphql/subscriptions.js +5 -4
- package/graphql/typeDef.js +53 -53
- package/index.js +5 -5
- package/lib/fileChunker.js +15 -11
- package/lib/keyValueStorageClient.js +5 -5
- package/lib/promiser.js +2 -2
- package/lib/request.js +11 -9
- package/lib/requestMonitor.js +2 -2
- package/package.json +15 -5
- package/pathways/basePathway.js +5 -4
- package/pathways/bias.js +2 -2
- package/pathways/chat.js +3 -2
- package/pathways/complete.js +4 -2
- package/pathways/edit.js +3 -2
- package/pathways/entities.js +3 -2
- package/pathways/index.js +25 -12
- package/pathways/lc_test.mjs +99 -0
- package/pathways/paraphrase.js +3 -2
- package/pathways/sentiment.js +3 -2
- package/pathways/summary.js +27 -10
- package/pathways/transcribe.js +4 -2
- package/pathways/translate.js +3 -2
- package/start.js +5 -2
- package/tests/chunkfunction.test.js +125 -0
- package/tests/chunking.test.js +25 -19
- package/tests/main.test.js +52 -38
- package/tests/translate.test.js +13 -10
package/config.js
CHANGED
@@ -1,7 +1,8 @@
-
-const
-
-
+import path from 'path';
+const __dirname = path.dirname(new URL(import.meta.url).pathname);
+import convict from 'convict';
+import handlebars from 'handlebars';
+import fs from 'fs';
 
 // Schema for config
 var config = convict({
@@ -108,7 +109,13 @@ var config = convict({
         format: String,
         default: null,
         env: 'CORTEX_CONFIG_FILE'
-    }
+    },
+    serpApiKey: {
+        format: String,
+        default: null,
+        env: 'SERPAPI_API_KEY',
+        sensitive: true
+    },
 });
 
 // Read in environment variables and set up service configuration
@@ -127,22 +134,21 @@ if (configFile && fs.existsSync(configFile)) {
     }
 }
 
-
 // Build and load pathways to config
-const buildPathways = (config) => {
+const buildPathways = async (config) => {
     const { pathwaysPath, corePathwaysPath, basePathwayPath } = config.getProperties();
 
     // Load cortex base pathway
-    const basePathway =
+    const basePathway = await import(basePathwayPath).then(module => module.default);
 
     // Load core pathways, default from the Cortex package
     console.log('Loading core pathways from', corePathwaysPath)
-    let loadedPathways =
+    let loadedPathways = await import(`${corePathwaysPath}/index.js`).then(module => module);
 
     // Load custom pathways and override core pathways if same
     if (pathwaysPath && fs.existsSync(pathwaysPath)) {
        console.log('Loading custom pathways from', pathwaysPath)
-        const customPathways =
+        const customPathways = await import(`${pathwaysPath}/index.js`).then(module => module);
        loadedPathways = { ...loadedPathways, ...customPathways };
    }
 
@@ -191,4 +197,4 @@ const buildModels = (config) => {
 // TODO: Perform validation
 // config.validate({ allowed: 'strict' });
 
-
+export { config, buildPathways, buildModels };

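The config change converts config.js from CommonJS to ES modules: __dirname is re-derived from import.meta.url, and buildPathways becomes async so pathway modules can be loaded with dynamic import(). A minimal sketch of that loading pattern; the directory names and the loadPathways helper are illustrative, not taken from the package:

import path from 'path';
import fs from 'fs';

// ESM has no __dirname, so derive it from import.meta.url as config.js now does.
const __dirname = path.dirname(new URL(import.meta.url).pathname);

const loadPathways = async (corePathwaysPath) => {
    // import() is asynchronous, which is why buildPathways is now async
    // and graphql.js awaits it before building the server.
    let pathways = await import(`${corePathwaysPath}/index.js`).then(module => module);

    const customPath = path.join(__dirname, 'pathways'); // hypothetical custom location
    if (fs.existsSync(customPath)) {
        const custom = await import(`${customPath}/index.js`).then(module => module);
        pathways = { ...pathways, ...custom }; // custom pathways override core ones
    }

    return pathways;
};
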
package/graphql/chunker.js
CHANGED
@@ -1,21 +1,4 @@
-
-
-const estimateCharPerToken = (text) => {
-    // check text only contains asciish characters
-    if (/^[ -~\t\n\r]+$/.test(text)) {
-        return 4;
-    }
-    return 1;
-}
-
-const getLastNChar = (text, maxLen) => {
-    if (text.length > maxLen) {
-        //slice text to avoid maxLen limit but keep the last n characters up to a \n or space to avoid cutting words
-        text = text.slice(-maxLen);
-        text = text.slice(text.search(/\s/) + 1);
-    }
-    return text;
-}
+import { encode, decode } from 'gpt-3-encoder';
 
 const getLastNToken = (text, maxTokenLen) => {
     const encoded = encode(text);
@@ -35,113 +18,120 @@ const getFirstNToken = (text, maxTokenLen) => {
     return text;
 }
 
-const
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    if (
-
-    }
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+const getSemanticChunks = (text, chunkSize) => {
+
+    const breakByRegex = (str, regex, preserveWhitespace = false) => {
+        const result = [];
+        let match;
+
+        while ((match = regex.exec(str)) !== null) {
+            const value = str.slice(0, match.index);
+            result.push(value);
+
+            if (preserveWhitespace || /\S/.test(match[0])) {
+                result.push(match[0]);
+            }
+
+            str = str.slice(match.index + match[0].length);
+        }
+
+        if (str) {
+            result.push(str);
+        }
+
+        return result.filter(Boolean);
+    };
+
+    const breakByParagraphs = (str) => breakByRegex(str, /[\r\n]+/, true);
+    const breakBySentences = (str) => breakByRegex(str, /(?<=[.。؟!\?!\n])\s+/, true);
+    const breakByWords = (str) => breakByRegex(str, /(\s,;:.+)/);
+
+    const createChunks = (tokens) => {
+        let chunks = [];
+        let currentChunk = '';
+
+        for (const token of tokens) {
+            const currentTokenLength = encode(currentChunk + token).length;
+            if (currentTokenLength <= chunkSize) {
+                currentChunk += token;
+            } else {
+                if (currentChunk) {
+                    chunks.push(currentChunk);
                 }
+                currentChunk = token;
+            }
         }
-
-
-
-    for (let i = 0; enableLineChunks && i < sentenceChunks.length; i++) {
-        if (isBig(sentenceChunks[i])) { // too long, split into lines
-            newlineChunks.push(...sentenceChunks[i].split('\n'));
-        } else {
-            newlineChunks.push(sentenceChunks[i]);
-        }
+
+        if (currentChunk) {
+            chunks.push(currentChunk);
         }
+
+        return chunks;
+    };
 
-
-    let
-
-
-
-
-
-
-
-
-
-    }
-        chunk += words[k] + ' ';
-    }
-    if (chunk.length > 0) {
-        chunks.push(chunk.trim());
-    }
+    const combineChunks = (chunks) => {
+        let optimizedChunks = [];
+
+        for (let i = 0; i < chunks.length; i++) {
+            if (i < chunks.length - 1) {
+                const combinedChunk = chunks[i] + chunks[i + 1];
+                const combinedLen = encode(combinedChunk).length;
+
+                if (combinedLen <= chunkSize) {
+                    optimizedChunks.push(combinedChunk);
+                    i += 1;
                 } else {
-
+                    optimizedChunks.push(chunks[i]);
                 }
+            } else {
+                optimizedChunks.push(chunks[i]);
+            }
         }
+
+        return optimizedChunks;
+    };
 
-
-
-    return finallyMergeChunks ? mergeChunks({ chunks, maxChunkLength, maxChunkToken }) : chunks;
-}
+    const breakText = (str) => {
+        const tokenLength = encode(str).length;
 
-
-
-    return isBigChunk({ text, maxChunkLength, maxChunkToken });
+        if (tokenLength <= chunkSize) {
+            return [str];
         }
 
-
-
-    let
-
-
-
-
+        const breakers = [breakByParagraphs, breakBySentences, breakByWords];
+
+        for (let i = 0; i < breakers.length; i++) {
+            const tokens = breakers[i](str);
+            if (tokens.length > 1) {
+                let chunks = createChunks(tokens);
+                chunks = combineChunks(chunks);
+                const brokenChunks = chunks.flatMap(breakText);
+                if (brokenChunks.every(chunk => encode(chunk).length <= chunkSize)) {
+                    return brokenChunks;
                 }
-
-            }
-            if (chunk.length > 0) {
-                mergedChunks.push(chunk);
+            }
         }
-
+
+        return createChunks([...str]); // Split by characters
+    };
+
+    return breakText(text);
 }
 
 
 const semanticTruncate = (text, maxLength) => {
-
-        text = getSemanticChunks({ text, maxChunkLength: maxLength })[0].slice(0, maxLength - 3).trim() + "...";
-    }
+    if (text.length <= maxLength) {
         return text;
-}
+    }
 
+    const truncatedText = text.slice(0, maxLength - 3).trim();
+    const lastSpaceIndex = truncatedText.lastIndexOf(" ");
 
+    return (lastSpaceIndex !== -1)
+        ? truncatedText.slice(0, lastSpaceIndex) + "..."
+        : truncatedText + "...";
+};
 
-
-    getSemanticChunks, semanticTruncate,
-
-}
+export {
+    getSemanticChunks, semanticTruncate, getLastNToken, getFirstNToken
+};

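The chunker rewrite replaces the old character-count heuristics (estimateCharPerToken, getLastNChar) with token-aware chunking: getSemanticChunks(text, chunkSize) measures everything with gpt-3-encoder and falls back from paragraph to sentence to word to character splits until every chunk fits the token budget. A rough usage sketch; the sample text, import path, and token budget are illustrative:

import { encode } from 'gpt-3-encoder';
import { getSemanticChunks, semanticTruncate } from './graphql/chunker.js';

const text = 'First paragraph of a long document.\n\n' +
    'Second paragraph. It has two sentences, so it can be split if needed.';

// Split into pieces that each fit within a 50-token budget.
const chunks = getSemanticChunks(text, 50);
console.log(chunks.map(chunk => encode(chunk).length)); // each count should be <= 50

// Character-based truncation at a word boundary, with "..." appended.
console.log(semanticTruncate(text, 60));
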
package/graphql/graphql.js
CHANGED
@@ -1,23 +1,21 @@
-
-
+import { createServer } from 'http';
+import {
     ApolloServerPluginDrainHttpServer,
     ApolloServerPluginLandingPageLocalDefault,
-}
-
-
-
-
-
-
-
-
-
-
-
-
-const { buildPathways, buildModels } = require('../config');
-const { requestState } = require('./requestState');
+} from 'apollo-server-core';
+import { makeExecutableSchema } from '@graphql-tools/schema';
+import { WebSocketServer } from 'ws';
+import { useServer } from 'graphql-ws/lib/use/ws';
+import express from 'express';
+import { ApolloServer } from 'apollo-server-express';
+import Keyv from 'keyv';
+import { KeyvAdapter } from '@apollo/utils.keyvadapter';
+import responseCachePlugin from 'apollo-server-plugin-response-cache';
+import subscriptions from './subscriptions.js';
+import { buildLimiters } from '../lib/request.js';
+import { cancelRequestResolver } from './resolver.js';
+import { buildPathways, buildModels } from '../config.js';
+import { requestState } from './requestState.js';
 
 const getPlugins = (config) => {
     // server plugins
@@ -134,9 +132,9 @@ const getResolvers = (config, pathways) => {
 }
 
 //graphql api build factory method
-const build = (config) => {
+const build = async (config) => {
     // First perform config build
-    buildPathways(config);
+    await buildPathways(config);
     buildModels(config);
 
     // build api limiters
@@ -152,7 +150,6 @@ const build = (config) => {
 
     const { plugins, cache } = getPlugins(config);
 
-    const { ApolloServer, gql } = require('apollo-server-express');
     const app = express()
 
     const httpServer = createServer(app);
@@ -221,6 +218,6 @@ const build = (config) => {
    }
 
 
-
+export {
    build
 };

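Because buildPathways now loads pathway modules with dynamic import(), build() itself became async, so callers have to await it before the server exists. A minimal sketch of the calling side, using the package's own exports; the surrounding start-up code is assumed, not taken from this diff:

import { config } from '../config.js';
import { build } from './graphql.js';

// build() awaits buildPathways(config) before constructing the Apollo server,
// so callers such as index.js / start.js must await it as well.
const api = await build(config); // top-level await is available in ESM
// `api` is whatever build() returns; its exact shape isn't shown in this diff.
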
package/graphql/pathwayPrompter.js
CHANGED
@@ -1,10 +1,9 @@
 // PathwayPrompter.js
-
-
-
-
-
-const { Exception } = require("handlebars");
+import OpenAIChatPlugin from './plugins/openAIChatPlugin.js';
+import OpenAICompletionPlugin from './plugins/openAICompletionPlugin.js';
+import AzureTranslatePlugin from './plugins/azureTranslatePlugin.js';
+import OpenAIWhisperPlugin from './plugins/openAiWhisperPlugin.js';
+import handlebars from 'handlebars';
 
 // register functions that can be called directly in the prompt markdown
 handlebars.registerHelper('stripHTML', function (value) {
@@ -27,7 +26,7 @@ class PathwayPrompter {
         const model = config.get('models')[modelName];
 
         if (!model) {
-            throw new Exception(`Model ${modelName} not found in config`);
+            throw new handlebars.Exception(`Model ${modelName} not found in config`);
         }
 
         let plugin;
@@ -46,7 +45,7 @@ class PathwayPrompter {
                 plugin = new OpenAIWhisperPlugin(config, pathway);
                 break;
             default:
-                throw new Exception(`Unsupported model type: ${model.type}`);
+                throw new handlebars.Exception(`Unsupported model type: ${model.type}`);
         }
 
         this.plugin = plugin;
@@ -57,6 +56,6 @@ class PathwayPrompter {
    }
 }
 
-
+export {
    PathwayPrompter
 };

package/graphql/pathwayResolver.js
CHANGED
@@ -1,14 +1,12 @@
-
-
-
-}
-
-
-
-
-
-const { getv, setv } = require('../lib/keyValueStorageClient');
-const { requestState } = require('./requestState');
+import { PathwayPrompter } from './pathwayPrompter.js';
+import { v4 as uuidv4 } from 'uuid';
+import pubsub from './pubsub.js';
+import { encode } from 'gpt-3-encoder';
+import { getFirstNToken, getLastNToken, getSemanticChunks } from './chunker.js';
+import { PathwayResponseParser } from './pathwayResponseParser.js';
+import { Prompt } from './prompt.js';
+import { getv, setv } from '../lib/keyValueStorageClient.js';
+import { requestState } from './requestState.js';
 
 const MAX_PREVIOUS_RESULT_TOKEN_LENGTH = 1000;
 
@@ -125,7 +123,7 @@ class PathwayResolver {
         // Get saved context from contextId or change contextId if needed
         const { contextId } = args;
         this.savedContextId = contextId ? contextId : null;
-        this.savedContext = contextId ? (getv && await getv(contextId) || {}) : {};
+        this.savedContext = contextId ? (getv && (await getv(contextId)) || {}) : {};
 
         // Save the context before processing the request
         const savedContextStr = JSON.stringify(this.savedContext);
@@ -163,7 +161,7 @@ class PathwayResolver {
        }
 
        // chunk the text and return the chunks with newline separators
-        return getSemanticChunks(
+        return getSemanticChunks(text, chunkTokenLength);
    }
 
    truncate(str, n) {
@@ -312,4 +310,4 @@ class PathwayResolver {
    }
 }
 
-
+export { PathwayResolver };

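The savedContext change only adds parentheses, but it makes the grouping explicit: the getv lookup is awaited first, and the empty-object fallback applies when the store is absent or returns nothing. A small standalone sketch of the same guard, with a stand-in getv and loadContext helper used purely for illustration:

// Stand-in for keyValueStorageClient's getv; it may be undefined when no store is configured.
const getv = async (key) => (key === 'known' ? { history: [] } : null);

const loadContext = async (contextId) => {
    // (getv && (await getv(contextId)) || {}) — await the lookup, then default to {}.
    return contextId ? (getv && (await getv(contextId)) || {}) : {};
};

console.log(await loadContext('known'));   // { history: [] }
console.log(await loadContext('missing')); // {}
console.log(await loadContext(null));      // {}
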
package/graphql/pathwayResponseParser.js
CHANGED
@@ -1,4 +1,4 @@
-
+import { parseNumberedList, parseNumberedObjectList } from './parser.js';
 
 class PathwayResponseParser {
     constructor(pathway) {
@@ -21,4 +21,4 @@ class PathwayResponseParser {
    }
 }
 
-
+export { PathwayResponseParser };

package/graphql/plugins/azureTranslatePlugin.js
CHANGED
@@ -1,5 +1,5 @@
 // AzureTranslatePlugin.js
-
+import ModelPlugin from './modelPlugin.js';
 
 class AzureTranslatePlugin extends ModelPlugin {
     constructor(config, pathway) {
@@ -37,4 +37,4 @@ class AzureTranslatePlugin extends ModelPlugin {
    }
 }
 
-
+export default AzureTranslatePlugin;

package/graphql/plugins/modelPlugin.js
CHANGED
@@ -1,7 +1,9 @@
 // ModelPlugin.js
-
-
-
+import handlebars from 'handlebars';
+
+import { request } from '../../lib/request.js';
+import { encode } from 'gpt-3-encoder';
+import { getFirstNToken } from '../chunker.js';
 
 const DEFAULT_MAX_TOKENS = 4096;
 const DEFAULT_PROMPT_TOKEN_RATIO = 0.5;
@@ -38,37 +40,77 @@ class ModelPlugin {
         this.shouldCache = config.get('enableCache') && (pathway.enableCache || pathway.temperature == 0);
     }
 
-
-
-
-
+    truncateMessagesToTargetLength = (messages, targetTokenLength) => {
+        // Calculate the token length of each message
+        const tokenLengths = messages.map((message) => ({
+            message,
+            tokenLength: encode(this.messagesToChatML([message], false)).length,
+        }));
 
-
-
-
-
-
-
-
-
+        // Calculate the total token length of all messages
+        let totalTokenLength = tokenLengths.reduce(
+            (sum, { tokenLength }) => sum + tokenLength,
+            0
+        );
+
+        // If we're already under the target token length, just bail
+        if (totalTokenLength <= targetTokenLength) return messages;
+
+        // Remove and/or truncate messages until the target token length is reached
+        let index = 0;
+        while (totalTokenLength > targetTokenLength) {
+            const message = tokenLengths[index].message;
+
+            // Skip system messages
+            if (message.role === 'system') {
+                index++;
+                continue;
            }
-
-
+
+            const currentTokenLength = tokenLengths[index].tokenLength;
+
+            if (totalTokenLength - currentTokenLength >= targetTokenLength) {
+                // Remove the message entirely if doing so won't go below the target token length
+                totalTokenLength -= currentTokenLength;
+                tokenLengths.splice(index, 1);
+            } else {
+                // Truncate the message to fit the remaining target token length
+                const emptyContentLength = encode(this.messagesToChatML([{ ...message, content: '' }], false)).length;
+                const otherMessageTokens = totalTokenLength - currentTokenLength;
+                const tokensToKeep = targetTokenLength - (otherMessageTokens + emptyContentLength);
+
+                const truncatedContent = getFirstNToken(message.content, tokensToKeep);
+                const truncatedMessage = { ...message, content: truncatedContent };
+
+                tokenLengths[index] = {
+                    message: truncatedMessage,
+                    tokenLength: encode(this.messagesToChatML([ truncatedMessage ], false)).length
+                }
+
+                // calculate the length again to keep us honest
+                totalTokenLength = tokenLengths.reduce(
+                    (sum, { tokenLength }) => sum + tokenLength,
+                    0
+                );
            }
        }
-
-
-
+
+        // Return the modified messages array
+        return tokenLengths.map(({ message }) => message);
+    };
+
     //convert a messages array to a simple chatML format
-    messagesToChatML =
+    messagesToChatML(messages, addAssistant = true) {
         let output = "";
         if (messages && messages.length) {
             for (let message of messages) {
-                output += (message.role && message.content) ? `<|im_start|>${message.role}\n${message.content}\n<|im_end|>\n` : `${message}\n`;
+                output += (message.role && (message.content || message.content === '')) ? `<|im_start|>${message.role}\n${message.content}\n<|im_end|>\n` : `${message}\n`;
            }
            // you always want the assistant to respond next so add a
            // directive for that
-
+            if (addAssistant) {
+                output += "<|im_start|>assistant\n";
+            }
        }
        return output;
    }
@@ -196,7 +238,7 @@ class ModelPlugin {
         const responseData = await request({ url, data, params, headers, cache: this.shouldCache }, this.modelName);
 
         if (responseData.error) {
-            throw new
+            throw new Error(`An error was returned from the server: ${JSON.stringify(responseData.error)}`);
         }
 
         this.logRequestData(data, responseData, prompt);
@@ -205,6 +247,6 @@ class ModelPlugin {
 
 }
 
-
+export default ModelPlugin;
 
 

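The two new ModelPlugin helpers work together: messagesToChatML renders a messages array in ChatML (<|im_start|>role ... <|im_end|>) so token counts can be measured with gpt-3-encoder, and truncateMessagesToTargetLength removes or trims non-system messages until that measured total fits the target. The sketch below re-implements the ChatML rendering standalone purely to illustrate the format being counted; it is a mirror for demonstration, not the plugin itself:

import { encode } from 'gpt-3-encoder';

// Standalone mirror of ModelPlugin.messagesToChatML, for illustration only.
const messagesToChatML = (messages, addAssistant = true) => {
    let output = "";
    for (const message of messages) {
        output += (message.role && (message.content || message.content === ''))
            ? `<|im_start|>${message.role}\n${message.content}\n<|im_end|>\n`
            : `${message}\n`;
    }
    if (addAssistant) output += "<|im_start|>assistant\n";
    return output;
};

const messages = [
    { role: 'system', content: 'You are a helpful assistant.' },
    { role: 'user', content: 'Summarize the last meeting for me.' },
];

// truncateMessagesToTargetLength measures each message this way (addAssistant = false),
// skips system messages, and trims the rest with getFirstNToken until the sum fits.
console.log(encode(messagesToChatML([messages[1]], false)).length);
console.log(messagesToChatML(messages));
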
package/graphql/plugins/openAiChatPlugin.js
CHANGED
@@ -1,5 +1,5 @@
 // OpenAIChatPlugin.js
-
+import ModelPlugin from './modelPlugin.js';
 
 class OpenAIChatPlugin extends ModelPlugin {
     constructor(config, pathway) {
@@ -19,7 +19,7 @@ class OpenAIChatPlugin extends ModelPlugin {
         // Check if the token length exceeds the model's max token length
         if (tokenLength > modelMaxTokenLength) {
             // Remove older messages until the token length is within the model's limit
-            requestMessages = this.
+            requestMessages = this.truncateMessagesToTargetLength(requestMessages, modelMaxTokenLength);
         }
 
         const requestParameters = {
@@ -43,4 +43,4 @@ class OpenAIChatPlugin extends ModelPlugin {
    }
 }
 
-
+export default OpenAIChatPlugin;