@aj-archipelago/cortex 1.1.0 → 1.1.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/.eslintignore CHANGED
@@ -18,6 +18,12 @@
18
18
  # Ignore documentation
19
19
  /docs
20
20
 
21
+ # Ignore helper apps
22
+ /helper_apps
23
+
24
+ # Ignore tests
25
+ /tests
26
+
21
27
  # Ignore any generated or bundled files
22
28
  *.min.js
23
29
  *.bundle.js
package/.eslintrc CHANGED
@@ -16,8 +16,9 @@
16
16
  ],
17
17
  "rules": {
18
18
  "import/no-unresolved": "error",
19
- "import/no-extraneous-dependencies": ["error", {"devDependencies": true, "dependencies": true}],
20
- "no-unused-vars": ["error", { "argsIgnorePattern": "^_" }]
19
+ "import/no-extraneous-dependencies": ["error", {"devDependencies": true}],
20
+ "no-unused-vars": ["error", { "argsIgnorePattern": "^_" }],
21
+ "no-useless-escape": "off"
21
22
  },
22
23
  "settings": {
23
24
  "import/resolver": {
@@ -57,7 +57,7 @@ if (connectionString) {
57
57
  subscriptionClient.on('message', (channel, message) => {
58
58
  logger.debug(`Received message from ${channel}: ${message}`);
59
59
 
60
- let decryptedMessage = message;
60
+ let decryptedMessage;
61
61
 
62
62
  if (channel === requestProgressChannel && redisEncryptionKey) {
63
63
  try {
@@ -67,7 +67,10 @@ if (connectionString) {
67
67
  }
68
68
  }
69
69
 
70
- let parsedMessage = decryptedMessage;
70
+ decryptedMessage = decryptedMessage || message;
71
+
72
+ let parsedMessage;
73
+
71
74
  try {
72
75
  parsedMessage = JSON.parse(decryptedMessage);
73
76
  } catch (error) {
@@ -76,10 +79,10 @@ if (connectionString) {
76
79
 
77
80
  switch(channel) {
78
81
  case requestProgressChannel:
79
- pubsubHandleMessage(parsedMessage);
82
+ parsedMessage && pubsubHandleMessage(parsedMessage);
80
83
  break;
81
84
  case requestProgressSubscriptionsChannel:
82
- handleSubscription(parsedMessage);
85
+ parsedMessage && handleSubscription(parsedMessage);
83
86
  break;
84
87
  default:
85
88
  logger.error(`Unsupported channel: ${channel}`);
@@ -92,8 +95,8 @@ if (connectionString) {
92
95
  logger.info(`Using pubsub publish for channel ${requestProgressChannel}`);
93
96
  }
94
97
 
95
- async function publishRequestProgress(data) {
96
- if (publisherClient) {
98
+ async function publishRequestProgress(data, useRedis = true) {
99
+ if (publisherClient && useRedis) {
97
100
  try {
98
101
  let message = JSON.stringify(data);
99
102
  if (redisEncryptionKey) {
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@aj-archipelago/cortex",
3
- "version": "1.1.0",
3
+ "version": "1.1.2",
4
4
  "description": "Cortex is a GraphQL API for AI. It provides a simple, extensible interface for using AI services from OpenAI, Azure and others.",
5
5
  "private": false,
6
6
  "repository": {
@@ -1,8 +1,6 @@
1
1
  // sys_openai_completion.js
2
2
  // default handler for openAI completion endpoints when REST endpoints are enabled
3
3
 
4
- import { Prompt } from '../server/prompt.js';
5
-
6
4
  export default {
7
5
  prompt: `{{text}}`,
8
6
  model: 'oai-gpturbo',
@@ -70,7 +70,7 @@ class PathwayResolver {
70
70
  // the graphql subscription to send progress updates to the client. Most of
71
71
  // the time the client will be an external client, but it could also be the
72
72
  // Cortex REST api code.
73
- async asyncResolve(args) {
73
+ async asyncResolve(args, useRedis = true) {
74
74
  const MAX_RETRY_COUNT = 3;
75
75
  let attempt = 0;
76
76
  let streamErrorOccurred = false;
@@ -88,7 +88,7 @@ class PathwayResolver {
88
88
  requestId: this.requestId,
89
89
  progress: completedCount / totalCount,
90
90
  data: JSON.stringify(responseData),
91
- });
91
+ }, useRedis);
92
92
  }
93
93
  } else {
94
94
  try {
@@ -140,7 +140,7 @@ class PathwayResolver {
140
140
 
141
141
  try {
142
142
  //logger.info(`Publishing stream message to requestId ${this.requestId}: ${message}`);
143
- publishRequestProgress(requestProgress);
143
+ publishRequestProgress(requestProgress, useRedis);
144
144
  } catch (error) {
145
145
  logger.error(`Could not publish the stream message: "${messageBuffer}", ${error}`);
146
146
  }
@@ -30,7 +30,7 @@ class AzureCognitivePlugin extends ModelPlugin {
30
30
  }
31
31
 
32
32
  // Set up parameters specific to the Azure Cognitive API
33
- async getRequestParameters(text, parameters, prompt, mode, indexName, savedContextId, {headers, requestId, pathway, url}) {
33
+ async getRequestParameters(text, parameters, prompt, mode, indexName, savedContextId, {headers, requestId, pathway, _url}) {
34
34
  const combinedParameters = { ...this.promptParameters, ...parameters };
35
35
  const { modelPromptText } = this.getCompiledPrompt(text, combinedParameters, prompt);
36
36
  const { inputVector, calculateInputVector, privateData, filter, docId } = combinedParameters;
@@ -1,5 +1,6 @@
1
1
  // AzureTranslatePlugin.js
2
2
  import ModelPlugin from './modelPlugin.js';
3
+ import logger from '../../lib/logger.js';
3
4
 
4
5
  class AzureTranslatePlugin extends ModelPlugin {
5
6
  constructor(config, pathway, modelName, model) {
@@ -8,7 +8,7 @@ class CohereGeneratePlugin extends ModelPlugin {
8
8
 
9
9
  // Set up parameters specific to the Cohere API
10
10
  getRequestParameters(text, parameters, prompt) {
11
- const { modelPromptText, tokenLength } = this.getCompiledPrompt(text, parameters, prompt);
11
+ let { modelPromptText, tokenLength } = this.getCompiledPrompt(text, parameters, prompt);
12
12
 
13
13
  // Define the model's max token length
14
14
  const modelTargetTokenLength = this.getModelMaxTokenLength() * this.getPromptTokenRatio();
@@ -2,6 +2,7 @@
2
2
  import ModelPlugin from './modelPlugin.js';
3
3
  import { execFileSync } from 'child_process';
4
4
  import { encode } from 'gpt-3-encoder';
5
+ import logger from '../../lib/logger.js';
5
6
 
6
7
  class LocalModelPlugin extends ModelPlugin {
7
8
  constructor(config, pathway, modelName, model) {
@@ -1,6 +1,5 @@
1
1
  // OpenAICompletionPlugin.js
2
2
 
3
- import { request } from 'https';
4
3
  import ModelPlugin from './modelPlugin.js';
5
4
  import { encode } from 'gpt-3-encoder';
6
5
  import logger from '../../lib/logger.js';
@@ -75,7 +75,6 @@ class PalmChatPlugin extends ModelPlugin {
75
75
  // Set up parameters specific to the PaLM Chat API
76
76
  getRequestParameters(text, parameters, prompt) {
77
77
  const { modelPromptText, modelPromptMessages, tokenLength } = this.getCompiledPrompt(text, parameters, prompt);
78
- const { stream } = parameters;
79
78
 
80
79
  // Define the model's max token length
81
80
  const modelTargetTokenLength = this.getModelMaxTokenLength() * this.getPromptTokenRatio();
@@ -11,7 +11,6 @@ class PalmCodeCompletionPlugin extends PalmCompletionPlugin {
11
11
  // Set up parameters specific to the PaLM API Code Completion API
12
12
  getRequestParameters(text, parameters, prompt, pathwayResolver) {
13
13
  const { modelPromptText, tokenLength } = this.getCompiledPrompt(text, parameters, prompt);
14
- const { stream } = parameters;
15
14
  // Define the model's max token length
16
15
  const modelTargetTokenLength = this.getModelMaxTokenLength() * this.getPromptTokenRatio();
17
16
 
@@ -22,7 +22,7 @@ class PalmCompletionPlugin extends ModelPlugin {
22
22
  // Set up parameters specific to the PaLM API Text Completion API
23
23
  getRequestParameters(text, parameters, prompt, pathwayResolver) {
24
24
  const { modelPromptText, tokenLength } = this.getCompiledPrompt(text, parameters, prompt);
25
- const { stream } = parameters;
25
+
26
26
  // Define the model's max token length
27
27
  const modelTargetTokenLength = this.getModelMaxTokenLength() * this.getPromptTokenRatio();
28
28
 
package/server/rest.js CHANGED
@@ -1,7 +1,6 @@
1
1
  // rest.js
2
2
  // Implement the REST endpoints for the pathways
3
3
 
4
- import { json } from 'express';
5
4
  import pubsub from './pubsub.js';
6
5
  import { requestState } from './requestState.js';
7
6
  import { v4 as uuidv4 } from 'uuid';
@@ -168,7 +167,11 @@ const processIncomingStream = (requestId, res, jsonResponse) => {
168
167
  // Fire the resolver for the async requestProgress
169
168
  logger.info(`Rest Endpoint starting async requestProgress, requestId: ${requestId}`);
170
169
  const { resolver, args } = requestState[requestId];
171
- resolver(args);
170
+ // The false here means never use a Redis subscription channel
171
+ // to handle these streaming messages. This is because we are
172
+ // guaranteed in this case that the stream is going to the same
173
+ // client.
174
+ resolver(args, false);
172
175
 
173
176
  return subscription;
174
177
 
@@ -236,17 +239,18 @@ function buildRestEndpoints(pathways, app, server, config) {
236
239
  ],
237
240
  };
238
241
 
242
+ // eslint-disable-next-line no-extra-boolean-cast
239
243
  if (Boolean(req.body.stream)) {
240
244
  jsonResponse.id = `cmpl-${resultText}`;
241
245
  jsonResponse.choices[0].finish_reason = null;
242
246
  //jsonResponse.object = "text_completion.chunk";
243
247
 
244
- const subscription = processIncomingStream(resultText, res, jsonResponse);
248
+ processIncomingStream(resultText, res, jsonResponse);
245
249
  } else {
246
250
  const requestId = uuidv4();
247
251
  jsonResponse.id = `cmpl-${requestId}`;
248
252
  res.json(jsonResponse);
249
- };
253
+ }
250
254
  });
251
255
 
252
256
  app.post('/v1/chat/completions', async (req, res) => {
@@ -281,6 +285,7 @@ function buildRestEndpoints(pathways, app, server, config) {
281
285
  ],
282
286
  };
283
287
 
288
+ // eslint-disable-next-line no-extra-boolean-cast
284
289
  if (Boolean(req.body.stream)) {
285
290
  jsonResponse.id = `chatcmpl-${resultText}`;
286
291
  jsonResponse.choices[0] = {
@@ -292,7 +297,7 @@ function buildRestEndpoints(pathways, app, server, config) {
292
297
  }
293
298
  jsonResponse.object = "chat.completion.chunk";
294
299
 
295
- const subscription = processIncomingStream(resultText, res, jsonResponse);
300
+ processIncomingStream(resultText, res, jsonResponse);
296
301
  } else {
297
302
  const requestId = uuidv4();
298
303
  jsonResponse.id = `chatcmpl-${requestId}`;
@@ -330,6 +335,6 @@ function buildRestEndpoints(pathways, app, server, config) {
330
335
  });
331
336
 
332
337
  }
333
- };
338
+ }
334
339
 
335
340
  export { buildRestEndpoints };
@@ -1,5 +1,4 @@
1
1
  import pubsub from './pubsub.js';
2
- import logger from '../lib/logger.js';
3
2
  import { withFilter } from 'graphql-subscriptions';
4
3
  import { publishRequestProgressSubscription } from '../lib/redisSubscription.js';
5
4