@aj-archipelago/cortex 1.1.7 → 1.1.8
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/package.json +3 -1
- package/pathways/test_langchain.mjs +6 -78
- package/server/modelExecutor.js +4 -0
- package/server/pathwayResolver.js +92 -115
- package/server/plugins/azureBingPlugin.js +4 -0
- package/server/plugins/claude3VertexPlugin.js +126 -0
- package/server/plugins/geminiChatPlugin.js +12 -1
- package/server/plugins/modelPlugin.js +41 -2
- package/server/plugins/openAiChatPlugin.js +1 -0
- package/server/plugins/openAiWhisperPlugin.js +7 -1
- package/server/plugins/palmChatPlugin.js +4 -1
- package/server/rest.js +4 -0
package/package.json
CHANGED
@@ -1,6 +1,6 @@
 {
   "name": "@aj-archipelago/cortex",
-  "version": "1.1.7",
+  "version": "1.1.8",
   "description": "Cortex is a GraphQL API for AI. It provides a simple, extensible interface for using AI services from OpenAI, Azure and others.",
   "private": false,
   "repository": {
@@ -35,6 +35,7 @@
     "@datastructures-js/deque": "^1.0.4",
     "@graphql-tools/schema": "^9.0.12",
     "@keyv/redis": "^2.5.4",
+    "@langchain/openai": "^0.0.24",
     "axios": "^1.3.4",
     "axios-cache-interceptor": "^1.0.1",
     "bottleneck": "^2.19.5",
@@ -42,6 +43,7 @@
     "compromise": "^14.8.1",
     "compromise-paragraphs": "^0.1.0",
     "convict": "^6.2.3",
+    "eventsource-parser": "^1.1.2",
     "express": "^4.18.2",
     "form-data": "^4.0.0",
     "google-auth-library": "^8.8.0",
package/pathways/test_langchain.mjs
CHANGED
@@ -2,12 +2,7 @@
 // LangChain Cortex integration test
 
 // Import required modules
-import {
-//import { PromptTemplate } from "langchain/prompts";
-//import { LLMChain, ConversationChain } from "langchain/chains";
-import { initializeAgentExecutor } from "langchain/agents";
-import { SerpAPI, Calculator } from "langchain/tools";
-//import { BufferMemory } from "langchain/memory";
+import { ChatOpenAI } from "@langchain/openai";
 
 export default {
 
@@ -15,89 +10,22 @@ export default {
     resolver: async (parent, args, contextValue, _info) => {
 
         const { config } = contextValue;
-        const env = config.getEnv();
 
         // example of reading from a predefined config variable
        const openAIApiKey = config.get('openaiApiKey');
-        // example of reading straight from environment
-        const serpApiKey = env.SERPAPI_API_KEY;
 
-        const model = new
-        const tools = [new SerpAPI( serpApiKey ), new Calculator()];
-
-        const executor = await initializeAgentExecutor(
-            tools,
-            model,
-            "zero-shot-react-description"
-        );
+        const model = new ChatOpenAI({ openAIApiKey: openAIApiKey, temperature: 0 });
 
         console.log(`====================`);
-        console.log("Loaded langchain
+        console.log("Loaded langchain.");
         const input = args.text;
         console.log(`Executing with input "${input}"...`);
-        const result = await
-        console.log(`Got output ${result.
-        console.log(`====================`);
-
-        return result?.output;
-    },
-
-    /*
-    // Agent test case
-    resolver: async (parent, args, contextValue, info) => {
-
-        const { config } = contextValue;
-        const openAIApiKey = config.get('openaiApiKey');
-        const serpApiKey = config.get('serpApiKey');
-
-        const model = new OpenAI({ openAIApiKey: openAIApiKey, temperature: 0 });
-        const tools = [new SerpAPI( serpApiKey ), new Calculator()];
-
-        const executor = await initializeAgentExecutor(
-            tools,
-            model,
-            "zero-shot-react-description"
-        );
-
-        console.log(`====================`);
-        console.log("Loaded langchain agent.");
-        const input = args.text;
-        console.log(`Executing with input "${input}"...`);
-        const result = await executor.call({ input });
-        console.log(`Got output ${result.output}`);
-        console.log(`====================`);
-
-        return result?.output;
-    },
-    */
-    // Simplest test case
-    /*
-    resolver: async (parent, args, contextValue, info) => {
-
-        const { config } = contextValue;
-        const openAIApiKey = config.get('openaiApiKey');
-
-        const model = new OpenAI({ openAIApiKey: openAIApiKey, temperature: 0.9 });
-
-        const template = "What is a good name for a company that makes {product}?";
-
-        const prompt = new PromptTemplate({
-            template: template,
-            inputVariables: ["product"],
-        });
-
-        const chain = new LLMChain({ llm: model, prompt: prompt });
-
+        const result = await model.invoke(input);
+        console.log(`Got output "${result.content}"`);
         console.log(`====================`);
-        console.log(`Calling langchain with prompt: ${prompt?.template}`);
-        console.log(`Input text: ${args.text}`);
-        const res = await chain.call({ product: args.text });
-        console.log(`Result: ${res?.text}`);
-        console.log(`====================`);
 
-        return
+        return result?.content;
     },
-    */
 };
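For reference, here is a minimal sketch of the @langchain/openai pattern the rewritten test pathway now relies on. This is an illustrative standalone script, not the pathway itself; the API key value below is a placeholder (in Cortex it comes from config.get('openaiApiKey')).

```js
// Sketch of the @langchain/openai usage adopted by test_langchain.mjs.
// The key is a placeholder; Cortex resolves the real one from its config.
import { ChatOpenAI } from "@langchain/openai";

const model = new ChatOpenAI({ openAIApiKey: "sk-placeholder", temperature: 0 });

// invoke() accepts a plain string and resolves to an AIMessage;
// the generated text lives on the .content property.
const result = await model.invoke("Say hello in one word.");
console.log(result?.content);
```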
package/server/modelExecutor.js
CHANGED
@@ -20,6 +20,7 @@ import OpenAIVisionPlugin from './plugins/openAiVisionPlugin.js';
 import GeminiChatPlugin from './plugins/geminiChatPlugin.js';
 import GeminiVisionPlugin from './plugins/geminiVisionPlugin.js';
 import AzureBingPlugin from './plugins/azureBingPlugin.js';
+import Claude3VertexPlugin from './plugins/claude3VertexPlugin.js';
 
 class ModelExecutor {
     constructor(pathway, model) {
@@ -84,6 +85,9 @@ class ModelExecutor {
             case 'AZURE-BING':
                 plugin = new AzureBingPlugin(pathway, model);
                 break;
+            case 'CLAUDE-3-VERTEX':
+                plugin = new Claude3VertexPlugin(pathway, model);
+                break;
             default:
                 throw new Error(`Unsupported model type: ${model.type}`);
         }
package/server/pathwayResolver.js
CHANGED
@@ -1,6 +1,5 @@
 import { ModelExecutor } from './modelExecutor.js';
 import { modelEndpoints } from '../lib/requestExecutor.js';
-// eslint-disable-next-line import/no-extraneous-dependencies
 import { v4 as uuidv4 } from 'uuid';
 import { encode } from '../lib/encodeCache.js';
 import { getFirstNToken, getLastNToken, getSemanticChunks } from './chunker.js';
@@ -11,6 +10,8 @@ import { requestState } from './requestState.js';
 import { callPathway } from '../lib/pathwayTools.js';
 import { publishRequestProgress } from '../lib/redisSubscription.js';
 import logger from '../lib/logger.js';
+// eslint-disable-next-line import/no-extraneous-dependencies
+import { createParser } from 'eventsource-parser';
 
 const modelTypesExcludedFromProgressUpdates = ['OPENAI-DALLE2', 'OPENAI-DALLE3'];
 
@@ -69,136 +70,112 @@ class PathwayResolver {
         this.pathwayPrompt = pathway.prompt;
     }
 
-    // This code handles async and streaming responses
-    //
-    // the time the client will be an external client, but it could also be the
-    // Cortex REST api code.
+    // This code handles async and streaming responses for either long-running
+    // tasks or streaming model responses
     async asyncResolve(args) {
-        const MAX_RETRY_COUNT = 3;
-        let attempt = 0;
         let streamErrorOccurred = false;
+        let responseData = null;
 
-
-
+        try {
+            responseData = await this.executePathway(args);
+        }
+        catch (error) {
+            if (!args.async) {
+                publishRequestProgress({
+                    requestId: this.requestId,
+                    progress: 1,
+                    data: '[DONE]',
+                });
+            }
+            return;
+        }
 
-
-
-
-
-
-
-
-
-
-
-
-            }
-        }
-
-
-
-
-
-
-
-
-
-                let events = data.toString().split('\n');
-
-                //events = "data: {\"id\":\"chatcmpl-20bf1895-2fa7-4ef9-abfe-4d142aba5817\",\"object\":\"chat.completion.chunk\",\"created\":1689303423723,\"model\":\"gpt-4\",\"choices\":[{\"delta\":{\"role\":\"assistant\",\"content\":{\"error\":{\"message\":\"The server had an error while processing your request. Sorry about that!\",\"type\":\"server_error\",\"param\":null,\"code\":null}}},\"finish_reason\":null}]}\n\n".split("\n");
-
-                for (let event of events) {
-                    if (streamErrorOccurred) break;
-
-                    // skip empty events
-                    if (!(event.trim() === '')) {
-                        //logger.info(`Processing stream event for requestId ${this.requestId}: ${event}`);
-                        messageBuffer += event.replace(/^data: /, '');
-
-                        const requestProgress = {
-                            requestId: this.requestId,
-                            data: messageBuffer,
-                        }
-
-                        // check for end of stream or in-stream errors
-                        if (messageBuffer.trim() === '[DONE]') {
-                            requestProgress.progress = 1;
-                        } else {
-                            let parsedMessage;
-                            try {
-                                parsedMessage = JSON.parse(messageBuffer);
-                                messageBuffer = '';
-                            } catch (error) {
-                                // incomplete stream message, try to buffer more data
-                                return;
-                            }
-
-                            // error can be in different places in the message
-                            const streamError = parsedMessage?.error || parsedMessage?.choices?.[0]?.delta?.content?.error || parsedMessage?.choices?.[0]?.text?.error;
-                            if (streamError) {
-                                streamErrorOccurred = true;
-                                logger.error(`Stream error: ${streamError.message}`);
-                                incomingMessage.off('data', processStreamSSE);
-                                return;
-                            }
-
-                            // finish reason can be in different places in the message
-                            const finishReason = parsedMessage?.choices?.[0]?.finish_reason || parsedMessage?.candidates?.[0]?.finishReason;
-                            if (finishReason?.toLowerCase() === 'stop') {
-                                requestProgress.progress = 1;
-                            } else {
-                                if (finishReason?.toLowerCase() === 'safety') {
-                                    const safetyRatings = JSON.stringify(parsedMessage?.candidates?.[0]?.safetyRatings) || '';
-                                    logger.warn(`Request ${this.requestId} was blocked by the safety filter. ${safetyRatings}`);
-                                    requestProgress.data = `\n\nResponse blocked by safety filter: ${safetyRatings}`;
-                                    requestProgress.progress = 1;
-                                }
-                            }
-                        }
-
-                        try {
-                            if (!streamEnded) {
-                                //logger.info(`Publishing stream message to requestId ${this.requestId}: ${message}`);
-                                publishRequestProgress(requestProgress);
-                                streamEnded = requestProgress.progress === 1;
-                            }
-                        } catch (error) {
-                            logger.error(`Could not publish the stream message: "${messageBuffer}", ${error}`);
-                        }
-                    }
-                }
-            } catch (error) {
-                logger.error(`Could not process stream data: ${error}`);
-            }
+        // If the response is a string, it's a regular long running response
+        if (args.async || typeof responseData === 'string') {
+            const { completedCount, totalCount } = requestState[this.requestId];
+            requestState[this.requestId].data = responseData;
+
+            // some models don't support progress updates
+            if (!modelTypesExcludedFromProgressUpdates.includes(this.model.type)) {
+                await publishRequestProgress({
+                    requestId: this.requestId,
+                    progress: completedCount / totalCount,
+                    data: JSON.stringify(responseData),
+                });
+            }
+        // If the response is an object, it's a streaming response
+        } else {
+            try {
+                const incomingMessage = responseData;
+                let streamEnded = false;
+
+                const onParse = (event) => {
+                    let requestProgress = {
+                        requestId: this.requestId
                     };
 
-
-
-
-
-                })
+                    logger.debug(`Received event: ${event.type}`);
+
+                    if (event.type === 'event') {
+                        logger.debug('Received event!')
+                        logger.debug(`id: ${event.id || '<none>'}`)
+                        logger.debug(`name: ${event.name || '<none>'}`)
+                        logger.debug(`data: ${event.data}`)
+                    } else if (event.type === 'reconnect-interval') {
+                        logger.debug(`We should set reconnect interval to ${event.value} milliseconds`)
+                    }
+
+                    try {
+                        requestProgress = this.modelExecutor.plugin.processStreamEvent(event, requestProgress);
+                    } catch (error) {
+                        streamErrorOccurred = true;
+                        logger.error(`Stream error: ${error.message}`);
+                        incomingMessage.off('data', processStream);
+                        return;
+                    }
+
+                    try {
+                        if (!streamEnded && requestProgress.data) {
+                            //logger.info(`Publishing stream message to requestId ${this.requestId}: ${message}`);
+                            publishRequestProgress(requestProgress);
+                            streamEnded = requestProgress.progress === 1;
+                        }
+                    } catch (error) {
+                        logger.error(`Could not publish the stream message: "${event.data}", ${error}`);
                     }
 
-            } catch (error) {
-                logger.error(`Could not subscribe to stream: ${error}`);
                 }
+
+                const sseParser = createParser(onParse);
+
+                const processStream = (data) => {
+                    //logger.warn(`RECEIVED DATA: ${JSON.stringify(data.toString())}`);
+                    sseParser.feed(data.toString());
+                }
+
+                if (incomingMessage) {
+                    await new Promise((resolve, reject) => {
+                        incomingMessage.on('data', processStream);
+                        incomingMessage.on('end', resolve);
+                        incomingMessage.on('error', reject);
+                    });
+                }
+
+            } catch (error) {
+                logger.error(`Could not subscribe to stream: ${error}`);
             }
 
             if (streamErrorOccurred) {
-
-
-
+                logger.error(`Stream read failed. Finishing stream...`);
+                publishRequestProgress({
+                    requestId: this.requestId,
+                    progress: 1,
+                    data: '[DONE]',
+                });
             } else {
                 return;
             }
         }
-        // if all retries failed, publish the stream end message
-        publishRequestProgress({
-            requestId: this.requestId,
-            progress: 1,
-            data: '[DONE]',
-        });
     }
 
     async resolve(args) {
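For context, this is roughly how the eventsource-parser v1 API used by the new stream handling behaves. A minimal sketch; the SSE chunks below are invented example payloads, not Cortex output.

```js
// Sketch of the eventsource-parser API that pathwayResolver.js now feeds
// raw response chunks into. The SSE payloads here are invented examples.
import { createParser } from 'eventsource-parser';

const onParse = (event) => {
  if (event.type === 'event') {
    // event.data is the text after "data:"; in Cortex, each plugin's
    // processStreamEvent() decides what to publish from it.
    console.log('event data:', event.data);
  } else if (event.type === 'reconnect-interval') {
    console.log('suggested retry interval (ms):', event.value);
  }
};

const parser = createParser(onParse);

// feed() accepts partial chunks; the parser buffers until events are complete.
parser.feed('data: {"hello":');
parser.feed('"world"}\n\ndata: [DONE]\n\n');
```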
package/server/plugins/azureBingPlugin.js
CHANGED
@@ -1,5 +1,6 @@
 import ModelPlugin from './modelPlugin.js';
 import logger from '../../lib/logger.js';
+import { config } from '../../config.js';
 
 class AzureBingPlugin extends ModelPlugin {
     constructor(pathway, model) {
@@ -18,6 +19,9 @@ class AzureBingPlugin extends ModelPlugin {
     }
 
     async execute(text, parameters, prompt, cortexRequest) {
+        if(!config.getEnv()["AZURE_BING_KEY"]){
+            throw new Error("AZURE_BING_KEY is not set in the environment variables!");
+        }
         const requestParameters = this.getRequestParameters(text, parameters, prompt);
 
         cortexRequest.data = requestParameters.data;
package/server/plugins/claude3VertexPlugin.js
ADDED
@@ -0,0 +1,126 @@
+import OpenAIVisionPlugin from './openAiVisionPlugin.js';
+
+class Claude3VertexPlugin extends OpenAIVisionPlugin {
+
+    parseResponse(data)
+    {
+        if (!data) {
+            return data;
+        }
+
+        const { content } = data;
+
+        // if the response is an array, return the text property of the first item
+        // if the type property is 'text'
+        if (content && Array.isArray(content) && content[0].type === 'text') {
+            return content[0].text;
+        } else {
+            return data;
+        }
+    }
+
+    // This code converts messages to the format required by the Claude Vertex API
+    convertMessagesToClaudeVertex(messages) {
+        let modifiedMessages = [];
+        let system = '';
+        let lastAuthor = '';
+
+        // Claude needs system messages in a separate field
+        const systemMessages = messages.filter(message => message.role === 'system');
+        if (systemMessages.length > 0) {
+            system = systemMessages.map(message => message.content).join('\n');
+            modifiedMessages = messages.filter(message => message.role !== 'system');
+        } else {
+            modifiedMessages = messages;
+        }
+
+        // remove any empty messages
+        modifiedMessages = modifiedMessages.filter(message => message.content);
+
+        // combine any consecutive messages from the same author
+        var combinedMessages = [];
+
+        modifiedMessages.forEach((message) => {
+            if (message.role === lastAuthor) {
+                combinedMessages[combinedMessages.length - 1].content += '\n' + message.content;
+            } else {
+                combinedMessages.push(message);
+                lastAuthor = message.role;
+            }
+        });
+
+        modifiedMessages = combinedMessages;
+
+        // Claude vertex requires an even number of messages
+        if (modifiedMessages.length % 2 === 0) {
+            modifiedMessages = modifiedMessages.slice(1);
+        }
+
+        return {
+            system,
+            modifiedMessages,
+        };
+    }
+
+    getRequestParameters(text, parameters, prompt, cortexRequest) {
+        const requestParameters = super.getRequestParameters(text, parameters, prompt, cortexRequest);
+        const { system, modifiedMessages } = this.convertMessagesToClaudeVertex(requestParameters.messages);
+        requestParameters.system = system;
+        requestParameters.messages = modifiedMessages;
+        requestParameters.max_tokens = this.getModelMaxReturnTokens();
+        requestParameters.anthropic_version = 'vertex-2023-10-16';
+        return requestParameters;
+    }
+
+    async execute(text, parameters, prompt, cortexRequest) {
+        const requestParameters = this.getRequestParameters(text, parameters, prompt, cortexRequest);
+        const { stream } = parameters;
+
+        cortexRequest.data = { ...(cortexRequest.data || {}), ...requestParameters };
+        cortexRequest.params = {}; // query params
+        cortexRequest.stream = stream;
+        cortexRequest.url = cortexRequest.stream ? `${cortexRequest.url}:streamRawPredict` : `${cortexRequest.url}:rawPredict`;
+
+        const gcpAuthTokenHelper = this.config.get('gcpAuthTokenHelper');
+        const authToken = await gcpAuthTokenHelper.getAccessToken();
+        cortexRequest.headers.Authorization = `Bearer ${authToken}`;
+
+        return this.executeRequest(cortexRequest);
+    }
+
+    processStreamEvent(event, requestProgress) {
+        const eventData = JSON.parse(event.data);
+        switch (eventData.type) {
+            case 'message_start':
+                requestProgress.data = JSON.stringify(eventData.message);
+                break;
+            case 'content_block_start':
+                break;
+            case 'ping':
+                break;
+            case 'content_block_delta':
+                if (eventData.delta.type === 'text_delta') {
+                    requestProgress.data = JSON.stringify(eventData.delta.text);
+                }
+                break;
+            case 'content_block_stop':
+                break;
+            case 'message_delta':
+                break;
+            case 'message_stop':
+                requestProgress.data = '[DONE]';
+                requestProgress.progress = 1;
+                break;
+            case 'error':
+                requestProgress.data = `\n\n*** ${eventData.error.message || eventData.error} ***`;
+                requestProgress.progress = 1;
+                break;
+        }
+
+        return requestProgress;
+
+    }
+
+}
+
+export default Claude3VertexPlugin;
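To illustrate what convertMessagesToClaudeVertex does, here is a hypothetical input and the shape of the result, traced by hand from the code above (not taken from the package's tests):

```js
// Hypothetical input: system content is lifted into a separate field,
// empty messages are dropped, and consecutive same-role messages are merged.
const messages = [
    { role: 'system', content: 'You are a helpful assistant.' },
    { role: 'user', content: 'Hello' },
    { role: 'user', content: 'Are you there?' },
    { role: 'assistant', content: '' },
];

// Expected result shape:
// {
//   system: 'You are a helpful assistant.',
//   modifiedMessages: [
//     { role: 'user', content: 'Hello\nAre you there?' }
//   ]
// }
```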
package/server/plugins/geminiChatPlugin.js
CHANGED
@@ -5,8 +5,18 @@ import logger from '../../lib/logger.js';
 const mergeResults = (data) => {
     let output = '';
     let safetyRatings = [];
+    const RESPONSE_BLOCKED = 'The response was blocked because the input or response potentially violates policies. Try rephrasing the prompt or adjusting the parameter settings.';
 
     for (let chunk of data) {
+        const { promptfeedback } = chunk;
+        if (promptfeedback) {
+            const { blockReason } = promptfeedback;
+            if (blockReason) {
+                logger.warn(`Response blocked due to prompt feedback: ${blockReason}`);
+                return {mergedResult: RESPONSE_BLOCKED, safetyRatings: safetyRatings};
+            }
+        }
+
         const { candidates } = chunk;
         if (!candidates || !candidates.length) {
             continue;
@@ -15,7 +25,8 @@ const mergeResults = (data) => {
         // If it was blocked, return the blocked message
         if (candidates[0].safetyRatings.some(rating => rating.blocked)) {
             safetyRatings = candidates[0].safetyRatings;
-
+            logger.warn(`Response blocked due to safety ratings: ${JSON.stringify(safetyRatings, null, 2)}`);
+            return {mergedResult: RESPONSE_BLOCKED, safetyRatings: safetyRatings};
         }
 
         // Append the content of the first part of the first candidate to the output
package/server/plugins/modelPlugin.js
CHANGED
@@ -236,8 +236,11 @@ class ModelPlugin {
 
     getLength(data) {
         const isProd = config.get('env') === 'production';
-
-
+        let length = 0;
+        let units = isProd ? 'characters' : 'tokens';
+        if (data) {
+            length = isProd ? data.length : encode(data).length;
+        }
         return {length, units};
     }
 
@@ -288,6 +291,42 @@ class ModelPlugin {
         }
     }
 
+    processStreamEvent(event, requestProgress) {
+        // check for end of stream or in-stream errors
+        if (event.data.trim() === '[DONE]') {
+            requestProgress.progress = 1;
+        } else {
+            let parsedMessage;
+            try {
+                parsedMessage = JSON.parse(event.data);
+                requestProgress.data = event.data;
+            } catch (error) {
+                throw new Error(`Could not parse stream data: ${error}`);
+            }
+
+            // error can be in different places in the message
+            const streamError = parsedMessage?.error || parsedMessage?.choices?.[0]?.delta?.content?.error || parsedMessage?.choices?.[0]?.text?.error;
+            if (streamError) {
+                throw new Error(streamError);
+            }
+
+            // finish reason can be in different places in the message
+            const finishReason = parsedMessage?.choices?.[0]?.finish_reason || parsedMessage?.candidates?.[0]?.finishReason;
+            if (finishReason?.toLowerCase() === 'stop') {
+                requestProgress.progress = 1;
+            } else {
+                if (finishReason?.toLowerCase() === 'safety') {
+                    const safetyRatings = JSON.stringify(parsedMessage?.candidates?.[0]?.safetyRatings) || '';
+                    logger.warn(`Request ${this.requestId} was blocked by the safety filter. ${safetyRatings}`);
+                    requestProgress.data = `\n\nResponse blocked by safety filter: ${safetyRatings}`;
+                    requestProgress.progress = 1;
+                }
+            }
+        }
+        return requestProgress;
+    }
+
+
 }
 
 export default ModelPlugin;
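To show the contract pathwayResolver.js relies on, here is a hedged sketch of how the default processStreamEvent behaves on two typical SSE events. The chunk payload is an invented OpenAI-style delta, and `plugin` is assumed to be an already-constructed ModelPlugin (or subclass) instance.

```js
// Invented OpenAI-style SSE chunk plus the '[DONE]' sentinel; `plugin` is
// assumed to be a ModelPlugin (or subclass) instance obtained elsewhere.
const chunkEvent = {
  type: 'event',
  data: JSON.stringify({ choices: [{ delta: { content: 'Hello' }, finish_reason: null }] }),
};
const doneEvent = { type: 'event', data: '[DONE]' };

// JSON chunks are copied into requestProgress.data; progress stays unset.
let progress = plugin.processStreamEvent(chunkEvent, { requestId: 'abc' });

// The '[DONE]' sentinel (or a terminal finish_reason) sets progress to 1.
progress = plugin.processStreamEvent(doneEvent, { requestId: 'abc' });
```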
package/server/plugins/openAiWhisperPlugin.js
CHANGED
@@ -100,7 +100,13 @@ function alignSubtitles(subtitles, format) {
     const result = [];
 
     function preprocessStr(str) {
-
+        try{
+            if(!str) return '';
+            return str.trim().replace(/(\n\n)(?!\n)/g, '\n\n\n');
+        }catch(e){
+            logger.error(`An error occurred in content text preprocessing: ${e}`);
+            return '';
+        }
     }
 
     function shiftSubtitles(subtitle, shiftOffset) {
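A quick sketch of what the new preprocessStr regex does to subtitle text (the sample string is made up):

```js
// (\n\n)(?!\n) matches a double newline that is not already followed by a
// third one and pads it to a triple newline, so cue blocks stay separated.
const raw = "1\n00:00:01,000 --> 00:00:02,000\nHello\n\n2\n00:00:03,000 --> 00:00:04,000\nWorld\n\n\n";
const processed = raw.trim().replace(/(\n\n)(?!\n)/g, '\n\n\n');
// The single blank line between cue 1 and cue 2 becomes a double blank line.
console.log(processed);
```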
package/server/plugins/palmChatPlugin.js
CHANGED
@@ -14,6 +14,9 @@ class PalmChatPlugin extends ModelPlugin {
         let modifiedMessages = [];
         let lastAuthor = '';
 
+        // remove any empty messages
+        messages = messages.filter(message => message.content);
+
         messages.forEach(message => {
             const { role, author, content } = message;
 
@@ -153,7 +156,7 @@ class PalmChatPlugin extends ModelPlugin {
     parseResponse(data) {
         const { predictions } = data;
         if (!predictions || !predictions.length) {
-            return
+            return data;
         }
 
         // Get the candidates array from the first prediction
package/server/rest.js
CHANGED
@@ -148,6 +148,10 @@ const processIncomingStream = (requestId, res, jsonResponse) => {
     } else if (messageJson.candidates) {
         const { content, finishReason } = messageJson.candidates[0];
         fillJsonResponse(jsonResponse, content.parts[0].text, finishReason);
+    } else if (messageJson.content) {
+        const text = messageJson.content?.[0]?.text || '';
+        const finishReason = messageJson.stop_reason;
+        fillJsonResponse(jsonResponse, text, finishReason);
     } else {
         fillJsonResponse(jsonResponse, messageJson, null);
     }