@aj-archipelago/cortex 1.1.4 → 1.1.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/config.js +2 -2
- package/lib/cortexRequest.js +11 -1
- package/lib/requestExecutor.js +4 -4
- package/package.json +2 -1
- package/pathways/bias.js +1 -1
- package/pathways/cognitive_insert.js +1 -1
- package/server/graphql.js +2 -0
- package/server/modelExecutor.js +8 -0
- package/server/pathwayResolver.js +23 -5
- package/server/plugins/geminiChatPlugin.js +195 -0
- package/server/plugins/geminiVisionPlugin.js +102 -0
- package/server/plugins/modelPlugin.js +4 -3
- package/server/plugins/openAiEmbeddingsPlugin.js +3 -1
- package/server/rest.js +11 -5
package/config.js
CHANGED
@@ -122,9 +122,9 @@ var config = convict({
         },
         "oai-embeddings": {
             "type": "OPENAI-EMBEDDINGS",
-            "url": "https://
+            "url": "https://api.openai.com/v1/embeddings",
             "headers": {
-                "
+                "Authorization": "Bearer {{OPENAI_API_KEY}}",
                 "Content-Type": "application/json"
             },
             "params": {
package/lib/cortexRequest.js
CHANGED
@@ -1,7 +1,7 @@
 import { selectEndpoint } from './requestExecutor.js';
 
 class CortexRequest {
-    constructor( { url, data, params, headers, cache, model, pathwayResolver, selectedEndpoint } = {}) {
+    constructor( { url, data, params, headers, cache, model, pathwayResolver, selectedEndpoint, stream } = {}) {
         this._url = url || '';
         this._data = data || {};
         this._params = params || {};
@@ -10,6 +10,7 @@ class CortexRequest {
         this._model = model || '';
         this._pathwayResolver = pathwayResolver || {};
         this._selectedEndpoint = selectedEndpoint || {};
+        this._stream = stream || false;
 
         if (this._pathwayResolver) {
             this._model = this._pathwayResolver.model;
@@ -112,6 +113,15 @@ class CortexRequest {
     set pathwayResolver(value) {
         this._pathwayResolver = value;
     }
+
+    // stream getter and setter
+    get stream() {
+        return this._stream;
+    }
+
+    set stream(value) {
+        this._stream = value;
+    }
 }
 
 export default CortexRequest;
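Illustrative usage sketch (not part of the published diff): how the new stream option might be set on a CortexRequest. The import path, URL, and payload below are placeholders, and a real request is normally wired up with a pathwayResolver that this sketch omits.

// Hypothetical example; the field names come from the diff above, the values are placeholders.
import CortexRequest from './lib/cortexRequest.js';

const request = new CortexRequest({
    url: 'https://example.com/v1/chat/completions', // placeholder endpoint
    data: { messages: [] },
    stream: true                                    // new in 1.1.5
});

console.log(request.stream); // true, via the new getter
request.stream = false;      // can also be toggled via the new setter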
package/lib/requestExecutor.js
CHANGED
@@ -192,7 +192,7 @@ const DUPLICATE_REQUEST_AFTER = 10; // 10 seconds
 const postRequest = async (cortexRequest) => {
     let promises = [];
     for (let i = 0; i < MAX_RETRY; i++) {
-        const { url, data, params, headers, cache, selectedEndpoint, requestId, pathway, model} = cortexRequest;
+        const { url, data, params, headers, cache, selectedEndpoint, requestId, pathway, model, stream} = cortexRequest;
         const enableDuplicateRequests = pathway?.enableDuplicateRequests !== undefined ? pathway.enableDuplicateRequests : config.get('enableDuplicateRequests');
         let maxDuplicateRequests = enableDuplicateRequests ? MAX_DUPLICATE_REQUESTS : 1;
         let duplicateRequestAfter = (pathway?.duplicateRequestAfter || DUPLICATE_REQUEST_AFTER) * 1000;
@@ -202,7 +202,7 @@ const postRequest = async (cortexRequest) => {
         }
 
         const axiosConfigObj = { params, headers, cache };
-        const streamRequested = (params?.stream || data?.stream);
+        const streamRequested = (stream || params?.stream || data?.stream);
         if (streamRequested && model.supportsStreaming) {
             axiosConfigObj.responseType = 'stream';
             promises.push(selectedEndpoint.limiter.schedule({expiration: pathway.timeout * 1000 + 1000, id: `${requestId}_${uuidv4()}`},() => postWithMonitor(selectedEndpoint, url, data, axiosConfigObj)));
@@ -249,7 +249,7 @@ const postRequest = async (cortexRequest) => {
 
         if (!controller.signal?.aborted) {
 
-
+            logger.debug(`<<< [${requestId}] received response for request ${index}`);
 
             if (axiosConfigObj.responseType === 'stream') {
                 // Buffering and collecting the stream data
@@ -258,7 +258,7 @@ const postRequest = async (cortexRequest) => {
                 let responseData = '';
                 response.data.on('data', (chunk) => {
                     responseData += chunk;
-
+                    logger.debug(`<<< [${requestId}] received chunk for request ${index}`);
                 });
                 response.data.on('end', () => {
                     response.data = JSON.parse(responseData);
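A minimal sketch (assumed values, not package code) of the new stream precedence in postRequest: an explicit cortexRequest.stream now requests streaming even when neither params nor data carry a stream flag.

// Assumed example values; the expression mirrors the changed line above.
const stream = true;   // cortexRequest.stream (new in 1.1.5)
const params = {};     // no params.stream
const data = {};       // no data.stream
const streamRequested = (stream || params?.stream || data?.stream);
console.log(streamRequested); // true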
package/package.json
CHANGED
@@ -1,6 +1,6 @@
 {
   "name": "@aj-archipelago/cortex",
-  "version": "1.1.4",
+  "version": "1.1.5",
   "description": "Cortex is a GraphQL API for AI. It provides a simple, extensible interface for using AI services from OpenAI, Azure and others.",
   "private": false,
   "repository": {
@@ -53,6 +53,7 @@
     "ioredis": "^5.3.1",
     "keyv": "^4.5.2",
     "langchain": "^0.0.47",
+    "mime-types": "^2.1.35",
     "subsrt": "^1.1.1",
     "uuid": "^9.0.0",
     "winston": "^3.11.0",
package/pathways/bias.js
CHANGED
@@ -6,5 +6,5 @@ export default {
     // Uncomment the following line to enable caching for this prompt, if desired.
     // enableCache: true,
 
-    prompt: `{{text}}\n\nIs the above text written objectively? Why or why not, explain with details:\n
+    prompt: `{{text}}\n\nIs the above text written objectively? Why or why not, explain with details:\n`,
 };
package/server/graphql.js
CHANGED
package/server/modelExecutor.js
CHANGED
@@ -17,6 +17,8 @@ import OpenAiEmbeddingsPlugin from './plugins/openAiEmbeddingsPlugin.js';
 import OpenAIImagePlugin from './plugins/openAiImagePlugin.js';
 import OpenAIDallE3Plugin from './plugins/openAiDallE3Plugin.js';
 import OpenAIVisionPlugin from './plugins/openAiVisionPlugin.js';
+import GeminiChatPlugin from './plugins/geminiChatPlugin.js';
+import GeminiVisionPlugin from './plugins/geminiVisionPlugin.js';
 
 class ModelExecutor {
     constructor(pathway, model) {
@@ -72,6 +74,12 @@ class ModelExecutor {
             case 'OPENAI-VISION':
                 plugin = new OpenAIVisionPlugin(pathway, model);
                 break;
+            case 'GEMINI-CHAT':
+                plugin = new GeminiChatPlugin(pathway, model);
+                break;
+            case 'GEMINI-VISION':
+                plugin = new GeminiVisionPlugin(pathway, model);
+                break;
             default:
                 throw new Error(`Unsupported model type: ${model.type}`);
         }
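Hypothetical configuration sketch (not from the package): model entries whose type is GEMINI-CHAT or GEMINI-VISION are routed to the new plugins by the switch above. The shape follows the config.js model entries shown earlier; the entry names, URLs, and extra fields here are placeholders, only the "type" strings come from the diff.

// Hypothetical model entries; only the "type" strings are taken from the diff.
const models = {
    "gemini-chat": {
        "type": "GEMINI-CHAT",
        "url": "https://REGION-aiplatform.googleapis.com/PLACEHOLDER", // placeholder
        "headers": { "Content-Type": "application/json" },
        "supportsStreaming": true // checked by requestExecutor.js when streaming is requested
    },
    "gemini-vision": {
        "type": "GEMINI-VISION",
        "url": "https://REGION-aiplatform.googleapis.com/PLACEHOLDER", // placeholder
        "headers": { "Content-Type": "application/json" }
    }
};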
package/server/pathwayResolver.js
CHANGED
@@ -98,8 +98,9 @@ class PathwayResolver {
         const incomingMessage = responseData;
 
         let messageBuffer = '';
+        let streamEnded = false;
 
-        const
+        const processStreamSSE = (data) => {
             try {
                 //logger.info(`\n\nReceived stream data for requestId ${this.requestId}: ${data.toString()}`);
                 let events = data.toString().split('\n');
@@ -132,18 +133,35 @@ class PathwayResolver {
                     return;
                 }
 
+                // error can be in different places in the message
                 const streamError = parsedMessage?.error || parsedMessage?.choices?.[0]?.delta?.content?.error || parsedMessage?.choices?.[0]?.text?.error;
                 if (streamError) {
                     streamErrorOccurred = true;
                     logger.error(`Stream error: ${streamError.message}`);
-                    incomingMessage.off('data',
+                    incomingMessage.off('data', processStreamSSE);
                     return;
                 }
+
+                // finish reason can be in different places in the message
+                const finishReason = parsedMessage?.choices?.[0]?.finish_reason || parsedMessage?.candidates?.[0]?.finishReason;
+                if (finishReason?.toLowerCase() === 'stop') {
+                    requestProgress.progress = 1;
+                } else {
+                    if (finishReason?.toLowerCase() === 'safety') {
+                        const safetyRatings = JSON.stringify(parsedMessage?.candidates?.[0]?.safetyRatings) || '';
+                        logger.warn(`Request ${this.requestId} was blocked by the safety filter. ${safetyRatings}`);
+                        requestProgress.data = `\n\nResponse blocked by safety filter: ${safetyRatings}`;
+                        requestProgress.progress = 1;
+                    }
+                }
             }
 
             try {
-
-
+                if (!streamEnded) {
+                    //logger.info(`Publishing stream message to requestId ${this.requestId}: ${message}`);
+                    publishRequestProgress(requestProgress);
+                    streamEnded = requestProgress.progress === 1;
+                }
             } catch (error) {
                 logger.error(`Could not publish the stream message: "${messageBuffer}", ${error}`);
             }
@@ -156,7 +174,7 @@ class PathwayResolver {
 
         if (incomingMessage) {
             await new Promise((resolve, reject) => {
-                incomingMessage.on('data',
+                incomingMessage.on('data', processStreamSSE);
                 incomingMessage.on('end', resolve);
                 incomingMessage.on('error', reject);
             });
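A minimal illustration (assumed payloads, not package code) of where the new finish-reason lookup finds its value for OpenAI-style and Gemini-style stream messages.

// Assumed example stream messages; the lookup mirrors the changed line above.
const openAiStyle = { choices: [{ delta: { content: 'Hello' }, finish_reason: 'stop' }] };
const geminiStyle = { candidates: [{ content: { parts: [{ text: 'Hello' }] }, finishReason: 'STOP' }] };

const finishReason = (m) => m?.choices?.[0]?.finish_reason || m?.candidates?.[0]?.finishReason;
console.log(finishReason(openAiStyle)); // 'stop'
console.log(finishReason(geminiStyle)); // 'STOP', lowercased to 'stop', so progress is set to 1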
package/server/plugins/geminiChatPlugin.js
ADDED
@@ -0,0 +1,195 @@
+// geminiChatPlugin.js
+import ModelPlugin from './modelPlugin.js';
+import { encode } from 'gpt-3-encoder';
+import logger from '../../lib/logger.js';
+
+const mergeResults = (data) => {
+    let output = '';
+    let safetyRatings = [];
+
+    for (let chunk of data) {
+        const { candidates } = chunk;
+        if (!candidates || !candidates.length) {
+            continue;
+        }
+
+        // If it was blocked, return the blocked message
+        if (candidates[0].safetyRatings.some(rating => rating.blocked)) {
+            safetyRatings = candidates[0].safetyRatings;
+            return {mergedResult: 'The response was blocked because the input or response potentially violates policies. Try rephrasing the prompt or adjusting the parameter settings.', safetyRatings: safetyRatings};
+        }
+
+        // Append the content of the first part of the first candidate to the output
+        const message = candidates[0].content.parts[0].text;
+        output += message;
+    }
+
+    return {mergedResult: output || null, safetyRatings: safetyRatings};
+};
+
+class GeminiChatPlugin extends ModelPlugin {
+    constructor(pathway, model) {
+        super(pathway, model);
+    }
+
+    // This code converts either OpenAI or PaLM messages to the Gemini messages format
+    convertMessagesToGemini(messages) {
+        let modifiedMessages = [];
+        let lastAuthor = '';
+
+        // Check if the messages are already in the Gemini format
+        if (messages[0] && Object.prototype.hasOwnProperty.call(messages[0], 'parts')) {
+            modifiedMessages = messages;
+        } else {
+            messages.forEach(message => {
+                const { role, author, content } = message;
+
+                // Right now Gemini API has no direct translation for system messages,
+                // but they work fine as parts of user messages
+                if (role === 'system') {
+                    modifiedMessages.push({
+                        role: 'user',
+                        parts: [{ text: content }],
+                    });
+                    lastAuthor = 'user';
+                    return;
+                }
+
+                // Aggregate consecutive author messages, appending the content
+                if ((role === lastAuthor || author === lastAuthor) && modifiedMessages.length > 0) {
+                    modifiedMessages[modifiedMessages.length - 1].parts.push({ text: content });
+                }
+
+                // Push messages that are role: 'user' or 'assistant', changing 'assistant' to 'model'
+                else if (role === 'user' || role === 'assistant' || author) {
+                    modifiedMessages.push({
+                        role: author || role,
+                        parts: [{ text: content }],
+                    });
+                    lastAuthor = author || role;
+                }
+            });
+        }
+
+        // Gemini requires an even number of messages
+        if (modifiedMessages.length % 2 === 0) {
+            modifiedMessages = modifiedMessages.slice(1);
+        }
+
+        return {
+            modifiedMessages,
+        };
+    }
+
+    // Set up parameters specific to the Gemini API
+    getRequestParameters(text, parameters, prompt, cortexRequest) {
+        const { modelPromptText, modelPromptMessages, tokenLength } = this.getCompiledPrompt(text, parameters, prompt);
+        const { geminiSafetySettings, geminiTools, max_tokens } = cortexRequest ? cortexRequest.pathway : {};
+
+        // Define the model's max token length
+        const modelTargetTokenLength = this.getModelMaxTokenLength() * this.getPromptTokenRatio();
+
+        const geminiMessages = this.convertMessagesToGemini(modelPromptMessages || [{ "role": "user", "parts": [{ "text": modelPromptText }]}]);
+
+        let requestMessages = geminiMessages.modifiedMessages;
+
+        // Check if the token length exceeds the model's max token length
+        if (tokenLength > modelTargetTokenLength) {
+            // Remove older messages until the token length is within the model's limit
+            requestMessages = this.truncateMessagesToTargetLength(requestMessages, modelTargetTokenLength);
+        }
+
+        if (max_tokens < 0) {
+            throw new Error(`Prompt is too long to successfully call the model at ${tokenLength} tokens. The model will not be called.`);
+        }
+
+        const requestParameters = {
+            contents: requestMessages,
+            generationConfig: {
+                temperature: this.temperature || 0.7,
+                maxOutputTokens: max_tokens || this.getModelMaxReturnTokens(),
+                topP: parameters.topP || 0.95,
+                topK: parameters.topK || 40,
+            },
+            safety_settings: geminiSafetySettings || undefined,
+            tools: geminiTools || undefined
+        };
+
+        return requestParameters;
+    }
+
+    // Parse the response from the new Chat API
+    parseResponse(data) {
+        // If data is not an array, return it directly
+        if (!Array.isArray(data)) {
+            return data;
+        }
+
+        return mergeResults(data).mergedResult || null;
+
+    }
+
+    // Execute the request to the new Chat API
+    async execute(text, parameters, prompt, cortexRequest) {
+        const requestParameters = this.getRequestParameters(text, parameters, prompt, cortexRequest);
+        const { stream } = parameters;
+
+        cortexRequest.data = { ...(cortexRequest.data || {}), ...requestParameters };
+        cortexRequest.params = {}; // query params
+        cortexRequest.stream = stream;
+        cortexRequest.url = cortexRequest.stream ? `${cortexRequest.url}?alt=sse` : cortexRequest.url;
+
+        const gcpAuthTokenHelper = this.config.get('gcpAuthTokenHelper');
+        const authToken = await gcpAuthTokenHelper.getAccessToken();
+        cortexRequest.headers.Authorization = `Bearer ${authToken}`;
+
+        return this.executeRequest(cortexRequest);
+    }
+
+    // Override the logging function to display the messages and responses
+    logRequestData(data, responseData, prompt) {
+        this.logAIRequestFinished();
+
+        const messages = data && data.contents;
+
+        if (messages && messages.length > 1) {
+            logger.info(`[chat request contains ${messages.length} messages]`);
+            messages.forEach((message, index) => {
+                const messageContent = message.parts.reduce((acc, part) => {
+                    if (part.text) {
+                        return acc + part.text;
+                    }
+                    return acc;
+                } , '');
+                const words = messageContent.split(" ");
+                const tokenCount = encode(messageContent).length;
+                const preview = words.length < 41 ? messageContent : words.slice(0, 20).join(" ") + " ... " + words.slice(-20).join(" ");
+
+                logger.debug(`Message ${index + 1}: Role: ${message.role}, Tokens: ${tokenCount}, Content: "${preview}"`);
+            });
+        } else if (messages && messages.length === 1) {
+            logger.debug(`${messages[0].parts[0].text}`);
+        }
+
+        // check if responseData is an array
+        if (!Array.isArray(responseData)) {
+            logger.info(`[response received as an SSE stream]`);
+        } else {
+            const { mergedResult, safetyRatings } = mergeResults(responseData);
+            if (safetyRatings?.length) {
+                logger.warn(`!!! response was blocked because the input or response potentially violates policies`);
+                logger.debug(`Safety Ratings: ${JSON.stringify(safetyRatings, null, 2)}`);
+            }
+            const responseTokens = encode(mergedResult).length;
+            logger.info(`[response received containing ${responseTokens} tokens]`);
+            logger.debug(`${mergedResult}`);
+        }
+
+        if (prompt && prompt.debugInfo) {
+            prompt.debugInfo += `\n${JSON.stringify(data)}`;
+        }
+    }
+
+}
+
+export default GeminiChatPlugin;
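Illustrative sketch (assumed input and output, not from the package) of what convertMessagesToGemini does to OpenAI-style chat messages, following the logic in the plugin above.

// Assumed example: OpenAI-style chat messages...
const messages = [
    { role: 'system', content: 'You are a helpful assistant.' },
    { role: 'user', content: 'Summarize this article.' },
    { role: 'assistant', content: 'Sure, here is a summary.' },
    { role: 'user', content: 'Shorter, please.' }
];

// ...become Gemini "contents" roughly like this: the system message is emitted as a user
// turn, consecutive same-role messages are merged into one turn with multiple parts, and
// the role/author handling follows the code shown above.
// [
//   { role: 'user', parts: [{ text: 'You are a helpful assistant.' }, { text: 'Summarize this article.' }] },
//   { role: 'assistant', parts: [{ text: 'Sure, here is a summary.' }] },
//   { role: 'user', parts: [{ text: 'Shorter, please.' }] }
// ]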
package/server/plugins/geminiVisionPlugin.js
ADDED
@@ -0,0 +1,102 @@
+import GeminiChatPlugin from './geminiChatPlugin.js';
+import mime from 'mime-types';
+import logger from '../../lib/logger.js';
+
+class GeminiVisionPlugin extends GeminiChatPlugin {
+
+    // Override the convertMessagesToGemini method to handle multimodal vision messages
+    // This function can operate on messages in Gemini native format or in OpenAI's format
+    // It will convert the messages to the Gemini format
+    convertMessagesToGemini(messages) {
+        let modifiedMessages = [];
+        let lastAuthor = '';
+
+        // Check if the messages are already in the Gemini format
+        if (messages[0] && Object.prototype.hasOwnProperty.call(messages[0], 'parts')) {
+            modifiedMessages = messages;
+        } else {
+            messages.forEach(message => {
+                const { role, author, content } = message;
+
+                // Right now Gemini API has no direct translation for system messages,
+                // so we insert them as parts of the first user: role message
+                if (role === 'system') {
+                    modifiedMessages.push({
+                        role: 'user',
+                        parts: [{ text: content }],
+                    });
+                    lastAuthor = 'user';
+                    return;
+                }
+
+                // Convert content to Gemini format, trying to maintain compatibility
+                const convertPartToGemini = (partString) => {
+                    try {
+                        const part = JSON.parse(partString);
+                        if (typeof part === 'string') {
+                            return { text: part };
+                        } else if (part.type === 'text') {
+                            return { text: part.text };
+                        } else if (part.type === 'image_url') {
+                            if (part.image_url.url.startsWith('gs://')) {
+                                return {
+                                    fileData: {
+                                        mimeType: mime.lookup(part.image_url.url),
+                                        fileUri: part.image_url.url
+                                    }
+                                };
+                            } else {
+                                return {
+                                    inlineData: {
+                                        mimeType: 'image/jpeg', // fixed for now as there's no MIME type in the request
+                                        data: part.image_url.url.split('base64,')[1]
+                                    }
+                                };
+                            }
+                        }
+                    } catch (e) {
+                        logger.warn(`Unable to parse part - including as string: ${partString}`);
+                    }
+                    return { text: partString };
+                };
+
+                const addPartToMessages = (geminiPart) => {
+                    // Gemini requires alternating user: and model: messages
+                    if ((role === lastAuthor || author === lastAuthor) && modifiedMessages.length > 0) {
+                        modifiedMessages[modifiedMessages.length - 1].parts.push(geminiPart);
+                    }
+                    // Gemini only supports user: and model: roles
+                    else if (role === 'user' || role === 'assistant' || author) {
+                        modifiedMessages.push({
+                            role: author || role,
+                            parts: [geminiPart],
+                        });
+                        lastAuthor = author || role;
+                    }
+                };
+
+                // Content can either be in the "vision" format (array) or in the "chat" format (string)
+                if (Array.isArray(content)) {
+                    content.forEach(part => {
+                        addPartToMessages(convertPartToGemini(part));
+                    });
+                }
+                else {
+                    addPartToMessages(convertPartToGemini(content));
+                }
+            });
+        }
+
+        // Gemini requires an even number of messages
+        if (modifiedMessages.length % 2 === 0) {
+            modifiedMessages = modifiedMessages.slice(1);
+        }
+
+        return {
+            modifiedMessages,
+        };
+    }
+
+}
+
+export default GeminiVisionPlugin;
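Illustrative sketch (assumed values, not from the package) of how multimodal message parts map to Gemini parts in the vision plugin above.

// Assumed example parts (stringified, as convertPartToGemini expects JSON strings):
const textPart   = JSON.stringify({ type: 'text', text: 'What is in this image?' });
const gcsPart    = JSON.stringify({ type: 'image_url', image_url: { url: 'gs://my-bucket/photo.png' } });
const inlinePart = JSON.stringify({ type: 'image_url', image_url: { url: 'data:image/jpeg;base64,/9j/4AAQ...' } });

// Per the plugin above, these become (roughly):
//   { text: 'What is in this image?' }
//   { fileData: { mimeType: 'image/png', fileUri: 'gs://my-bucket/photo.png' } }   // mime.lookup on the gs:// URL
//   { inlineData: { mimeType: 'image/jpeg', data: '/9j/4AAQ...' } }                // base64 payload after 'base64,'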
package/server/plugins/modelPlugin.js
CHANGED
@@ -269,9 +269,10 @@ class ModelPlugin {
 
         const responseData = await executeRequest(cortexRequest);
 
-
-
-
+        let errorData = Array.isArray(responseData) ? responseData[0] : responseData;
+
+        if (errorData && errorData.error) {
+            throw new Error(`Server error: ${JSON.stringify(errorData.error)}`);
         }
 
         this.logRequestData(data, responseData, prompt);
package/server/plugins/openAiEmbeddingsPlugin.js
CHANGED
@@ -7,11 +7,13 @@ class OpenAiEmbeddingsPlugin extends ModelPlugin {
     }
 
     getRequestParameters(text, parameters, prompt) {
-        const combinedParameters = { ...this.promptParameters, ...parameters };
+        const combinedParameters = { ...this.promptParameters, ...this.model.params, ...parameters };
         const { modelPromptText } = this.getCompiledPrompt(text, combinedParameters, prompt);
+        const { model } = combinedParameters;
         const requestParameters = {
             data: {
                 input: combinedParameters?.input?.length ? combinedParameters.input : modelPromptText || text,
+                model
             }
         };
         return requestParameters;
package/server/rest.js
CHANGED
@@ -85,7 +85,7 @@ const processIncomingStream = (requestId, res, jsonResponse) => {
     }
 
     const sendStreamData = (data) => {
-
+        logger.debug(`REST SEND: data: ${JSON.stringify(data)}`);
         const dataString = (data==='[DONE]') ? data : JSON.stringify(data);
 
         if (!res.writableEnded) {
@@ -93,9 +93,9 @@ const processIncomingStream = (requestId, res, jsonResponse) => {
         }
     }
 
-    const fillJsonResponse = (jsonResponse, inputText,
+    const fillJsonResponse = (jsonResponse, inputText, _finishReason) => {
 
-        jsonResponse.choices[0].finish_reason =
+        jsonResponse.choices[0].finish_reason = null;
         if (jsonResponse.object === 'text_completion') {
             jsonResponse.choices[0].text = inputText;
         } else {
@@ -114,7 +114,10 @@ const processIncomingStream = (requestId, res, jsonResponse) => {
     const safeUnsubscribe = async () => {
         if (subscription) {
             try {
-
+                const subPromiseResult = await subscription;
+                if (subPromiseResult) {
+                    pubsub.unsubscribe(subPromiseResult);
+                }
             } catch (error) {
                 logger.error(`Error unsubscribing from pubsub: ${error}`);
             }
@@ -122,7 +125,7 @@ const processIncomingStream = (requestId, res, jsonResponse) => {
         }
 
         if (data.requestProgress.requestId === requestId) {
-
+            logger.debug(`REQUEST_PROGRESS received progress: ${data.requestProgress.progress}, data: ${data.requestProgress.data}`);
 
             const progress = data.requestProgress.progress;
             const progressData = data.requestProgress.data;
@@ -142,6 +145,9 @@ const processIncomingStream = (requestId, res, jsonResponse) => {
                 } else {
                     fillJsonResponse(jsonResponse, delta.content, finish_reason);
                 }
+            } else if (messageJson.candidates) {
+                const { content, finishReason } = messageJson.candidates[0];
+                fillJsonResponse(jsonResponse, content.parts[0].text, finishReason);
             } else {
                 fillJsonResponse(jsonResponse, messageJson, null);
             }