@aj-archipelago/cortex 1.1.4-0 → 1.1.5
This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
- package/README.md +1 -1
- package/config.js +17 -5
- package/helper-apps/cortex-file-handler/fileChunker.js +3 -1
- package/lib/cortexRequest.js +11 -1
- package/lib/requestExecutor.js +4 -4
- package/package.json +2 -1
- package/pathways/bias.js +1 -1
- package/pathways/cognitive_insert.js +1 -1
- package/server/graphql.js +6 -4
- package/server/modelExecutor.js +8 -0
- package/server/pathwayResolver.js +23 -5
- package/server/plugins/geminiChatPlugin.js +195 -0
- package/server/plugins/geminiVisionPlugin.js +102 -0
- package/server/plugins/modelPlugin.js +4 -3
- package/server/plugins/openAiEmbeddingsPlugin.js +3 -1
- package/server/plugins/openAiWhisperPlugin.js +1 -1
- package/server/rest.js +11 -5
package/README.md
CHANGED
@@ -422,7 +422,7 @@ Configuration of Cortex is done via a [convict](https://github.com/mozilla/node-
 - `enableCache`: A boolean flag indicating whether to enable Axios-level request caching. Default is true. The value can be set using the `CORTEX_ENABLE_CACHE` environment variable.
 - `enableGraphqlCache`: A boolean flag indicating whether to enable GraphQL query caching. Default is false. The value can be set using the `CORTEX_ENABLE_GRAPHQL_CACHE` environment variable.
 - `enableRestEndpoints`: A boolean flag indicating whether create REST endpoints for pathways as well as GraphQL queries. Default is false. The value can be set using the `CORTEX_ENABLE_REST` environment variable.
-- `
+- `cortexApiKeys`: A string containing one or more comma separated API keys that the client must pass to Cortex for authorization. Default is null in which case Cortex is unprotected. The value can be set using the `CORTEX_API_KEY` environment variable
 - `models`: An object containing the different models used by the project. The value can be set using the `CORTEX_MODELS` environment variable. Cortex is model and vendor agnostic - you can use this config to set up models of any type from any vendor.
 - `openaiApiKey`: The API key used for accessing the OpenAI API. This is sensitive information and has no default value. The value can be set using the `OPENAI_API_KEY` environment variable.
 - `openaiApiUrl`: The URL used for accessing the OpenAI API. Default is https://api.openai.com/v1/completions. The value can be set using the `OPENAI_API_URL` environment variable.
package/config.js
CHANGED
@@ -8,6 +8,18 @@ import logger from './lib/logger.js';

 const __dirname = path.dirname(fileURLToPath(import.meta.url));

+convict.addFormat({
+    name: 'string-array',
+    validate: function(val) {
+        if (!Array.isArray(val)) {
+            throw new Error('must be of type Array');
+        }
+    },
+    coerce: function(val) {
+        return val.split(',');
+    },
+});
+
 // Schema for config
 var config = convict({
     env: {
@@ -30,8 +42,8 @@ var config = convict({
         default: path.join(__dirname, 'pathways'),
         env: 'CORTEX_CORE_PATHWAYS_PATH'
     },
-
-        format:
+    cortexApiKeys: {
+        format: 'string-array',
         default: null,
         env: 'CORTEX_API_KEY',
         sensitive: true
@@ -110,9 +122,9 @@ var config = convict({
         },
         "oai-embeddings": {
             "type": "OPENAI-EMBEDDINGS",
-            "url": "https://
+            "url": "https://api.openai.com/v1/embeddings",
             "headers": {
-                "
+                "Authorization": "Bearer {{OPENAI_API_KEY}}",
                 "Content-Type": "application/json"
             },
             "params": {
@@ -264,7 +276,7 @@ const buildPathways = async (config) => {

 // Build and load models to config
 const buildModels = (config) => {
-
+    const { models } = config.getProperties();

     // iterate over each model
     for (let [key, model] of Object.entries(models)) {
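For reference, a minimal standalone sketch of how the new `string-array` convict format behaves; this assumes the same schema shape as above, with only the `cortexApiKeys` property shown:

import convict from 'convict';

// Register the same custom format added in config.js: validate expects an
// array, and coerce turns a comma-separated string into one.
convict.addFormat({
    name: 'string-array',
    validate: function (val) {
        if (!Array.isArray(val)) {
            throw new Error('must be of type Array');
        }
    },
    coerce: function (val) {
        return val.split(',');
    },
});

const config = convict({
    cortexApiKeys: {
        format: 'string-array',
        default: null,
        env: 'CORTEX_API_KEY',
        sensitive: true,
    },
});

// With CORTEX_API_KEY="key-one,key-two" set in the environment,
// config.get('cortexApiKeys') returns ['key-one', 'key-two'].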
package/helper-apps/cortex-file-handler/fileChunker.js
CHANGED
@@ -113,7 +113,9 @@ async function splitMediaFile(inputPath, chunkDurationInSeconds = 600) {

         return { chunkPromises, uniqueOutputPath };
     } catch (err) {
-
+        const msg = `Error processing media file, check if the file is a valid media file or is accessible`;
+        console.error(msg, err);
+        throw new Error(msg);
     }
 }

package/lib/cortexRequest.js
CHANGED
@@ -1,7 +1,7 @@
 import { selectEndpoint } from './requestExecutor.js';

 class CortexRequest {
-    constructor( { url, data, params, headers, cache, model, pathwayResolver, selectedEndpoint } = {}) {
+    constructor( { url, data, params, headers, cache, model, pathwayResolver, selectedEndpoint, stream } = {}) {
         this._url = url || '';
         this._data = data || {};
         this._params = params || {};
@@ -10,6 +10,7 @@ class CortexRequest {
         this._model = model || '';
         this._pathwayResolver = pathwayResolver || {};
         this._selectedEndpoint = selectedEndpoint || {};
+        this._stream = stream || false;

         if (this._pathwayResolver) {
             this._model = this._pathwayResolver.model;
@@ -112,6 +113,15 @@ class CortexRequest {
     set pathwayResolver(value) {
         this._pathwayResolver = value;
     }
+
+    // stream getter and setter
+    get stream() {
+        return this._stream;
+    }
+
+    set stream(value) {
+        this._stream = value;
+    }
 }

 export default CortexRequest;
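A short sketch of the new `stream` option on `CortexRequest`, assuming a caller constructs the request directly; the import path, URL, and data are placeholders:

import CortexRequest from './lib/cortexRequest.js'; // path assumed relative to the package root

// The constructor now accepts `stream`, stored as `this._stream` and exposed
// through the new getter/setter pair.
const cortexRequest = new CortexRequest({
    url: 'https://example.invalid/v1/chat/completions', // placeholder endpoint
    data: { messages: [] },
    stream: true,
});

console.log(cortexRequest.stream); // true
cortexRequest.stream = false;      // the setter added in this release

requestExecutor.js, shown next, folds this flag into its streaming check alongside `params.stream` and `data.stream`.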
package/lib/requestExecutor.js
CHANGED
@@ -192,7 +192,7 @@ const DUPLICATE_REQUEST_AFTER = 10; // 10 seconds
 const postRequest = async (cortexRequest) => {
     let promises = [];
     for (let i = 0; i < MAX_RETRY; i++) {
-        const { url, data, params, headers, cache, selectedEndpoint, requestId, pathway, model} = cortexRequest;
+        const { url, data, params, headers, cache, selectedEndpoint, requestId, pathway, model, stream} = cortexRequest;
         const enableDuplicateRequests = pathway?.enableDuplicateRequests !== undefined ? pathway.enableDuplicateRequests : config.get('enableDuplicateRequests');
         let maxDuplicateRequests = enableDuplicateRequests ? MAX_DUPLICATE_REQUESTS : 1;
         let duplicateRequestAfter = (pathway?.duplicateRequestAfter || DUPLICATE_REQUEST_AFTER) * 1000;
@@ -202,7 +202,7 @@ const postRequest = async (cortexRequest) => {
         }

         const axiosConfigObj = { params, headers, cache };
-        const streamRequested = (params?.stream || data?.stream);
+        const streamRequested = (stream || params?.stream || data?.stream);
         if (streamRequested && model.supportsStreaming) {
             axiosConfigObj.responseType = 'stream';
             promises.push(selectedEndpoint.limiter.schedule({expiration: pathway.timeout * 1000 + 1000, id: `${requestId}_${uuidv4()}`},() => postWithMonitor(selectedEndpoint, url, data, axiosConfigObj)));
@@ -249,7 +249,7 @@ const postRequest = async (cortexRequest) => {

         if (!controller.signal?.aborted) {

-
+            logger.debug(`<<< [${requestId}] received response for request ${index}`);

             if (axiosConfigObj.responseType === 'stream') {
                 // Buffering and collecting the stream data
@@ -258,7 +258,7 @@ const postRequest = async (cortexRequest) => {
                 let responseData = '';
                 response.data.on('data', (chunk) => {
                     responseData += chunk;
-
+                    logger.debug(`<<< [${requestId}] received chunk for request ${index}`);
                 });
                 response.data.on('end', () => {
                     response.data = JSON.parse(responseData);
package/package.json
CHANGED
@@ -1,6 +1,6 @@
 {
   "name": "@aj-archipelago/cortex",
-  "version": "1.1.
+  "version": "1.1.5",
   "description": "Cortex is a GraphQL API for AI. It provides a simple, extensible interface for using AI services from OpenAI, Azure and others.",
   "private": false,
   "repository": {
@@ -53,6 +53,7 @@
     "ioredis": "^5.3.1",
     "keyv": "^4.5.2",
     "langchain": "^0.0.47",
+    "mime-types": "^2.1.35",
     "subsrt": "^1.1.1",
     "uuid": "^9.0.0",
     "winston": "^3.11.0",
package/pathways/bias.js
CHANGED
@@ -6,5 +6,5 @@ export default {
     // Uncomment the following line to enable caching for this prompt, if desired.
     // enableCache: true,

-    prompt: `{{text}}\n\nIs the above text written objectively? Why or why not, explain with details:\n
+    prompt: `{{text}}\n\nIs the above text written objectively? Why or why not, explain with details:\n`,
 };
package/server/graphql.js
CHANGED
@@ -97,7 +97,7 @@ const getResolvers = (config, pathways) => {
                 // add shared state to contextValue
                 contextValue.pathway = pathway;
                 contextValue.config = config;
-
+                return pathway.rootResolver(parent, args, contextValue, info);
             }
         }

@@ -131,6 +131,8 @@ const build = async (config) => {

     const app = express();

+    app.use(express.json({ limit: '50mb' }));
+
     const httpServer = http.createServer(app);

     // Creating the WebSocket server
@@ -176,8 +178,8 @@ const build = async (config) => {
     });

     // If CORTEX_API_KEY is set, we roll our own auth middleware - usually not used if you're being fronted by a proxy
-    const
-    if (
+    const cortexApiKeys = config.get('cortexApiKeys');
+    if (cortexApiKeys && Array.isArray(cortexApiKeys)) {
         app.use((req, res, next) => {
             let providedApiKey = req.headers['cortex-api-key'] || req.query['cortex-api-key'];
             if (!providedApiKey) {
@@ -185,7 +187,7 @@ const build = async (config) => {
                 providedApiKey = providedApiKey?.startsWith('Bearer ') ? providedApiKey.slice(7) : providedApiKey;
             }

-            if (
+            if (!cortexApiKeys.includes(providedApiKey)) {
                 if (req.baseUrl === '/graphql' || req.headers['content-type'] === 'application/graphql') {
                     res.status(401)
                         .set('WWW-Authenticate', 'Cortex-Api-Key')
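A minimal client-side sketch of what this middleware expects when `CORTEX_API_KEY` is configured; the host, port, and query here are placeholders:

// The middleware accepts the key as a `cortex-api-key` header (optionally
// prefixed with "Bearer ") or as a `cortex-api-key` query parameter; an
// unknown or missing key yields a 401 with `WWW-Authenticate: Cortex-Api-Key`.
const response = await fetch('http://localhost:4000/graphql', { // placeholder host and port
    method: 'POST',
    headers: {
        'Content-Type': 'application/json',
        'cortex-api-key': 'key-one', // must match one of the configured keys
    },
    body: JSON.stringify({ query: '{ __typename }' }),
});
console.log(response.status);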
package/server/modelExecutor.js
CHANGED
@@ -17,6 +17,8 @@ import OpenAiEmbeddingsPlugin from './plugins/openAiEmbeddingsPlugin.js';
 import OpenAIImagePlugin from './plugins/openAiImagePlugin.js';
 import OpenAIDallE3Plugin from './plugins/openAiDallE3Plugin.js';
 import OpenAIVisionPlugin from './plugins/openAiVisionPlugin.js';
+import GeminiChatPlugin from './plugins/geminiChatPlugin.js';
+import GeminiVisionPlugin from './plugins/geminiVisionPlugin.js';

 class ModelExecutor {
     constructor(pathway, model) {
@@ -72,6 +74,12 @@
             case 'OPENAI-VISION':
                 plugin = new OpenAIVisionPlugin(pathway, model);
                 break;
+            case 'GEMINI-CHAT':
+                plugin = new GeminiChatPlugin(pathway, model);
+                break;
+            case 'GEMINI-VISION':
+                plugin = new GeminiVisionPlugin(pathway, model);
+                break;
             default:
                 throw new Error(`Unsupported model type: ${model.type}`);
         }
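A hypothetical model definition that would route through the new `GEMINI-CHAT` case above (a `GEMINI-VISION` entry would look the same apart from the type); every field value here is an illustrative placeholder following the shape of the model entries in config.js, not a package default:

// Hypothetical model entry -- values are placeholders only.
const models = {
    "gemini-chat-example": {
        "type": "GEMINI-CHAT",                    // selects GeminiChatPlugin in the switch above
        "url": "https://example.invalid/gemini",  // placeholder endpoint
        "headers": { "Content-Type": "application/json" },
        "supportsStreaming": true                 // checked by requestExecutor before requesting SSE
    }
};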
package/server/pathwayResolver.js
CHANGED
@@ -98,8 +98,9 @@ class PathwayResolver {
         const incomingMessage = responseData;

         let messageBuffer = '';
+        let streamEnded = false;

-        const
+        const processStreamSSE = (data) => {
             try {
                 //logger.info(`\n\nReceived stream data for requestId ${this.requestId}: ${data.toString()}`);
                 let events = data.toString().split('\n');
@@ -132,18 +133,35 @@ class PathwayResolver {
                     return;
                 }

+                // error can be in different places in the message
                 const streamError = parsedMessage?.error || parsedMessage?.choices?.[0]?.delta?.content?.error || parsedMessage?.choices?.[0]?.text?.error;
                 if (streamError) {
                     streamErrorOccurred = true;
                     logger.error(`Stream error: ${streamError.message}`);
-                    incomingMessage.off('data',
+                    incomingMessage.off('data', processStreamSSE);
                     return;
                 }
+
+                // finish reason can be in different places in the message
+                const finishReason = parsedMessage?.choices?.[0]?.finish_reason || parsedMessage?.candidates?.[0]?.finishReason;
+                if (finishReason?.toLowerCase() === 'stop') {
+                    requestProgress.progress = 1;
+                } else {
+                    if (finishReason?.toLowerCase() === 'safety') {
+                        const safetyRatings = JSON.stringify(parsedMessage?.candidates?.[0]?.safetyRatings) || '';
+                        logger.warn(`Request ${this.requestId} was blocked by the safety filter. ${safetyRatings}`);
+                        requestProgress.data = `\n\nResponse blocked by safety filter: ${safetyRatings}`;
+                        requestProgress.progress = 1;
+                    }
+                }
             }

             try {
-
-
+                if (!streamEnded) {
+                    //logger.info(`Publishing stream message to requestId ${this.requestId}: ${message}`);
+                    publishRequestProgress(requestProgress);
+                    streamEnded = requestProgress.progress === 1;
+                }
             } catch (error) {
                 logger.error(`Could not publish the stream message: "${messageBuffer}", ${error}`);
             }
@@ -156,7 +174,7 @@ class PathwayResolver {

         if (incomingMessage) {
             await new Promise((resolve, reject) => {
-                incomingMessage.on('data',
+                incomingMessage.on('data', processStreamSSE);
                 incomingMessage.on('end', resolve);
                 incomingMessage.on('error', reject);
             });
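A sketch of the two parsed SSE message shapes the new finish-reason logic distinguishes; the field values are illustrative:

// OpenAI-style chunk: the finish reason lives at choices[0].finish_reason.
const openAiChunk = { choices: [{ delta: { content: 'Hello' }, finish_reason: 'stop' }] };

// Gemini-style chunk: the finish reason lives at candidates[0].finishReason,
// with safetyRatings alongside when a response is blocked.
const geminiChunk = { candidates: [{ content: { parts: [{ text: 'Hello' }] }, finishReason: 'STOP', safetyRatings: [] }] };

// Same extraction expression as processStreamSSE above:
const finishReason = (m) => m?.choices?.[0]?.finish_reason || m?.candidates?.[0]?.finishReason;
console.log(finishReason(openAiChunk)); // 'stop'  -> progress set to 1
console.log(finishReason(geminiChunk)); // 'STOP'  -> lowercased to 'stop', progress set to 1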
package/server/plugins/geminiChatPlugin.js
ADDED
@@ -0,0 +1,195 @@
+// geminiChatPlugin.js
+import ModelPlugin from './modelPlugin.js';
+import { encode } from 'gpt-3-encoder';
+import logger from '../../lib/logger.js';
+
+const mergeResults = (data) => {
+    let output = '';
+    let safetyRatings = [];
+
+    for (let chunk of data) {
+        const { candidates } = chunk;
+        if (!candidates || !candidates.length) {
+            continue;
+        }
+
+        // If it was blocked, return the blocked message
+        if (candidates[0].safetyRatings.some(rating => rating.blocked)) {
+            safetyRatings = candidates[0].safetyRatings;
+            return {mergedResult: 'The response was blocked because the input or response potentially violates policies. Try rephrasing the prompt or adjusting the parameter settings.', safetyRatings: safetyRatings};
+        }
+
+        // Append the content of the first part of the first candidate to the output
+        const message = candidates[0].content.parts[0].text;
+        output += message;
+    }
+
+    return {mergedResult: output || null, safetyRatings: safetyRatings};
+};
+
+class GeminiChatPlugin extends ModelPlugin {
+    constructor(pathway, model) {
+        super(pathway, model);
+    }
+
+    // This code converts either OpenAI or PaLM messages to the Gemini messages format
+    convertMessagesToGemini(messages) {
+        let modifiedMessages = [];
+        let lastAuthor = '';
+
+        // Check if the messages are already in the Gemini format
+        if (messages[0] && Object.prototype.hasOwnProperty.call(messages[0], 'parts')) {
+            modifiedMessages = messages;
+        } else {
+            messages.forEach(message => {
+                const { role, author, content } = message;
+
+                // Right now Gemini API has no direct translation for system messages,
+                // but they work fine as parts of user messages
+                if (role === 'system') {
+                    modifiedMessages.push({
+                        role: 'user',
+                        parts: [{ text: content }],
+                    });
+                    lastAuthor = 'user';
+                    return;
+                }
+
+                // Aggregate consecutive author messages, appending the content
+                if ((role === lastAuthor || author === lastAuthor) && modifiedMessages.length > 0) {
+                    modifiedMessages[modifiedMessages.length - 1].parts.push({ text: content });
+                }
+
+                // Push messages that are role: 'user' or 'assistant', changing 'assistant' to 'model'
+                else if (role === 'user' || role === 'assistant' || author) {
+                    modifiedMessages.push({
+                        role: author || role,
+                        parts: [{ text: content }],
+                    });
+                    lastAuthor = author || role;
+                }
+            });
+        }
+
+        // Gemini requires an even number of messages
+        if (modifiedMessages.length % 2 === 0) {
+            modifiedMessages = modifiedMessages.slice(1);
+        }
+
+        return {
+            modifiedMessages,
+        };
+    }
+
+    // Set up parameters specific to the Gemini API
+    getRequestParameters(text, parameters, prompt, cortexRequest) {
+        const { modelPromptText, modelPromptMessages, tokenLength } = this.getCompiledPrompt(text, parameters, prompt);
+        const { geminiSafetySettings, geminiTools, max_tokens } = cortexRequest ? cortexRequest.pathway : {};
+
+        // Define the model's max token length
+        const modelTargetTokenLength = this.getModelMaxTokenLength() * this.getPromptTokenRatio();
+
+        const geminiMessages = this.convertMessagesToGemini(modelPromptMessages || [{ "role": "user", "parts": [{ "text": modelPromptText }]}]);
+
+        let requestMessages = geminiMessages.modifiedMessages;
+
+        // Check if the token length exceeds the model's max token length
+        if (tokenLength > modelTargetTokenLength) {
+            // Remove older messages until the token length is within the model's limit
+            requestMessages = this.truncateMessagesToTargetLength(requestMessages, modelTargetTokenLength);
+        }
+
+        if (max_tokens < 0) {
+            throw new Error(`Prompt is too long to successfully call the model at ${tokenLength} tokens. The model will not be called.`);
+        }
+
+        const requestParameters = {
+            contents: requestMessages,
+            generationConfig: {
+                temperature: this.temperature || 0.7,
+                maxOutputTokens: max_tokens || this.getModelMaxReturnTokens(),
+                topP: parameters.topP || 0.95,
+                topK: parameters.topK || 40,
+            },
+            safety_settings: geminiSafetySettings || undefined,
+            tools: geminiTools || undefined
+        };
+
+        return requestParameters;
+    }
+
+    // Parse the response from the new Chat API
+    parseResponse(data) {
+        // If data is not an array, return it directly
+        if (!Array.isArray(data)) {
+            return data;
+        }
+
+        return mergeResults(data).mergedResult || null;
+
+    }
+
+    // Execute the request to the new Chat API
+    async execute(text, parameters, prompt, cortexRequest) {
+        const requestParameters = this.getRequestParameters(text, parameters, prompt, cortexRequest);
+        const { stream } = parameters;
+
+        cortexRequest.data = { ...(cortexRequest.data || {}), ...requestParameters };
+        cortexRequest.params = {}; // query params
+        cortexRequest.stream = stream;
+        cortexRequest.url = cortexRequest.stream ? `${cortexRequest.url}?alt=sse` : cortexRequest.url;
+
+        const gcpAuthTokenHelper = this.config.get('gcpAuthTokenHelper');
+        const authToken = await gcpAuthTokenHelper.getAccessToken();
+        cortexRequest.headers.Authorization = `Bearer ${authToken}`;
+
+        return this.executeRequest(cortexRequest);
+    }
+
+    // Override the logging function to display the messages and responses
+    logRequestData(data, responseData, prompt) {
+        this.logAIRequestFinished();
+
+        const messages = data && data.contents;
+
+        if (messages && messages.length > 1) {
+            logger.info(`[chat request contains ${messages.length} messages]`);
+            messages.forEach((message, index) => {
+                const messageContent = message.parts.reduce((acc, part) => {
+                    if (part.text) {
+                        return acc + part.text;
+                    }
+                    return acc;
+                } , '');
+                const words = messageContent.split(" ");
+                const tokenCount = encode(messageContent).length;
+                const preview = words.length < 41 ? messageContent : words.slice(0, 20).join(" ") + " ... " + words.slice(-20).join(" ");
+
+                logger.debug(`Message ${index + 1}: Role: ${message.role}, Tokens: ${tokenCount}, Content: "${preview}"`);
+            });
+        } else if (messages && messages.length === 1) {
+            logger.debug(`${messages[0].parts[0].text}`);
+        }
+
+        // check if responseData is an array
+        if (!Array.isArray(responseData)) {
+            logger.info(`[response received as an SSE stream]`);
+        } else {
+            const { mergedResult, safetyRatings } = mergeResults(responseData);
+            if (safetyRatings?.length) {
+                logger.warn(`!!! response was blocked because the input or response potentially violates policies`);
+                logger.debug(`Safety Ratings: ${JSON.stringify(safetyRatings, null, 2)}`);
+            }
+            const responseTokens = encode(mergedResult).length;
+            logger.info(`[response received containing ${responseTokens} tokens]`);
+            logger.debug(`${mergedResult}`);
+        }
+
+        if (prompt && prompt.debugInfo) {
+            prompt.debugInfo += `\n${JSON.stringify(data)}`;
+        }
+    }
+
+}
+
+export default GeminiChatPlugin;
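A sketch of what `convertMessagesToGemini` produces for OpenAI-style input (values illustrative): a leading system message becomes a user message, and a consecutive same-role message is merged into it as an additional part.

// Input: OpenAI-style chat messages (illustrative values).
const input = [
    { role: 'system', content: 'You are a helpful assistant.' },
    { role: 'user', content: 'Summarize this article.' },
];

// convertMessagesToGemini(input).modifiedMessages yields a single
// Gemini-format message with both texts as parts:
const output = [
    {
        role: 'user',
        parts: [
            { text: 'You are a helpful assistant.' },
            { text: 'Summarize this article.' },
        ],
    },
];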
package/server/plugins/geminiVisionPlugin.js
ADDED
@@ -0,0 +1,102 @@
+import GeminiChatPlugin from './geminiChatPlugin.js';
+import mime from 'mime-types';
+import logger from '../../lib/logger.js';
+
+class GeminiVisionPlugin extends GeminiChatPlugin {
+
+    // Override the convertMessagesToGemini method to handle multimodal vision messages
+    // This function can operate on messages in Gemini native format or in OpenAI's format
+    // It will convert the messages to the Gemini format
+    convertMessagesToGemini(messages) {
+        let modifiedMessages = [];
+        let lastAuthor = '';
+
+        // Check if the messages are already in the Gemini format
+        if (messages[0] && Object.prototype.hasOwnProperty.call(messages[0], 'parts')) {
+            modifiedMessages = messages;
+        } else {
+            messages.forEach(message => {
+                const { role, author, content } = message;
+
+                // Right now Gemini API has no direct translation for system messages,
+                // so we insert them as parts of the first user: role message
+                if (role === 'system') {
+                    modifiedMessages.push({
+                        role: 'user',
+                        parts: [{ text: content }],
+                    });
+                    lastAuthor = 'user';
+                    return;
+                }
+
+                // Convert content to Gemini format, trying to maintain compatibility
+                const convertPartToGemini = (partString) => {
+                    try {
+                        const part = JSON.parse(partString);
+                        if (typeof part === 'string') {
+                            return { text: part };
+                        } else if (part.type === 'text') {
+                            return { text: part.text };
+                        } else if (part.type === 'image_url') {
+                            if (part.image_url.url.startsWith('gs://')) {
+                                return {
+                                    fileData: {
+                                        mimeType: mime.lookup(part.image_url.url),
+                                        fileUri: part.image_url.url
+                                    }
+                                };
+                            } else {
+                                return {
+                                    inlineData: {
+                                        mimeType: 'image/jpeg', // fixed for now as there's no MIME type in the request
+                                        data: part.image_url.url.split('base64,')[1]
+                                    }
+                                };
+                            }
+                        }
+                    } catch (e) {
+                        logger.warn(`Unable to parse part - including as string: ${partString}`);
+                    }
+                    return { text: partString };
+                };
+
+                const addPartToMessages = (geminiPart) => {
+                    // Gemini requires alternating user: and model: messages
+                    if ((role === lastAuthor || author === lastAuthor) && modifiedMessages.length > 0) {
+                        modifiedMessages[modifiedMessages.length - 1].parts.push(geminiPart);
+                    }
+                    // Gemini only supports user: and model: roles
+                    else if (role === 'user' || role === 'assistant' || author) {
+                        modifiedMessages.push({
+                            role: author || role,
+                            parts: [geminiPart],
+                        });
+                        lastAuthor = author || role;
+                    }
+                };
+
+                // Content can either be in the "vision" format (array) or in the "chat" format (string)
+                if (Array.isArray(content)) {
+                    content.forEach(part => {
+                        addPartToMessages(convertPartToGemini(part));
+                    });
+                }
+                else {
+                    addPartToMessages(convertPartToGemini(content));
+                }
+            });
+        }
+
+        // Gemini requires an even number of messages
+        if (modifiedMessages.length % 2 === 0) {
+            modifiedMessages = modifiedMessages.slice(1);
+        }
+
+        return {
+            modifiedMessages,
+        };
+    }
+
+}
+
+export default GeminiVisionPlugin;
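A sketch of how the vision plugin's `convertPartToGemini` maps incoming parts, which arrive as JSON strings; the bucket name and data values are illustrative:

// A gs:// URL becomes a fileData reference, with the MIME type looked up from the extension.
const gcsPart = '{"type":"image_url","image_url":{"url":"gs://my-bucket/photo.png"}}';
// -> { fileData: { mimeType: 'image/png', fileUri: 'gs://my-bucket/photo.png' } }

// A base64 data URL becomes inlineData, with the MIME type fixed to image/jpeg for now.
const base64Part = '{"type":"image_url","image_url":{"url":"data:image/jpeg;base64,/9j/4AAQ..."}}';
// -> { inlineData: { mimeType: 'image/jpeg', data: '/9j/4AAQ...' } }

// Anything that fails JSON.parse is kept as plain text.
const plainPart = 'just a caption';
// -> { text: 'just a caption' }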
package/server/plugins/modelPlugin.js
CHANGED
@@ -269,9 +269,10 @@ class ModelPlugin {

         const responseData = await executeRequest(cortexRequest);

-
-
-
+        let errorData = Array.isArray(responseData) ? responseData[0] : responseData;
+
+        if (errorData && errorData.error) {
+            throw new Error(`Server error: ${JSON.stringify(errorData.error)}`);
         }

         this.logRequestData(data, responseData, prompt);
package/server/plugins/openAiEmbeddingsPlugin.js
CHANGED
@@ -7,11 +7,13 @@ class OpenAiEmbeddingsPlugin extends ModelPlugin {
     }

     getRequestParameters(text, parameters, prompt) {
-        const combinedParameters = { ...this.promptParameters, ...parameters };
+        const combinedParameters = { ...this.promptParameters, ...this.model.params, ...parameters };
         const { modelPromptText } = this.getCompiledPrompt(text, combinedParameters, prompt);
+        const { model } = combinedParameters;
         const requestParameters = {
             data: {
                 input: combinedParameters?.input?.length ? combinedParameters.input : modelPromptText || text,
+                model
             }
         };
         return requestParameters;
package/server/plugins/openAiWhisperPlugin.js
CHANGED
@@ -289,7 +289,7 @@ class OpenAIWhisperPlugin extends ModelPlugin {
             }

         } catch (error) {
-            const errMsg = `Transcribe error: ${error?.message || error}`;
+            const errMsg = `Transcribe error: ${error?.response?.data || error?.message || error}`;
             logger.error(errMsg);
             return errMsg;
         }
package/server/rest.js
CHANGED
@@ -85,7 +85,7 @@ const processIncomingStream = (requestId, res, jsonResponse) => {
     }

     const sendStreamData = (data) => {
-
+        logger.debug(`REST SEND: data: ${JSON.stringify(data)}`);
         const dataString = (data==='[DONE]') ? data : JSON.stringify(data);

         if (!res.writableEnded) {
@@ -93,9 +93,9 @@ const processIncomingStream = (requestId, res, jsonResponse) => {
         }
     }

-    const fillJsonResponse = (jsonResponse, inputText,
+    const fillJsonResponse = (jsonResponse, inputText, _finishReason) => {

-        jsonResponse.choices[0].finish_reason =
+        jsonResponse.choices[0].finish_reason = null;
         if (jsonResponse.object === 'text_completion') {
             jsonResponse.choices[0].text = inputText;
         } else {
@@ -114,7 +114,10 @@ const processIncomingStream = (requestId, res, jsonResponse) => {
     const safeUnsubscribe = async () => {
         if (subscription) {
             try {
-
+                const subPromiseResult = await subscription;
+                if (subPromiseResult) {
+                    pubsub.unsubscribe(subPromiseResult);
+                }
             } catch (error) {
                 logger.error(`Error unsubscribing from pubsub: ${error}`);
             }
@@ -122,7 +125,7 @@ const processIncomingStream = (requestId, res, jsonResponse) => {
         }

         if (data.requestProgress.requestId === requestId) {
-
+            logger.debug(`REQUEST_PROGRESS received progress: ${data.requestProgress.progress}, data: ${data.requestProgress.data}`);

             const progress = data.requestProgress.progress;
             const progressData = data.requestProgress.data;
@@ -142,6 +145,9 @@ const processIncomingStream = (requestId, res, jsonResponse) => {
             } else {
                 fillJsonResponse(jsonResponse, delta.content, finish_reason);
             }
+        } else if (messageJson.candidates) {
+            const { content, finishReason } = messageJson.candidates[0];
+            fillJsonResponse(jsonResponse, content.parts[0].text, finishReason);
         } else {
             fillJsonResponse(jsonResponse, messageJson, null);
         }