@aj-archipelago/cortex 1.1.1 → 1.1.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/.eslintignore CHANGED
@@ -18,6 +18,12 @@
18
18
  # Ignore documentation
19
19
  /docs
20
20
 
21
+ # Ignore helper apps
22
+ /helper_apps
23
+
24
+ # Ignore tests
25
+ /tests
26
+
21
27
  # Ignore any generated or bundled files
22
28
  *.min.js
23
29
  *.bundle.js
package/.eslintrc CHANGED
@@ -16,8 +16,9 @@
16
16
  ],
17
17
  "rules": {
18
18
  "import/no-unresolved": "error",
19
- "import/no-extraneous-dependencies": ["error", {"devDependencies": true, "dependencies": true}],
20
- "no-unused-vars": ["error", { "argsIgnorePattern": "^_" }]
19
+ "import/no-extraneous-dependencies": ["error", {"devDependencies": true}],
20
+ "no-unused-vars": ["error", { "argsIgnorePattern": "^_" }],
21
+ "no-useless-escape": "off"
21
22
  },
22
23
  "settings": {
23
24
  "import/resolver": {
@@ -46,7 +46,7 @@ if (connectionString) {
46
46
  channels.forEach(channel => {
47
47
  subscriptionClient.subscribe(channel, (error) => {
48
48
  if (error) {
49
- logger.error(`Error subscribing to redis channel ${channel}: ${error}`);
49
+ logger.error(`Error subscribing to Redis channel ${channel}: ${error}`);
50
50
  } else {
51
51
  logger.info(`Subscribed to channel ${channel}`);
52
52
  }
@@ -55,26 +55,22 @@ if (connectionString) {
55
55
  });
56
56
 
57
57
  subscriptionClient.on('message', (channel, message) => {
58
- logger.debug(`Received message from ${channel}: ${message}`);
58
+ logger.debug(`Received message from Redis channel ${channel}: ${message}`);
59
59
 
60
- let decryptedMessage;
61
-
62
- if (channel === requestProgressChannel && redisEncryptionKey) {
63
- try {
64
- decryptedMessage = decrypt(message, redisEncryptionKey);
65
- } catch (error) {
66
- logger.error(`Error decrypting message: ${error}`);
67
- }
68
- }
69
-
70
- decryptedMessage = decryptedMessage || message;
71
-
72
60
  let parsedMessage;
73
61
 
74
62
  try {
75
- parsedMessage = JSON.parse(decryptedMessage);
63
+ parsedMessage = JSON.parse(message);
76
64
  } catch (error) {
77
- logger.error(`Error parsing message: ${error}`);
65
+ if (channel === requestProgressChannel && redisEncryptionKey) {
66
+ try {
67
+ parsedMessage = JSON.parse(decrypt(message, redisEncryptionKey));
68
+ } catch (error) {
69
+ logger.error(`Error parsing or decrypting message: ${error}`);
70
+ }
71
+ } else {
72
+ logger.error(`Error parsing message: ${error}`);
73
+ }
78
74
  }
79
75
 
80
76
  switch(channel) {
@@ -96,7 +92,7 @@ if (connectionString) {
96
92
  }
97
93
 
98
94
  async function publishRequestProgress(data) {
99
- if (publisherClient) {
95
+ if (publisherClient && requestState?.[data?.requestId]?.useRedis) {
100
96
  try {
101
97
  let message = JSON.stringify(data);
102
98
  if (redisEncryptionKey) {
@@ -106,10 +102,10 @@ async function publishRequestProgress(data) {
106
102
  logger.error(`Error encrypting message: ${error}`);
107
103
  }
108
104
  }
109
- logger.debug(`Publishing message ${message} to channel ${requestProgressChannel}`);
105
+ logger.debug(`Publishing request progress ${message} to Redis channel ${requestProgressChannel}`);
110
106
  await publisherClient.publish(requestProgressChannel, message);
111
107
  } catch (error) {
112
- logger.error(`Error publishing message: ${error}`);
108
+ logger.error(`Error publishing request progress to Redis: ${error}`);
113
109
  }
114
110
  } else {
115
111
  pubsubHandleMessage(data);
@@ -119,11 +115,30 @@ async function publishRequestProgress(data) {
119
115
  async function publishRequestProgressSubscription(data) {
120
116
  if (publisherClient) {
121
117
  try {
122
- const message = JSON.stringify(data);
123
- logger.debug(`Publishing message ${message} to channel ${requestProgressSubscriptionsChannel}`);
124
- await publisherClient.publish(requestProgressSubscriptionsChannel, message);
118
+ const requestIds = data;
119
+ const idsToForward = [];
120
+ // If any of these requests belong to this instance, we can just start and handle them locally
121
+ for (const requestId of requestIds) {
122
+ if (requestState[requestId]) {
123
+ if (!requestState[requestId].started) {
124
+ requestState[requestId].started = true;
125
+ requestState[requestId].useRedis = false;
126
+ logger.info(`Starting local execution for registered async request: ${requestId}`);
127
+ const { resolver, args } = requestState[requestId];
128
+ resolver(args, false);
129
+ }
130
+ } else {
131
+ idsToForward.push(requestId);
132
+ }
133
+ }
134
+
135
+ if (idsToForward.length > 0) {
136
+ const message = JSON.stringify(idsToForward);
137
+ logger.debug(`Sending subscription request(s) to channel ${requestProgressSubscriptionsChannel} for remote execution: ${message}`);
138
+ await publisherClient.publish(requestProgressSubscriptionsChannel, message);
139
+ }
125
140
  } catch (error) {
126
- logger.error(`Error publishing message: ${error}`);
141
+ logger.error(`Error handling subscription: ${error}`);
127
142
  }
128
143
  } else {
129
144
  handleSubscription(data);
@@ -132,11 +147,11 @@ async function publishRequestProgressSubscription(data) {
132
147
 
133
148
  function pubsubHandleMessage(data){
134
149
  const message = JSON.stringify(data);
135
- logger.debug(`Publishing message to pubsub: ${message}`);
150
+ logger.debug(`Publishing request progress to local subscribers: ${message}`);
136
151
  try {
137
152
  pubsub.publish('REQUEST_PROGRESS', { requestProgress: data });
138
153
  } catch (error) {
139
- logger.error(`Error publishing data to pubsub: ${error}`);
154
+ logger.error(`Error publishing request progress to local subscribers: ${error}`);
140
155
  }
141
156
  }
142
157
 
@@ -145,7 +160,8 @@ function handleSubscription(data){
145
160
  for (const requestId of requestIds) {
146
161
  if (requestState[requestId] && !requestState[requestId].started) {
147
162
  requestState[requestId].started = true;
148
- logger.info(`Subscription starting async requestProgress, requestId: ${requestId}`);
163
+ requestState[requestId].useRedis = true;
164
+ logger.info(`Starting execution for registered async request: ${requestId}`);
149
165
  const { resolver, args } = requestState[requestId];
150
166
  resolver(args);
151
167
  }
package/lib/request.js CHANGED
@@ -46,7 +46,6 @@ const buildLimiters = (config) => {
46
46
  if (connection) {
47
47
  limiterOptions.id = `${cortexId}-${name}-limiter`; // Unique id for each limiter
48
48
  limiterOptions.connection = connection; // Shared Redis connection
49
- limiterOptions.clearDatastore = true; // Clear Redis datastore on startup
50
49
  }
51
50
 
52
51
  limiters[name] = new Bottleneck(limiterOptions);
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@aj-archipelago/cortex",
3
- "version": "1.1.1",
3
+ "version": "1.1.3",
4
4
  "description": "Cortex is a GraphQL API for AI. It provides a simple, extensible interface for using AI services from OpenAI, Azure and others.",
5
5
  "private": false,
6
6
  "repository": {
@@ -1,8 +1,6 @@
1
1
  // sys_openai_completion.js
2
2
  // default handler for openAI completion endpoints when REST endpoints are enabled
3
3
 
4
- import { Prompt } from '../server/prompt.js';
5
-
6
4
  export default {
7
5
  prompt: `{{text}}`,
8
6
  model: 'oai-gpturbo',
@@ -30,7 +30,7 @@ class AzureCognitivePlugin extends ModelPlugin {
30
30
  }
31
31
 
32
32
  // Set up parameters specific to the Azure Cognitive API
33
- async getRequestParameters(text, parameters, prompt, mode, indexName, savedContextId, {headers, requestId, pathway, url}) {
33
+ async getRequestParameters(text, parameters, prompt, mode, indexName, savedContextId, {headers, requestId, pathway, _url}) {
34
34
  const combinedParameters = { ...this.promptParameters, ...parameters };
35
35
  const { modelPromptText } = this.getCompiledPrompt(text, combinedParameters, prompt);
36
36
  const { inputVector, calculateInputVector, privateData, filter, docId } = combinedParameters;
@@ -1,5 +1,6 @@
1
1
  // AzureTranslatePlugin.js
2
2
  import ModelPlugin from './modelPlugin.js';
3
+ import logger from '../../lib/logger.js';
3
4
 
4
5
  class AzureTranslatePlugin extends ModelPlugin {
5
6
  constructor(config, pathway, modelName, model) {
@@ -8,7 +8,7 @@ class CohereGeneratePlugin extends ModelPlugin {
8
8
 
9
9
  // Set up parameters specific to the Cohere API
10
10
  getRequestParameters(text, parameters, prompt) {
11
- const { modelPromptText, tokenLength } = this.getCompiledPrompt(text, parameters, prompt);
11
+ let { modelPromptText, tokenLength } = this.getCompiledPrompt(text, parameters, prompt);
12
12
 
13
13
  // Define the model's max token length
14
14
  const modelTargetTokenLength = this.getModelMaxTokenLength() * this.getPromptTokenRatio();
@@ -2,6 +2,7 @@
2
2
  import ModelPlugin from './modelPlugin.js';
3
3
  import { execFileSync } from 'child_process';
4
4
  import { encode } from 'gpt-3-encoder';
5
+ import logger from '../../lib/logger.js';
5
6
 
6
7
  class LocalModelPlugin extends ModelPlugin {
7
8
  constructor(config, pathway, modelName, model) {
@@ -1,6 +1,5 @@
1
1
  // OpenAICompletionPlugin.js
2
2
 
3
- import { request } from 'https';
4
3
  import ModelPlugin from './modelPlugin.js';
5
4
  import { encode } from 'gpt-3-encoder';
6
5
  import logger from '../../lib/logger.js';
@@ -75,7 +75,6 @@ class PalmChatPlugin extends ModelPlugin {
75
75
  // Set up parameters specific to the PaLM Chat API
76
76
  getRequestParameters(text, parameters, prompt) {
77
77
  const { modelPromptText, modelPromptMessages, tokenLength } = this.getCompiledPrompt(text, parameters, prompt);
78
- const { stream } = parameters;
79
78
 
80
79
  // Define the model's max token length
81
80
  const modelTargetTokenLength = this.getModelMaxTokenLength() * this.getPromptTokenRatio();
@@ -11,7 +11,6 @@ class PalmCodeCompletionPlugin extends PalmCompletionPlugin {
11
11
  // Set up parameters specific to the PaLM API Code Completion API
12
12
  getRequestParameters(text, parameters, prompt, pathwayResolver) {
13
13
  const { modelPromptText, tokenLength } = this.getCompiledPrompt(text, parameters, prompt);
14
- const { stream } = parameters;
15
14
  // Define the model's max token length
16
15
  const modelTargetTokenLength = this.getModelMaxTokenLength() * this.getPromptTokenRatio();
17
16
 
@@ -22,7 +22,7 @@ class PalmCompletionPlugin extends ModelPlugin {
22
22
  // Set up parameters specific to the PaLM API Text Completion API
23
23
  getRequestParameters(text, parameters, prompt, pathwayResolver) {
24
24
  const { modelPromptText, tokenLength } = this.getCompiledPrompt(text, parameters, prompt);
25
- const { stream } = parameters;
25
+
26
26
  // Define the model's max token length
27
27
  const modelTargetTokenLength = this.getModelMaxTokenLength() * this.getPromptTokenRatio();
28
28
 
package/server/rest.js CHANGED
@@ -1,7 +1,6 @@
1
1
  // rest.js
2
2
  // Implement the REST endpoints for the pathways
3
3
 
4
- import { json } from 'express';
5
4
  import pubsub from './pubsub.js';
6
5
  import { requestState } from './requestState.js';
7
6
  import { v4 as uuidv4 } from 'uuid';
@@ -168,6 +167,9 @@ const processIncomingStream = (requestId, res, jsonResponse) => {
168
167
  // Fire the resolver for the async requestProgress
169
168
  logger.info(`Rest Endpoint starting async requestProgress, requestId: ${requestId}`);
170
169
  const { resolver, args } = requestState[requestId];
170
+ requestState[requestId].useRedis = false;
171
+ requestState[requestId].started = true;
172
+
171
173
  resolver(args);
172
174
 
173
175
  return subscription;
@@ -236,17 +238,18 @@ function buildRestEndpoints(pathways, app, server, config) {
236
238
  ],
237
239
  };
238
240
 
241
+ // eslint-disable-next-line no-extra-boolean-cast
239
242
  if (Boolean(req.body.stream)) {
240
243
  jsonResponse.id = `cmpl-${resultText}`;
241
244
  jsonResponse.choices[0].finish_reason = null;
242
245
  //jsonResponse.object = "text_completion.chunk";
243
246
 
244
- const subscription = processIncomingStream(resultText, res, jsonResponse);
247
+ processIncomingStream(resultText, res, jsonResponse);
245
248
  } else {
246
249
  const requestId = uuidv4();
247
250
  jsonResponse.id = `cmpl-${requestId}`;
248
251
  res.json(jsonResponse);
249
- };
252
+ }
250
253
  });
251
254
 
252
255
  app.post('/v1/chat/completions', async (req, res) => {
@@ -281,6 +284,7 @@ function buildRestEndpoints(pathways, app, server, config) {
281
284
  ],
282
285
  };
283
286
 
287
+ // eslint-disable-next-line no-extra-boolean-cast
284
288
  if (Boolean(req.body.stream)) {
285
289
  jsonResponse.id = `chatcmpl-${resultText}`;
286
290
  jsonResponse.choices[0] = {
@@ -292,7 +296,7 @@ function buildRestEndpoints(pathways, app, server, config) {
292
296
  }
293
297
  jsonResponse.object = "chat.completion.chunk";
294
298
 
295
- const subscription = processIncomingStream(resultText, res, jsonResponse);
299
+ processIncomingStream(resultText, res, jsonResponse);
296
300
  } else {
297
301
  const requestId = uuidv4();
298
302
  jsonResponse.id = `chatcmpl-${requestId}`;
@@ -330,6 +334,6 @@ function buildRestEndpoints(pathways, app, server, config) {
330
334
  });
331
335
 
332
336
  }
333
- };
337
+ }
334
338
 
335
339
  export { buildRestEndpoints };
@@ -1,12 +1,13 @@
1
1
  import pubsub from './pubsub.js';
2
- import logger from '../lib/logger.js';
3
2
  import { withFilter } from 'graphql-subscriptions';
4
3
  import { publishRequestProgressSubscription } from '../lib/redisSubscription.js';
4
+ import logger from '../lib/logger.js';
5
5
 
6
6
  const subscriptions = {
7
7
  requestProgress: {
8
8
  subscribe: withFilter(
9
9
  (_, args, __, _info) => {
10
+ logger.debug(`Client requested subscription for request ids: ${args.requestIds}`);
10
11
  publishRequestProgressSubscription(args.requestIds);
11
12
  return pubsub.asyncIterator(['REQUEST_PROGRESS'])
12
13
  },