@aj-archipelago/cortex 1.0.6 → 1.0.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/config/default.example.json +4 -2
- package/package.json +2 -1
- package/pathways/basePathway.js +1 -0
- package/server/chunker.js +48 -3
- package/server/graphql.js +7 -1
- package/server/pathwayPrompter.js +13 -16
- package/server/pathwayResolver.js +58 -32
- package/server/plugins/azureTranslatePlugin.js +2 -2
- package/server/plugins/localModelPlugin.js +2 -2
- package/server/plugins/modelPlugin.js +8 -10
- package/server/plugins/openAiChatPlugin.js +2 -2
- package/server/plugins/openAiCompletionPlugin.js +2 -2
- package/server/plugins/openAiWhisperPlugin.js +3 -3
- package/server/plugins/palmChatPlugin.js +4 -6
- package/server/plugins/palmCodeCompletionPlugin.js +46 -0
- package/server/plugins/palmCompletionPlugin.js +13 -15
- package/tests/chunkfunction.test.js +112 -26
- package/tests/mocks.js +42 -1
- package/tests/modelPlugin.test.js +3 -3
- package/tests/openAiChatPlugin.test.js +20 -13
- package/tests/palmChatPlugin.test.js +2 -3
- package/tests/palmCompletionPlugin.test.js +2 -3
- package/tests/truncateMessages.test.js +3 -4
- package/tests/server.js +0 -23
package/config/default.example.json
CHANGED

@@ -58,7 +58,8 @@
         "Content-Type": "application/json"
       },
       "requestsPerSecond": 10,
-      "maxTokenLength": 2048
+      "maxTokenLength": 2048,
+      "maxReturnTokens": 1024
     },
     "palm-chat": {
       "type": "PALM-CHAT",
@@ -67,7 +68,8 @@
         "Content-Type": "application/json"
      },
       "requestsPerSecond": 10,
-      "maxTokenLength": 2048
+      "maxTokenLength": 2048,
+      "maxReturnTokens": 1024
     },
     "local-llama13B": {
       "type": "LOCAL-CPP-MODEL",
package/package.json
CHANGED

@@ -1,6 +1,6 @@
 {
   "name": "@aj-archipelago/cortex",
-  "version": "1.0.6",
+  "version": "1.0.7",
   "description": "Cortex is a GraphQL API for AI. It provides a simple, extensible interface for using AI services from OpenAI, Azure and others.",
   "repository": {
     "type": "git",
@@ -36,6 +36,7 @@
     "axios": "^1.3.4",
     "axios-cache-interceptor": "^1.0.1",
     "bottleneck": "^2.19.5",
+    "cheerio": "^1.0.0-rc.12",
     "compromise": "^14.8.1",
     "compromise-paragraphs": "^0.1.0",
     "convict": "^6.2.3",
package/pathways/basePathway.js
CHANGED
package/server/chunker.js
CHANGED

@@ -1,4 +1,5 @@
 import { encode, decode } from 'gpt-3-encoder';
+import cheerio from 'cheerio';
 
 const getLastNToken = (text, maxTokenLen) => {
     const encoded = encode(text);
@@ -18,8 +19,18 @@ const getFirstNToken = (text, maxTokenLen) => {
     return text;
 }
 
-const getSemanticChunks = (text, chunkSize) => {
+const determineTextFormat = (text) => {
+    const htmlTagPattern = /<[^>]*>/g;
+
+    if (htmlTagPattern.test(text)) {
+        return 'html';
+    }
+    else {
+        return 'text';
+    }
+}
 
+const getSemanticChunks = (text, chunkSize, inputFormat = 'text') => {
     const breakByRegex = (str, regex, preserveWhitespace = false) => {
         const result = [];
         let match;
@@ -46,6 +57,19 @@ const getSemanticChunks = (text, chunkSize) => {
     const breakBySentences = (str) => breakByRegex(str, /(?<=[.。؟!?!\n])\s+/, true);
     const breakByWords = (str) => breakByRegex(str, /(\s,;:.+)/);
 
+    const breakByHtmlElements = (str) => {
+        const $ = cheerio.load(str, null, true);
+
+        // the .filter() call is important to get the text nodes
+        // https://stackoverflow.com/questions/54878673/cheerio-get-normal-text-nodes
+        let rootNodes = $('body').contents();
+
+        // create an array with the outerHTML of each node
+        const nodes = rootNodes.map((i, el) => $(el).prop('outerHTML') || $(el).text()).get();
+
+        return nodes;
+    };
+
     const createChunks = (tokens) => {
         let chunks = [];
         let currentChunk = '';
@@ -115,7 +139,28 @@ const getSemanticChunks = (text, chunkSize) => {
         return createChunks([...str]); // Split by characters
     };
 
-    return breakText(text);
+    if (inputFormat === 'html') {
+        const tokens = breakByHtmlElements(text);
+        let chunks = createChunks(tokens);
+        chunks = combineChunks(chunks);
+
+        chunks = chunks.flatMap(chunk => {
+            if (determineTextFormat(chunk) === 'text') {
+                return getSemanticChunks(chunk, chunkSize);
+            } else {
+                return chunk;
+            }
+        });
+
+        if (chunks.some(chunk => encode(chunk).length > chunkSize)) {
+            throw new Error('The HTML contains elements that are larger than the chunk size. Please try again with HTML that has smaller elements.');
+        }
+
+        return chunks;
+    }
+    else {
+        return breakText(text);
+    }
 }
 
 
@@ -133,5 +178,5 @@ const semanticTruncate = (text, maxLength) => {
 };
 
 export {
-    getSemanticChunks, semanticTruncate, getLastNToken, getFirstNToken
+    getSemanticChunks, semanticTruncate, getLastNToken, getFirstNToken, determineTextFormat
 };
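The practical effect of the chunker change is that HTML input can now be split without breaking elements apart: whole elements become chunks, and only the plain text between them is further chunked. A minimal usage sketch (the sample markup, chunk size, and import path are illustrative, not taken from the package):

    import { getSemanticChunks, determineTextFormat } from './server/chunker.js';

    const html = '<p>Intro paragraph.</p>Some plain text between elements.<p><img src="logo.png"></p>';

    // Detect the format first, then pass it as the new third argument.
    const format = determineTextFormat(html);            // 'html'
    const chunks = getSemanticChunks(html, 50, format);

    // Each chunk is either a complete HTML element or a slice of the plain
    // text between elements; an element larger than the chunk size throws.
    console.log(chunks);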
package/server/graphql.js
CHANGED

@@ -164,7 +164,13 @@ const build = async (config) => {
     const cortexApiKey = config.get('cortexApiKey');
     if (cortexApiKey) {
         app.use((req, res, next) => {
-
+            let providedApiKey = req.headers['cortex-api-key'] || req.query['cortex-api-key'];
+            if (!providedApiKey) {
+                providedApiKey = req.headers['authorization'];
+                providedApiKey = providedApiKey?.startsWith('Bearer ') ? providedApiKey.slice(7) : providedApiKey;
+            }
+
+            if (cortexApiKey && cortexApiKey !== providedApiKey) {
                 if (req.baseUrl === '/graphql' || req.headers['content-type'] === 'application/graphql') {
                     res.status(401)
                         .set('WWW-Authenticate', 'Cortex-Api-Key')
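With this change the key can arrive in the cortex-api-key header, the cortex-api-key query parameter, or a standard Authorization: Bearer header. A hedged client-side sketch (the host URL and query are placeholders, not part of the package):

    // Any one of the three transports below satisfies the check.
    const response = await fetch('https://my-cortex-host/graphql', {
        method: 'POST',
        headers: {
            'Content-Type': 'application/json',
            // 'cortex-api-key': process.env.CORTEX_API_KEY,            // option 1: dedicated header
            'Authorization': `Bearer ${process.env.CORTEX_API_KEY}`,    // option 2: Bearer token
        },
        body: JSON.stringify({ query: '{ __typename }' }),
    });
    // option 3: append ?cortex-api-key=<key> to the URL instead of sending a header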
package/server/pathwayPrompter.js
CHANGED

@@ -6,40 +6,37 @@ import OpenAIWhisperPlugin from './plugins/openAiWhisperPlugin.js';
 import LocalModelPlugin from './plugins/localModelPlugin.js';
 import PalmChatPlugin from './plugins/palmChatPlugin.js';
 import PalmCompletionPlugin from './plugins/palmCompletionPlugin.js';
+import PalmCodeCompletionPlugin from './plugins/palmCodeCompletionPlugin.js';
 
 class PathwayPrompter {
-    constructor(config, pathway) {
-
-        const modelName = pathway.model || config.get('defaultModelName');
-        const model = config.get('models')[modelName];
-
-        if (!model) {
-            throw new Error(`Model ${modelName} not found in config`);
-        }
-
+    constructor(config, pathway, modelName, model) {
+
         let plugin;
 
         switch (model.type) {
             case 'OPENAI-CHAT':
-                plugin = new OpenAIChatPlugin(config, pathway);
+                plugin = new OpenAIChatPlugin(config, pathway, modelName, model);
                 break;
             case 'AZURE-TRANSLATE':
-                plugin = new AzureTranslatePlugin(config, pathway);
+                plugin = new AzureTranslatePlugin(config, pathway, modelName, model);
                 break;
             case 'OPENAI-COMPLETION':
-                plugin = new OpenAICompletionPlugin(config, pathway);
+                plugin = new OpenAICompletionPlugin(config, pathway, modelName, model);
                 break;
             case 'OPENAI-WHISPER':
-                plugin = new OpenAIWhisperPlugin(config, pathway);
+                plugin = new OpenAIWhisperPlugin(config, pathway, modelName, model);
                 break;
             case 'LOCAL-CPP-MODEL':
-                plugin = new LocalModelPlugin(config, pathway);
+                plugin = new LocalModelPlugin(config, pathway, modelName, model);
                 break;
             case 'PALM-CHAT':
-                plugin = new PalmChatPlugin(config, pathway);
+                plugin = new PalmChatPlugin(config, pathway, modelName, model);
                 break;
             case 'PALM-COMPLETION':
-                plugin = new PalmCompletionPlugin(config, pathway);
+                plugin = new PalmCompletionPlugin(config, pathway, modelName, model);
+                break;
+            case 'PALM-CODE-COMPLETION':
+                plugin = new PalmCodeCompletionPlugin(config, pathway, modelName, model);
                 break;
             default:
                 throw new Error(`Unsupported model type: ${model.type}`);
package/server/pathwayResolver.js
CHANGED

@@ -20,9 +20,31 @@ class PathwayResolver {
         this.warnings = [];
         this.requestId = uuidv4();
         this.responseParser = new PathwayResponseParser(pathway);
-        this.pathwayPrompter = new PathwayPrompter(config, pathway);
+        this.modelName = [
+            pathway.model,
+            args?.model,
+            pathway.inputParameters?.model,
+            config.get('defaultModelName')
+        ].find(modelName => modelName && config.get('models').hasOwnProperty(modelName));
+        this.model = config.get('models')[this.modelName];
+
+        if (!this.model) {
+            throw new Error(`Model ${this.modelName} not found in config`);
+        }
+
+        const specifiedModelName = pathway.model || args?.model || pathway.inputParameters?.model;
+
+        if (this.modelName !== (specifiedModelName)) {
+            if (specifiedModelName) {
+                this.logWarning(`Specified model ${specifiedModelName} not found in config, using ${this.modelName} instead.`);
+            } else {
+                this.logWarning(`No model specified in the pathway, using ${this.modelName}.`);
+            }
+        }
+
         this.previousResult = '';
         this.prompts = [];
+        this.pathwayPrompter = new PathwayPrompter(this.config, this.pathway, this.modelName, this.model);
 
         Object.defineProperty(this, 'pathwayPrompt', {
             get() {
@@ -56,37 +78,41 @@ class PathwayResolver {
                 }
             });
         } else { // stream
-            [old lines 59-87: previous stream-handling block, not expanded in the diff view]
+            try {
+                const incomingMessage = Array.isArray(responseData) && responseData.length > 0 ? responseData[0] : responseData;
+                incomingMessage.on('data', data => {
+                    const events = data.toString().split('\n');
+
+                    events.forEach(event => {
+                        if (event.trim() === '') return; // Skip empty lines
+
+                        const message = event.replace(/^data: /, '');
+
+                        //console.log(`====================================`);
+                        //console.log(`STREAM EVENT: ${event}`);
+                        //console.log(`MESSAGE: ${message}`);
+
+                        const requestProgress = {
+                            requestId: this.requestId,
+                            data: message,
+                        }
+
+                        if (message.trim() === '[DONE]') {
+                            requestProgress.progress = 1;
+                        }
+
+                        try {
+                            pubsub.publish('REQUEST_PROGRESS', {
+                                requestProgress: requestProgress
+                            });
+                        } catch (error) {
+                            console.error('Could not JSON parse stream message', message, error);
+                        }
+                    });
                 });
-            })
+            } catch (error) {
+                console.error('Could not subscribe to stream', error);
+            }
         }
     }
 
@@ -152,7 +178,7 @@ class PathwayResolver {
         }
 
         // chunk the text and return the chunks with newline separators
-        return getSemanticChunks(text, chunkTokenLength);
+        return getSemanticChunks(text, chunkTokenLength, this.pathway.inputFormat);
     }
 
     truncate(str, n) {
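In practice this means the resolver, not the prompter, now decides which model runs: the first of pathway.model, args.model, and pathway.inputParameters.model that actually exists in the models config is used, otherwise defaultModelName, and a warning is logged when the requested name is missing. A standalone sketch of that precedence (the model names here are made up):

    // Same .find() precedence as in the PathwayResolver constructor.
    const models = { 'oai-gpturbo': { type: 'OPENAI-CHAT' } };   // stand-in for config.get('models')
    const candidates = [
        undefined,        // pathway.model (not set)
        'palm-chat',      // args.model (requested, but not present in models)
        undefined,        // pathway.inputParameters.model
        'oai-gpturbo',    // config.get('defaultModelName')
    ];
    const modelName = candidates.find(name => name && Object.prototype.hasOwnProperty.call(models, name));
    console.log(modelName); // 'oai-gpturbo' — and the resolver would warn that 'palm-chat' was not found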
package/server/plugins/azureTranslatePlugin.js
CHANGED

@@ -2,8 +2,8 @@
 import ModelPlugin from './modelPlugin.js';
 
 class AzureTranslatePlugin extends ModelPlugin {
-    constructor(config, pathway) {
-        super(config, pathway);
+    constructor(config, pathway, modelName, model) {
+        super(config, pathway, modelName, model);
     }
 
     // Set up parameters specific to the Azure Translate API
package/server/plugins/localModelPlugin.js
CHANGED

@@ -4,8 +4,8 @@ import { execFileSync } from 'child_process';
 import { encode } from 'gpt-3-encoder';
 
 class LocalModelPlugin extends ModelPlugin {
-    constructor(config, pathway) {
-        super(config, pathway);
+    constructor(config, pathway, modelName, model) {
+        super(config, pathway, modelName, model);
     }
 
     // if the input starts with a chatML response, just return that
package/server/plugins/modelPlugin.js
CHANGED

@@ -6,19 +6,13 @@ import { encode } from 'gpt-3-encoder';
 import { getFirstNToken } from '../chunker.js';
 
 const DEFAULT_MAX_TOKENS = 4096;
+const DEFAULT_MAX_RETURN_TOKENS = 256;
 const DEFAULT_PROMPT_TOKEN_RATIO = 0.5;
 
 class ModelPlugin {
-    constructor(config, pathway) {
-
-        this.modelName = pathway.model || config.get('defaultModelName');
-        // Get the model from the config
-        this.model = config.get('models')[this.modelName];
-        // If the model doesn't exist, throw an exception
-        if (!this.model) {
-            throw new Error(`Model ${this.modelName} not found in config`);
-        }
-
+    constructor(config, pathway, modelName, model) {
+        this.modelName = modelName;
+        this.model = model;
         this.config = config;
         this.environmentVariables = config.getEnv();
         this.temperature = pathway.temperature;
@@ -143,6 +137,10 @@ class ModelPlugin {
         return (this.promptParameters.maxTokenLength ?? this.model.maxTokenLength ?? DEFAULT_MAX_TOKENS);
     }
 
+    getModelMaxReturnTokens() {
+        return (this.promptParameters.maxReturnTokens ?? this.model.maxReturnTokens ?? DEFAULT_MAX_RETURN_TOKENS);
+    }
+
     getPromptTokenRatio() {
         // TODO: Is this the right order of precedence? inputParameters should maybe be second?
         return this.promptParameters.inputParameters?.tokenRatio ?? this.promptParameters.tokenRatio ?? DEFAULT_PROMPT_TOKEN_RATIO;
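getModelMaxReturnTokens mirrors getModelMaxTokenLength: a per-pathway maxReturnTokens overrides the model's configured value, and 256 is the fallback. A minimal illustration of that precedence chain (the values are examples):

    const DEFAULT_MAX_RETURN_TOKENS = 256;

    // Same nullish-coalescing chain as ModelPlugin.getModelMaxReturnTokens()
    const resolveMaxReturnTokens = (promptParameters, model) =>
        promptParameters.maxReturnTokens ?? model.maxReturnTokens ?? DEFAULT_MAX_RETURN_TOKENS;

    console.log(resolveMaxReturnTokens({}, {}));                          // 256 (default)
    console.log(resolveMaxReturnTokens({}, { maxReturnTokens: 1024 }));   // 1024 (model config)
    console.log(resolveMaxReturnTokens({ maxReturnTokens: 64 }, { maxReturnTokens: 1024 })); // 64 (pathway wins)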
package/server/plugins/openAiChatPlugin.js
CHANGED

@@ -3,8 +3,8 @@ import ModelPlugin from './modelPlugin.js';
 import { encode } from 'gpt-3-encoder';
 
 class OpenAIChatPlugin extends ModelPlugin {
-    constructor(config, pathway) {
-        super(config, pathway);
+    constructor(config, pathway, modelName, model) {
+        super(config, pathway, modelName, model);
     }
 
     // convert to OpenAI messages array format if necessary
package/server/plugins/openAiCompletionPlugin.js
CHANGED

@@ -15,8 +15,8 @@ const truncatePromptIfNecessary = (text, textTokenCount, modelMaxTokenCount, tar
 }
 
 class OpenAICompletionPlugin extends ModelPlugin {
-    constructor(config, pathway) {
-        super(config, pathway);
+    constructor(config, pathway, modelName, model) {
+        super(config, pathway, modelName, model);
     }
 
     // Set up parameters specific to the OpenAI Completion API
package/server/plugins/openAiWhisperPlugin.js
CHANGED

@@ -75,14 +75,14 @@ const downloadFile = async (fileUrl) => {
             fs.unlink(localFilePath, () => {
                 reject(error);
             });
-            throw error;
+            //throw error;
         }
     });
 };
 
 class OpenAIWhisperPlugin extends ModelPlugin {
-    constructor(config, pathway) {
-        super(config, pathway);
+    constructor(config, pathway, modelName, model) {
+        super(config, pathway, modelName, model);
     }
 
     async getMediaChunks(file, requestId) {
package/server/plugins/palmChatPlugin.js
CHANGED

@@ -4,8 +4,8 @@ import { encode } from 'gpt-3-encoder';
 import HandleBars from '../../lib/handleBars.js';
 
 class PalmChatPlugin extends ModelPlugin {
-    constructor(config, pathway) {
-        super(config, pathway);
+    constructor(config, pathway, modelName, model) {
+        super(config, pathway, modelName, model);
     }
 
     // Convert to PaLM messages array format if necessary
@@ -92,10 +92,8 @@ class PalmChatPlugin extends ModelPlugin {
         const context = this.getCompiledContext(text, parameters, prompt.context || palmMessages.context || '');
         const examples = this.getCompiledExamples(text, parameters, prompt.examples || []);
 
-
-
-        const max_tokens = 1024//this.getModelMaxTokenLength() - tokenLength;
-
+        const max_tokens = this.getModelMaxReturnTokens();
+
         if (max_tokens < 0) {
             throw new Error(`Prompt is too long to successfully call the model at ${tokenLength} tokens. The model will not be called.`);
         }
package/server/plugins/palmCodeCompletionPlugin.js
ADDED

@@ -0,0 +1,46 @@
+// palmCodeCompletionPlugin.js
+
+import PalmCompletionPlugin from './palmCompletionPlugin.js';
+
+// PalmCodeCompletionPlugin class for handling requests and responses to the PaLM API Code Completion API
+class PalmCodeCompletionPlugin extends PalmCompletionPlugin {
+    constructor(config, pathway, modelName, model) {
+        super(config, pathway, modelName, model);
+    }
+
+    // Set up parameters specific to the PaLM API Code Completion API
+    getRequestParameters(text, parameters, prompt, pathwayResolver) {
+        const { modelPromptText, tokenLength } = this.getCompiledPrompt(text, parameters, prompt);
+        const { stream } = parameters;
+        // Define the model's max token length
+        const modelTargetTokenLength = this.getModelMaxTokenLength() * this.getPromptTokenRatio();
+
+        const truncatedPrompt = this.truncatePromptIfNecessary(modelPromptText, tokenLength, this.getModelMaxTokenLength(), modelTargetTokenLength, pathwayResolver);
+
+        const max_tokens = this.getModelMaxReturnTokens();
+
+        if (max_tokens < 0) {
+            throw new Error(`Prompt is too long to successfully call the model at ${tokenLength} tokens. The model will not be called.`);
+        }
+
+        if (!truncatedPrompt) {
+            throw new Error(`Prompt is empty. The model will not be called.`);
+        }
+
+        const requestParameters = {
+            instances: [
+                { prefix: truncatedPrompt }
+            ],
+            parameters: {
+                temperature: this.temperature ?? 0.7,
+                maxOutputTokens: max_tokens,
+                topP: parameters.topP ?? 0.95,
+                topK: parameters.topK ?? 40,
+            }
+        };
+
+        return requestParameters;
+    }
+}
+
+export default PalmCodeCompletionPlugin;
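For reference, getRequestParameters here produces the instances/parameters body that the PaLM code completion endpoint expects: one instance carrying the prompt as prefix, plus sampling settings. A sketch of the resulting object for a short prompt (the prompt text and the 1024 value are illustrative; the rest follow the defaults above):

    // Shape returned by PalmCodeCompletionPlugin.getRequestParameters()
    // when the prompt needs no truncation.
    const requestParameters = {
        instances: [
            { prefix: 'function add(a, b) {' }
        ],
        parameters: {
            temperature: 0.7,        // this.temperature ?? 0.7
            maxOutputTokens: 1024,   // getModelMaxReturnTokens(), e.g. maxReturnTokens from the model config
            topP: 0.95,              // parameters.topP ?? 0.95
            topK: 40,                // parameters.topK ?? 40
        }
    };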
package/server/plugins/palmCompletionPlugin.js
CHANGED

@@ -2,23 +2,21 @@
 
 import ModelPlugin from './modelPlugin.js';
 
-// Helper function to truncate the prompt if it is too long
-const truncatePromptIfNecessary = (text, textTokenCount, modelMaxTokenCount, targetTextTokenCount, pathwayResolver) => {
-    const maxAllowedTokens = textTokenCount + ((modelMaxTokenCount - targetTextTokenCount) * 0.5);
-
-    if (textTokenCount > maxAllowedTokens) {
-        pathwayResolver.logWarning(`Prompt is too long at ${textTokenCount} tokens (this target token length for this pathway is ${targetTextTokenCount} tokens because the response is expected to take up the rest of the model's max tokens (${modelMaxTokenCount}). Prompt will be truncated.`);
-        return pathwayResolver.truncate(text, maxAllowedTokens);
-    }
-    return text;
-}
-
 // PalmCompletionPlugin class for handling requests and responses to the PaLM API Text Completion API
 class PalmCompletionPlugin extends ModelPlugin {
-    constructor(config, pathway) {
-        super(config, pathway);
+    constructor(config, pathway, modelName, model) {
+        super(config, pathway, modelName, model);
     }
 
+    truncatePromptIfNecessary (text, textTokenCount, modelMaxTokenCount, targetTextTokenCount, pathwayResolver) {
+        const maxAllowedTokens = textTokenCount + ((modelMaxTokenCount - targetTextTokenCount) * 0.5);
+
+        if (textTokenCount > maxAllowedTokens) {
+            pathwayResolver.logWarning(`Prompt is too long at ${textTokenCount} tokens (this target token length for this pathway is ${targetTextTokenCount} tokens because the response is expected to take up the rest of the model's max tokens (${modelMaxTokenCount}). Prompt will be truncated.`);
+            return pathwayResolver.truncate(text, maxAllowedTokens);
+        }
+        return text;
+    }
     // Set up parameters specific to the PaLM API Text Completion API
     getRequestParameters(text, parameters, prompt, pathwayResolver) {
         const { modelPromptText, tokenLength } = this.getCompiledPrompt(text, parameters, prompt);
@@ -26,9 +24,9 @@ class PalmCompletionPlugin extends ModelPlugin {
         // Define the model's max token length
         const modelTargetTokenLength = this.getModelMaxTokenLength() * this.getPromptTokenRatio();
 
-        const truncatedPrompt = truncatePromptIfNecessary(modelPromptText, tokenLength, this.getModelMaxTokenLength(), modelTargetTokenLength, pathwayResolver);
+        const truncatedPrompt = this.truncatePromptIfNecessary(modelPromptText, tokenLength, this.getModelMaxTokenLength(), modelTargetTokenLength, pathwayResolver);
 
-        const max_tokens =
+        const max_tokens = this.getModelMaxReturnTokens();
 
         if (max_tokens < 0) {
             throw new Error(`Prompt is too long to successfully call the model at ${tokenLength} tokens. The model will not be called.`);
package/tests/chunkfunction.test.js
CHANGED

@@ -1,5 +1,6 @@
 import test from 'ava';
-import { getSemanticChunks } from '../server/chunker.js';
+import { getSemanticChunks, determineTextFormat } from '../server/chunker.js';
+
 import { encode } from 'gpt-3-encoder';
 
 const testText = `Lorem ipsum dolor sit amet, consectetur adipiscing elit. In id erat sem. Phasellus ac dapibus purus, in fermentum nunc. Mauris quis rutrum magna. Quisque rutrum, augue vel blandit posuere, augue magna convallis turpis, nec elementum augue mauris sit amet nunc. Aenean sit amet leo est. Nunc ante ex, blandit et felis ut, iaculis lacinia est. Phasellus dictum orci id libero ullamcorper tempor.
@@ -69,34 +70,119 @@ test('should return identical text that chunker was passed, given tiny chunk size
     t.is(recomposedText, testText); //check recomposition
 });
 
-
-
-
-
-
-
-
-
+const htmlChunkOne = `<p>Lorem ipsum dolor sit amet, consectetur adipiscing elit. <a href="https://www.google.com">Google</a></p> Vivamus id pharetra odio. Sed consectetur leo sed tortor dictum venenatis.Donec gravida libero non accumsan suscipit.Donec lectus turpis, ullamcorper eu pulvinar iaculis, ornare ut risus.Phasellus aliquam, turpis quis viverra condimentum, risus est pretium metus, in porta ipsum tortor vitae elit.Pellentesque id finibus erat. In suscipit, sapien non posuere dignissim, augue nisl ultrices tortor, sit amet eleifend nibh elit at risus.`
+const htmlVoidElement = `<br>`
+const htmlChunkTwo = `<p><img src="https://www.google.com/googlelogo_color_272x92dp.png"></p>`
+const htmlSelfClosingElement = `<img src="https://www.google.com/images/branding/googlelogo/1x/googlelogo_color_272x92dp.png" />`
+const plainTextChunk = 'Vestibulum ante ipsum primis in faucibus orci luctus et ultrices posuere cubilia curae; Fusce at dignissim quam.'
+
+test('should throw an error if html cannot be accommodated within the chunk size', async t => {
+    const chunkSize = encode(htmlChunkTwo).length;
+    const error = t.throws(() => getSemanticChunks(htmlChunkTwo, chunkSize - 1, 'html'));
+    t.is(error.message, 'The HTML contains elements that are larger than the chunk size. Please try again with HTML that has smaller elements.');
+});
+
+test('should chunk text between html elements if needed', async t => {
+    const chunkSize = encode(htmlChunkTwo).length;
+    const chunks = getSemanticChunks(htmlChunkTwo + plainTextChunk + htmlChunkTwo, chunkSize, 'html');
+
+    t.is(chunks.length, 4);
+    t.is(chunks[0], htmlChunkTwo);
+    t.is(chunks[1], 'Vestibulum ante ipsum primis in faucibus orci luctus et ultrices posuere cubilia curae');
+    t.is(encode(chunks[1]).length, chunkSize);
+    t.is(chunks[2], '; Fusce at dignissim quam.');
+    t.is(chunks[3], htmlChunkTwo);
+});
+
+test('should chunk html element correctly when chunk size is exactly the same as the element length', async t => {
+    const chunkSize = encode(htmlChunkTwo).length;
+    const chunks = getSemanticChunks(htmlChunkTwo, chunkSize, 'html');
+
+    t.is(chunks.length, 1);
+    t.is(chunks[0], htmlChunkTwo);
+});
+
+test('should chunk html element correctly when chunk size is greater than the element length', async t => {
+    const chunkSize = encode(htmlChunkTwo).length;
+    const chunks = getSemanticChunks(htmlChunkTwo, chunkSize + 1, 'html');
+
+    t.is(chunks.length, 1);
+    t.is(chunks[0], htmlChunkTwo);
+});
+
+test('should not break up second html element correctly when chunk size is greater than the first element length', async t => {
+    const chunkSize = encode(htmlChunkTwo).length;
+    const chunks = getSemanticChunks(htmlChunkTwo + htmlChunkTwo, chunkSize + 10, 'html');
+
+    t.is(chunks.length, 2);
+    t.is(chunks[0], htmlChunkTwo);
+    t.is(chunks[1], htmlChunkTwo);
+});
+
+test('should treat text chunks as also unbreakable chunks', async t => {
+    const chunkSize = encode(htmlChunkTwo).length;
+    const chunks = getSemanticChunks(htmlChunkTwo + plainTextChunk + htmlChunkTwo, chunkSize + 20, 'html');
+
+    t.is(chunks.length, 3);
+    t.is(chunks[0], htmlChunkTwo);
+    t.is(chunks[1], plainTextChunk);
+    t.is(chunks[2], htmlChunkTwo);
+});
+
+
+test('should determine format correctly for text only', async t => {
+    const format = determineTextFormat(plainTextChunk);
+    t.is(format, 'text');
+});
+
+test('should determine format correctly for simple html element', async t => {
+    const format = determineTextFormat(htmlChunkTwo);
+    t.is(format, 'html');
+});
+
+test('should determine format correctly for simple html element embedded in text', async t => {
+    const format = determineTextFormat(plainTextChunk + htmlChunkTwo + plainTextChunk);
+    t.is(format, 'html');
+});
+
+test('should determine format correctly for self-closing html element', async t => {
+    const format = determineTextFormat(htmlSelfClosingElement);
+    t.is(format, 'html');
+});
+
+test('should determine format correctly for self-closing html element embedded in text', async t => {
+    const format = determineTextFormat(plainTextChunk + htmlSelfClosingElement + plainTextChunk);
+    t.is(format, 'html');
+});
+
+test('should determine format correctly for void element', async t => {
+    const format = determineTextFormat(htmlVoidElement);
+    t.is(format, 'html');
+});
+
+test('should determine format correctly for void element embedded in text', async t => {
+    const format = determineTextFormat(plainTextChunk + htmlVoidElement + plainTextChunk);
+    t.is(format, 'html');
 });
 
-
+test('should return identical text that chunker was passed, given huge chunk size (32000)', t => {
     const maxChunkToken = 32000;
     const chunks = getSemanticChunks(testText, maxChunkToken);
-
-
+    t.assert(chunks.length === 1); //check chunking
+    t.assert(chunks.every(chunk => encode(chunk).length <= maxChunkToken)); //check chunk size
     const recomposedText = chunks.reduce((acc, chunk) => acc + chunk, '');
-
+    t.assert(recomposedText === testText); //check recomposition
 });
 
 const testTextNoSpaces = `Loremipsumdolorsitamet,consecteturadipiscingelit.Inideratsem.Phasellusacdapibuspurus,infermentumnunc.Maurisquisrutrummagna.Quisquerutrum,auguevelblanditposuere,auguemagnacon vallisturpis,necelementumauguemaurissitametnunc.Aeneansitametleoest.Nuncanteex,blanditetfelisut,iaculislaciniaest.Phasellusdictumorciidliberoullamcorpertempor.Vivamusidpharetraodioq.Sedconsecteturleosedtortordictumvenenatis.Donecgravidaliberononaccumsansuscipit.Doneclectusturpis,ullamcorpereupulvinariaculis,ornareutrisus.Phasellusaliquam,turpisquisviverracondimentum,risusestpretiummetus,inportaips umtortorvita elit.Pellentesqueidfinibuserat.Insuscipit,sapiennonposueredignissim,auguenisl ultricestortor,sitameteleifendnibhelitatrisus.`;
 
-
+test('should return identical text that chunker was passed, given no spaces and small chunks(5)', t => {
     const maxChunkToken = 5;
     const chunks = getSemanticChunks(testTextNoSpaces, maxChunkToken);
-
-
+    t.assert(chunks.length > 0); //check chunking
+    t.assert(chunks.every(chunk => encode(chunk).length <= maxChunkToken)); //check chunk size
     const recomposedText = chunks.reduce((acc, chunk) => acc + chunk, '');
-
+    t.assert(recomposedText === testTextNoSpaces); //check recomposition
 });
 
 const testTextShortWeirdSpaces=`Lorem ipsum dolor sit amet, consectetur adipiscing elit. In id erat sem. Phasellus ac dapibus purus, in fermentum nunc.............................. Mauris quis rutrum magna. Quisque rutrum, augue vel blandit posuere, augue magna convallis turpis, nec elementum augue mauris sit amet nunc. Aenean sit a;lksjdf 098098- -23 eln ;lkn l;kn09 oij[0u ,,,,,,,,,,,,,,,,,,,,, amet leo est. Nunc ante ex, blandit et felis ut, iaculis lacinia est. Phasellus dictum orci id libero ullamcorper tempor.
@@ -106,20 +192,20 @@ const testTextShortWeirdSpaces=`Lorem ipsum dolor sit amet, consectetur adipisci
 
 Vivamus id pharetra odio. Sed consectetur leo sed tortor dictum venenatis.Donec gravida libero non accumsan suscipit.Donec lectus turpis, ullamcorper eu pulvinar iaculis, ornare ut risus.Phasellus aliquam, turpis quis viverra condimentum, risus est pretium metus, in porta ipsum tortor vitae elit.Pellentesque id finibus erat. In suscipit, sapien non posuere dignissim, augue nisl ultrices tortor, sit amet eleifend nibh elit at risus.`;
 
-
+test('should return identical text that chunker was passed, given weird spaces and tiny chunks(1)', t => {
     const maxChunkToken = 1;
     const chunks = getSemanticChunks(testTextShortWeirdSpaces, maxChunkToken);
-
-
+    t.assert(chunks.length > 0); //check chunking
+    t.assert(chunks.every(chunk => encode(chunk).length <= maxChunkToken)); //check chunk size
     const recomposedText = chunks.reduce((acc, chunk) => acc + chunk, '');
-
+    t.assert(recomposedText === testTextShortWeirdSpaces); //check recomposition
 });
 
-
+test('should return identical text that chunker was passed, given weird spaces and small chunks(10)', t => {
     const maxChunkToken = 1;
     const chunks = getSemanticChunks(testTextShortWeirdSpaces, maxChunkToken);
-
-
+    t.assert(chunks.length > 0); //check chunking
+    t.assert(chunks.every(chunk => encode(chunk).length <= maxChunkToken)); //check chunk size
     const recomposedText = chunks.reduce((acc, chunk) => acc + chunk, '');
-
-})
+    t.assert(recomposedText === testTextShortWeirdSpaces); //check recomposition
+});
package/tests/mocks.js
CHANGED

@@ -36,4 +36,45 @@ export const mockConfig = {
         { role: 'assistant', content: 'Translating: {{{text}}}' },
     ],
   }),
-};
+};
+
+export const mockPathwayResolverString = {
+    model: {
+        url: 'https://api.example.com/testModel',
+        type: 'OPENAI-COMPLETION',
+    },
+    modelName: 'testModel',
+    pathway: mockPathwayString,
+    config: mockConfig,
+    prompt: new Prompt('User: {{text}}\nAssistant: Please help {{name}} who is {{age}} years old.'),
+};
+
+export const mockPathwayResolverFunction = {
+    model: {
+        url: 'https://api.example.com/testModel',
+        type: 'OPENAI-COMPLETION',
+    },
+    modelName: 'testModel',
+    pathway: mockPathwayFunction,
+    config: mockConfig,
+    prompt: () => {
+        return new Prompt('User: {{text}}\nAssistant: Please help {{name}} who is {{age}} years old.')
+    }
+};
+
+export const mockPathwayResolverMessages = {
+    model: {
+        url: 'https://api.example.com/testModel',
+        type: 'OPENAI-COMPLETION',
+    },
+    modelName: 'testModel',
+    pathway: mockPathwayMessages,
+    config: mockConfig,
+    prompt: new Prompt({
+        messages: [
+            { role: 'user', content: 'Translate this: {{{text}}}' },
+            { role: 'assistant', content: 'Translating: {{{text}}}' },
+        ],
+    }),
+};
+
package/tests/modelPlugin.test.js
CHANGED

@@ -2,7 +2,7 @@
 import test from 'ava';
 import ModelPlugin from '../server/plugins/modelPlugin.js';
 import HandleBars from '../lib/handleBars.js';
-import { mockConfig, mockPathwayString, mockPathwayFunction, mockPathwayMessages } from './mocks.js';
+import { mockConfig, mockPathwayString, mockPathwayFunction, mockPathwayMessages, mockPathwayResolverString } from './mocks.js';
 
 const DEFAULT_MAX_TOKENS = 4096;
 const DEFAULT_PROMPT_TOKEN_RATIO = 0.5;
@@ -12,7 +12,7 @@ const config = mockConfig;
 const pathway = mockPathwayString;
 
 test('ModelPlugin constructor', (t) => {
-    const modelPlugin = new ModelPlugin(
+    const modelPlugin = new ModelPlugin(mockPathwayResolverString);
 
     t.is(modelPlugin.modelName, pathway.model, 'modelName should be set from pathway');
     t.deepEqual(modelPlugin.model, config.get('models')[pathway.model], 'model should be set from config');
@@ -21,7 +21,7 @@ test('ModelPlugin constructor', (t) => {
 });
 
 test.beforeEach((t) => {
-    t.context.modelPlugin = new ModelPlugin(
+    t.context.modelPlugin = new ModelPlugin(mockPathwayResolverString);
 });
 
 test('getCompiledPrompt - text and parameters', (t) => {
package/tests/openAiChatPlugin.test.js
CHANGED

@@ -1,17 +1,17 @@
 import test from 'ava';
 import OpenAIChatPlugin from '../server/plugins/openAiChatPlugin.js';
-import {
+import { mockPathwayResolverMessages } from './mocks.js';
 
 // Test the constructor
 test('constructor', (t) => {
-    const plugin = new OpenAIChatPlugin(
-    t.is(plugin.config,
-    t.is(plugin.pathwayPrompt,
+    const plugin = new OpenAIChatPlugin(mockPathwayResolverMessages);
+    t.is(plugin.config, mockPathwayResolverMessages.config);
+    t.is(plugin.pathwayPrompt, mockPathwayResolverMessages.pathway.prompt);
 });
 
 // Test the convertPalmToOpenAIMessages function
 test('convertPalmToOpenAIMessages', (t) => {
-    const plugin = new OpenAIChatPlugin(
+    const plugin = new OpenAIChatPlugin(mockPathwayResolverMessages);
     const context = 'This is a test context.';
     const examples = [
         {
@@ -35,14 +35,21 @@ test('convertPalmToOpenAIMessages', (t) => {
 
 // Test the getRequestParameters function
 test('getRequestParameters', async (t) => {
-    const plugin = new OpenAIChatPlugin(
+    const plugin = new OpenAIChatPlugin(mockPathwayResolverMessages);
     const text = 'Help me';
     const parameters = { name: 'John', age: 30 };
-    const prompt =
+    const prompt = mockPathwayResolverMessages.pathway.prompt;
     const result = await plugin.getRequestParameters(text, parameters, prompt);
     t.deepEqual(result, {
         messages: [
-            {
+            {
+                content: 'Translate this: Help me',
+                role: 'user',
+            },
+            {
+                content: 'Translating: Help me',
+                role: 'assistant',
+            },
         ],
         temperature: 0.7,
     });
@@ -50,10 +57,10 @@ test('getRequestParameters', async (t) => {
 
 // Test the execute function
 test('execute', async (t) => {
-    const plugin = new OpenAIChatPlugin(
+    const plugin = new OpenAIChatPlugin(mockPathwayResolverMessages);
     const text = 'Help me';
     const parameters = { name: 'John', age: 30 };
-    const prompt =
+    const prompt = mockPathwayResolverMessages.pathway.prompt;
 
     // Mock the executeRequest function
     plugin.executeRequest = () => {
@@ -82,7 +89,7 @@ test('execute', async (t) => {
 
 // Test the parseResponse function
 test('parseResponse', (t) => {
-    const plugin = new OpenAIChatPlugin(
+    const plugin = new OpenAIChatPlugin(mockPathwayResolverMessages);
     const data = {
         choices: [
             {
@@ -98,7 +105,7 @@ test('parseResponse', (t) => {
 
 // Test the logRequestData function
 test('logRequestData', (t) => {
-    const plugin = new OpenAIChatPlugin(
+    const plugin = new OpenAIChatPlugin(mockPathwayResolverMessages);
     const data = {
         messages: [
             { role: 'user', content: 'User: Help me\nAssistant: Please help John who is 30 years old.' },
@@ -113,7 +120,7 @@ test('logRequestData', (t) => {
             },
         ],
     };
-    const prompt =
+    const prompt = mockPathwayResolverMessages.pathway.prompt;
 
     // Mock console.log function
     const originalConsoleLog = console.log;
package/tests/palmChatPlugin.test.js
CHANGED

@@ -1,11 +1,10 @@
 // test_palmChatPlugin.js
 import test from 'ava';
 import PalmChatPlugin from '../server/plugins/palmChatPlugin.js';
-import {
+import { mockPathwayResolverMessages } from './mocks.js';
 
 test.beforeEach((t) => {
-    const
-    const palmChatPlugin = new PalmChatPlugin(mockConfig, pathway);
+    const palmChatPlugin = new PalmChatPlugin(mockPathwayResolverMessages);
     t.context = { palmChatPlugin };
 });
 
package/tests/palmCompletionPlugin.test.js
CHANGED

@@ -2,11 +2,10 @@
 
 import test from 'ava';
 import PalmCompletionPlugin from '../server/plugins/palmCompletionPlugin.js';
-import {
+import { mockPathwayResolverString } from './mocks.js';
 
 test.beforeEach((t) => {
-    const
-    const palmCompletionPlugin = new PalmCompletionPlugin(mockConfig, pathway);
+    const palmCompletionPlugin = new PalmCompletionPlugin(mockPathwayResolverString);
     t.context = { palmCompletionPlugin };
 });
 
package/tests/truncateMessages.test.js
CHANGED

@@ -2,12 +2,11 @@
 import test from 'ava';
 import ModelPlugin from '../server/plugins/modelPlugin.js';
 import { encode } from 'gpt-3-encoder';
-import {
+import { mockPathwayResolverString } from './mocks.js';
 
-const config =
-const pathway = mockPathwayString;
+const { config, pathway } = mockPathwayResolverString;
 
-const modelPlugin = new ModelPlugin(
+const modelPlugin = new ModelPlugin(mockPathwayResolverString);
 
 const generateMessage = (role, content) => ({ role, content });
 
package/tests/server.js
DELETED

@@ -1,23 +0,0 @@
-import 'dotenv/config'
-import { ApolloServer } from 'apollo-server';
-import { config } from '../config.js';
-import typeDefsresolversFactory from '../index.js';
-
-let typeDefs;
-let resolvers;
-
-const initTypeDefsResolvers = async () => {
-    const result = await typeDefsresolversFactory();
-    typeDefs = result.typeDefs;
-    resolvers = result.resolvers;
-};
-
-export const startTestServer = async () => {
-    await initTypeDefsResolvers();
-
-    return new ApolloServer({
-        typeDefs,
-        resolvers,
-        context: () => ({ config, requestState: {} }),
-    });
-};