@aj-archipelago/cortex 1.1.7 → 1.1.8

This diff shows the changes between two publicly released versions of the package as they appear in the public registry. It is provided for informational purposes only.
package/package.json CHANGED
@@ -1,6 +1,6 @@
  {
  "name": "@aj-archipelago/cortex",
- "version": "1.1.7",
+ "version": "1.1.8",
  "description": "Cortex is a GraphQL API for AI. It provides a simple, extensible interface for using AI services from OpenAI, Azure and others.",
  "private": false,
  "repository": {
@@ -35,6 +35,7 @@
  "@datastructures-js/deque": "^1.0.4",
  "@graphql-tools/schema": "^9.0.12",
  "@keyv/redis": "^2.5.4",
+ "@langchain/openai": "^0.0.24",
  "axios": "^1.3.4",
  "axios-cache-interceptor": "^1.0.1",
  "bottleneck": "^2.19.5",
@@ -42,6 +43,7 @@
  "compromise": "^14.8.1",
  "compromise-paragraphs": "^0.1.0",
  "convict": "^6.2.3",
+ "eventsource-parser": "^1.1.2",
  "express": "^4.18.2",
  "form-data": "^4.0.0",
  "google-auth-library": "^8.8.0",
@@ -2,12 +2,7 @@
  // LangChain Cortex integration test

  // Import required modules
- import { OpenAI } from "langchain/llms";
- //import { PromptTemplate } from "langchain/prompts";
- //import { LLMChain, ConversationChain } from "langchain/chains";
- import { initializeAgentExecutor } from "langchain/agents";
- import { SerpAPI, Calculator } from "langchain/tools";
- //import { BufferMemory } from "langchain/memory";
+ import { ChatOpenAI } from "@langchain/openai";

  export default {

@@ -15,89 +10,22 @@ export default {
  resolver: async (parent, args, contextValue, _info) => {

  const { config } = contextValue;
- const env = config.getEnv();

  // example of reading from a predefined config variable
  const openAIApiKey = config.get('openaiApiKey');
- // example of reading straight from environment
- const serpApiKey = env.SERPAPI_API_KEY;

- const model = new OpenAI({ openAIApiKey: openAIApiKey, temperature: 0 });
- const tools = [new SerpAPI( serpApiKey ), new Calculator()];
-
- const executor = await initializeAgentExecutor(
- tools,
- model,
- "zero-shot-react-description"
- );
+ const model = new ChatOpenAI({ openAIApiKey: openAIApiKey, temperature: 0 });

  console.log(`====================`);
- console.log("Loaded langchain agent.");
+ console.log("Loaded langchain.");
  const input = args.text;
  console.log(`Executing with input "${input}"...`);
- const result = await executor.call({ input });
- console.log(`Got output ${result.output}`);
- console.log(`====================`);
-
- return result?.output;
- },
-
- /*
- // Agent test case
- resolver: async (parent, args, contextValue, info) => {
-
- const { config } = contextValue;
- const openAIApiKey = config.get('openaiApiKey');
- const serpApiKey = config.get('serpApiKey');
-
- const model = new OpenAI({ openAIApiKey: openAIApiKey, temperature: 0 });
- const tools = [new SerpAPI( serpApiKey ), new Calculator()];
-
- const executor = await initializeAgentExecutor(
- tools,
- model,
- "zero-shot-react-description"
- );
-
- console.log(`====================`);
- console.log("Loaded langchain agent.");
- const input = args.text;
- console.log(`Executing with input "${input}"...`);
- const result = await executor.call({ input });
- console.log(`Got output ${result.output}`);
- console.log(`====================`);
-
- return result?.output;
- },
- */
- // Simplest test case
- /*
- resolver: async (parent, args, contextValue, info) => {
-
- const { config } = contextValue;
- const openAIApiKey = config.get('openaiApiKey');
-
- const model = new OpenAI({ openAIApiKey: openAIApiKey, temperature: 0.9 });
-
- const template = "What is a good name for a company that makes {product}?";
-
- const prompt = new PromptTemplate({
- template: template,
- inputVariables: ["product"],
- });
-
- const chain = new LLMChain({ llm: model, prompt: prompt });
-
+ const result = await model.invoke(input);
+ console.log(`Got output "${result.content}"`);
  console.log(`====================`);
- console.log(`Calling langchain with prompt: ${prompt?.template}`);
- console.log(`Input text: ${args.text}`);
- const res = await chain.call({ product: args.text });
- console.log(`Result: ${res?.text}`);
- console.log(`====================`);

- return res?.text?.trim();
+ return result?.content;
  },
- */
  };
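
Note: model.invoke() from @langchain/openai accepts either a plain string (as the rewritten pathway above uses) or an array of chat messages, and resolves to a message object whose text is on .content. A minimal standalone sketch, assuming the key comes straight from the environment rather than Cortex config, and that @langchain/core (shipped as a dependency of @langchain/openai) provides the message classes:

import { ChatOpenAI } from "@langchain/openai";
import { HumanMessage, SystemMessage } from "@langchain/core/messages";

// Hypothetical standalone usage; inside Cortex the key comes from config.get('openaiApiKey').
const model = new ChatOpenAI({ openAIApiKey: process.env.OPENAI_API_KEY, temperature: 0 });

// String input, as in the pathway above:
const fromString = await model.invoke("Say hello");
console.log(fromString.content);

// A message array is also accepted:
const fromMessages = await model.invoke([
    new SystemMessage("You are a terse assistant."),
    new HumanMessage("Say hello"),
]);
console.log(fromMessages.content);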
@@ -20,6 +20,7 @@ import OpenAIVisionPlugin from './plugins/openAiVisionPlugin.js';
  import GeminiChatPlugin from './plugins/geminiChatPlugin.js';
  import GeminiVisionPlugin from './plugins/geminiVisionPlugin.js';
  import AzureBingPlugin from './plugins/azureBingPlugin.js';
+ import Claude3VertexPlugin from './plugins/claude3VertexPlugin.js';

  class ModelExecutor {
  constructor(pathway, model) {
@@ -84,6 +85,9 @@ class ModelExecutor {
  case 'AZURE-BING':
  plugin = new AzureBingPlugin(pathway, model);
  break;
+ case 'CLAUDE-3-VERTEX':
+ plugin = new Claude3VertexPlugin(pathway, model);
+ break;
  default:
  throw new Error(`Unsupported model type: ${model.type}`);
  }
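
The new CLAUDE-3-VERTEX case routes any configured model whose type is 'CLAUDE-3-VERTEX' to the new plugin (added further down in this diff). A hypothetical sketch of such a model entry; apart from the type value, the field names and the endpoint are illustrative placeholders, not taken from this diff:

// Illustrative only: the real model-config schema lives elsewhere in Cortex.
const exampleModel = {
    type: "CLAUDE-3-VERTEX",   // selects Claude3VertexPlugin in ModelExecutor
    // The plugin itself appends :rawPredict or :streamRawPredict to this URL.
    url: "https://us-east5-aiplatform.googleapis.com/v1/projects/my-project/locations/us-east5/publishers/anthropic/models/claude-3-sonnet",
};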
@@ -1,6 +1,5 @@
  import { ModelExecutor } from './modelExecutor.js';
  import { modelEndpoints } from '../lib/requestExecutor.js';
- // eslint-disable-next-line import/no-extraneous-dependencies
  import { v4 as uuidv4 } from 'uuid';
  import { encode } from '../lib/encodeCache.js';
  import { getFirstNToken, getLastNToken, getSemanticChunks } from './chunker.js';
@@ -11,6 +10,8 @@ import { requestState } from './requestState.js';
  import { callPathway } from '../lib/pathwayTools.js';
  import { publishRequestProgress } from '../lib/redisSubscription.js';
  import logger from '../lib/logger.js';
+ // eslint-disable-next-line import/no-extraneous-dependencies
+ import { createParser } from 'eventsource-parser';

  const modelTypesExcludedFromProgressUpdates = ['OPENAI-DALLE2', 'OPENAI-DALLE3'];

@@ -69,136 +70,112 @@ class PathwayResolver {
  this.pathwayPrompt = pathway.prompt;
  }

- // This code handles async and streaming responses. In either case, we use
- // the graphql subscription to send progress updates to the client. Most of
- // the time the client will be an external client, but it could also be the
- // Cortex REST api code.
+ // This code handles async and streaming responses for either long-running
+ // tasks or streaming model responses
  async asyncResolve(args) {
- const MAX_RETRY_COUNT = 3;
- let attempt = 0;
  let streamErrorOccurred = false;
+ let responseData = null;

- while (attempt < MAX_RETRY_COUNT) {
- const responseData = await this.executePathway(args);
+ try {
+ responseData = await this.executePathway(args);
+ }
+ catch (error) {
+ if (!args.async) {
+ publishRequestProgress({
+ requestId: this.requestId,
+ progress: 1,
+ data: '[DONE]',
+ });
+ }
+ return;
+ }

- if (args.async || typeof responseData === 'string') {
- const { completedCount, totalCount } = requestState[this.requestId];
- requestState[this.requestId].data = responseData;
-
- // if model type is OPENAI-IMAGE
- if (!modelTypesExcludedFromProgressUpdates.includes(this.model.type)) {
- await publishRequestProgress({
- requestId: this.requestId,
- progress: completedCount / totalCount,
- data: JSON.stringify(responseData),
- });
- }
- } else {
- try {
- const incomingMessage = responseData;
-
- let messageBuffer = '';
- let streamEnded = false;
-
- const processStreamSSE = (data) => {
- try {
- //logger.info(`\n\nReceived stream data for requestId ${this.requestId}: ${data.toString()}`);
- let events = data.toString().split('\n');
-
- //events = "data: {\"id\":\"chatcmpl-20bf1895-2fa7-4ef9-abfe-4d142aba5817\",\"object\":\"chat.completion.chunk\",\"created\":1689303423723,\"model\":\"gpt-4\",\"choices\":[{\"delta\":{\"role\":\"assistant\",\"content\":{\"error\":{\"message\":\"The server had an error while processing your request. Sorry about that!\",\"type\":\"server_error\",\"param\":null,\"code\":null}}},\"finish_reason\":null}]}\n\n".split("\n");
-
- for (let event of events) {
- if (streamErrorOccurred) break;
-
- // skip empty events
- if (!(event.trim() === '')) {
- //logger.info(`Processing stream event for requestId ${this.requestId}: ${event}`);
- messageBuffer += event.replace(/^data: /, '');
-
- const requestProgress = {
- requestId: this.requestId,
- data: messageBuffer,
- }
-
- // check for end of stream or in-stream errors
- if (messageBuffer.trim() === '[DONE]') {
- requestProgress.progress = 1;
- } else {
- let parsedMessage;
- try {
- parsedMessage = JSON.parse(messageBuffer);
- messageBuffer = '';
- } catch (error) {
- // incomplete stream message, try to buffer more data
- return;
- }
-
- // error can be in different places in the message
- const streamError = parsedMessage?.error || parsedMessage?.choices?.[0]?.delta?.content?.error || parsedMessage?.choices?.[0]?.text?.error;
- if (streamError) {
- streamErrorOccurred = true;
- logger.error(`Stream error: ${streamError.message}`);
- incomingMessage.off('data', processStreamSSE);
- return;
- }
-
- // finish reason can be in different places in the message
- const finishReason = parsedMessage?.choices?.[0]?.finish_reason || parsedMessage?.candidates?.[0]?.finishReason;
- if (finishReason?.toLowerCase() === 'stop') {
- requestProgress.progress = 1;
- } else {
- if (finishReason?.toLowerCase() === 'safety') {
- const safetyRatings = JSON.stringify(parsedMessage?.candidates?.[0]?.safetyRatings) || '';
- logger.warn(`Request ${this.requestId} was blocked by the safety filter. ${safetyRatings}`);
- requestProgress.data = `\n\nResponse blocked by safety filter: ${safetyRatings}`;
- requestProgress.progress = 1;
- }
- }
- }
-
- try {
- if (!streamEnded) {
- //logger.info(`Publishing stream message to requestId ${this.requestId}: ${message}`);
- publishRequestProgress(requestProgress);
- streamEnded = requestProgress.progress === 1;
- }
- } catch (error) {
- logger.error(`Could not publish the stream message: "${messageBuffer}", ${error}`);
- }
- }
- }
- } catch (error) {
- logger.error(`Could not process stream data: ${error}`);
- }
+ // If the response is a string, it's a regular long running response
+ if (args.async || typeof responseData === 'string') {
+ const { completedCount, totalCount } = requestState[this.requestId];
+ requestState[this.requestId].data = responseData;
+
+ // some models don't support progress updates
+ if (!modelTypesExcludedFromProgressUpdates.includes(this.model.type)) {
+ await publishRequestProgress({
+ requestId: this.requestId,
+ progress: completedCount / totalCount,
+ data: JSON.stringify(responseData),
+ });
+ }
+ // If the response is an object, it's a streaming response
+ } else {
+ try {
+ const incomingMessage = responseData;
+ let streamEnded = false;
+
+ const onParse = (event) => {
+ let requestProgress = {
+ requestId: this.requestId
  };

- if (incomingMessage) {
- await new Promise((resolve, reject) => {
- incomingMessage.on('data', processStreamSSE);
- incomingMessage.on('end', resolve);
- incomingMessage.on('error', reject);
- });
+ logger.debug(`Received event: ${event.type}`);
+
+ if (event.type === 'event') {
+ logger.debug('Received event!')
+ logger.debug(`id: ${event.id || '<none>'}`)
+ logger.debug(`name: ${event.name || '<none>'}`)
+ logger.debug(`data: ${event.data}`)
+ } else if (event.type === 'reconnect-interval') {
+ logger.debug(`We should set reconnect interval to ${event.value} milliseconds`)
+ }
+
+ try {
+ requestProgress = this.modelExecutor.plugin.processStreamEvent(event, requestProgress);
+ } catch (error) {
+ streamErrorOccurred = true;
+ logger.error(`Stream error: ${error.message}`);
+ incomingMessage.off('data', processStream);
+ return;
+ }
+
+ try {
+ if (!streamEnded && requestProgress.data) {
+ //logger.info(`Publishing stream message to requestId ${this.requestId}: ${message}`);
+ publishRequestProgress(requestProgress);
+ streamEnded = requestProgress.progress === 1;
+ }
+ } catch (error) {
+ logger.error(`Could not publish the stream message: "${event.data}", ${error}`);
  }

- } catch (error) {
- logger.error(`Could not subscribe to stream: ${error}`);
  }
+
+ const sseParser = createParser(onParse);
+
+ const processStream = (data) => {
+ //logger.warn(`RECEIVED DATA: ${JSON.stringify(data.toString())}`);
+ sseParser.feed(data.toString());
+ }
+
+ if (incomingMessage) {
+ await new Promise((resolve, reject) => {
+ incomingMessage.on('data', processStream);
+ incomingMessage.on('end', resolve);
+ incomingMessage.on('error', reject);
+ });
+ }
+
+ } catch (error) {
+ logger.error(`Could not subscribe to stream: ${error}`);
  }

  if (streamErrorOccurred) {
- attempt++;
- logger.error(`Stream attempt ${attempt} failed. Retrying...`);
- streamErrorOccurred = false; // Reset the flag for the next attempt
+ logger.error(`Stream read failed. Finishing stream...`);
+ publishRequestProgress({
+ requestId: this.requestId,
+ progress: 1,
+ data: '[DONE]',
+ });
  } else {
  return;
  }
  }
- // if all retries failed, publish the stream end message
- publishRequestProgress({
- requestId: this.requestId,
- progress: 1,
- data: '[DONE]',
- });
  }

  async resolve(args) {
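
The rewritten stream handling delegates SSE framing to eventsource-parser: raw chunks are fed to the parser, which only invokes the callback once a complete event has arrived, so the old manual messageBuffer logic is no longer needed. A minimal sketch of that pattern in isolation, assuming incomingStream is a Node.js readable stream of raw SSE bytes:

import { createParser } from 'eventsource-parser';

// The callback fires once per complete SSE event, never for partial chunks.
const parser = createParser((event) => {
    if (event.type === 'event') {
        // event.data is a full `data:` payload, e.g. '[DONE]' or a JSON chunk
        console.log('SSE data:', event.data);
    }
});

incomingStream.on('data', (chunk) => parser.feed(chunk.toString()));
incomingStream.on('end', () => console.log('stream ended'));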
@@ -1,5 +1,6 @@
  import ModelPlugin from './modelPlugin.js';
  import logger from '../../lib/logger.js';
+ import { config } from '../../config.js';

  class AzureBingPlugin extends ModelPlugin {
  constructor(pathway, model) {
@@ -18,6 +19,9 @@ class AzureBingPlugin extends ModelPlugin {
  }

  async execute(text, parameters, prompt, cortexRequest) {
+ if(!config.getEnv()["AZURE_BING_KEY"]){
+ throw new Error("AZURE_BING_KEY is not set in the environment variables!");
+ }
  const requestParameters = this.getRequestParameters(text, parameters, prompt);

  cortexRequest.data = requestParameters.data;
@@ -0,0 +1,126 @@
+ import OpenAIVisionPlugin from './openAiVisionPlugin.js';
+
+ class Claude3VertexPlugin extends OpenAIVisionPlugin {
+
+ parseResponse(data)
+ {
+ if (!data) {
+ return data;
+ }
+
+ const { content } = data;
+
+ // if the response is an array, return the text property of the first item
+ // if the type property is 'text'
+ if (content && Array.isArray(content) && content[0].type === 'text') {
+ return content[0].text;
+ } else {
+ return data;
+ }
+ }
+
+ // This code converts messages to the format required by the Claude Vertex API
+ convertMessagesToClaudeVertex(messages) {
+ let modifiedMessages = [];
+ let system = '';
+ let lastAuthor = '';
+
+ // Claude needs system messages in a separate field
+ const systemMessages = messages.filter(message => message.role === 'system');
+ if (systemMessages.length > 0) {
+ system = systemMessages.map(message => message.content).join('\n');
+ modifiedMessages = messages.filter(message => message.role !== 'system');
+ } else {
+ modifiedMessages = messages;
+ }
+
+ // remove any empty messages
+ modifiedMessages = modifiedMessages.filter(message => message.content);
+
+ // combine any consecutive messages from the same author
+ var combinedMessages = [];
+
+ modifiedMessages.forEach((message) => {
+ if (message.role === lastAuthor) {
+ combinedMessages[combinedMessages.length - 1].content += '\n' + message.content;
+ } else {
+ combinedMessages.push(message);
+ lastAuthor = message.role;
+ }
+ });
+
+ modifiedMessages = combinedMessages;
+
+ // Claude vertex requires an even number of messages
+ if (modifiedMessages.length % 2 === 0) {
+ modifiedMessages = modifiedMessages.slice(1);
+ }
+
+ return {
+ system,
+ modifiedMessages,
+ };
+ }
+
+ getRequestParameters(text, parameters, prompt, cortexRequest) {
+ const requestParameters = super.getRequestParameters(text, parameters, prompt, cortexRequest);
+ const { system, modifiedMessages } = this.convertMessagesToClaudeVertex(requestParameters.messages);
+ requestParameters.system = system;
+ requestParameters.messages = modifiedMessages;
+ requestParameters.max_tokens = this.getModelMaxReturnTokens();
+ requestParameters.anthropic_version = 'vertex-2023-10-16';
+ return requestParameters;
+ }
+
+ async execute(text, parameters, prompt, cortexRequest) {
+ const requestParameters = this.getRequestParameters(text, parameters, prompt, cortexRequest);
+ const { stream } = parameters;
+
+ cortexRequest.data = { ...(cortexRequest.data || {}), ...requestParameters };
+ cortexRequest.params = {}; // query params
+ cortexRequest.stream = stream;
+ cortexRequest.url = cortexRequest.stream ? `${cortexRequest.url}:streamRawPredict` : `${cortexRequest.url}:rawPredict`;
+
+ const gcpAuthTokenHelper = this.config.get('gcpAuthTokenHelper');
+ const authToken = await gcpAuthTokenHelper.getAccessToken();
+ cortexRequest.headers.Authorization = `Bearer ${authToken}`;
+
+ return this.executeRequest(cortexRequest);
+ }
+
+ processStreamEvent(event, requestProgress) {
+ const eventData = JSON.parse(event.data);
+ switch (eventData.type) {
+ case 'message_start':
+ requestProgress.data = JSON.stringify(eventData.message);
+ break;
+ case 'content_block_start':
+ break;
+ case 'ping':
+ break;
+ case 'content_block_delta':
+ if (eventData.delta.type === 'text_delta') {
+ requestProgress.data = JSON.stringify(eventData.delta.text);
+ }
+ break;
+ case 'content_block_stop':
+ break;
+ case 'message_delta':
+ break;
+ case 'message_stop':
+ requestProgress.data = '[DONE]';
+ requestProgress.progress = 1;
+ break;
+ case 'error':
+ requestProgress.data = `\n\n*** ${eventData.error.message || eventData.error} ***`;
+ requestProgress.progress = 1;
+ break;
+ }
+
+ return requestProgress;
+
+ }
+
+ }
+
+ export default Claude3VertexPlugin;
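
A worked example of the message conversion above, with illustrative values: system messages are pulled out into system, empty messages are dropped, consecutive same-role messages are merged, and if the merged list ends up with an even number of messages the first one is removed:

// Input to convertMessagesToClaudeVertex:
const messages = [
    { role: 'system', content: 'You are helpful.' },
    { role: 'user', content: 'Hello' },
    { role: 'assistant', content: 'Hi there!' },
    { role: 'assistant', content: 'How can I help?' },   // merged into the previous assistant turn
    { role: 'user', content: 'Summarize this file.' },
];

// Result:
// system           => 'You are helpful.'
// modifiedMessages => [
//     { role: 'user', content: 'Hello' },
//     { role: 'assistant', content: 'Hi there!\nHow can I help?' },
//     { role: 'user', content: 'Summarize this file.' },
// ]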
@@ -5,8 +5,18 @@ import logger from '../../lib/logger.js';
  const mergeResults = (data) => {
  let output = '';
  let safetyRatings = [];
+ const RESPONSE_BLOCKED = 'The response was blocked because the input or response potentially violates policies. Try rephrasing the prompt or adjusting the parameter settings.';

  for (let chunk of data) {
+ const { promptfeedback } = chunk;
+ if (promptfeedback) {
+ const { blockReason } = promptfeedback;
+ if (blockReason) {
+ logger.warn(`Response blocked due to prompt feedback: ${blockReason}`);
+ return {mergedResult: RESPONSE_BLOCKED, safetyRatings: safetyRatings};
+ }
+ }
+
  const { candidates } = chunk;
  if (!candidates || !candidates.length) {
  continue;
@@ -15,7 +25,8 @@ const mergeResults = (data) => {
  // If it was blocked, return the blocked message
  if (candidates[0].safetyRatings.some(rating => rating.blocked)) {
  safetyRatings = candidates[0].safetyRatings;
- return {mergedResult: 'The response was blocked because the input or response potentially violates policies. Try rephrasing the prompt or adjusting the parameter settings.', safetyRatings: safetyRatings};
+ logger.warn(`Response blocked due to safety ratings: ${JSON.stringify(safetyRatings, null, 2)}`);
+ return {mergedResult: RESPONSE_BLOCKED, safetyRatings: safetyRatings};
  }

  // Append the content of the first part of the first candidate to the output
@@ -236,8 +236,11 @@ class ModelPlugin {

  getLength(data) {
  const isProd = config.get('env') === 'production';
- const length = isProd ? data.length : encode(data).length;
- const units = isProd ? 'characters' : 'tokens';
+ let length = 0;
+ let units = isProd ? 'characters' : 'tokens';
+ if (data) {
+ length = isProd ? data.length : encode(data).length;
+ }
  return {length, units};
  }

@@ -288,6 +291,42 @@ class ModelPlugin {
  }
  }

+ processStreamEvent(event, requestProgress) {
+ // check for end of stream or in-stream errors
+ if (event.data.trim() === '[DONE]') {
+ requestProgress.progress = 1;
+ } else {
+ let parsedMessage;
+ try {
+ parsedMessage = JSON.parse(event.data);
+ requestProgress.data = event.data;
+ } catch (error) {
+ throw new Error(`Could not parse stream data: ${error}`);
+ }
+
+ // error can be in different places in the message
+ const streamError = parsedMessage?.error || parsedMessage?.choices?.[0]?.delta?.content?.error || parsedMessage?.choices?.[0]?.text?.error;
+ if (streamError) {
+ throw new Error(streamError);
+ }
+
+ // finish reason can be in different places in the message
+ const finishReason = parsedMessage?.choices?.[0]?.finish_reason || parsedMessage?.candidates?.[0]?.finishReason;
+ if (finishReason?.toLowerCase() === 'stop') {
+ requestProgress.progress = 1;
+ } else {
+ if (finishReason?.toLowerCase() === 'safety') {
+ const safetyRatings = JSON.stringify(parsedMessage?.candidates?.[0]?.safetyRatings) || '';
+ logger.warn(`Request ${this.requestId} was blocked by the safety filter. ${safetyRatings}`);
+ requestProgress.data = `\n\nResponse blocked by safety filter: ${safetyRatings}`;
+ requestProgress.progress = 1;
+ }
+ }
+ }
+ return requestProgress;
+ }
+
+
  }

  export default ModelPlugin;
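
The new base-class processStreamEvent handles one parsed SSE event at a time and throws on in-stream errors, which the resolver catches to abort the stream. A rough sketch of its behavior for OpenAI-style chunks, assuming plugin is an already-constructed ModelPlugin instance:

// Events shaped like those produced by eventsource-parser:
const chunkEvent = {
    type: 'event',
    data: '{"choices":[{"delta":{"content":"Hello"},"finish_reason":null}]}',
};
const doneEvent = { type: 'event', data: '[DONE]' };

let progress = { requestId: 'example-request-id' };   // hypothetical id
progress = plugin.processStreamEvent(chunkEvent, progress);
// progress.data now carries the raw JSON chunk; progress.progress is still unset

progress = plugin.processStreamEvent(doneEvent, progress);
// progress.progress === 1, which the resolver treats as end of stream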
@@ -87,6 +87,7 @@ class OpenAIChatPlugin extends ModelPlugin {

  // Parse the response from the OpenAI Chat API
  parseResponse(data) {
+ if(!data) return "";
  const { choices } = data;
  if (!choices || !choices.length) {
  return data;
@@ -100,7 +100,13 @@ function alignSubtitles(subtitles, format) {
  const result = [];

  function preprocessStr(str) {
- return str.trim().replace(/(\n\n)(?!\n)/g, '\n\n\n');
+ try{
+ if(!str) return '';
+ return str.trim().replace(/(\n\n)(?!\n)/g, '\n\n\n');
+ }catch(e){
+ logger.error(`An error occurred in content text preprocessing: ${e}`);
+ return '';
+ }
  }

  function shiftSubtitles(subtitle, shiftOffset) {
@@ -14,6 +14,9 @@ class PalmChatPlugin extends ModelPlugin {
  let modifiedMessages = [];
  let lastAuthor = '';

+ // remove any empty messages
+ messages = messages.filter(message => message.content);
+
  messages.forEach(message => {
  const { role, author, content } = message;

@@ -153,7 +156,7 @@ class PalmChatPlugin extends ModelPlugin {
  parseResponse(data) {
  const { predictions } = data;
  if (!predictions || !predictions.length) {
- return null;
+ return data;
  }

  // Get the candidates array from the first prediction
package/server/rest.js CHANGED
@@ -148,6 +148,10 @@ const processIncomingStream = (requestId, res, jsonResponse) => {
  } else if (messageJson.candidates) {
  const { content, finishReason } = messageJson.candidates[0];
  fillJsonResponse(jsonResponse, content.parts[0].text, finishReason);
+ } else if (messageJson.content) {
+ const text = messageJson.content?.[0]?.text || '';
+ const finishReason = messageJson.stop_reason;
+ fillJsonResponse(jsonResponse, text, finishReason);
  } else {
  fillJsonResponse(jsonResponse, messageJson, null);
  }
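
The new messageJson.content branch covers Anthropic-style message payloads forwarded by the Claude 3 Vertex plugin, where the text sits in a content array and the finish reason is stop_reason. An illustrative payload this branch would handle (values are examples, not taken from the diff):

// Example Anthropic-style message object:
const messageJson = {
    type: 'message',
    role: 'assistant',
    content: [{ type: 'text', text: 'Hello from Claude.' }],
    stop_reason: 'end_turn',
};
// The branch above resolves to:
// fillJsonResponse(jsonResponse, 'Hello from Claude.', 'end_turn');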