@aj-archipelago/cortex 1.0.5 → 1.0.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +2 -2
- package/config/default.example.json +4 -2
- package/config.js +14 -8
- package/helper_apps/WhisperX/.dockerignore +27 -0
- package/helper_apps/WhisperX/Dockerfile +31 -0
- package/helper_apps/WhisperX/app-ts.py +76 -0
- package/helper_apps/WhisperX/app.py +115 -0
- package/helper_apps/WhisperX/docker-compose.debug.yml +12 -0
- package/helper_apps/WhisperX/docker-compose.yml +10 -0
- package/helper_apps/WhisperX/requirements.txt +6 -0
- package/index.js +1 -1
- package/lib/redisSubscription.js +1 -1
- package/package.json +8 -7
- package/pathways/basePathway.js +3 -2
- package/pathways/index.js +4 -0
- package/pathways/summary.js +2 -2
- package/pathways/sys_openai_chat.js +19 -0
- package/pathways/sys_openai_completion.js +11 -0
- package/pathways/test_palm_chat.js +1 -1
- package/pathways/transcribe.js +2 -1
- package/{graphql → server}/chunker.js +48 -3
- package/{graphql → server}/graphql.js +70 -62
- package/{graphql → server}/pathwayPrompter.js +14 -17
- package/{graphql → server}/pathwayResolver.js +59 -42
- package/{graphql → server}/plugins/azureTranslatePlugin.js +2 -2
- package/{graphql → server}/plugins/localModelPlugin.js +2 -2
- package/{graphql → server}/plugins/modelPlugin.js +8 -10
- package/{graphql → server}/plugins/openAiChatPlugin.js +13 -8
- package/{graphql → server}/plugins/openAiCompletionPlugin.js +9 -3
- package/{graphql → server}/plugins/openAiWhisperPlugin.js +30 -7
- package/{graphql → server}/plugins/palmChatPlugin.js +4 -6
- package/server/plugins/palmCodeCompletionPlugin.js +46 -0
- package/{graphql → server}/plugins/palmCompletionPlugin.js +13 -15
- package/server/rest.js +321 -0
- package/{graphql → server}/typeDef.js +30 -13
- package/tests/chunkfunction.test.js +112 -26
- package/tests/config.test.js +1 -1
- package/tests/main.test.js +282 -43
- package/tests/mocks.js +43 -2
- package/tests/modelPlugin.test.js +4 -4
- package/tests/openAiChatPlugin.test.js +21 -14
- package/tests/openai_api.test.js +147 -0
- package/tests/palmChatPlugin.test.js +10 -11
- package/tests/palmCompletionPlugin.test.js +3 -4
- package/tests/pathwayResolver.test.js +1 -1
- package/tests/truncateMessages.test.js +4 -5
- package/pathways/completions.js +0 -17
- package/pathways/test_oai_chat.js +0 -18
- package/pathways/test_oai_cmpl.js +0 -13
- package/tests/chunking.test.js +0 -157
- package/tests/translate.test.js +0 -126
- /package/{graphql → server}/parser.js +0 -0
- /package/{graphql → server}/pathwayResponseParser.js +0 -0
- /package/{graphql → server}/prompt.js +0 -0
- /package/{graphql → server}/pubsub.js +0 -0
- /package/{graphql → server}/requestState.js +0 -0
- /package/{graphql → server}/resolver.js +0 -0
- /package/{graphql → server}/subscriptions.js +0 -0
|
@@ -1,24 +1,29 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
} from 'apollo
|
|
1
|
+
// graphql.js
|
|
2
|
+
// Setup the Apollo server and Express middleware
|
|
3
|
+
|
|
4
|
+
import { ApolloServerPluginDrainHttpServer } from '@apollo/server/plugin/drainHttpServer';
|
|
5
|
+
import { ApolloServerPluginLandingPageLocalDefault } from '@apollo/server/plugin/landingPage/default';
|
|
6
|
+
import { ApolloServer } from '@apollo/server';
|
|
7
|
+
import { expressMiddleware } from '@apollo/server/express4';
|
|
6
8
|
import { makeExecutableSchema } from '@graphql-tools/schema';
|
|
7
9
|
import { WebSocketServer } from 'ws';
|
|
8
10
|
import { useServer } from 'graphql-ws/lib/use/ws';
|
|
9
11
|
import express from 'express';
|
|
10
|
-
import
|
|
12
|
+
import http from 'http';
|
|
11
13
|
import Keyv from 'keyv';
|
|
14
|
+
import cors from 'cors';
|
|
12
15
|
import { KeyvAdapter } from '@apollo/utils.keyvadapter';
|
|
13
|
-
import responseCachePlugin from 'apollo
|
|
16
|
+
import responseCachePlugin from '@apollo/server-plugin-response-cache';
|
|
14
17
|
import subscriptions from './subscriptions.js';
|
|
15
18
|
import { buildLimiters } from '../lib/request.js';
|
|
16
19
|
import { cancelRequestResolver } from './resolver.js';
|
|
17
20
|
import { buildPathways, buildModels } from '../config.js';
|
|
18
21
|
import { requestState } from './requestState.js';
|
|
22
|
+
import { buildRestEndpoints } from './rest.js';
|
|
19
23
|
|
|
24
|
+
// Utility functions
|
|
25
|
+
// Server plugins
|
|
20
26
|
const getPlugins = (config) => {
|
|
21
|
-
// server plugins
|
|
22
27
|
const plugins = [
|
|
23
28
|
ApolloServerPluginLandingPageLocalDefault({ embed: true }), // For local development.
|
|
24
29
|
];
|
|
@@ -39,41 +44,8 @@ const getPlugins = (config) => {
|
|
|
39
44
|
return { plugins, cache };
|
|
40
45
|
}
|
|
41
46
|
|
|
42
|
-
const buildRestEndpoints = (pathways, app, server, config) => {
|
|
43
|
-
for (const [name, pathway] of Object.entries(pathways)) {
|
|
44
|
-
// Only expose endpoints for enabled pathways that explicitly want to expose a REST endpoint
|
|
45
|
-
if (pathway.disabled || !config.get('enableRestEndpoints')) continue;
|
|
46
|
-
|
|
47
|
-
const fieldVariableDefs = pathway.typeDef(pathway).restDefinition || [];
|
|
48
|
-
|
|
49
|
-
app.post(`/rest/${name}`, async (req, res) => {
|
|
50
|
-
const variables = fieldVariableDefs.reduce((acc, variableDef) => {
|
|
51
|
-
if (Object.prototype.hasOwnProperty.call(req.body, variableDef.name)) {
|
|
52
|
-
acc[variableDef.name] = req.body[variableDef.name];
|
|
53
|
-
}
|
|
54
|
-
return acc;
|
|
55
|
-
}, {});
|
|
56
|
-
|
|
57
|
-
const variableParams = fieldVariableDefs.map(({ name, type }) => `$${name}: ${type}`).join(', ');
|
|
58
|
-
const queryArgs = fieldVariableDefs.map(({ name }) => `${name}: $${name}`).join(', ');
|
|
59
|
-
|
|
60
|
-
const query = `
|
|
61
|
-
query ${name}(${variableParams}) {
|
|
62
|
-
${name}(${queryArgs}) {
|
|
63
|
-
contextId
|
|
64
|
-
previousResult
|
|
65
|
-
result
|
|
66
|
-
}
|
|
67
|
-
}
|
|
68
|
-
`;
|
|
69
|
-
|
|
70
|
-
const result = await server.executeOperation({ query, variables });
|
|
71
|
-
res.json(result.data[name]);
|
|
72
|
-
});
|
|
73
|
-
}
|
|
74
|
-
};
|
|
75
47
|
|
|
76
|
-
//
|
|
48
|
+
// Type Definitions for GraphQL
|
|
77
49
|
const getTypedefs = (pathways) => {
|
|
78
50
|
|
|
79
51
|
const defaultTypeDefs = `#graphql
|
|
@@ -111,6 +83,7 @@ const getTypedefs = (pathways) => {
|
|
|
111
83
|
return typeDefs.join('\n');
|
|
112
84
|
}
|
|
113
85
|
|
|
86
|
+
// Resolvers for GraphQL
|
|
114
87
|
const getResolvers = (config, pathways) => {
|
|
115
88
|
const resolverFunctions = {};
|
|
116
89
|
for (const [name, pathway] of Object.entries(pathways)) {
|
|
@@ -118,6 +91,7 @@ const getResolvers = (config, pathways) => {
|
|
|
118
91
|
resolverFunctions[name] = (parent, args, contextValue, info) => {
|
|
119
92
|
// add shared state to contextValue
|
|
120
93
|
contextValue.pathway = pathway;
|
|
94
|
+
contextValue.config = config;
|
|
121
95
|
return pathway.rootResolver(parent, args, contextValue, info);
|
|
122
96
|
}
|
|
123
97
|
}
|
|
@@ -131,7 +105,7 @@ const getResolvers = (config, pathways) => {
|
|
|
131
105
|
return resolvers;
|
|
132
106
|
}
|
|
133
107
|
|
|
134
|
-
//
|
|
108
|
+
// Build the server including the GraphQL schema and REST endpoints
|
|
135
109
|
const build = async (config) => {
|
|
136
110
|
// First perform config build
|
|
137
111
|
await buildPathways(config);
|
|
@@ -150,9 +124,9 @@ const build = async (config) => {
|
|
|
150
124
|
|
|
151
125
|
const { plugins, cache } = getPlugins(config);
|
|
152
126
|
|
|
153
|
-
const app = express()
|
|
127
|
+
const app = express();
|
|
154
128
|
|
|
155
|
-
const httpServer = createServer(app);
|
|
129
|
+
const httpServer = http.createServer(app);
|
|
156
130
|
|
|
157
131
|
// Creating the WebSocket server
|
|
158
132
|
const wsServer = new WebSocketServer({
|
|
@@ -182,35 +156,69 @@ const build = async (config) => {
|
|
|
182
156
|
},
|
|
183
157
|
};
|
|
184
158
|
},
|
|
185
|
-
}
|
|
186
|
-
|
|
159
|
+
}
|
|
160
|
+
]),
|
|
187
161
|
});
|
|
188
162
|
|
|
189
163
|
// If CORTEX_API_KEY is set, we roll our own auth middleware - usually not used if you're being fronted by a proxy
|
|
190
164
|
const cortexApiKey = config.get('cortexApiKey');
|
|
165
|
+
if (cortexApiKey) {
|
|
166
|
+
app.use((req, res, next) => {
|
|
167
|
+
let providedApiKey = req.headers['cortex-api-key'] || req.query['cortex-api-key'];
|
|
168
|
+
if (!providedApiKey) {
|
|
169
|
+
providedApiKey = req.headers['authorization'];
|
|
170
|
+
providedApiKey = providedApiKey?.startsWith('Bearer ') ? providedApiKey.slice(7) : providedApiKey;
|
|
171
|
+
}
|
|
172
|
+
|
|
173
|
+
if (cortexApiKey && cortexApiKey !== providedApiKey) {
|
|
174
|
+
if (req.baseUrl === '/graphql' || req.headers['content-type'] === 'application/graphql') {
|
|
175
|
+
res.status(401)
|
|
176
|
+
.set('WWW-Authenticate', 'Cortex-Api-Key')
|
|
177
|
+
.set('X-Cortex-Api-Key-Info', 'Server requires Cortex API Key')
|
|
178
|
+
.json({
|
|
179
|
+
errors: [
|
|
180
|
+
{
|
|
181
|
+
message: 'Unauthorized',
|
|
182
|
+
extensions: {
|
|
183
|
+
code: 'UNAUTHENTICATED',
|
|
184
|
+
},
|
|
185
|
+
},
|
|
186
|
+
],
|
|
187
|
+
});
|
|
188
|
+
} else {
|
|
189
|
+
res.status(401)
|
|
190
|
+
.set('WWW-Authenticate', 'Cortex-Api-Key')
|
|
191
|
+
.set('X-Cortex-Api-Key-Info', 'Server requires Cortex API Key')
|
|
192
|
+
.send('Unauthorized');
|
|
193
|
+
}
|
|
194
|
+
} else {
|
|
195
|
+
next();
|
|
196
|
+
}
|
|
197
|
+
});
|
|
198
|
+
};
|
|
191
199
|
|
|
192
|
-
|
|
193
|
-
if (cortexApiKey && req.headers.cortexApiKey !== cortexApiKey && req.query.cortexApiKey !== cortexApiKey) {
|
|
194
|
-
res.status(401).send('Unauthorized');
|
|
195
|
-
} else {
|
|
196
|
-
next();
|
|
197
|
-
}
|
|
198
|
-
});
|
|
199
|
-
|
|
200
|
-
// Use the JSON body parser middleware for REST endpoints
|
|
200
|
+
// Parse the body for REST endpoints
|
|
201
201
|
app.use(express.json());
|
|
202
|
-
|
|
203
|
-
// add the REST endpoints
|
|
204
|
-
buildRestEndpoints(pathways, app, server, config);
|
|
205
202
|
|
|
206
|
-
//
|
|
203
|
+
// Server Startup Function
|
|
207
204
|
const startServer = async () => {
|
|
208
205
|
await server.start();
|
|
209
|
-
|
|
206
|
+
app.use(
|
|
207
|
+
'/graphql',
|
|
208
|
+
|
|
209
|
+
cors(),
|
|
210
|
+
|
|
211
|
+
expressMiddleware(server, {
|
|
212
|
+
context: async ({ req, res }) => ({ req, res, config, requestState }),
|
|
213
|
+
}),
|
|
214
|
+
);
|
|
215
|
+
|
|
216
|
+
// add the REST endpoints
|
|
217
|
+
buildRestEndpoints(pathways, app, server, config);
|
|
210
218
|
|
|
211
219
|
// Now that our HTTP server is fully set up, we can listen to it.
|
|
212
220
|
httpServer.listen(config.get('PORT'), () => {
|
|
213
|
-
console.log(`🚀 Server is now running at http://localhost:${config.get('PORT')}
|
|
221
|
+
console.log(`🚀 Server is now running at http://localhost:${config.get('PORT')}/graphql`);
|
|
214
222
|
});
|
|
215
223
|
};
|
|
216
224
|
|
|
@@ -6,40 +6,37 @@ import OpenAIWhisperPlugin from './plugins/openAiWhisperPlugin.js';
|
|
|
6
6
|
import LocalModelPlugin from './plugins/localModelPlugin.js';
|
|
7
7
|
import PalmChatPlugin from './plugins/palmChatPlugin.js';
|
|
8
8
|
import PalmCompletionPlugin from './plugins/palmCompletionPlugin.js';
|
|
9
|
+
import PalmCodeCompletionPlugin from './plugins/palmCodeCompletionPlugin.js';
|
|
9
10
|
|
|
10
11
|
class PathwayPrompter {
|
|
11
|
-
constructor(
|
|
12
|
-
|
|
13
|
-
const modelName = pathway.model || config.get('defaultModelName');
|
|
14
|
-
const model = config.get('models')[modelName];
|
|
15
|
-
|
|
16
|
-
if (!model) {
|
|
17
|
-
throw new Error(`Model ${modelName} not found in config`);
|
|
18
|
-
}
|
|
19
|
-
|
|
12
|
+
constructor(config, pathway, modelName, model) {
|
|
13
|
+
|
|
20
14
|
let plugin;
|
|
21
15
|
|
|
22
16
|
switch (model.type) {
|
|
23
17
|
case 'OPENAI-CHAT':
|
|
24
|
-
plugin = new OpenAIChatPlugin(config, pathway);
|
|
18
|
+
plugin = new OpenAIChatPlugin(config, pathway, modelName, model);
|
|
25
19
|
break;
|
|
26
20
|
case 'AZURE-TRANSLATE':
|
|
27
|
-
plugin = new AzureTranslatePlugin(config, pathway);
|
|
21
|
+
plugin = new AzureTranslatePlugin(config, pathway, modelName, model);
|
|
28
22
|
break;
|
|
29
23
|
case 'OPENAI-COMPLETION':
|
|
30
|
-
plugin = new OpenAICompletionPlugin(config, pathway);
|
|
24
|
+
plugin = new OpenAICompletionPlugin(config, pathway, modelName, model);
|
|
31
25
|
break;
|
|
32
|
-
case '
|
|
33
|
-
plugin = new OpenAIWhisperPlugin(config, pathway);
|
|
26
|
+
case 'OPENAI-WHISPER':
|
|
27
|
+
plugin = new OpenAIWhisperPlugin(config, pathway, modelName, model);
|
|
34
28
|
break;
|
|
35
29
|
case 'LOCAL-CPP-MODEL':
|
|
36
|
-
plugin = new LocalModelPlugin(config, pathway);
|
|
30
|
+
plugin = new LocalModelPlugin(config, pathway, modelName, model);
|
|
37
31
|
break;
|
|
38
32
|
case 'PALM-CHAT':
|
|
39
|
-
plugin = new PalmChatPlugin(config, pathway);
|
|
33
|
+
plugin = new PalmChatPlugin(config, pathway, modelName, model);
|
|
40
34
|
break;
|
|
41
35
|
case 'PALM-COMPLETION':
|
|
42
|
-
plugin = new PalmCompletionPlugin(config, pathway);
|
|
36
|
+
plugin = new PalmCompletionPlugin(config, pathway, modelName, model);
|
|
37
|
+
break;
|
|
38
|
+
case 'PALM-CODE-COMPLETION':
|
|
39
|
+
plugin = new PalmCodeCompletionPlugin(config, pathway, modelName, model);
|
|
43
40
|
break;
|
|
44
41
|
default:
|
|
45
42
|
throw new Error(`Unsupported model type: ${model.type}`);
|
|
@@ -20,9 +20,31 @@ class PathwayResolver {
|
|
|
20
20
|
this.warnings = [];
|
|
21
21
|
this.requestId = uuidv4();
|
|
22
22
|
this.responseParser = new PathwayResponseParser(pathway);
|
|
23
|
-
this.
|
|
23
|
+
this.modelName = [
|
|
24
|
+
pathway.model,
|
|
25
|
+
args?.model,
|
|
26
|
+
pathway.inputParameters?.model,
|
|
27
|
+
config.get('defaultModelName')
|
|
28
|
+
].find(modelName => modelName && config.get('models').hasOwnProperty(modelName));
|
|
29
|
+
this.model = config.get('models')[this.modelName];
|
|
30
|
+
|
|
31
|
+
if (!this.model) {
|
|
32
|
+
throw new Error(`Model ${this.modelName} not found in config`);
|
|
33
|
+
}
|
|
34
|
+
|
|
35
|
+
const specifiedModelName = pathway.model || args?.model || pathway.inputParameters?.model;
|
|
36
|
+
|
|
37
|
+
if (this.modelName !== (specifiedModelName)) {
|
|
38
|
+
if (specifiedModelName) {
|
|
39
|
+
this.logWarning(`Specified model ${specifiedModelName} not found in config, using ${this.modelName} instead.`);
|
|
40
|
+
} else {
|
|
41
|
+
this.logWarning(`No model specified in the pathway, using ${this.modelName}.`);
|
|
42
|
+
}
|
|
43
|
+
}
|
|
44
|
+
|
|
24
45
|
this.previousResult = '';
|
|
25
46
|
this.prompts = [];
|
|
47
|
+
this.pathwayPrompter = new PathwayPrompter(this.config, this.pathway, this.modelName, this.model);
|
|
26
48
|
|
|
27
49
|
Object.defineProperty(this, 'pathwayPrompt', {
|
|
28
50
|
get() {
|
|
@@ -42,66 +64,61 @@ class PathwayResolver {
|
|
|
42
64
|
}
|
|
43
65
|
|
|
44
66
|
async asyncResolve(args) {
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
if(args.async || typeof data === 'string') { // if async flag set or processed async and got string response
|
|
67
|
+
const responseData = await this.promptAndParse(args);
|
|
68
|
+
|
|
69
|
+
// Either we're dealing with an async request or a stream
|
|
70
|
+
if(args.async || typeof responseData === 'string') {
|
|
50
71
|
const { completedCount, totalCount } = requestState[this.requestId];
|
|
51
|
-
requestState[this.requestId].data =
|
|
72
|
+
requestState[this.requestId].data = responseData;
|
|
52
73
|
pubsub.publish('REQUEST_PROGRESS', {
|
|
53
74
|
requestProgress: {
|
|
54
75
|
requestId: this.requestId,
|
|
55
76
|
progress: completedCount / totalCount,
|
|
56
|
-
data: JSON.stringify(
|
|
77
|
+
data: JSON.stringify(responseData),
|
|
57
78
|
}
|
|
58
79
|
});
|
|
59
|
-
} else { //stream
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
const
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
if (
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
80
|
+
} else { // stream
|
|
81
|
+
try {
|
|
82
|
+
const incomingMessage = Array.isArray(responseData) && responseData.length > 0 ? responseData[0] : responseData;
|
|
83
|
+
incomingMessage.on('data', data => {
|
|
84
|
+
const events = data.toString().split('\n');
|
|
85
|
+
|
|
86
|
+
events.forEach(event => {
|
|
87
|
+
if (event.trim() === '') return; // Skip empty lines
|
|
88
|
+
|
|
89
|
+
const message = event.replace(/^data: /, '');
|
|
90
|
+
|
|
91
|
+
//console.log(`====================================`);
|
|
92
|
+
//console.log(`STREAM EVENT: ${event}`);
|
|
93
|
+
//console.log(`MESSAGE: ${message}`);
|
|
94
|
+
|
|
95
|
+
const requestProgress = {
|
|
96
|
+
requestId: this.requestId,
|
|
97
|
+
data: message,
|
|
76
98
|
}
|
|
99
|
+
|
|
100
|
+
if (message.trim() === '[DONE]') {
|
|
101
|
+
requestProgress.progress = 1;
|
|
102
|
+
}
|
|
103
|
+
|
|
77
104
|
try {
|
|
78
|
-
const parsed = JSON.parse(message);
|
|
79
|
-
const result = this.pathwayPrompter.plugin.parseResponse(parsed)
|
|
80
|
-
|
|
81
105
|
pubsub.publish('REQUEST_PROGRESS', {
|
|
82
|
-
requestProgress:
|
|
83
|
-
requestId: this.requestId,
|
|
84
|
-
data: JSON.stringify(result)
|
|
85
|
-
}
|
|
106
|
+
requestProgress: requestProgress
|
|
86
107
|
});
|
|
87
108
|
} catch (error) {
|
|
88
109
|
console.error('Could not JSON parse stream message', message, error);
|
|
89
110
|
}
|
|
90
|
-
}
|
|
111
|
+
});
|
|
91
112
|
});
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
// console.log("stream done");
|
|
95
|
-
// });
|
|
113
|
+
} catch (error) {
|
|
114
|
+
console.error('Could not subscribe to stream', error);
|
|
96
115
|
}
|
|
97
|
-
|
|
98
116
|
}
|
|
99
117
|
}
|
|
100
118
|
|
|
101
119
|
async resolve(args) {
|
|
120
|
+
// Either we're dealing with an async request, stream, or regular request
|
|
102
121
|
if (args.async || args.stream) {
|
|
103
|
-
// Asyncronously process the request
|
|
104
|
-
// this.asyncResolve(args);
|
|
105
122
|
if (!requestState[this.requestId]) {
|
|
106
123
|
requestState[this.requestId] = {}
|
|
107
124
|
}
|
|
@@ -161,7 +178,7 @@ class PathwayResolver {
|
|
|
161
178
|
}
|
|
162
179
|
|
|
163
180
|
// chunk the text and return the chunks with newline separators
|
|
164
|
-
return getSemanticChunks(text, chunkTokenLength);
|
|
181
|
+
return getSemanticChunks(text, chunkTokenLength, this.pathway.inputFormat);
|
|
165
182
|
}
|
|
166
183
|
|
|
167
184
|
truncate(str, n) {
|
|
@@ -292,7 +309,7 @@ class PathwayResolver {
|
|
|
292
309
|
let result = '';
|
|
293
310
|
|
|
294
311
|
// If this text is empty, skip applying the prompt as it will likely be a nonsensical result
|
|
295
|
-
if (!/^\s*$/.test(text)) {
|
|
312
|
+
if (!/^\s*$/.test(text) || parameters?.file) {
|
|
296
313
|
result = await this.pathwayPrompter.execute(text, { ...parameters, ...this.savedContext }, prompt, this);
|
|
297
314
|
} else {
|
|
298
315
|
result = text;
|
|
@@ -2,8 +2,8 @@
|
|
|
2
2
|
import ModelPlugin from './modelPlugin.js';
|
|
3
3
|
|
|
4
4
|
class AzureTranslatePlugin extends ModelPlugin {
|
|
5
|
-
constructor(config, pathway) {
|
|
6
|
-
super(config, pathway);
|
|
5
|
+
constructor(config, pathway, modelName, model) {
|
|
6
|
+
super(config, pathway, modelName, model);
|
|
7
7
|
}
|
|
8
8
|
|
|
9
9
|
// Set up parameters specific to the Azure Translate API
|
|
@@ -4,8 +4,8 @@ import { execFileSync } from 'child_process';
|
|
|
4
4
|
import { encode } from 'gpt-3-encoder';
|
|
5
5
|
|
|
6
6
|
class LocalModelPlugin extends ModelPlugin {
|
|
7
|
-
constructor(config, pathway) {
|
|
8
|
-
super(config, pathway);
|
|
7
|
+
constructor(config, pathway, modelName, model) {
|
|
8
|
+
super(config, pathway, modelName, model);
|
|
9
9
|
}
|
|
10
10
|
|
|
11
11
|
// if the input starts with a chatML response, just return that
|
|
@@ -6,19 +6,13 @@ import { encode } from 'gpt-3-encoder';
|
|
|
6
6
|
import { getFirstNToken } from '../chunker.js';
|
|
7
7
|
|
|
8
8
|
const DEFAULT_MAX_TOKENS = 4096;
|
|
9
|
+
const DEFAULT_MAX_RETURN_TOKENS = 256;
|
|
9
10
|
const DEFAULT_PROMPT_TOKEN_RATIO = 0.5;
|
|
10
11
|
|
|
11
12
|
class ModelPlugin {
|
|
12
|
-
constructor(config, pathway) {
|
|
13
|
-
|
|
14
|
-
this.
|
|
15
|
-
// Get the model from the config
|
|
16
|
-
this.model = config.get('models')[this.modelName];
|
|
17
|
-
// If the model doesn't exist, throw an exception
|
|
18
|
-
if (!this.model) {
|
|
19
|
-
throw new Error(`Model ${this.modelName} not found in config`);
|
|
20
|
-
}
|
|
21
|
-
|
|
13
|
+
constructor(config, pathway, modelName, model) {
|
|
14
|
+
this.modelName = modelName;
|
|
15
|
+
this.model = model;
|
|
22
16
|
this.config = config;
|
|
23
17
|
this.environmentVariables = config.getEnv();
|
|
24
18
|
this.temperature = pathway.temperature;
|
|
@@ -143,6 +137,10 @@ class ModelPlugin {
|
|
|
143
137
|
return (this.promptParameters.maxTokenLength ?? this.model.maxTokenLength ?? DEFAULT_MAX_TOKENS);
|
|
144
138
|
}
|
|
145
139
|
|
|
140
|
+
getModelMaxReturnTokens() {
|
|
141
|
+
return (this.promptParameters.maxReturnTokens ?? this.model.maxReturnTokens ?? DEFAULT_MAX_RETURN_TOKENS);
|
|
142
|
+
}
|
|
143
|
+
|
|
146
144
|
getPromptTokenRatio() {
|
|
147
145
|
// TODO: Is this the right order of precedence? inputParameters should maybe be second?
|
|
148
146
|
return this.promptParameters.inputParameters?.tokenRatio ?? this.promptParameters.tokenRatio ?? DEFAULT_PROMPT_TOKEN_RATIO;
|
|
@@ -3,8 +3,8 @@ import ModelPlugin from './modelPlugin.js';
|
|
|
3
3
|
import { encode } from 'gpt-3-encoder';
|
|
4
4
|
|
|
5
5
|
class OpenAIChatPlugin extends ModelPlugin {
|
|
6
|
-
constructor(config, pathway) {
|
|
7
|
-
super(config, pathway);
|
|
6
|
+
constructor(config, pathway, modelName, model) {
|
|
7
|
+
super(config, pathway, modelName, model);
|
|
8
8
|
}
|
|
9
9
|
|
|
10
10
|
// convert to OpenAI messages array format if necessary
|
|
@@ -90,7 +90,7 @@ class OpenAIChatPlugin extends ModelPlugin {
|
|
|
90
90
|
parseResponse(data) {
|
|
91
91
|
const { choices } = data;
|
|
92
92
|
if (!choices || !choices.length) {
|
|
93
|
-
return
|
|
93
|
+
return data;
|
|
94
94
|
}
|
|
95
95
|
|
|
96
96
|
// if we got a choices array back with more than one choice, return the whole array
|
|
@@ -108,8 +108,9 @@ class OpenAIChatPlugin extends ModelPlugin {
|
|
|
108
108
|
const separator = `\n=== ${this.pathwayName}.${this.requestCount++} ===\n`;
|
|
109
109
|
console.log(separator);
|
|
110
110
|
|
|
111
|
-
|
|
112
|
-
|
|
111
|
+
const { stream, messages } = data;
|
|
112
|
+
if (messages && messages.length > 1) {
|
|
113
|
+
messages.forEach((message, index) => {
|
|
113
114
|
const words = message.content.split(" ");
|
|
114
115
|
const tokenCount = encode(message.content).length;
|
|
115
116
|
const preview = words.length < 41 ? message.content : words.slice(0, 20).join(" ") + " ... " + words.slice(-20).join(" ");
|
|
@@ -117,11 +118,15 @@ class OpenAIChatPlugin extends ModelPlugin {
|
|
|
117
118
|
console.log(`\x1b[36mMessage ${index + 1}: Role: ${message.role}, Tokens: ${tokenCount}, Content: "${preview}"\x1b[0m`);
|
|
118
119
|
});
|
|
119
120
|
} else {
|
|
120
|
-
console.log(`\x1b[36m${
|
|
121
|
+
console.log(`\x1b[36m${messages[0].content}\x1b[0m`);
|
|
121
122
|
}
|
|
122
123
|
|
|
123
|
-
|
|
124
|
-
|
|
124
|
+
if (stream) {
|
|
125
|
+
console.log(`\x1b[34m> Response is streaming...\x1b[0m`);
|
|
126
|
+
} else {
|
|
127
|
+
console.log(`\x1b[34m> ${this.parseResponse(responseData)}\x1b[0m`);
|
|
128
|
+
}
|
|
129
|
+
|
|
125
130
|
prompt && prompt.debugInfo && (prompt.debugInfo += `${separator}${JSON.stringify(data)}`);
|
|
126
131
|
}
|
|
127
132
|
}
|
|
@@ -15,8 +15,8 @@ const truncatePromptIfNecessary = (text, textTokenCount, modelMaxTokenCount, tar
|
|
|
15
15
|
}
|
|
16
16
|
|
|
17
17
|
class OpenAICompletionPlugin extends ModelPlugin {
|
|
18
|
-
constructor(config, pathway) {
|
|
19
|
-
super(config, pathway);
|
|
18
|
+
constructor(config, pathway, modelName, model) {
|
|
19
|
+
super(config, pathway, modelName, model);
|
|
20
20
|
}
|
|
21
21
|
|
|
22
22
|
// Set up parameters specific to the OpenAI Completion API
|
|
@@ -108,10 +108,16 @@ class OpenAICompletionPlugin extends ModelPlugin {
|
|
|
108
108
|
const separator = `\n=== ${this.pathwayName}.${this.requestCount++} ===\n`;
|
|
109
109
|
console.log(separator);
|
|
110
110
|
|
|
111
|
+
const stream = data.stream;
|
|
111
112
|
const modelInput = data.prompt;
|
|
112
113
|
|
|
113
114
|
console.log(`\x1b[36m${modelInput}\x1b[0m`);
|
|
114
|
-
|
|
115
|
+
|
|
116
|
+
if (stream) {
|
|
117
|
+
console.log(`\x1b[34m> Response is streaming...\x1b[0m`);
|
|
118
|
+
} else {
|
|
119
|
+
console.log(`\x1b[34m> ${this.parseResponse(responseData)}\x1b[0m`);
|
|
120
|
+
}
|
|
115
121
|
|
|
116
122
|
prompt && prompt.debugInfo && (prompt.debugInfo += `${separator}${JSON.stringify(data)}`);
|
|
117
123
|
}
|
|
@@ -19,6 +19,7 @@ const pipeline = promisify(stream.pipeline);
|
|
|
19
19
|
|
|
20
20
|
|
|
21
21
|
const API_URL = config.get('whisperMediaApiUrl');
|
|
22
|
+
const WHISPER_TS_API_URL = config.get('whisperTSApiUrl');
|
|
22
23
|
|
|
23
24
|
function alignSubtitles(subtitles) {
|
|
24
25
|
const result = [];
|
|
@@ -74,14 +75,14 @@ const downloadFile = async (fileUrl) => {
|
|
|
74
75
|
fs.unlink(localFilePath, () => {
|
|
75
76
|
reject(error);
|
|
76
77
|
});
|
|
77
|
-
throw error;
|
|
78
|
+
//throw error;
|
|
78
79
|
}
|
|
79
80
|
});
|
|
80
81
|
};
|
|
81
82
|
|
|
82
83
|
class OpenAIWhisperPlugin extends ModelPlugin {
|
|
83
|
-
constructor(config, pathway) {
|
|
84
|
-
super(config, pathway);
|
|
84
|
+
constructor(config, pathway, modelName, model) {
|
|
85
|
+
super(config, pathway, modelName, model);
|
|
85
86
|
}
|
|
86
87
|
|
|
87
88
|
async getMediaChunks(file, requestId) {
|
|
@@ -115,11 +116,28 @@ class OpenAIWhisperPlugin extends ModelPlugin {
|
|
|
115
116
|
|
|
116
117
|
// Execute the request to the OpenAI Whisper API
|
|
117
118
|
async execute(text, parameters, prompt, pathwayResolver) {
|
|
118
|
-
const { responseFormat } = parameters;
|
|
119
|
+
const { responseFormat, wordTimestamped } = parameters;
|
|
119
120
|
const url = this.requestUrl(text);
|
|
120
121
|
const params = {};
|
|
121
122
|
const { modelPromptText } = this.getCompiledPrompt(text, parameters, prompt);
|
|
122
123
|
|
|
124
|
+
const processTS = async (uri) => {
|
|
125
|
+
if (wordTimestamped) {
|
|
126
|
+
if (!WHISPER_TS_API_URL) {
|
|
127
|
+
throw new Error(`WHISPER_TS_API_URL not set for word timestamped processing`);
|
|
128
|
+
}
|
|
129
|
+
|
|
130
|
+
try {
|
|
131
|
+
// const res = await axios.post(WHISPER_TS_API_URL, { params: { fileurl: uri } });
|
|
132
|
+
const res = await this.executeRequest(WHISPER_TS_API_URL, {fileurl:uri},{},{});
|
|
133
|
+
return res;
|
|
134
|
+
} catch (err) {
|
|
135
|
+
console.log(`Error getting word timestamped data from api:`, err);
|
|
136
|
+
throw err;
|
|
137
|
+
}
|
|
138
|
+
}
|
|
139
|
+
}
|
|
140
|
+
|
|
123
141
|
const processChunk = async (chunk) => {
|
|
124
142
|
try {
|
|
125
143
|
const { language, responseFormat } = parameters;
|
|
@@ -159,7 +177,6 @@ class OpenAIWhisperPlugin extends ModelPlugin {
|
|
|
159
177
|
|
|
160
178
|
let chunks = []; // array of local file paths
|
|
161
179
|
try {
|
|
162
|
-
|
|
163
180
|
const uris = await this.getMediaChunks(file, requestId); // array of remote file uris
|
|
164
181
|
if (!uris || !uris.length) {
|
|
165
182
|
throw new Error(`Error in getting chunks from media helper for file ${file}`);
|
|
@@ -169,7 +186,13 @@ class OpenAIWhisperPlugin extends ModelPlugin {
|
|
|
169
186
|
|
|
170
187
|
// sequential download of chunks
|
|
171
188
|
for (const uri of uris) {
|
|
172
|
-
|
|
189
|
+
if (wordTimestamped) { // get word timestamped data
|
|
190
|
+
sendProgress(); // no download needed auto progress
|
|
191
|
+
const ts = await processTS(uri);
|
|
192
|
+
result.push(ts);
|
|
193
|
+
} else {
|
|
194
|
+
chunks.push(await downloadFile(uri));
|
|
195
|
+
}
|
|
173
196
|
sendProgress();
|
|
174
197
|
}
|
|
175
198
|
|
|
@@ -210,7 +233,7 @@ class OpenAIWhisperPlugin extends ModelPlugin {
|
|
|
210
233
|
}
|
|
211
234
|
}
|
|
212
235
|
|
|
213
|
-
if (['srt','vtt'].includes(responseFormat)) { // align subtitles for formats
|
|
236
|
+
if (['srt','vtt'].includes(responseFormat) || wordTimestamped) { // align subtitles for formats
|
|
214
237
|
return alignSubtitles(result);
|
|
215
238
|
}
|
|
216
239
|
return result.join(` `);
|
|
@@ -4,8 +4,8 @@ import { encode } from 'gpt-3-encoder';
|
|
|
4
4
|
import HandleBars from '../../lib/handleBars.js';
|
|
5
5
|
|
|
6
6
|
class PalmChatPlugin extends ModelPlugin {
|
|
7
|
-
constructor(config, pathway) {
|
|
8
|
-
super(config, pathway);
|
|
7
|
+
constructor(config, pathway, modelName, model) {
|
|
8
|
+
super(config, pathway, modelName, model);
|
|
9
9
|
}
|
|
10
10
|
|
|
11
11
|
// Convert to PaLM messages array format if necessary
|
|
@@ -92,10 +92,8 @@ class PalmChatPlugin extends ModelPlugin {
|
|
|
92
92
|
const context = this.getCompiledContext(text, parameters, prompt.context || palmMessages.context || '');
|
|
93
93
|
const examples = this.getCompiledExamples(text, parameters, prompt.examples || []);
|
|
94
94
|
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
const max_tokens = 1024//this.getModelMaxTokenLength() - tokenLength;
|
|
98
|
-
|
|
95
|
+
const max_tokens = this.getModelMaxReturnTokens();
|
|
96
|
+
|
|
99
97
|
if (max_tokens < 0) {
|
|
100
98
|
throw new Error(`Prompt is too long to successfully call the model at ${tokenLength} tokens. The model will not be called.`);
|
|
101
99
|
}
|