@aj-archipelago/cortex 1.0.2 → 1.0.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/.eslintrc +1 -1
- package/README.md +8 -6
- package/config.js +7 -2
- package/graphql/parser.js +6 -0
- package/graphql/pathwayPrompter.js +2 -17
- package/graphql/pathwayResolver.js +10 -8
- package/graphql/pathwayResponseParser.js +13 -4
- package/graphql/plugins/modelPlugin.js +27 -18
- package/graphql/plugins/openAiCompletionPlugin.js +29 -12
- package/graphql/plugins/openAiWhisperPlugin.js +112 -19
- package/helper_apps/MediaFileChunker/blobHandler.js +150 -0
- package/helper_apps/MediaFileChunker/fileChunker.js +123 -0
- package/helper_apps/MediaFileChunker/function.json +20 -0
- package/helper_apps/MediaFileChunker/helper.js +33 -0
- package/helper_apps/MediaFileChunker/index.js +116 -0
- package/helper_apps/MediaFileChunker/localFileHandler.js +36 -0
- package/helper_apps/MediaFileChunker/package-lock.json +2919 -0
- package/helper_apps/MediaFileChunker/package.json +22 -0
- package/helper_apps/MediaFileChunker/redis.js +32 -0
- package/helper_apps/MediaFileChunker/start.js +27 -0
- package/lib/handleBars.js +26 -0
- package/lib/pathwayTools.js +15 -0
- package/lib/redisSubscription.js +51 -0
- package/lib/request.js +4 -4
- package/package.json +5 -6
- package/pathways/transcribe.js +2 -1
- package/tests/config.test.js +69 -0
- package/tests/handleBars.test.js +43 -0
- package/tests/mocks.js +39 -0
- package/tests/modelPlugin.test.js +129 -0
- package/tests/pathwayResolver.test.js +77 -0
- package/tests/truncateMessages.test.js +99 -0
- package/lib/fileChunker.js +0 -147
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
{
|
|
2
|
+
"name": "@aj-archipelago/mediafilechunker",
|
|
3
|
+
"version": "1.0.0",
|
|
4
|
+
"description": "",
|
|
5
|
+
"type": "module",
|
|
6
|
+
"scripts": {
|
|
7
|
+
"start": "node start.js",
|
|
8
|
+
"test": "echo \"No tests yet...\""
|
|
9
|
+
},
|
|
10
|
+
"dependencies": {
|
|
11
|
+
"@azure/storage-blob": "^12.13.0",
|
|
12
|
+
"@ffmpeg-installer/ffmpeg": "^1.1.0",
|
|
13
|
+
"@ffprobe-installer/ffprobe": "^2.0.0",
|
|
14
|
+
"busboy": "^1.6.0",
|
|
15
|
+
"express": "^4.18.2",
|
|
16
|
+
"fluent-ffmpeg": "^2.1.2",
|
|
17
|
+
"ioredis": "^5.3.1",
|
|
18
|
+
"public-ip": "^6.0.1",
|
|
19
|
+
"uuid": "^9.0.0",
|
|
20
|
+
"ytdl-core": "git+ssh://git@github.com:khlevon/node-ytdl-core.git#v4.11.3-patch.1"
|
|
21
|
+
}
|
|
22
|
+
}
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
import redis from 'ioredis';
|
|
2
|
+
const connectionString = process.env["REDIS_CONNECTION_STRING"];
// ioredis begins connecting as soon as the client is constructed (lazyConnect
// defaults to false), so no explicit connect() call is made here.
const client = redis.createClient(connectionString);

// Log connection errors explicitly (same handling as the subscriber client in
// lib/redisSubscription.js) instead of leaving the 'error' event unlistened.
client.on('error', (error) => {
    console.error(`Redis client error: ${error}`);
});

const channel = 'requestProgress';
|
|
7
|
+
|
|
8
|
+
/**
 * Ensure the shared Redis client is connected.
 *
 * ioredis reports connection state through `client.status` (it has no
 * `connected` property), and `client.connect()` rejects when the client is
 * already connecting or connected. So connect() is only attempted from an idle
 * or closed state. Failures are logged rather than thrown, so this stays a
 * best-effort call for callers.
 *
 * @returns {Promise<void>}
 */
const connectClient = async () => {
    // States in which a connection exists or is already being established.
    const activeStates = ['connecting', 'connect', 'ready', 'reconnecting'];
    if (activeStates.includes(client.status)) {
        return;
    }
    try {
        await client.connect();
    } catch (error) {
        console.error(`Error reconnecting to Redis: ${error}`);
    }
};
|
|
18
|
+
|
|
19
|
+
/**
 * Publish a progress update on the shared progress channel.
 *
 * The payload is serialized with JSON.stringify. Any failure, whether in
 * serialization or in the Redis publish, is logged and swallowed, so
 * progress reporting never interrupts the caller.
 *
 * @param {*} data - progress payload to broadcast
 * @returns {Promise<void>}
 */
const publishRequestProgress = async (data) => {
    const logFailure = (err) => console.error(`Error publishing message: ${err}`);
    try {
        const payload = JSON.stringify(data);
        console.log(`Publishing message ${payload} to channel ${channel}`);
        await client.publish(channel, payload);
    } catch (err) {
        logFailure(err);
    }
};
|
|
29
|
+
|
|
30
|
+
export {
|
|
31
|
+
publishRequestProgress, connectClient
|
|
32
|
+
}
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
import MediaFileChunker from "./index.js";
|
|
2
|
+
import express from "express";
|
|
3
|
+
import { fileURLToPath } from 'url';
|
|
4
|
+
import { dirname, join } from 'path';
|
|
5
|
+
|
|
6
|
+
import { publicIpv4 } from 'public-ip';
|
|
7
|
+
// Public IPv4 address of this host, looked up once at startup via top-level await.
const ipAddress = await publicIpv4();

const app = express();
const port = process.env.PORT || 7071;

// Static files live in ./files next to this module.
const moduleDir = dirname(fileURLToPath(import.meta.url));
const publicFolder = join(moduleDir, 'files');

// Serve static files from the public folder
app.use('/files', express.static(publicFolder));
|
|
16
|
+
|
|
17
|
+
// Adapter exposing the MediaFileChunker function handler over Express.
// MediaFileChunker receives an Azure-Functions-style context ({ req, res, log })
// and the response body is read back from context.res.body.
// Express 4 does not catch rejections from async handlers. Without the
// try/catch, a failure would become an unhandled promise rejection and the
// client request would never be answered.
app.all('/api/MediaFileChunker', async (req, res) => {
    const context = { req, res, log: console.log };
    try {
        await MediaFileChunker(context, req);
        res.send(context.res.body);
    } catch (error) {
        console.error(`MediaFileChunker request failed: ${error}`);
        if (!res.headersSent) {
            res.status(500).send('Error processing media file request');
        }
    }
});
|
|
22
|
+
|
|
23
|
+
// Start the HTTP server and log the port once it is accepting connections.
const logListening = () => console.log(`MediaFileChunker helper running on port ${port}`);
app.listen(port, logListening);
|
|
26
|
+
|
|
27
|
+
export { port, publicFolder, ipAddress };
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
// handleBars.js
|
|
2
|
+
|
|
3
|
+
import HandleBars from 'handlebars';
|
|
4
|
+
|
|
5
|
+
// register functions that can be called directly in the prompt markdown

// stripHTML: remove anything that looks like an HTML tag (`<...>`) from the value.
// null/undefined are passed through unchanged (Handlebars renders them as empty),
// and any other value is converted with String() first. A missing or non-string
// template variable therefore no longer throws "value.replace is not a function".
HandleBars.registerHelper('stripHTML', function (value) {
    if (value == null) {
        return value;
    }
    return String(value).replace(/<[^>]*>/g, '');
});
|
|
9
|
+
|
|
10
|
+
// now: current time as an ISO-8601 string (Date#toISOString, UTC).
HandleBars.registerHelper('now', () => new Date().toISOString());

// toJSON: serialize the argument with JSON.stringify.
HandleBars.registerHelper('toJSON', (object) => JSON.stringify(object));

// ctoW (presumably "characters to words"): divide a numeric value by 6.6 and round.
// Anything the global isNaN() treats as not-a-number is returned untouched.
HandleBars.registerHelper('ctoW', (value) => (isNaN(value) ? value : Math.round(value / 6.6)));
|
|
25
|
+
|
|
26
|
+
export default HandleBars;
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
// pathwayTools.js
|
|
2
|
+
|
|
3
|
+
/**
 * callPathway - call a pathway from another pathway.
 *
 * Looks up `pathways.<pathwayName>` in the config and invokes that pathway's
 * rootResolver. The call uses an empty parent object and a fresh, empty
 * requestState.
 *
 * @param {object} config - configuration object exposing `get(key)`
 * @param {string} pathwayName - name of the pathway to run
 * @param {object} args - arguments forwarded to the pathway's rootResolver
 * @returns {Promise<*>} the `result` field of the resolver's response (undefined if absent)
 * @throws {Error} when no pathway with that name is configured
 */
const callPathway = async (config, pathwayName, args) => {
    const pathway = config.get(`pathways.${pathwayName}`);
    if (!pathway) throw new Error(`Pathway ${pathwayName} not found`);

    const resolverContext = { config, pathway, requestState: {} };
    const response = await pathway.rootResolver({}, args, resolverContext);
    return response?.result;
};
|
|
14
|
+
|
|
15
|
+
export { callPathway };
|
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
import Redis from 'ioredis';
|
|
2
|
+
import { config } from '../config.js';
|
|
3
|
+
import pubsub from '../graphql/pubsub.js';
|
|
4
|
+
|
|
5
|
+
// The subscriber connects using the value of the storageConnectionString config key.
const connectionString = config.get('storageConnectionString');
const client = new Redis(connectionString);

// Channel carrying request-progress updates.
const channel = 'requestProgress';

// Log any error emitted by the Redis client.
const logRedisError = (error) => console.error(`Redis client error: ${error}`);
client.on('error', logRedisError);
|
|
13
|
+
|
|
14
|
+
// Report the outcome of the SUBSCRIBE call.
const onSubscribed = (error) => {
    if (error) {
        console.error(`Error subscribing to channel ${channel}: ${error}`);
        return;
    }
    console.log(`Subscribed to channel ${channel}`);
};

// (Re)subscribe to the progress channel whenever the client emits 'connect'.
client.on('connect', () => client.subscribe(channel, onSubscribed));
|
|
23
|
+
|
|
24
|
+
// Forward messages on the progress channel to handleMessage.
// The callback's channel argument is named receivedChannel so it does not shadow
// the module-level `channel`. The comparison uses that constant instead of
// repeating the 'requestProgress' literal.
client.on('message', (receivedChannel, message) => {
    if (receivedChannel !== channel) {
        return;
    }
    console.log(`Received message from ${receivedChannel}: ${message}`);

    // Try JSON first; if the payload is not valid JSON, pass the raw string through.
    let parsedMessage;
    try {
        parsedMessage = JSON.parse(message);
    } catch {
        parsedMessage = message;
    }

    handleMessage(parsedMessage);
});
|
|
38
|
+
|
|
39
|
+
/**
 * Re-publish a progress payload received from Redis on the GraphQL pubsub
 * under the REQUEST_PROGRESS trigger, wrapped as `{ requestProgress: data }`.
 *
 * Publishing is best-effort: both synchronous throws and asynchronous rejections
 * are logged instead of propagating. `publish` may return a promise (it does in
 * the graphql-subscriptions PubSub interface), and a try/catch alone would not
 * catch its rejection.
 *
 * @param {*} data - parsed message payload (object, or the raw string if it was not JSON)
 */
const handleMessage = (data) => {
    // Process the received data
    console.log('Processing data:', data);
    const logFailure = (error) => console.error(`Error publishing data to pubsub: ${error}`);
    try {
        Promise.resolve(pubsub.publish('REQUEST_PROGRESS', { requestProgress: data }))
            .catch(logFailure);
    } catch (error) {
        logFailure(error);
    }
};
|
|
48
|
+
|
|
49
|
+
export {
|
|
50
|
+
client as subscriptionClient,
|
|
51
|
+
};
|
package/lib/request.js
CHANGED
|
@@ -80,8 +80,8 @@ const postRequest = async ({ url, data, params, headers, cache }, model) => {
|
|
|
80
80
|
}
|
|
81
81
|
return await limiters[model].schedule(() => postWithMonitor(model, url, data, axiosConfigObj));
|
|
82
82
|
} catch (e) {
|
|
83
|
-
console.error(`Failed request with data ${JSON.stringify(data)}: ${e}`);
|
|
84
|
-
if (e.response
|
|
83
|
+
console.error(`Failed request with data ${JSON.stringify(data)}: ${e} - ${e.response?.data?.error?.type || 'error'}: ${e.response?.data?.error?.message}`);
|
|
84
|
+
if (e.response?.status && e.response?.status === 429) {
|
|
85
85
|
monitors[model].incrementError429Count();
|
|
86
86
|
}
|
|
87
87
|
errors.push(e);
|
|
@@ -94,7 +94,7 @@ const request = async (params, model) => {
|
|
|
94
94
|
const response = await postRequest(params, model);
|
|
95
95
|
const { error, data, cached } = response;
|
|
96
96
|
if (cached) {
|
|
97
|
-
console.info('
|
|
97
|
+
console.info('=== Request served with cached response. ===');
|
|
98
98
|
}
|
|
99
99
|
if (error && error.length > 0) {
|
|
100
100
|
const lastError = error[error.length - 1];
|
|
@@ -105,5 +105,5 @@ const request = async (params, model) => {
|
|
|
105
105
|
}
|
|
106
106
|
|
|
107
107
|
export {
|
|
108
|
-
request, postRequest, buildLimiters
|
|
108
|
+
axios,request, postRequest, buildLimiters
|
|
109
109
|
};
|
package/package.json
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
{
|
|
2
2
|
"name": "@aj-archipelago/cortex",
|
|
3
|
-
"version": "1.0.
|
|
3
|
+
"version": "1.0.3",
|
|
4
4
|
"description": "Cortex is a GraphQL API for AI. It provides a simple, extensible interface for using AI services from OpenAI, Azure and others.",
|
|
5
5
|
"repository": {
|
|
6
6
|
"type": "git",
|
|
@@ -29,7 +29,6 @@
|
|
|
29
29
|
"homepage": "https://github.com/aj-archipelago/cortex#readme",
|
|
30
30
|
"dependencies": {
|
|
31
31
|
"@apollo/utils.keyvadapter": "^1.1.2",
|
|
32
|
-
"@ffmpeg-installer/ffmpeg": "^1.1.0",
|
|
33
32
|
"@graphql-tools/schema": "^9.0.12",
|
|
34
33
|
"@keyv/redis": "^2.5.4",
|
|
35
34
|
"apollo-server": "^3.12.0",
|
|
@@ -43,24 +42,24 @@
|
|
|
43
42
|
"compromise-paragraphs": "^0.1.0",
|
|
44
43
|
"convict": "^6.2.3",
|
|
45
44
|
"express": "^4.18.2",
|
|
46
|
-
"fluent-ffmpeg": "^2.1.2",
|
|
47
45
|
"form-data": "^4.0.0",
|
|
48
46
|
"gpt-3-encoder": "^1.1.4",
|
|
49
47
|
"graphql": "^16.6.0",
|
|
50
48
|
"graphql-subscriptions": "^2.0.0",
|
|
51
49
|
"graphql-ws": "^5.11.2",
|
|
52
50
|
"handlebars": "^4.7.7",
|
|
51
|
+
"ioredis": "^5.3.1",
|
|
53
52
|
"keyv": "^4.5.2",
|
|
54
53
|
"langchain": "^0.0.47",
|
|
55
54
|
"uuid": "^9.0.0",
|
|
56
|
-
"ws": "^8.12.0"
|
|
57
|
-
"ytdl-core": "^4.11.2"
|
|
55
|
+
"ws": "^8.12.0"
|
|
58
56
|
},
|
|
59
57
|
"devDependencies": {
|
|
60
58
|
"ava": "^5.2.0",
|
|
61
59
|
"dotenv": "^16.0.3",
|
|
62
60
|
"eslint": "^8.38.0",
|
|
63
|
-
"eslint-plugin-import": "^2.27.5"
|
|
61
|
+
"eslint-plugin-import": "^2.27.5",
|
|
62
|
+
"sinon": "^15.0.3"
|
|
64
63
|
},
|
|
65
64
|
"publishConfig": {
|
|
66
65
|
"access": "restricted"
|
package/tests/config.test.js
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
// config.test.js
|
|
2
|
+
|
|
3
|
+
import test from 'ava';
import path from 'path';
import { fileURLToPath } from 'url';
import { config, buildPathways, buildModels } from '../config.js';
|
|
6
|
+
|
|
7
|
+
// Populate the config with pathways and models once, before any test runs.
test.before(async () => {
    await buildPathways(config);
    buildModels(config);
});

// pathwaysPath defaults to <cwd>/pathways.
test('config pathwaysPath', (t) => {
    t.is(config.get('pathwaysPath'), path.join(process.cwd(), '/pathways'));
});
|
|
16
|
+
|
|
17
|
+
// Directory containing this test file. fileURLToPath, unlike URL#pathname,
// decodes percent-escapes such as %20 and yields a native path on Windows
// (no leading "/C:/"), so the expected paths match on every platform.
const testDir = path.dirname(fileURLToPath(import.meta.url));

// corePathwaysPath points at the package's own pathways directory.
test('config corePathwaysPath', (t) => {
    const expectedPath = path.join(testDir, '..', 'pathways');
    t.is(config.get('corePathwaysPath'), expectedPath);
});

// basePathwayPath points at basePathway.js inside that directory.
test('config basePathwayPath', (t) => {
    const expectedPath = path.join(testDir, '..', 'pathways', 'basePathway.js');
    t.is(config.get('basePathwayPath'), expectedPath);
});
|
|
26
|
+
|
|
27
|
+
// Expected default values for scalar config keys; one test is registered per entry.
const expectedDefaults = [
    ['PORT', 4000],
    ['enableCache', true],
    ['enableGraphqlCache', false],
    ['enableRestEndpoints', false],
    ['openaiDefaultModel', 'text-davinci-003'],
    ['openaiApiUrl', 'https://api.openai.com/v1/completions'],
];

for (const [key, expectedDefault] of expectedDefaults) {
    test(`config ${key}`, (t) => {
        t.is(config.get(key), expectedDefault);
    });
}
|
|
56
|
+
|
|
57
|
+
// After test.before runs buildPathways/buildModels, both collections are non-empty.
test('buildPathways adds pathways to config', (t) => {
    const pathwayCount = Object.keys(config.get('pathways')).length;
    t.true(pathwayCount > 0);
});

test('buildModels adds models to config', (t) => {
    const modelCount = Object.keys(config.get('models')).length;
    t.true(modelCount > 0);
});

test('buildModels sets defaultModelName if not provided', (t) => {
    const defaultModelName = config.get('defaultModelName');
    t.truthy(defaultModelName);
});
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
// handleBars.test.js
|
|
2
|
+
|
|
3
|
+
import test from 'ava';
|
|
4
|
+
import HandleBars from '../lib/handleBars.js';
|
|
5
|
+
|
|
6
|
+
// stripHTML removes the tags and keeps the inner text.
test('stripHTML', (t) => {
    const { stripHTML } = HandleBars.helpers;
    t.is(stripHTML('<h1>Hello, World!</h1>'), 'Hello, World!');
});
|
|
13
|
+
|
|
14
|
+
// `now` returns an ISO-8601 UTC timestamp for the current time.
// The call is bracketed with timestamps instead of comparing date prefixes,
// which would fail whenever the test ran across midnight UTC.
test('now', (t) => {
    const before = Date.now();
    const result = HandleBars.helpers.now();
    const after = Date.now();

    t.regex(result, /^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{3}Z$/);
    const parsed = Date.parse(result);
    t.true(parsed >= before && parsed <= after);
});
|
|
20
|
+
|
|
21
|
+
// toJSON produces compact JSON.
test('toJSON', (t) => {
    t.is(HandleBars.helpers.toJSON({ key: 'value' }), '{"key":"value"}');
});

// ctoW divides numbers by 6.6 (rounded) and passes non-numeric input through.
const ctoWCases = [
    ['ctoW', 66, 10],
    ['ctoW non-numeric', 'Hello, World!', 'Hello, World!'],
];

for (const [title, input, expected] of ctoWCases) {
    test(title, (t) => {
        t.is(HandleBars.helpers.ctoW(input), expected);
    });
}
|
package/tests/mocks.js
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
import { Prompt } from '../graphql/prompt.js';
|
|
2
|
+
|
|
3
|
+
/**
 * Minimal stand-in for the convict config used by the plugins under test.
 * `get` builds a fresh settings object on every lookup, so a test that mutates
 * a returned value cannot affect later lookups. `getEnv` returns an empty object.
 */
export const mockConfig = {
    get(key) {
        const settings = {
            defaultModelName: 'testModel',
            models: {
                testModel: {
                    url: 'https://api.example.com/testModel',
                    type: 'OPENAI-COMPLETION',
                },
            },
        };
        return settings[key];
    },
    getEnv() {
        return {};
    },
};
|
|
18
|
+
|
|
19
|
+
// Text template shared by the string- and function-based pathway mocks.
const textPromptTemplate = 'User: {{text}}\nAssistant: Please help {{name}} who is {{age}} years old.';

// Pathway whose prompt is a single pre-built Prompt instance.
export const mockPathwayString = {
    model: 'testModel',
    prompt: new Prompt(textPromptTemplate),
};

// Pathway whose prompt is a factory that builds a new Prompt on each call.
export const mockPathwayFunction = {
    model: 'testModel',
    prompt: () => new Prompt(textPromptTemplate),
};

// Pathway with a chat-style messages prompt. The triple-stash {{{text}}} inserts
// the text without Handlebars HTML-escaping.
export const mockPathwayMessages = {
    model: 'testModel',
    prompt: new Prompt({
        messages: [
            { role: 'user', content: 'Translate this: {{{text}}}' },
            { role: 'assistant', content: 'Translating: {{{text}}}' },
        ],
    }),
};
|
|
@@ -0,0 +1,129 @@
|
|
|
1
|
+
// tests/modelPlugin.test.js
|
|
2
|
+
import test from 'ava';
|
|
3
|
+
import ModelPlugin from '../graphql/plugins/modelPlugin.js';
|
|
4
|
+
import HandleBars from '../lib/handleBars.js';
|
|
5
|
+
import { mockConfig, mockPathwayString, mockPathwayFunction, mockPathwayMessages } from './mocks.js';
|
|
6
|
+
|
|
7
|
+
// Values ModelPlugin is expected to report by default for the mock model.
const DEFAULT_MAX_TOKENS = 4096;
const DEFAULT_PROMPT_TOKEN_RATIO = 0.5;

// Shared fixtures: the mock config and the string-prompt pathway.
const config = mockConfig;
const pathway = mockPathwayString;

test('ModelPlugin constructor', (t) => {
    const plugin = new ModelPlugin(config, pathway);
    const expectedModel = config.get('models')[pathway.model];

    t.is(plugin.modelName, pathway.model, 'modelName should be set from pathway');
    t.deepEqual(plugin.model, expectedModel, 'model should be set from config');
    t.is(plugin.temperature, pathway.temperature, 'temperature should be set from pathway');
    t.is(plugin.pathwayPrompt, pathway.prompt, 'pathwayPrompt should be set from pathway');
});

// Each test gets its own plugin instance built from the string-prompt pathway.
test.beforeEach((t) => {
    t.context.modelPlugin = new ModelPlugin(mockConfig, mockPathwayString);
});
|
|
26
|
+
|
|
27
|
+
// Assert that a compiled text prompt contains every interpolated input value.
const assertTextPromptCompiled = (t, compiled, text, parameters) => {
    const { modelPromptText, tokenLength } = compiled;
    t.true(modelPromptText.includes(text));
    t.true(modelPromptText.includes(parameters.name));
    t.true(modelPromptText.includes(parameters.age.toString()));
    t.is(typeof tokenLength, 'number');
};

test('getCompiledPrompt - text and parameters', (t) => {
    const text = 'Hello, World!';
    const parameters = { name: 'John', age: 30 };
    const compiled = t.context.modelPlugin.getCompiledPrompt(text, parameters, pathway.prompt);
    assertTextPromptCompiled(t, compiled, text, parameters);
});

test('getCompiledPrompt - custom prompt function', (t) => {
    const text = 'Hello, World!';
    const parameters = { name: 'John', age: 30 };
    const compiled = t.context.modelPlugin.getCompiledPrompt(text, parameters, mockPathwayFunction.prompt);
    assertTextPromptCompiled(t, compiled, text, parameters);
});

test('getCompiledPrompt - model prompt messages', (t) => {
    const text = 'Translate the following text to French: "Hello, World!"';
    const { modelPromptMessages, tokenLength } = t.context.modelPlugin.getCompiledPrompt(text, {}, mockPathwayMessages.prompt);
    const [userMessage, assistantMessage] = modelPromptMessages;

    t.true(userMessage.content.includes(text));
    t.true(assistantMessage.content.includes(text));
    t.is(typeof tokenLength, 'number');
});
|
|
64
|
+
|
|
65
|
+
test('getModelMaxTokenLength', (t) => {
    const actual = t.context.modelPlugin.getModelMaxTokenLength();
    t.is(actual, DEFAULT_MAX_TOKENS, 'getModelMaxTokenLength should return default max tokens');
});

test('getPromptTokenRatio', (t) => {
    const actual = t.context.modelPlugin.getPromptTokenRatio();
    t.is(actual, DEFAULT_PROMPT_TOKEN_RATIO, 'getPromptTokenRatio should return default prompt token ratio');
});

// The expected URL is the model's url template rendered against model, env and config values.
test('requestUrl', (t) => {
    const { modelPlugin } = t.context;
    const urlTemplate = HandleBars.compile(modelPlugin.model.url);
    const expectedUrl = urlTemplate({ ...modelPlugin.model, ...config.getEnv(), ...config });
    t.is(modelPlugin.requestUrl(), expectedUrl, 'requestUrl should return the correct URL');
});
|
|
81
|
+
|
|
82
|
+
// A single choice collapses to its text.
test('parseResponse - single choice', (t) => {
    const response = { choices: [{ text: '42' }] };
    const result = t.context.modelPlugin.parseResponse(response);
    t.is(result, '42', 'parseResponse should return the correct value for a single choice response');
});

// Several choices are returned as the raw choices array.
test('parseResponse - multiple choices', (t) => {
    const choices = [{ text: '42' }, { text: 'life' }];
    const result = t.context.modelPlugin.parseResponse({ choices });
    t.deepEqual(result, choices, 'parseResponse should return the choices array for multiple choices response');
});
|
|
106
|
+
|
|
107
|
+
// Builds a fresh two-turn conversation for each message-handling test.
const sampleMessages = () => [
    { role: 'user', content: 'What is the meaning of life?' },
    { role: 'assistant', content: 'The meaning of life is a philosophical question regarding the purpose and significance of life or existence in general.' },
];

test('truncateMessagesToTargetLength', (t) => {
    const messages = sampleMessages();
    const result = t.context.modelPlugin.truncateMessagesToTargetLength(messages, 10);

    t.true(Array.isArray(result), 'truncateMessagesToTargetLength should return an array');
    t.true(result.length <= messages.length, 'truncateMessagesToTargetLength should not return more messages than the input');
});

test('messagesToChatML', (t) => {
    const chatML = t.context.modelPlugin.messagesToChatML(sampleMessages());
    t.is(typeof chatML, 'string', 'messagesToChatML should return a string');
});
|
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
import test from 'ava';
|
|
2
|
+
import { PathwayResolver } from '../graphql/pathwayResolver.js';
|
|
3
|
+
import sinon from 'sinon';
|
|
4
|
+
import { mockConfig, mockPathwayString } from './mocks.js';
|
|
5
|
+
|
|
6
|
+
// Resolver-specific pathway derived from the shared string mock.
// The original assigned the exported mockPathwayString and then mutated it
// (overwriting its Prompt with a plain string). Copying keeps the shared mock
// intact for any other importer of mocks.js.
const mockPathway = {
    ...mockPathwayString,
    useInputChunking: false,
    prompt: 'What is AI?',
};
|
|
9
|
+
|
|
10
|
+
// Arguments handed to every resolver under test.
const mockArgs = {
    text: 'Artificial intelligence',
};

// A fresh PathwayResolver per test, so sinon stubs installed on one instance
// never affect another test.
test.beforeEach((t) => {
    const resolverOptions = { config: mockConfig, pathway: mockPathway, args: mockArgs };
    t.context.pathwayResolver = new PathwayResolver(resolverOptions);
});
|
|
21
|
+
|
|
22
|
+
test('constructor initializes properties correctly', (t) => {
    const { pathwayResolver: resolver } = t.context;

    t.deepEqual(resolver.config, mockConfig);
    t.deepEqual(resolver.pathway, mockPathway);
    t.deepEqual(resolver.args, mockArgs);
    t.is(resolver.useInputChunking, mockPathway.useInputChunking);
    t.is(typeof resolver.requestId, 'string');
});

// With async: true, resolve hands back the resolver's request id.
test('resolve returns request id when async is true', async (t) => {
    const { pathwayResolver: resolver } = t.context;
    const requestId = await resolver.resolve({ ...mockArgs, async: true });

    t.is(typeof requestId, 'string');
    t.is(requestId, resolver.requestId);
});

// Without async, resolve delegates to promptAndParse and returns its result.
test('resolve calls promptAndParse when async is false', async (t) => {
    const { pathwayResolver: resolver } = t.context;
    const promptAndParse = sinon.stub(resolver, 'promptAndParse').resolves('test-result');

    const result = await resolver.resolve(mockArgs);

    t.true(promptAndParse.calledOnce);
    t.is(result, 'test-result');
});
|
|
46
|
+
|
|
47
|
+
test('processInputText returns input text if no chunking', (t) => {
    const text = 'This is a test input text';
    t.deepEqual(t.context.pathwayResolver.processInputText(text), [text]);
});

// applyPrompt is stubbed per call; the serial run must surface the final result.
test('applyPromptsSerially returns result of last prompt', async (t) => {
    const resolver = t.context.pathwayResolver;
    const applyPrompt = sinon.stub(resolver, 'applyPrompt');
    ['result1', 'result2'].forEach((value, callIndex) => {
        applyPrompt.onCall(callIndex).resolves(value);
    });

    resolver.pathwayPrompt = ['prompt1', 'prompt2'];
    const result = await resolver.applyPromptsSerially('This is a test input text', mockArgs);

    t.is(result, 'result2');
});

test('processRequest returns empty result when input text is empty', async (t) => {
    const resolver = t.context.pathwayResolver;
    const text = '';
    const processRequest = sinon.stub(resolver, 'processRequest').resolves('');

    await resolver.resolve({ ...mockArgs, text });

    t.true(processRequest.calledOnce);
    t.is(await processRequest.firstCall.returnValue, text);
});
|