@aj-archipelago/cortex 1.0.5 → 1.0.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58)
  1. package/README.md +2 -2
  2. package/config/default.example.json +4 -2
  3. package/config.js +14 -8
  4. package/helper_apps/WhisperX/.dockerignore +27 -0
  5. package/helper_apps/WhisperX/Dockerfile +31 -0
  6. package/helper_apps/WhisperX/app-ts.py +76 -0
  7. package/helper_apps/WhisperX/app.py +115 -0
  8. package/helper_apps/WhisperX/docker-compose.debug.yml +12 -0
  9. package/helper_apps/WhisperX/docker-compose.yml +10 -0
  10. package/helper_apps/WhisperX/requirements.txt +6 -0
  11. package/index.js +1 -1
  12. package/lib/redisSubscription.js +1 -1
  13. package/package.json +8 -7
  14. package/pathways/basePathway.js +3 -2
  15. package/pathways/index.js +4 -0
  16. package/pathways/summary.js +2 -2
  17. package/pathways/sys_openai_chat.js +19 -0
  18. package/pathways/sys_openai_completion.js +11 -0
  19. package/pathways/test_palm_chat.js +1 -1
  20. package/pathways/transcribe.js +2 -1
  21. package/{graphql → server}/chunker.js +48 -3
  22. package/{graphql → server}/graphql.js +70 -62
  23. package/{graphql → server}/pathwayPrompter.js +14 -17
  24. package/{graphql → server}/pathwayResolver.js +59 -42
  25. package/{graphql → server}/plugins/azureTranslatePlugin.js +2 -2
  26. package/{graphql → server}/plugins/localModelPlugin.js +2 -2
  27. package/{graphql → server}/plugins/modelPlugin.js +8 -10
  28. package/{graphql → server}/plugins/openAiChatPlugin.js +13 -8
  29. package/{graphql → server}/plugins/openAiCompletionPlugin.js +9 -3
  30. package/{graphql → server}/plugins/openAiWhisperPlugin.js +30 -7
  31. package/{graphql → server}/plugins/palmChatPlugin.js +4 -6
  32. package/server/plugins/palmCodeCompletionPlugin.js +46 -0
  33. package/{graphql → server}/plugins/palmCompletionPlugin.js +13 -15
  34. package/server/rest.js +321 -0
  35. package/{graphql → server}/typeDef.js +30 -13
  36. package/tests/chunkfunction.test.js +112 -26
  37. package/tests/config.test.js +1 -1
  38. package/tests/main.test.js +282 -43
  39. package/tests/mocks.js +43 -2
  40. package/tests/modelPlugin.test.js +4 -4
  41. package/tests/openAiChatPlugin.test.js +21 -14
  42. package/tests/openai_api.test.js +147 -0
  43. package/tests/palmChatPlugin.test.js +10 -11
  44. package/tests/palmCompletionPlugin.test.js +3 -4
  45. package/tests/pathwayResolver.test.js +1 -1
  46. package/tests/truncateMessages.test.js +4 -5
  47. package/pathways/completions.js +0 -17
  48. package/pathways/test_oai_chat.js +0 -18
  49. package/pathways/test_oai_cmpl.js +0 -13
  50. package/tests/chunking.test.js +0 -157
  51. package/tests/translate.test.js +0 -126
  52. /package/{graphql → server}/parser.js +0 -0
  53. /package/{graphql → server}/pathwayResponseParser.js +0 -0
  54. /package/{graphql → server}/prompt.js +0 -0
  55. /package/{graphql → server}/pubsub.js +0 -0
  56. /package/{graphql → server}/requestState.js +0 -0
  57. /package/{graphql → server}/resolver.js +0 -0
  58. /package/{graphql → server}/subscriptions.js +0 -0
@@ -1,24 +1,29 @@
1
- import { createServer } from 'http';
2
- import {
3
- ApolloServerPluginDrainHttpServer,
4
- ApolloServerPluginLandingPageLocalDefault,
5
- } from 'apollo-server-core';
1
+ // graphql.js
2
+ // Setup the Apollo server and Express middleware
3
+
4
+ import { ApolloServerPluginDrainHttpServer } from '@apollo/server/plugin/drainHttpServer';
5
+ import { ApolloServerPluginLandingPageLocalDefault } from '@apollo/server/plugin/landingPage/default';
6
+ import { ApolloServer } from '@apollo/server';
7
+ import { expressMiddleware } from '@apollo/server/express4';
6
8
  import { makeExecutableSchema } from '@graphql-tools/schema';
7
9
  import { WebSocketServer } from 'ws';
8
10
  import { useServer } from 'graphql-ws/lib/use/ws';
9
11
  import express from 'express';
10
- import { ApolloServer } from 'apollo-server-express';
12
+ import http from 'http';
11
13
  import Keyv from 'keyv';
14
+ import cors from 'cors';
12
15
  import { KeyvAdapter } from '@apollo/utils.keyvadapter';
13
- import responseCachePlugin from 'apollo-server-plugin-response-cache';
16
+ import responseCachePlugin from '@apollo/server-plugin-response-cache';
14
17
  import subscriptions from './subscriptions.js';
15
18
  import { buildLimiters } from '../lib/request.js';
16
19
  import { cancelRequestResolver } from './resolver.js';
17
20
  import { buildPathways, buildModels } from '../config.js';
18
21
  import { requestState } from './requestState.js';
22
+ import { buildRestEndpoints } from './rest.js';
19
23
 
24
+ // Utility functions
25
+ // Server plugins
20
26
  const getPlugins = (config) => {
21
- // server plugins
22
27
  const plugins = [
23
28
  ApolloServerPluginLandingPageLocalDefault({ embed: true }), // For local development.
24
29
  ];
@@ -39,41 +44,8 @@ const getPlugins = (config) => {
39
44
  return { plugins, cache };
40
45
  }
41
46
 
42
- const buildRestEndpoints = (pathways, app, server, config) => {
43
- for (const [name, pathway] of Object.entries(pathways)) {
44
- // Only expose endpoints for enabled pathways that explicitly want to expose a REST endpoint
45
- if (pathway.disabled || !config.get('enableRestEndpoints')) continue;
46
-
47
- const fieldVariableDefs = pathway.typeDef(pathway).restDefinition || [];
48
-
49
- app.post(`/rest/${name}`, async (req, res) => {
50
- const variables = fieldVariableDefs.reduce((acc, variableDef) => {
51
- if (Object.prototype.hasOwnProperty.call(req.body, variableDef.name)) {
52
- acc[variableDef.name] = req.body[variableDef.name];
53
- }
54
- return acc;
55
- }, {});
56
-
57
- const variableParams = fieldVariableDefs.map(({ name, type }) => `$${name}: ${type}`).join(', ');
58
- const queryArgs = fieldVariableDefs.map(({ name }) => `${name}: $${name}`).join(', ');
59
-
60
- const query = `
61
- query ${name}(${variableParams}) {
62
- ${name}(${queryArgs}) {
63
- contextId
64
- previousResult
65
- result
66
- }
67
- }
68
- `;
69
-
70
- const result = await server.executeOperation({ query, variables });
71
- res.json(result.data[name]);
72
- });
73
- }
74
- };
75
47
 
76
- //typeDefs
48
+ // Type Definitions for GraphQL
77
49
  const getTypedefs = (pathways) => {
78
50
 
79
51
  const defaultTypeDefs = `#graphql
@@ -111,6 +83,7 @@ const getTypedefs = (pathways) => {
111
83
  return typeDefs.join('\n');
112
84
  }
113
85
 
86
+ // Resolvers for GraphQL
114
87
  const getResolvers = (config, pathways) => {
115
88
  const resolverFunctions = {};
116
89
  for (const [name, pathway] of Object.entries(pathways)) {
@@ -118,6 +91,7 @@ const getResolvers = (config, pathways) => {
118
91
  resolverFunctions[name] = (parent, args, contextValue, info) => {
119
92
  // add shared state to contextValue
120
93
  contextValue.pathway = pathway;
94
+ contextValue.config = config;
121
95
  return pathway.rootResolver(parent, args, contextValue, info);
122
96
  }
123
97
  }
@@ -131,7 +105,7 @@ const getResolvers = (config, pathways) => {
131
105
  return resolvers;
132
106
  }
133
107
 
134
- //graphql api build factory method
108
+ // Build the server including the GraphQL schema and REST endpoints
135
109
  const build = async (config) => {
136
110
  // First perform config build
137
111
  await buildPathways(config);
@@ -150,9 +124,9 @@ const build = async (config) => {
150
124
 
151
125
  const { plugins, cache } = getPlugins(config);
152
126
 
153
- const app = express()
127
+ const app = express();
154
128
 
155
- const httpServer = createServer(app);
129
+ const httpServer = http.createServer(app);
156
130
 
157
131
  // Creating the WebSocket server
158
132
  const wsServer = new WebSocketServer({
@@ -182,35 +156,69 @@ const build = async (config) => {
182
156
  },
183
157
  };
184
158
  },
185
- }]),
186
- context: ({ req, res }) => ({ req, res, config, requestState }),
159
+ }
160
+ ]),
187
161
  });
188
162
 
189
163
  // If CORTEX_API_KEY is set, we roll our own auth middleware - usually not used if you're being fronted by a proxy
190
164
  const cortexApiKey = config.get('cortexApiKey');
165
+ if (cortexApiKey) {
166
+ app.use((req, res, next) => {
167
+ let providedApiKey = req.headers['cortex-api-key'] || req.query['cortex-api-key'];
168
+ if (!providedApiKey) {
169
+ providedApiKey = req.headers['authorization'];
170
+ providedApiKey = providedApiKey?.startsWith('Bearer ') ? providedApiKey.slice(7) : providedApiKey;
171
+ }
172
+
173
+ if (cortexApiKey && cortexApiKey !== providedApiKey) {
174
+ if (req.baseUrl === '/graphql' || req.headers['content-type'] === 'application/graphql') {
175
+ res.status(401)
176
+ .set('WWW-Authenticate', 'Cortex-Api-Key')
177
+ .set('X-Cortex-Api-Key-Info', 'Server requires Cortex API Key')
178
+ .json({
179
+ errors: [
180
+ {
181
+ message: 'Unauthorized',
182
+ extensions: {
183
+ code: 'UNAUTHENTICATED',
184
+ },
185
+ },
186
+ ],
187
+ });
188
+ } else {
189
+ res.status(401)
190
+ .set('WWW-Authenticate', 'Cortex-Api-Key')
191
+ .set('X-Cortex-Api-Key-Info', 'Server requires Cortex API Key')
192
+ .send('Unauthorized');
193
+ }
194
+ } else {
195
+ next();
196
+ }
197
+ });
198
+ };
191
199
 
192
- app.use((req, res, next) => {
193
- if (cortexApiKey && req.headers.cortexApiKey !== cortexApiKey && req.query.cortexApiKey !== cortexApiKey) {
194
- res.status(401).send('Unauthorized');
195
- } else {
196
- next();
197
- }
198
- });
199
-
200
- // Use the JSON body parser middleware for REST endpoints
200
+ // Parse the body for REST endpoints
201
201
  app.use(express.json());
202
-
203
- // add the REST endpoints
204
- buildRestEndpoints(pathways, app, server, config);
205
202
 
206
- // if local start server
203
+ // Server Startup Function
207
204
  const startServer = async () => {
208
205
  await server.start();
209
- server.applyMiddleware({ app });
206
+ app.use(
207
+ '/graphql',
208
+
209
+ cors(),
210
+
211
+ expressMiddleware(server, {
212
+ context: async ({ req, res }) => ({ req, res, config, requestState }),
213
+ }),
214
+ );
215
+
216
+ // add the REST endpoints
217
+ buildRestEndpoints(pathways, app, server, config);
210
218
 
211
219
  // Now that our HTTP server is fully set up, we can listen to it.
212
220
  httpServer.listen(config.get('PORT'), () => {
213
- console.log(`🚀 Server is now running at http://localhost:${config.get('PORT')}${server.graphqlPath}`);
221
+ console.log(`🚀 Server is now running at http://localhost:${config.get('PORT')}/graphql`);
214
222
  });
215
223
  };
216
224
 
@@ -6,40 +6,37 @@ import OpenAIWhisperPlugin from './plugins/openAiWhisperPlugin.js';
6
6
  import LocalModelPlugin from './plugins/localModelPlugin.js';
7
7
  import PalmChatPlugin from './plugins/palmChatPlugin.js';
8
8
  import PalmCompletionPlugin from './plugins/palmCompletionPlugin.js';
9
+ import PalmCodeCompletionPlugin from './plugins/palmCodeCompletionPlugin.js';
9
10
 
10
11
  class PathwayPrompter {
11
- constructor({ config, pathway }) {
12
-
13
- const modelName = pathway.model || config.get('defaultModelName');
14
- const model = config.get('models')[modelName];
15
-
16
- if (!model) {
17
- throw new Error(`Model ${modelName} not found in config`);
18
- }
19
-
12
+ constructor(config, pathway, modelName, model) {
13
+
20
14
  let plugin;
21
15
 
22
16
  switch (model.type) {
23
17
  case 'OPENAI-CHAT':
24
- plugin = new OpenAIChatPlugin(config, pathway);
18
+ plugin = new OpenAIChatPlugin(config, pathway, modelName, model);
25
19
  break;
26
20
  case 'AZURE-TRANSLATE':
27
- plugin = new AzureTranslatePlugin(config, pathway);
21
+ plugin = new AzureTranslatePlugin(config, pathway, modelName, model);
28
22
  break;
29
23
  case 'OPENAI-COMPLETION':
30
- plugin = new OpenAICompletionPlugin(config, pathway);
24
+ plugin = new OpenAICompletionPlugin(config, pathway, modelName, model);
31
25
  break;
32
- case 'OPENAI_WHISPER':
33
- plugin = new OpenAIWhisperPlugin(config, pathway);
26
+ case 'OPENAI-WHISPER':
27
+ plugin = new OpenAIWhisperPlugin(config, pathway, modelName, model);
34
28
  break;
35
29
  case 'LOCAL-CPP-MODEL':
36
- plugin = new LocalModelPlugin(config, pathway);
30
+ plugin = new LocalModelPlugin(config, pathway, modelName, model);
37
31
  break;
38
32
  case 'PALM-CHAT':
39
- plugin = new PalmChatPlugin(config, pathway);
33
+ plugin = new PalmChatPlugin(config, pathway, modelName, model);
40
34
  break;
41
35
  case 'PALM-COMPLETION':
42
- plugin = new PalmCompletionPlugin(config, pathway);
36
+ plugin = new PalmCompletionPlugin(config, pathway, modelName, model);
37
+ break;
38
+ case 'PALM-CODE-COMPLETION':
39
+ plugin = new PalmCodeCompletionPlugin(config, pathway, modelName, model);
43
40
  break;
44
41
  default:
45
42
  throw new Error(`Unsupported model type: ${model.type}`);
@@ -20,9 +20,31 @@ class PathwayResolver {
20
20
  this.warnings = [];
21
21
  this.requestId = uuidv4();
22
22
  this.responseParser = new PathwayResponseParser(pathway);
23
- this.pathwayPrompter = new PathwayPrompter({ config, pathway });
23
+ this.modelName = [
24
+ pathway.model,
25
+ args?.model,
26
+ pathway.inputParameters?.model,
27
+ config.get('defaultModelName')
28
+ ].find(modelName => modelName && config.get('models').hasOwnProperty(modelName));
29
+ this.model = config.get('models')[this.modelName];
30
+
31
+ if (!this.model) {
32
+ throw new Error(`Model ${this.modelName} not found in config`);
33
+ }
34
+
35
+ const specifiedModelName = pathway.model || args?.model || pathway.inputParameters?.model;
36
+
37
+ if (this.modelName !== (specifiedModelName)) {
38
+ if (specifiedModelName) {
39
+ this.logWarning(`Specified model ${specifiedModelName} not found in config, using ${this.modelName} instead.`);
40
+ } else {
41
+ this.logWarning(`No model specified in the pathway, using ${this.modelName}.`);
42
+ }
43
+ }
44
+
24
45
  this.previousResult = '';
25
46
  this.prompts = [];
47
+ this.pathwayPrompter = new PathwayPrompter(this.config, this.pathway, this.modelName, this.model);
26
48
 
27
49
  Object.defineProperty(this, 'pathwayPrompt', {
28
50
  get() {
@@ -42,66 +64,61 @@ class PathwayResolver {
42
64
  }
43
65
 
44
66
  async asyncResolve(args) {
45
- // Wait with a sleep promise for the race condition to resolve
46
- // const results = await Promise.all([this.promptAndParse(args), await new Promise(resolve => setTimeout(resolve, 250))]);
47
- const data = await this.promptAndParse(args);
48
- // Process the results for async
49
- if(args.async || typeof data === 'string') { // if async flag set or processed async and got string response
67
+ const responseData = await this.promptAndParse(args);
68
+
69
+ // Either we're dealing with an async request or a stream
70
+ if(args.async || typeof responseData === 'string') {
50
71
  const { completedCount, totalCount } = requestState[this.requestId];
51
- requestState[this.requestId].data = data;
72
+ requestState[this.requestId].data = responseData;
52
73
  pubsub.publish('REQUEST_PROGRESS', {
53
74
  requestProgress: {
54
75
  requestId: this.requestId,
55
76
  progress: completedCount / totalCount,
56
- data: JSON.stringify(data),
77
+ data: JSON.stringify(responseData),
57
78
  }
58
79
  });
59
- } else { //stream
60
- for (const handle of data) {
61
- handle.on('data', data => {
62
- console.log(data.toString());
63
- const lines = data.toString().split('\n').filter(line => line.trim() !== '');
64
- for (const line of lines) {
65
- const message = line.replace(/^data: /, '');
66
- if (message === '[DONE]') {
67
- // Send stream finished message
68
- pubsub.publish('REQUEST_PROGRESS', {
69
- requestProgress: {
70
- requestId: this.requestId,
71
- data: null,
72
- progress: 1,
73
- }
74
- });
75
- return; // Stream finished
80
+ } else { // stream
81
+ try {
82
+ const incomingMessage = Array.isArray(responseData) && responseData.length > 0 ? responseData[0] : responseData;
83
+ incomingMessage.on('data', data => {
84
+ const events = data.toString().split('\n');
85
+
86
+ events.forEach(event => {
87
+ if (event.trim() === '') return; // Skip empty lines
88
+
89
+ const message = event.replace(/^data: /, '');
90
+
91
+ //console.log(`====================================`);
92
+ //console.log(`STREAM EVENT: ${event}`);
93
+ //console.log(`MESSAGE: ${message}`);
94
+
95
+ const requestProgress = {
96
+ requestId: this.requestId,
97
+ data: message,
76
98
  }
99
+
100
+ if (message.trim() === '[DONE]') {
101
+ requestProgress.progress = 1;
102
+ }
103
+
77
104
  try {
78
- const parsed = JSON.parse(message);
79
- const result = this.pathwayPrompter.plugin.parseResponse(parsed)
80
-
81
105
  pubsub.publish('REQUEST_PROGRESS', {
82
- requestProgress: {
83
- requestId: this.requestId,
84
- data: JSON.stringify(result)
85
- }
106
+ requestProgress: requestProgress
86
107
  });
87
108
  } catch (error) {
88
109
  console.error('Could not JSON parse stream message', message, error);
89
110
  }
90
- }
111
+ });
91
112
  });
92
-
93
- // data.on('end', () => {
94
- // console.log("stream done");
95
- // });
113
+ } catch (error) {
114
+ console.error('Could not subscribe to stream', error);
96
115
  }
97
-
98
116
  }
99
117
  }
100
118
 
101
119
  async resolve(args) {
120
+ // Either we're dealing with an async request, stream, or regular request
102
121
  if (args.async || args.stream) {
103
- // Asyncronously process the request
104
- // this.asyncResolve(args);
105
122
  if (!requestState[this.requestId]) {
106
123
  requestState[this.requestId] = {}
107
124
  }
@@ -161,7 +178,7 @@ class PathwayResolver {
161
178
  }
162
179
 
163
180
  // chunk the text and return the chunks with newline separators
164
- return getSemanticChunks(text, chunkTokenLength);
181
+ return getSemanticChunks(text, chunkTokenLength, this.pathway.inputFormat);
165
182
  }
166
183
 
167
184
  truncate(str, n) {
@@ -292,7 +309,7 @@ class PathwayResolver {
292
309
  let result = '';
293
310
 
294
311
  // If this text is empty, skip applying the prompt as it will likely be a nonsensical result
295
- if (!/^\s*$/.test(text)) {
312
+ if (!/^\s*$/.test(text) || parameters?.file) {
296
313
  result = await this.pathwayPrompter.execute(text, { ...parameters, ...this.savedContext }, prompt, this);
297
314
  } else {
298
315
  result = text;
@@ -2,8 +2,8 @@
2
2
  import ModelPlugin from './modelPlugin.js';
3
3
 
4
4
  class AzureTranslatePlugin extends ModelPlugin {
5
- constructor(config, pathway) {
6
- super(config, pathway);
5
+ constructor(config, pathway, modelName, model) {
6
+ super(config, pathway, modelName, model);
7
7
  }
8
8
 
9
9
  // Set up parameters specific to the Azure Translate API
@@ -4,8 +4,8 @@ import { execFileSync } from 'child_process';
4
4
  import { encode } from 'gpt-3-encoder';
5
5
 
6
6
  class LocalModelPlugin extends ModelPlugin {
7
- constructor(config, pathway) {
8
- super(config, pathway);
7
+ constructor(config, pathway, modelName, model) {
8
+ super(config, pathway, modelName, model);
9
9
  }
10
10
 
11
11
  // if the input starts with a chatML response, just return that
@@ -6,19 +6,13 @@ import { encode } from 'gpt-3-encoder';
6
6
  import { getFirstNToken } from '../chunker.js';
7
7
 
8
8
  const DEFAULT_MAX_TOKENS = 4096;
9
+ const DEFAULT_MAX_RETURN_TOKENS = 256;
9
10
  const DEFAULT_PROMPT_TOKEN_RATIO = 0.5;
10
11
 
11
12
  class ModelPlugin {
12
- constructor(config, pathway) {
13
- // If the pathway specifies a model, use that, otherwise use the default
14
- this.modelName = pathway.model || config.get('defaultModelName');
15
- // Get the model from the config
16
- this.model = config.get('models')[this.modelName];
17
- // If the model doesn't exist, throw an exception
18
- if (!this.model) {
19
- throw new Error(`Model ${this.modelName} not found in config`);
20
- }
21
-
13
+ constructor(config, pathway, modelName, model) {
14
+ this.modelName = modelName;
15
+ this.model = model;
22
16
  this.config = config;
23
17
  this.environmentVariables = config.getEnv();
24
18
  this.temperature = pathway.temperature;
@@ -143,6 +137,10 @@ class ModelPlugin {
143
137
  return (this.promptParameters.maxTokenLength ?? this.model.maxTokenLength ?? DEFAULT_MAX_TOKENS);
144
138
  }
145
139
 
140
+ getModelMaxReturnTokens() {
141
+ return (this.promptParameters.maxReturnTokens ?? this.model.maxReturnTokens ?? DEFAULT_MAX_RETURN_TOKENS);
142
+ }
143
+
146
144
  getPromptTokenRatio() {
147
145
  // TODO: Is this the right order of precedence? inputParameters should maybe be second?
148
146
  return this.promptParameters.inputParameters?.tokenRatio ?? this.promptParameters.tokenRatio ?? DEFAULT_PROMPT_TOKEN_RATIO;
@@ -3,8 +3,8 @@ import ModelPlugin from './modelPlugin.js';
3
3
  import { encode } from 'gpt-3-encoder';
4
4
 
5
5
  class OpenAIChatPlugin extends ModelPlugin {
6
- constructor(config, pathway) {
7
- super(config, pathway);
6
+ constructor(config, pathway, modelName, model) {
7
+ super(config, pathway, modelName, model);
8
8
  }
9
9
 
10
10
  // convert to OpenAI messages array format if necessary
@@ -90,7 +90,7 @@ class OpenAIChatPlugin extends ModelPlugin {
90
90
  parseResponse(data) {
91
91
  const { choices } = data;
92
92
  if (!choices || !choices.length) {
93
- return null;
93
+ return data;
94
94
  }
95
95
 
96
96
  // if we got a choices array back with more than one choice, return the whole array
@@ -108,8 +108,9 @@ class OpenAIChatPlugin extends ModelPlugin {
108
108
  const separator = `\n=== ${this.pathwayName}.${this.requestCount++} ===\n`;
109
109
  console.log(separator);
110
110
 
111
- if (data && data.messages && data.messages.length > 1) {
112
- data.messages.forEach((message, index) => {
111
+ const { stream, messages } = data;
112
+ if (messages && messages.length > 1) {
113
+ messages.forEach((message, index) => {
113
114
  const words = message.content.split(" ");
114
115
  const tokenCount = encode(message.content).length;
115
116
  const preview = words.length < 41 ? message.content : words.slice(0, 20).join(" ") + " ... " + words.slice(-20).join(" ");
@@ -117,11 +118,15 @@ class OpenAIChatPlugin extends ModelPlugin {
117
118
  console.log(`\x1b[36mMessage ${index + 1}: Role: ${message.role}, Tokens: ${tokenCount}, Content: "${preview}"\x1b[0m`);
118
119
  });
119
120
  } else {
120
- console.log(`\x1b[36m${data.messages[0].content}\x1b[0m`);
121
+ console.log(`\x1b[36m${messages[0].content}\x1b[0m`);
121
122
  }
122
123
 
123
- console.log(`\x1b[34m> ${this.parseResponse(responseData)}\x1b[0m`);
124
-
124
+ if (stream) {
125
+ console.log(`\x1b[34m> Response is streaming...\x1b[0m`);
126
+ } else {
127
+ console.log(`\x1b[34m> ${this.parseResponse(responseData)}\x1b[0m`);
128
+ }
129
+
125
130
  prompt && prompt.debugInfo && (prompt.debugInfo += `${separator}${JSON.stringify(data)}`);
126
131
  }
127
132
  }
@@ -15,8 +15,8 @@ const truncatePromptIfNecessary = (text, textTokenCount, modelMaxTokenCount, tar
15
15
  }
16
16
 
17
17
  class OpenAICompletionPlugin extends ModelPlugin {
18
- constructor(config, pathway) {
19
- super(config, pathway);
18
+ constructor(config, pathway, modelName, model) {
19
+ super(config, pathway, modelName, model);
20
20
  }
21
21
 
22
22
  // Set up parameters specific to the OpenAI Completion API
@@ -108,10 +108,16 @@ class OpenAICompletionPlugin extends ModelPlugin {
108
108
  const separator = `\n=== ${this.pathwayName}.${this.requestCount++} ===\n`;
109
109
  console.log(separator);
110
110
 
111
+ const stream = data.stream;
111
112
  const modelInput = data.prompt;
112
113
 
113
114
  console.log(`\x1b[36m${modelInput}\x1b[0m`);
114
- console.log(`\x1b[34m> ${this.parseResponse(responseData)}\x1b[0m`);
115
+
116
+ if (stream) {
117
+ console.log(`\x1b[34m> Response is streaming...\x1b[0m`);
118
+ } else {
119
+ console.log(`\x1b[34m> ${this.parseResponse(responseData)}\x1b[0m`);
120
+ }
115
121
 
116
122
  prompt && prompt.debugInfo && (prompt.debugInfo += `${separator}${JSON.stringify(data)}`);
117
123
  }
@@ -19,6 +19,7 @@ const pipeline = promisify(stream.pipeline);
19
19
 
20
20
 
21
21
  const API_URL = config.get('whisperMediaApiUrl');
22
+ const WHISPER_TS_API_URL = config.get('whisperTSApiUrl');
22
23
 
23
24
  function alignSubtitles(subtitles) {
24
25
  const result = [];
@@ -74,14 +75,14 @@ const downloadFile = async (fileUrl) => {
74
75
  fs.unlink(localFilePath, () => {
75
76
  reject(error);
76
77
  });
77
- throw error;
78
+ //throw error;
78
79
  }
79
80
  });
80
81
  };
81
82
 
82
83
  class OpenAIWhisperPlugin extends ModelPlugin {
83
- constructor(config, pathway) {
84
- super(config, pathway);
84
+ constructor(config, pathway, modelName, model) {
85
+ super(config, pathway, modelName, model);
85
86
  }
86
87
 
87
88
  async getMediaChunks(file, requestId) {
@@ -115,11 +116,28 @@ class OpenAIWhisperPlugin extends ModelPlugin {
115
116
 
116
117
  // Execute the request to the OpenAI Whisper API
117
118
  async execute(text, parameters, prompt, pathwayResolver) {
118
- const { responseFormat } = parameters;
119
+ const { responseFormat, wordTimestamped } = parameters;
119
120
  const url = this.requestUrl(text);
120
121
  const params = {};
121
122
  const { modelPromptText } = this.getCompiledPrompt(text, parameters, prompt);
122
123
 
124
+ const processTS = async (uri) => {
125
+ if (wordTimestamped) {
126
+ if (!WHISPER_TS_API_URL) {
127
+ throw new Error(`WHISPER_TS_API_URL not set for word timestamped processing`);
128
+ }
129
+
130
+ try {
131
+ // const res = await axios.post(WHISPER_TS_API_URL, { params: { fileurl: uri } });
132
+ const res = await this.executeRequest(WHISPER_TS_API_URL, {fileurl:uri},{},{});
133
+ return res;
134
+ } catch (err) {
135
+ console.log(`Error getting word timestamped data from api:`, err);
136
+ throw err;
137
+ }
138
+ }
139
+ }
140
+
123
141
  const processChunk = async (chunk) => {
124
142
  try {
125
143
  const { language, responseFormat } = parameters;
@@ -159,7 +177,6 @@ class OpenAIWhisperPlugin extends ModelPlugin {
159
177
 
160
178
  let chunks = []; // array of local file paths
161
179
  try {
162
-
163
180
  const uris = await this.getMediaChunks(file, requestId); // array of remote file uris
164
181
  if (!uris || !uris.length) {
165
182
  throw new Error(`Error in getting chunks from media helper for file ${file}`);
@@ -169,7 +186,13 @@ class OpenAIWhisperPlugin extends ModelPlugin {
169
186
 
170
187
  // sequential download of chunks
171
188
  for (const uri of uris) {
172
- chunks.push(await downloadFile(uri));
189
+ if (wordTimestamped) { // get word timestamped data
190
+ sendProgress(); // no download needed auto progress
191
+ const ts = await processTS(uri);
192
+ result.push(ts);
193
+ } else {
194
+ chunks.push(await downloadFile(uri));
195
+ }
173
196
  sendProgress();
174
197
  }
175
198
 
@@ -210,7 +233,7 @@ class OpenAIWhisperPlugin extends ModelPlugin {
210
233
  }
211
234
  }
212
235
 
213
- if (['srt','vtt'].includes(responseFormat)) { // align subtitles for formats
236
+ if (['srt','vtt'].includes(responseFormat) || wordTimestamped) { // align subtitles for formats
214
237
  return alignSubtitles(result);
215
238
  }
216
239
  return result.join(` `);
@@ -4,8 +4,8 @@ import { encode } from 'gpt-3-encoder';
4
4
  import HandleBars from '../../lib/handleBars.js';
5
5
 
6
6
  class PalmChatPlugin extends ModelPlugin {
7
- constructor(config, pathway) {
8
- super(config, pathway);
7
+ constructor(config, pathway, modelName, model) {
8
+ super(config, pathway, modelName, model);
9
9
  }
10
10
 
11
11
  // Convert to PaLM messages array format if necessary
@@ -92,10 +92,8 @@ class PalmChatPlugin extends ModelPlugin {
92
92
  const context = this.getCompiledContext(text, parameters, prompt.context || palmMessages.context || '');
93
93
  const examples = this.getCompiledExamples(text, parameters, prompt.examples || []);
94
94
 
95
- // For PaLM right now, the max return tokens is 1024, regardless of the max context length
96
- // I can't think of a time you'd want to constrain it to fewer at the moment.
97
- const max_tokens = 1024//this.getModelMaxTokenLength() - tokenLength;
98
-
95
+ const max_tokens = this.getModelMaxReturnTokens();
96
+
99
97
  if (max_tokens < 0) {
100
98
  throw new Error(`Prompt is too long to successfully call the model at ${tokenLength} tokens. The model will not be called.`);
101
99
  }