@aj-archipelago/cortex 1.0.6 → 1.0.8

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -58,7 +58,8 @@
  "Content-Type": "application/json"
  },
  "requestsPerSecond": 10,
- "maxTokenLength": 2048
+ "maxTokenLength": 2048,
+ "maxReturnTokens": 1024
  },
  "palm-chat": {
  "type": "PALM-CHAT",
@@ -67,7 +68,8 @@
  "Content-Type": "application/json"
  },
  "requestsPerSecond": 10,
- "maxTokenLength": 2048
+ "maxTokenLength": 2048,
+ "maxReturnTokens": 1024
  },
  "local-llama13B": {
  "type": "LOCAL-CPP-MODEL",
package/package.json CHANGED
@@ -1,6 +1,6 @@
  {
  "name": "@aj-archipelago/cortex",
- "version": "1.0.6",
+ "version": "1.0.8",
  "description": "Cortex is a GraphQL API for AI. It provides a simple, extensible interface for using AI services from OpenAI, Azure and others.",
  "repository": {
  "type": "git",
@@ -36,6 +36,7 @@
  "axios": "^1.3.4",
  "axios-cache-interceptor": "^1.0.1",
  "bottleneck": "^2.19.5",
+ "cheerio": "^1.0.0-rc.12",
  "compromise": "^14.8.1",
  "compromise-paragraphs": "^0.1.0",
  "convict": "^6.2.3",
@@ -14,6 +14,7 @@ export default {
  typeDef,
  rootResolver,
  resolver,
+ inputFormat: 'text',
  useInputChunking: true,
  useParallelChunkProcessing: false,
  useInputSummarization: false,
package/server/chunker.js CHANGED
@@ -1,4 +1,5 @@
  import { encode, decode } from 'gpt-3-encoder';
+ import cheerio from 'cheerio';

  const getLastNToken = (text, maxTokenLen) => {
  const encoded = encode(text);
@@ -18,8 +19,18 @@ const getFirstNToken = (text, maxTokenLen) => {
  return text;
  }

- const getSemanticChunks = (text, chunkSize) => {
+ const determineTextFormat = (text) => {
+ const htmlTagPattern = /<[^>]*>/g;
+
+ if (htmlTagPattern.test(text)) {
+ return 'html';
+ }
+ else {
+ return 'text';
+ }
+ }

+ const getSemanticChunks = (text, chunkSize, inputFormat = 'text') => {
  const breakByRegex = (str, regex, preserveWhitespace = false) => {
  const result = [];
  let match;
@@ -46,6 +57,19 @@ const getSemanticChunks = (text, chunkSize) => {
  const breakBySentences = (str) => breakByRegex(str, /(?<=[.。؟!?!\n])\s+/, true);
  const breakByWords = (str) => breakByRegex(str, /(\s,;:.+)/);

+ const breakByHtmlElements = (str) => {
+ const $ = cheerio.load(str, null, true);
+
+ // the .filter() call is important to get the text nodes
+ // https://stackoverflow.com/questions/54878673/cheerio-get-normal-text-nodes
+ let rootNodes = $('body').contents();
+
+ // create an array with the outerHTML of each node
+ const nodes = rootNodes.map((i, el) => $(el).prop('outerHTML') || $(el).text()).get();
+
+ return nodes;
+ };
+
  const createChunks = (tokens) => {
  let chunks = [];
  let currentChunk = '';
@@ -115,7 +139,28 @@ const getSemanticChunks = (text, chunkSize) => {
  return createChunks([...str]); // Split by characters
  };

- return breakText(text);
+ if (inputFormat === 'html') {
+ const tokens = breakByHtmlElements(text);
+ let chunks = createChunks(tokens);
+ chunks = combineChunks(chunks);
+
+ chunks = chunks.flatMap(chunk => {
+ if (determineTextFormat(chunk) === 'text') {
+ return getSemanticChunks(chunk, chunkSize);
+ } else {
+ return chunk;
+ }
+ });
+
+ if (chunks.some(chunk => encode(chunk).length > chunkSize)) {
+ throw new Error('The HTML contains elements that are larger than the chunk size. Please try again with HTML that has smaller elements.');
+ }
+
+ return chunks;
+ }
+ else {
+ return breakText(text);
+ }
  }


@@ -133,5 +178,5 @@ const semanticTruncate = (text, maxLength) => {
  };

  export {
- getSemanticChunks, semanticTruncate, getLastNToken, getFirstNToken
+ getSemanticChunks, semanticTruncate, getLastNToken, getFirstNToken, determineTextFormat
  };
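For orientation, a small hedged sketch of how the new HTML-aware chunking above behaves; the import path mirrors the tests later in this diff and the chunk size is arbitrary:

// Sketch only: HTML input is split into top-level nodes via cheerio, elements are never
// broken apart, and plain-text runs between elements fall back to the existing text
// chunker. Elements larger than the chunk size throw instead of being split mid-tag.
import { getSemanticChunks, determineTextFormat } from './server/chunker.js';

const html = '<p>Hello <a href="https://example.com">world</a></p> some trailing text.';

determineTextFormat(html);                          // 'html' — any <...> tag flips the format
const chunks = getSemanticChunks(html, 50, 'html'); // third argument is the new inputFormat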
package/server/graphql.js CHANGED
@@ -164,7 +164,13 @@ const build = async (config) => {
  const cortexApiKey = config.get('cortexApiKey');
  if (cortexApiKey) {
  app.use((req, res, next) => {
- if (cortexApiKey && req.headers['Cortex-Api-Key'] !== cortexApiKey && req.query['Cortex-Api-Key'] !== cortexApiKey) {
+ let providedApiKey = req.headers['cortex-api-key'] || req.query['cortex-api-key'];
+ if (!providedApiKey) {
+ providedApiKey = req.headers['authorization'];
+ providedApiKey = providedApiKey?.startsWith('Bearer ') ? providedApiKey.slice(7) : providedApiKey;
+ }
+
+ if (cortexApiKey && cortexApiKey !== providedApiKey) {
  if (req.baseUrl === '/graphql' || req.headers['content-type'] === 'application/graphql') {
  res.status(401)
  .set('WWW-Authenticate', 'Cortex-Api-Key')
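The old check read req.headers['Cortex-Api-Key'], which Node lowercases on incoming requests, so the header form effectively never matched; the rewrite reads the lowercased header, the query string, or a standard Authorization: Bearer token. A hedged client sketch (endpoint and query are placeholders):

// Either request passes the middleware above when CORTEX_API_KEY matches the configured key.
const url = 'http://localhost:4000/graphql';               // placeholder endpoint
const body = JSON.stringify({ query: '{ __typename }' });  // placeholder query
const headers = { 'Content-Type': 'application/json' };

// Existing style: Cortex-Api-Key header (or ?cortex-api-key=... on the query string).
await fetch(url, { method: 'POST', body, headers: { ...headers, 'Cortex-Api-Key': process.env.CORTEX_API_KEY } });

// New in 1.0.8: a standard bearer token is also accepted.
await fetch(url, { method: 'POST', body, headers: { ...headers, Authorization: `Bearer ${process.env.CORTEX_API_KEY}` } });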
@@ -6,40 +6,37 @@ import OpenAIWhisperPlugin from './plugins/openAiWhisperPlugin.js';
  import LocalModelPlugin from './plugins/localModelPlugin.js';
  import PalmChatPlugin from './plugins/palmChatPlugin.js';
  import PalmCompletionPlugin from './plugins/palmCompletionPlugin.js';
+ import PalmCodeCompletionPlugin from './plugins/palmCodeCompletionPlugin.js';

  class PathwayPrompter {
- constructor({ config, pathway }) {
-
- const modelName = pathway.model || config.get('defaultModelName');
- const model = config.get('models')[modelName];
-
- if (!model) {
- throw new Error(`Model ${modelName} not found in config`);
- }
-
+ constructor(config, pathway, modelName, model) {
+
  let plugin;

  switch (model.type) {
  case 'OPENAI-CHAT':
- plugin = new OpenAIChatPlugin(config, pathway);
+ plugin = new OpenAIChatPlugin(config, pathway, modelName, model);
  break;
  case 'AZURE-TRANSLATE':
- plugin = new AzureTranslatePlugin(config, pathway);
+ plugin = new AzureTranslatePlugin(config, pathway, modelName, model);
  break;
  case 'OPENAI-COMPLETION':
- plugin = new OpenAICompletionPlugin(config, pathway);
+ plugin = new OpenAICompletionPlugin(config, pathway, modelName, model);
  break;
  case 'OPENAI-WHISPER':
- plugin = new OpenAIWhisperPlugin(config, pathway);
+ plugin = new OpenAIWhisperPlugin(config, pathway, modelName, model);
  break;
  case 'LOCAL-CPP-MODEL':
- plugin = new LocalModelPlugin(config, pathway);
+ plugin = new LocalModelPlugin(config, pathway, modelName, model);
  break;
  case 'PALM-CHAT':
- plugin = new PalmChatPlugin(config, pathway);
+ plugin = new PalmChatPlugin(config, pathway, modelName, model);
  break;
  case 'PALM-COMPLETION':
- plugin = new PalmCompletionPlugin(config, pathway);
+ plugin = new PalmCompletionPlugin(config, pathway, modelName, model);
+ break;
+ case 'PALM-CODE-COMPLETION':
+ plugin = new PalmCodeCompletionPlugin(config, pathway, modelName, model);
  break;
  default:
  throw new Error(`Unsupported model type: ${model.type}`);
@@ -20,9 +20,31 @@ class PathwayResolver {
  this.warnings = [];
  this.requestId = uuidv4();
  this.responseParser = new PathwayResponseParser(pathway);
- this.pathwayPrompter = new PathwayPrompter({ config, pathway });
+ this.modelName = [
+ pathway.model,
+ args?.model,
+ pathway.inputParameters?.model,
+ config.get('defaultModelName')
+ ].find(modelName => modelName && config.get('models').hasOwnProperty(modelName));
+ this.model = config.get('models')[this.modelName];
+
+ if (!this.model) {
+ throw new Error(`Model ${this.modelName} not found in config`);
+ }
+
+ const specifiedModelName = pathway.model || args?.model || pathway.inputParameters?.model;
+
+ if (this.modelName !== (specifiedModelName)) {
+ if (specifiedModelName) {
+ this.logWarning(`Specified model ${specifiedModelName} not found in config, using ${this.modelName} instead.`);
+ } else {
+ this.logWarning(`No model specified in the pathway, using ${this.modelName}.`);
+ }
+ }
+
  this.previousResult = '';
  this.prompts = [];
+ this.pathwayPrompter = new PathwayPrompter(this.config, this.pathway, this.modelName, this.model);

  Object.defineProperty(this, 'pathwayPrompt', {
  get() {
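To make the new selection order concrete, a hypothetical example; the unknown model name is illustrative, and 'palm-chat' is the only model shown in the config hunks above:

// Fallback order: pathway.model, then args.model, then pathway.inputParameters.model,
// then defaultModelName — the first name that actually exists in config.get('models') wins.
const pathway = { inputParameters: { model: 'palm-chat' } }; // no pathway.model set
const args = { model: 'some-model-not-in-config' };          // requested, but unknown
// args.model is skipped because it is not in the models config, so 'palm-chat' is used and
// a warning is logged: "Specified model some-model-not-in-config not found in config,
// using palm-chat instead."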
@@ -56,37 +78,41 @@ class PathwayResolver {
  }
  });
  } else { // stream
- const incomingMessage = Array.isArray(responseData) && responseData.length > 0 ? responseData[0] : responseData;
- incomingMessage.on('data', data => {
- const events = data.toString().split('\n');
-
- events.forEach(event => {
- if (event.trim() === '') return; // Skip empty lines
-
- const message = event.replace(/^data: /, '');
-
- //console.log(`====================================`);
- //console.log(`STREAM EVENT: ${event}`);
- //console.log(`MESSAGE: ${message}`);
-
- const requestProgress = {
- requestId: this.requestId,
- data: message,
- }
-
- if (message.trim() === '[DONE]') {
- requestProgress.progress = 1;
- }
-
- try {
- pubsub.publish('REQUEST_PROGRESS', {
- requestProgress: requestProgress
- });
- } catch (error) {
- console.error('Could not JSON parse stream message', message, error);
- }
+ try {
+ const incomingMessage = Array.isArray(responseData) && responseData.length > 0 ? responseData[0] : responseData;
+ incomingMessage.on('data', data => {
+ const events = data.toString().split('\n');
+
+ events.forEach(event => {
+ if (event.trim() === '') return; // Skip empty lines
+
+ const message = event.replace(/^data: /, '');
+
+ //console.log(`====================================`);
+ //console.log(`STREAM EVENT: ${event}`);
+ //console.log(`MESSAGE: ${message}`);
+
+ const requestProgress = {
+ requestId: this.requestId,
+ data: message,
+ }
+
+ if (message.trim() === '[DONE]') {
+ requestProgress.progress = 1;
+ }
+
+ try {
+ pubsub.publish('REQUEST_PROGRESS', {
+ requestProgress: requestProgress
+ });
+ } catch (error) {
+ console.error('Could not JSON parse stream message', message, error);
+ }
+ });
  });
- });
+ } catch (error) {
+ console.error('Could not subscribe to stream', error);
+ }
  }
  }

@@ -152,7 +178,7 @@ class PathwayResolver {
  }

  // chunk the text and return the chunks with newline separators
- return getSemanticChunks(text, chunkTokenLength);
+ return getSemanticChunks(text, chunkTokenLength, this.pathway.inputFormat);
  }

  truncate(str, n) {
@@ -2,8 +2,8 @@
  import ModelPlugin from './modelPlugin.js';

  class AzureTranslatePlugin extends ModelPlugin {
- constructor(config, pathway) {
- super(config, pathway);
+ constructor(config, pathway, modelName, model) {
+ super(config, pathway, modelName, model);
  }

  // Set up parameters specific to the Azure Translate API
@@ -24,8 +24,9 @@ class AzureTranslatePlugin extends ModelPlugin {
  }

  // Execute the request to the Azure Translate API
- async execute(text, parameters, prompt) {
+ async execute(text, parameters, prompt, pathwayResolver) {
  const requestParameters = this.getRequestParameters(text, parameters, prompt);
+ const requestId = pathwayResolver?.requestId;

  const url = this.requestUrl(text);

@@ -33,7 +34,7 @@ class AzureTranslatePlugin extends ModelPlugin {
  const params = requestParameters.params;
  const headers = this.model.headers || {};

- return this.executeRequest(url, data, params, headers, prompt);
+ return this.executeRequest(url, data, params, headers, prompt, requestId);
  }

  // Parse the response from the Azure Translate API
@@ -47,8 +48,7 @@ class AzureTranslatePlugin extends ModelPlugin {

  // Override the logging function to display the request and response
  logRequestData(data, responseData, prompt) {
- const separator = `\n=== ${this.pathwayName}.${this.requestCount++} ===\n`;
- console.log(separator);
+ this.logAIRequestFinished();

  const modelInput = data[0].Text;

@@ -4,8 +4,8 @@ import { execFileSync } from 'child_process';
  import { encode } from 'gpt-3-encoder';

  class LocalModelPlugin extends ModelPlugin {
- constructor(config, pathway) {
- super(config, pathway);
+ constructor(config, pathway, modelName, model) {
+ super(config, pathway, modelName, model);
  }

  // if the input starts with a chatML response, just return that
@@ -6,19 +6,13 @@ import { encode } from 'gpt-3-encoder';
  import { getFirstNToken } from '../chunker.js';

  const DEFAULT_MAX_TOKENS = 4096;
+ const DEFAULT_MAX_RETURN_TOKENS = 256;
  const DEFAULT_PROMPT_TOKEN_RATIO = 0.5;

  class ModelPlugin {
- constructor(config, pathway) {
- // If the pathway specifies a model, use that, otherwise use the default
- this.modelName = pathway.model || config.get('defaultModelName');
- // Get the model from the config
- this.model = config.get('models')[this.modelName];
- // If the model doesn't exist, throw an exception
- if (!this.model) {
- throw new Error(`Model ${this.modelName} not found in config`);
- }
-
+ constructor(config, pathway, modelName, model) {
+ this.modelName = modelName;
+ this.model = model;
  this.config = config;
  this.environmentVariables = config.getEnv();
  this.temperature = pathway.temperature;
@@ -36,7 +30,8 @@ class ModelPlugin {
  }
  }

- this.requestCount = 1;
+ this.requestCount = 0;
+ this.lastRequestStartTime = new Date();
  this.shouldCache = config.get('enableCache') && (pathway.enableCache || pathway.temperature == 0);
  }

@@ -143,6 +138,10 @@ class ModelPlugin {
  return (this.promptParameters.maxTokenLength ?? this.model.maxTokenLength ?? DEFAULT_MAX_TOKENS);
  }

+ getModelMaxReturnTokens() {
+ return (this.promptParameters.maxReturnTokens ?? this.model.maxReturnTokens ?? DEFAULT_MAX_RETURN_TOKENS);
+ }
+
  getPromptTokenRatio() {
  // TODO: Is this the right order of precedence? inputParameters should maybe be second?
  return this.promptParameters.inputParameters?.tokenRatio ?? this.promptParameters.tokenRatio ?? DEFAULT_PROMPT_TOKEN_RATIO;
@@ -201,10 +200,24 @@ class ModelPlugin {
  parseResponse(data) { return data; };

  // Default simple logging
+ logRequestStart(url, data) {
+ this.requestCount++;
+ const logMessage = `>>> [${this.requestId}: ${this.pathwayName}.${this.requestCount}] request`;
+ const header = '>'.repeat(logMessage.length);
+ console.log(`\n${header}\n${logMessage}`);
+ console.log(`>>> Making API request to ${url}`);
+ };
+
+ logAIRequestFinished() {
+ const currentTime = new Date();
+ const timeElapsed = (currentTime - this.lastRequestStartTime) / 1000;
+ const logMessage = `<<< [${this.requestId}: ${this.pathwayName}.${this.requestCount}] response - complete in ${timeElapsed}s - data:`;
+ const header = '<'.repeat(logMessage.length);
+ console.log(`\n${header}\n${logMessage}\n`);
+ };
+
  logRequestData(data, responseData, prompt) {
- const separator = `\n=== ${this.pathwayName}.${this.requestCount++} ===\n`;
- console.log(separator);
-
+ this.logAIRequestFinished();
  const modelInput = data.prompt || (data.messages && data.messages[0].content) || (data.length > 0 && data[0].Text) || null;

  if (modelInput) {
@@ -216,7 +229,10 @@ class ModelPlugin {
  prompt && prompt.debugInfo && (prompt.debugInfo += `${separator}${JSON.stringify(data)}`);
  }

- async executeRequest(url, data, params, headers, prompt) {
+ async executeRequest(url, data, params, headers, prompt, requestId) {
+ this.aiRequestStartTime = new Date();
+ this.requestId = requestId;
+ this.logRequestStart(url, data);
  const responseData = await request({ url, data, params, headers, cache: this.shouldCache }, this.modelName);

  if (responseData.error) {
@@ -3,8 +3,8 @@ import ModelPlugin from './modelPlugin.js';
  import { encode } from 'gpt-3-encoder';

  class OpenAIChatPlugin extends ModelPlugin {
- constructor(config, pathway) {
- super(config, pathway);
+ constructor(config, pathway, modelName, model) {
+ super(config, pathway, modelName, model);
  }

  // convert to OpenAI messages array format if necessary
@@ -76,14 +76,15 @@ class OpenAIChatPlugin extends ModelPlugin {
  }

  // Execute the request to the OpenAI Chat API
- async execute(text, parameters, prompt) {
+ async execute(text, parameters, prompt, pathwayResolver) {
  const url = this.requestUrl(text);
  const requestParameters = this.getRequestParameters(text, parameters, prompt);
+ const requestId = pathwayResolver?.requestId;

  const data = { ...(this.model.params || {}), ...requestParameters };
  const params = {};
  const headers = this.model.headers || {};
- return this.executeRequest(url, data, params, headers, prompt);
+ return this.executeRequest(url, data, params, headers, prompt, requestId);
  }

  // Parse the response from the OpenAI Chat API
@@ -105,8 +106,7 @@ class OpenAIChatPlugin extends ModelPlugin {

  // Override the logging function to display the messages and responses
  logRequestData(data, responseData, prompt) {
- const separator = `\n=== ${this.pathwayName}.${this.requestCount++} ===\n`;
- console.log(separator);
+ this.logAIRequestFinished();

  const { stream, messages } = data;
  if (messages && messages.length > 1) {
@@ -1,5 +1,6 @@
  // OpenAICompletionPlugin.js

+ import { request } from 'https';
  import ModelPlugin from './modelPlugin.js';
  import { encode } from 'gpt-3-encoder';

@@ -15,8 +16,8 @@ const truncatePromptIfNecessary = (text, textTokenCount, modelMaxTokenCount, tar
  }

  class OpenAICompletionPlugin extends ModelPlugin {
- constructor(config, pathway) {
- super(config, pathway);
+ constructor(config, pathway, modelName, model) {
+ super(config, pathway, modelName, model);
  }

  // Set up parameters specific to the OpenAI Completion API
@@ -78,12 +79,13 @@ class OpenAICompletionPlugin extends ModelPlugin {
  async execute(text, parameters, prompt, pathwayResolver) {
  const url = this.requestUrl(text);
  const requestParameters = this.getRequestParameters(text, parameters, prompt, pathwayResolver);
-
+ const requestId = pathwayResolver?.requestId;
+
  const data = { ...(this.model.params || {}), ...requestParameters };
  const params = {};
  const headers = this.model.headers || {};

- return this.executeRequest(url, data, params, headers, prompt);
+ return this.executeRequest(url, data, params, headers, prompt, requestId);
  }

  // Parse the response from the OpenAI Completion API
@@ -105,8 +107,7 @@ class OpenAICompletionPlugin extends ModelPlugin {

  // Override the logging function to log the prompt and response
  logRequestData(data, responseData, prompt) {
- const separator = `\n=== ${this.pathwayName}.${this.requestCount++} ===\n`;
- console.log(separator);
+ this.logAIRequestFinished();

  const stream = data.stream;
  const modelInput = data.prompt;
@@ -75,14 +75,14 @@ const downloadFile = async (fileUrl) => {
  fs.unlink(localFilePath, () => {
  reject(error);
  });
- throw error;
+ //throw error;
  }
  });
  };

  class OpenAIWhisperPlugin extends ModelPlugin {
- constructor(config, pathway) {
- super(config, pathway);
+ constructor(config, pathway, modelName, model) {
+ super(config, pathway, modelName, model);
  }

  async getMediaChunks(file, requestId) {
@@ -4,8 +4,8 @@ import { encode } from 'gpt-3-encoder';
  import HandleBars from '../../lib/handleBars.js';

  class PalmChatPlugin extends ModelPlugin {
- constructor(config, pathway) {
- super(config, pathway);
+ constructor(config, pathway, modelName, model) {
+ super(config, pathway, modelName, model);
  }

  // Convert to PaLM messages array format if necessary
@@ -92,10 +92,8 @@ class PalmChatPlugin extends ModelPlugin {
  const context = this.getCompiledContext(text, parameters, prompt.context || palmMessages.context || '');
  const examples = this.getCompiledExamples(text, parameters, prompt.examples || []);

- // For PaLM right now, the max return tokens is 1024, regardless of the max context length
- // I can't think of a time you'd want to constrain it to fewer at the moment.
- const max_tokens = 1024//this.getModelMaxTokenLength() - tokenLength;
-
+ const max_tokens = this.getModelMaxReturnTokens();
+
  if (max_tokens < 0) {
  throw new Error(`Prompt is too long to successfully call the model at ${tokenLength} tokens. The model will not be called.`);
  }
@@ -139,9 +137,10 @@ class PalmChatPlugin extends ModelPlugin {
  }

  // Execute the request to the PaLM Chat API
- async execute(text, parameters, prompt) {
+ async execute(text, parameters, prompt, pathwayResolver) {
  const url = this.requestUrl(text);
  const requestParameters = this.getRequestParameters(text, parameters, prompt);
+ const requestId = pathwayResolver?.requestId;

  const data = { ...(this.model.params || {}), ...requestParameters };
  const params = {};
@@ -149,7 +148,7 @@ class PalmChatPlugin extends ModelPlugin {
  const gcpAuthTokenHelper = this.config.get('gcpAuthTokenHelper');
  const authToken = await gcpAuthTokenHelper.getAccessToken();
  headers.Authorization = `Bearer ${authToken}`;
- return this.executeRequest(url, data, params, headers, prompt);
+ return this.executeRequest(url, data, params, headers, prompt, requestId);
  }

  // Parse the response from the PaLM Chat API
@@ -183,8 +182,7 @@ class PalmChatPlugin extends ModelPlugin {

  // Override the logging function to display the messages and responses
  logRequestData(data, responseData, prompt) {
- const separator = `\n=== ${this.pathwayName}.${this.requestCount++} ===\n`;
- console.log(separator);
+ this.logAIRequestFinished();

  const instances = data && data.instances;
  const messages = instances && instances[0] && instances[0].messages;
@@ -0,0 +1,46 @@
+ // palmCodeCompletionPlugin.js
+
+ import PalmCompletionPlugin from './palmCompletionPlugin.js';
+
+ // PalmCodeCompletionPlugin class for handling requests and responses to the PaLM API Code Completion API
+ class PalmCodeCompletionPlugin extends PalmCompletionPlugin {
+ constructor(config, pathway, modelName, model) {
+ super(config, pathway, modelName, model);
+ }
+
+ // Set up parameters specific to the PaLM API Code Completion API
+ getRequestParameters(text, parameters, prompt, pathwayResolver) {
+ const { modelPromptText, tokenLength } = this.getCompiledPrompt(text, parameters, prompt);
+ const { stream } = parameters;
+ // Define the model's max token length
+ const modelTargetTokenLength = this.getModelMaxTokenLength() * this.getPromptTokenRatio();
+
+ const truncatedPrompt = this.truncatePromptIfNecessary(modelPromptText, tokenLength, this.getModelMaxTokenLength(), modelTargetTokenLength, pathwayResolver);
+
+ const max_tokens = this.getModelMaxReturnTokens();
+
+ if (max_tokens < 0) {
+ throw new Error(`Prompt is too long to successfully call the model at ${tokenLength} tokens. The model will not be called.`);
+ }
+
+ if (!truncatedPrompt) {
+ throw new Error(`Prompt is empty. The model will not be called.`);
+ }
+
+ const requestParameters = {
+ instances: [
+ { prefix: truncatedPrompt }
+ ],
+ parameters: {
+ temperature: this.temperature ?? 0.7,
+ maxOutputTokens: max_tokens,
+ topP: parameters.topP ?? 0.95,
+ topK: parameters.topK ?? 40,
+ }
+ };
+
+ return requestParameters;
+ }
+ }
+
+ export default PalmCodeCompletionPlugin;
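To route requests to this new plugin, the models config needs an entry whose type is PALM-CODE-COMPLETION, which the PathwayPrompter switch earlier in this diff maps to PalmCodeCompletionPlugin. A hedged sketch; the entry name, endpoint URL, and limits are placeholders:

// Hypothetical entry: only the type string and the maxTokenLength/maxReturnTokens fields
// are grounded in this diff; everything else is illustrative.
const models = {
  'palm-code-completion': {
    type: 'PALM-CODE-COMPLETION',
    url: 'https://example.com/code-completion-endpoint',
    headers: { 'Content-Type': 'application/json' },
    requestsPerSecond: 10,
    maxTokenLength: 2048,
    maxReturnTokens: 1024,
  },
};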
@@ -2,23 +2,21 @@

  import ModelPlugin from './modelPlugin.js';

- // Helper function to truncate the prompt if it is too long
- const truncatePromptIfNecessary = (text, textTokenCount, modelMaxTokenCount, targetTextTokenCount, pathwayResolver) => {
- const maxAllowedTokens = textTokenCount + ((modelMaxTokenCount - targetTextTokenCount) * 0.5);
-
- if (textTokenCount > maxAllowedTokens) {
- pathwayResolver.logWarning(`Prompt is too long at ${textTokenCount} tokens (this target token length for this pathway is ${targetTextTokenCount} tokens because the response is expected to take up the rest of the model's max tokens (${modelMaxTokenCount}). Prompt will be truncated.`);
- return pathwayResolver.truncate(text, maxAllowedTokens);
- }
- return text;
- }
-
  // PalmCompletionPlugin class for handling requests and responses to the PaLM API Text Completion API
  class PalmCompletionPlugin extends ModelPlugin {
- constructor(config, pathway) {
- super(config, pathway);
+ constructor(config, pathway, modelName, model) {
+ super(config, pathway, modelName, model);
  }

+ truncatePromptIfNecessary (text, textTokenCount, modelMaxTokenCount, targetTextTokenCount, pathwayResolver) {
+ const maxAllowedTokens = textTokenCount + ((modelMaxTokenCount - targetTextTokenCount) * 0.5);
+
+ if (textTokenCount > maxAllowedTokens) {
+ pathwayResolver.logWarning(`Prompt is too long at ${textTokenCount} tokens (this target token length for this pathway is ${targetTextTokenCount} tokens because the response is expected to take up the rest of the model's max tokens (${modelMaxTokenCount}). Prompt will be truncated.`);
+ return pathwayResolver.truncate(text, maxAllowedTokens);
+ }
+ return text;
+ }
  // Set up parameters specific to the PaLM API Text Completion API
  getRequestParameters(text, parameters, prompt, pathwayResolver) {
  const { modelPromptText, tokenLength } = this.getCompiledPrompt(text, parameters, prompt);
@@ -26,9 +24,9 @@ class PalmCompletionPlugin extends ModelPlugin {
  // Define the model's max token length
  const modelTargetTokenLength = this.getModelMaxTokenLength() * this.getPromptTokenRatio();

- const truncatedPrompt = truncatePromptIfNecessary(modelPromptText, tokenLength, this.getModelMaxTokenLength(), modelTargetTokenLength, pathwayResolver);
+ const truncatedPrompt = this.truncatePromptIfNecessary(modelPromptText, tokenLength, this.getModelMaxTokenLength(), modelTargetTokenLength, pathwayResolver);

- const max_tokens = 1024//this.getModelMaxTokenLength() - tokenLength;
+ const max_tokens = this.getModelMaxReturnTokens();

  if (max_tokens < 0) {
  throw new Error(`Prompt is too long to successfully call the model at ${tokenLength} tokens. The model will not be called.`);
@@ -57,6 +55,7 @@ class PalmCompletionPlugin extends ModelPlugin {
  async execute(text, parameters, prompt, pathwayResolver) {
  const url = this.requestUrl(text);
  const requestParameters = this.getRequestParameters(text, parameters, prompt, pathwayResolver);
+ const requestId = pathwayResolver?.requestId;

  const data = { ...requestParameters };
  const params = {};
@@ -64,7 +63,7 @@ class PalmCompletionPlugin extends ModelPlugin {
  const gcpAuthTokenHelper = this.config.get('gcpAuthTokenHelper');
  const authToken = await gcpAuthTokenHelper.getAccessToken();
  headers.Authorization = `Bearer ${authToken}`;
- return this.executeRequest(url, data, params, headers, prompt);
+ return this.executeRequest(url, data, params, headers, prompt, requestId);
  }

  // Parse the response from the PaLM API Text Completion API
@@ -107,8 +106,7 @@ class PalmCompletionPlugin extends ModelPlugin {

  // Override the logging function to log the prompt and response
  logRequestData(data, responseData, prompt) {
- const separator = `\n=== ${this.pathwayName}.${this.requestCount++} ===\n`;
- console.log(separator);
+ this.logAIRequestFinished();

  const safetyAttributes = this.getSafetyAttributes(responseData);

@@ -1,5 +1,6 @@
  import test from 'ava';
- import { getSemanticChunks } from '../server/chunker.js';
+ import { getSemanticChunks, determineTextFormat } from '../server/chunker.js';
+
  import { encode } from 'gpt-3-encoder';

  const testText = `Lorem ipsum dolor sit amet, consectetur adipiscing elit. In id erat sem. Phasellus ac dapibus purus, in fermentum nunc. Mauris quis rutrum magna. Quisque rutrum, augue vel blandit posuere, augue magna convallis turpis, nec elementum augue mauris sit amet nunc. Aenean sit amet leo est. Nunc ante ex, blandit et felis ut, iaculis lacinia est. Phasellus dictum orci id libero ullamcorper tempor.
@@ -69,34 +70,119 @@ test('should return identical text that chunker was passed, given tiny chunk siz
  t.is(recomposedText, testText); //check recomposition
  });

- /*
- it('should return identical text that chunker was passed, given tiny chunk size (1)', () => {
- const maxChunkToken = 1;
- const chunks = getSemanticChunks(testText, maxChunkToken);
- expect(chunks.length).toBeGreaterThan(1); //check chunking
- expect(chunks.every(chunk => encode(chunk).length <= maxChunkToken)).toBe(true); //check chunk size
- const recomposedText = chunks.reduce((acc, chunk) => acc + chunk, '');
- expect(recomposedText).toBe(testText); //check recomposition
+ const htmlChunkOne = `<p>Lorem ipsum dolor sit amet, consectetur adipiscing elit. <a href="https://www.google.com">Google</a></p> Vivamus id pharetra odio. Sed consectetur leo sed tortor dictum venenatis.Donec gravida libero non accumsan suscipit.Donec lectus turpis, ullamcorper eu pulvinar iaculis, ornare ut risus.Phasellus aliquam, turpis quis viverra condimentum, risus est pretium metus, in porta ipsum tortor vitae elit.Pellentesque id finibus erat. In suscipit, sapien non posuere dignissim, augue nisl ultrices tortor, sit amet eleifend nibh elit at risus.`
+ const htmlVoidElement = `<br>`
+ const htmlChunkTwo = `<p><img src="https://www.google.com/googlelogo_color_272x92dp.png"></p>`
+ const htmlSelfClosingElement = `<img src="https://www.google.com/images/branding/googlelogo/1x/googlelogo_color_272x92dp.png" />`
+ const plainTextChunk = 'Vestibulum ante ipsum primis in faucibus orci luctus et ultrices posuere cubilia curae; Fusce at dignissim quam.'
+
+ test('should throw an error if html cannot be accommodated within the chunk size', async t => {
+ const chunkSize = encode(htmlChunkTwo).length;
+ const error = t.throws(() => getSemanticChunks(htmlChunkTwo, chunkSize - 1, 'html'));
+ t.is(error.message, 'The HTML contains elements that are larger than the chunk size. Please try again with HTML that has smaller elements.');
+ });
+
+ test('should chunk text between html elements if needed', async t => {
+ const chunkSize = encode(htmlChunkTwo).length;
+ const chunks = getSemanticChunks(htmlChunkTwo + plainTextChunk + htmlChunkTwo, chunkSize, 'html');
+
+ t.is(chunks.length, 4);
+ t.is(chunks[0], htmlChunkTwo);
+ t.is(chunks[1], 'Vestibulum ante ipsum primis in faucibus orci luctus et ultrices posuere cubilia curae');
+ t.is(encode(chunks[1]).length, chunkSize);
+ t.is(chunks[2], '; Fusce at dignissim quam.');
+ t.is(chunks[3], htmlChunkTwo);
+ });
+
+ test('should chunk html element correctly when chunk size is exactly the same as the element length', async t => {
+ const chunkSize = encode(htmlChunkTwo).length;
+ const chunks = getSemanticChunks(htmlChunkTwo, chunkSize, 'html');
+
+ t.is(chunks.length, 1);
+ t.is(chunks[0], htmlChunkTwo);
+ });
+
+ test('should chunk html element correctly when chunk size is greater than the element length', async t => {
+ const chunkSize = encode(htmlChunkTwo).length;
+ const chunks = getSemanticChunks(htmlChunkTwo, chunkSize + 1, 'html');
+
+ t.is(chunks.length, 1);
+ t.is(chunks[0], htmlChunkTwo);
+ });
+
+ test('should not break up second html element correctly when chunk size is greater than the first element length', async t => {
+ const chunkSize = encode(htmlChunkTwo).length;
+ const chunks = getSemanticChunks(htmlChunkTwo + htmlChunkTwo, chunkSize + 10, 'html');
+
+ t.is(chunks.length, 2);
+ t.is(chunks[0], htmlChunkTwo);
+ t.is(chunks[1], htmlChunkTwo);
+ });
+
+ test('should treat text chunks as also unbreakable chunks', async t => {
+ const chunkSize = encode(htmlChunkTwo).length;
+ const chunks = getSemanticChunks(htmlChunkTwo + plainTextChunk + htmlChunkTwo, chunkSize + 20, 'html');
+
+ t.is(chunks.length, 3);
+ t.is(chunks[0], htmlChunkTwo);
+ t.is(chunks[1], plainTextChunk);
+ t.is(chunks[2], htmlChunkTwo);
+ });
+
+
+ test('should determine format correctly for text only', async t => {
+ const format = determineTextFormat(plainTextChunk);
+ t.is(format, 'text');
+ });
+
+ test('should determine format correctly for simple html element', async t => {
+ const format = determineTextFormat(htmlChunkTwo);
+ t.is(format, 'html');
+ });
+
+ test('should determine format correctly for simple html element embedded in text', async t => {
+ const format = determineTextFormat(plainTextChunk + htmlChunkTwo + plainTextChunk);
+ t.is(format, 'html');
+ });
+
+ test('should determine format correctly for self-closing html element', async t => {
+ const format = determineTextFormat(htmlSelfClosingElement);
+ t.is(format, 'html');
+ });
+
+ test('should determine format correctly for self-closing html element embedded in text', async t => {
+ const format = determineTextFormat(plainTextChunk + htmlSelfClosingElement + plainTextChunk);
+ t.is(format, 'html');
+ });
+
+ test('should determine format correctly for void element', async t => {
+ const format = determineTextFormat(htmlVoidElement);
+ t.is(format, 'html');
+ });
+
+ test('should determine format correctly for void element embedded in text', async t => {
+ const format = determineTextFormat(plainTextChunk + htmlVoidElement + plainTextChunk);
+ t.is(format, 'html');
  });

- it('should return identical text that chunker was passed, given huge chunk size (32000)', () => {
+ test('should return identical text that chunker was passed, given huge chunk size (32000)', t => {
  const maxChunkToken = 32000;
  const chunks = getSemanticChunks(testText, maxChunkToken);
- expect(chunks.length).toBe(1); //check chunking
- expect(chunks.every(chunk => encode(chunk).length <= maxChunkToken)).toBe(true); //check chunk size
+ t.assert(chunks.length === 1); //check chunking
+ t.assert(chunks.every(chunk => encode(chunk).length <= maxChunkToken)); //check chunk size
  const recomposedText = chunks.reduce((acc, chunk) => acc + chunk, '');
- expect(recomposedText).toBe(testText); //check recomposition
+ t.assert(recomposedText === testText); //check recomposition
  });

  const testTextNoSpaces = `Loremipsumdolorsitamet,consecteturadipiscingelit.Inideratsem.Phasellusacdapibuspurus,infermentumnunc.Maurisquisrutrummagna.Quisquerutrum,auguevelblanditposuere,auguemagnacon vallisturpis,necelementumauguemaurissitametnunc.Aeneansitametleoest.Nuncanteex,blanditetfelisut,iaculislaciniaest.Phasellusdictumorciidliberoullamcorpertempor.Vivamusidpharetraodioq.Sedconsecteturleosedtortordictumvenenatis.Donecgravidaliberononaccumsansuscipit.Doneclectusturpis,ullamcorpereupulvinariaculis,ornareutrisus.Phasellusaliquam,turpisquisviverracondimentum,risusestpretiummetus,inportaips umtortorvita elit.Pellentesqueidfinibuserat.Insuscipit,sapiennonposueredignissim,auguenisl ultricestortor,sitameteleifendnibhelitatrisus.`;

- it('should return identical text that chunker was passed, given no spaces and small chunks(5)', () => {
+ test('should return identical text that chunker was passed, given no spaces and small chunks(5)', t => {
  const maxChunkToken = 5;
  const chunks = getSemanticChunks(testTextNoSpaces, maxChunkToken);
- expect(chunks.length).toBeGreaterThan(0); //check chunking
- expect(chunks.every(chunk => encode(chunk).length <= maxChunkToken)).toBe(true); //check chunk size
+ t.assert(chunks.length > 0); //check chunking
+ t.assert(chunks.every(chunk => encode(chunk).length <= maxChunkToken)); //check chunk size
  const recomposedText = chunks.reduce((acc, chunk) => acc + chunk, '');
- expect(recomposedText).toBe(testTextNoSpaces); //check recomposition
+ t.assert(recomposedText === testTextNoSpaces); //check recomposition
  });

  const testTextShortWeirdSpaces=`Lorem ipsum dolor sit amet, consectetur adipiscing elit. In id erat sem. Phasellus ac dapibus purus, in fermentum nunc.............................. Mauris quis rutrum magna. Quisque rutrum, augue vel blandit posuere, augue magna convallis turpis, nec elementum augue mauris sit amet nunc. Aenean sit a;lksjdf 098098- -23 eln ;lkn l;kn09 oij[0u ,,,,,,,,,,,,,,,,,,,,, amet leo est. Nunc ante ex, blandit et felis ut, iaculis lacinia est. Phasellus dictum orci id libero ullamcorper tempor.
@@ -106,20 +192,20 @@ const testTextShortWeirdSpaces=`Lorem ipsum dolor sit amet, consectetur adipisci

  Vivamus id pharetra odio. Sed consectetur leo sed tortor dictum venenatis.Donec gravida libero non accumsan suscipit.Donec lectus turpis, ullamcorper eu pulvinar iaculis, ornare ut risus.Phasellus aliquam, turpis quis viverra condimentum, risus est pretium metus, in porta ipsum tortor vitae elit.Pellentesque id finibus erat. In suscipit, sapien non posuere dignissim, augue nisl ultrices tortor, sit amet eleifend nibh elit at risus.`;

- it('should return identical text that chunker was passed, given weird spaces and tiny chunks(1)', () => {
+ test('should return identical text that chunker was passed, given weird spaces and tiny chunks(1)', t => {
  const maxChunkToken = 1;
  const chunks = getSemanticChunks(testTextShortWeirdSpaces, maxChunkToken);
- expect(chunks.length).toBeGreaterThan(0); //check chunking
- expect(chunks.every(chunk => encode(chunk).length <= maxChunkToken)).toBe(true); //check chunk size
+ t.assert(chunks.length > 0); //check chunking
+ t.assert(chunks.every(chunk => encode(chunk).length <= maxChunkToken)); //check chunk size
  const recomposedText = chunks.reduce((acc, chunk) => acc + chunk, '');
- expect(recomposedText).toBe(testTextShortWeirdSpaces); //check recomposition
+ t.assert(recomposedText === testTextShortWeirdSpaces); //check recomposition
  });

- it('should return identical text that chunker was passed, given weird spaces and small chunks(10)', () => {
+ test('should return identical text that chunker was passed, given weird spaces and small chunks(10)', t => {
  const maxChunkToken = 1;
  const chunks = getSemanticChunks(testTextShortWeirdSpaces, maxChunkToken);
- expect(chunks.length).toBeGreaterThan(0); //check chunking
- expect(chunks.every(chunk => encode(chunk).length <= maxChunkToken)).toBe(true); //check chunk size
+ t.assert(chunks.length > 0); //check chunking
+ t.assert(chunks.every(chunk => encode(chunk).length <= maxChunkToken)); //check chunk size
  const recomposedText = chunks.reduce((acc, chunk) => acc + chunk, '');
- expect(recomposedText).toBe(testTextShortWeirdSpaces); //check recomposition
- });*/
+ t.assert(recomposedText === testTextShortWeirdSpaces); //check recomposition
+ });
package/tests/mocks.js CHANGED
@@ -36,4 +36,45 @@ export const mockConfig = {
  { role: 'assistant', content: 'Translating: {{{text}}}' },
  ],
  }),
- };
+ };
+
+ export const mockPathwayResolverString = {
+ model: {
+ url: 'https://api.example.com/testModel',
+ type: 'OPENAI-COMPLETION',
+ },
+ modelName: 'testModel',
+ pathway: mockPathwayString,
+ config: mockConfig,
+ prompt: new Prompt('User: {{text}}\nAssistant: Please help {{name}} who is {{age}} years old.'),
+ };
+
+ export const mockPathwayResolverFunction = {
+ model: {
+ url: 'https://api.example.com/testModel',
+ type: 'OPENAI-COMPLETION',
+ },
+ modelName: 'testModel',
+ pathway: mockPathwayFunction,
+ config: mockConfig,
+ prompt: () => {
+ return new Prompt('User: {{text}}\nAssistant: Please help {{name}} who is {{age}} years old.')
+ }
+ };
+
+ export const mockPathwayResolverMessages = {
+ model: {
+ url: 'https://api.example.com/testModel',
+ type: 'OPENAI-COMPLETION',
+ },
+ modelName: 'testModel',
+ pathway: mockPathwayMessages,
+ config: mockConfig,
+ prompt: new Prompt({
+ messages: [
+ { role: 'user', content: 'Translate this: {{{text}}}' },
+ { role: 'assistant', content: 'Translating: {{{text}}}' },
+ ],
+ }),
+ };
+
@@ -2,17 +2,16 @@
  import test from 'ava';
  import ModelPlugin from '../server/plugins/modelPlugin.js';
  import HandleBars from '../lib/handleBars.js';
- import { mockConfig, mockPathwayString, mockPathwayFunction, mockPathwayMessages } from './mocks.js';
+ import { mockConfig, mockPathwayString, mockPathwayFunction, mockPathwayMessages, mockPathwayResolverString } from './mocks.js';

  const DEFAULT_MAX_TOKENS = 4096;
  const DEFAULT_PROMPT_TOKEN_RATIO = 0.5;

  // Mock configuration and pathway objects
- const config = mockConfig;
- const pathway = mockPathwayString;
+ const { config, pathway, modelName, model } = mockPathwayResolverString;

  test('ModelPlugin constructor', (t) => {
- const modelPlugin = new ModelPlugin(config, pathway);
+ const modelPlugin = new ModelPlugin(config, pathway, modelName, model);

  t.is(modelPlugin.modelName, pathway.model, 'modelName should be set from pathway');
  t.deepEqual(modelPlugin.model, config.get('models')[pathway.model], 'model should be set from config');
@@ -21,7 +20,7 @@ test('ModelPlugin constructor', (t) => {
  });

  test.beforeEach((t) => {
- t.context.modelPlugin = new ModelPlugin(mockConfig, mockPathwayString);
+ t.context.modelPlugin = new ModelPlugin(config, pathway, modelName, model);
  });

  test('getCompiledPrompt - text and parameters', (t) => {
@@ -1,17 +1,19 @@
  import test from 'ava';
  import OpenAIChatPlugin from '../server/plugins/openAiChatPlugin.js';
- import { mockConfig, mockPathwayString, mockPathwayFunction, mockPathwayMessages } from './mocks.js';
+ import { mockPathwayResolverMessages } from './mocks.js';
+
+ const { config, pathway, modelName, model } = mockPathwayResolverMessages;

  // Test the constructor
  test('constructor', (t) => {
- const plugin = new OpenAIChatPlugin(mockConfig, mockPathwayString);
- t.is(plugin.config, mockConfig);
- t.is(plugin.pathwayPrompt, mockPathwayString.prompt);
+ const plugin = new OpenAIChatPlugin(config, pathway, modelName, model);
+ t.is(plugin.config, mockPathwayResolverMessages.config);
+ t.is(plugin.pathwayPrompt, mockPathwayResolverMessages.pathway.prompt);
  });

  // Test the convertPalmToOpenAIMessages function
  test('convertPalmToOpenAIMessages', (t) => {
- const plugin = new OpenAIChatPlugin(mockConfig, mockPathwayString);
+ const plugin = new OpenAIChatPlugin(config, pathway, modelName, model);
  const context = 'This is a test context.';
  const examples = [
  {
@@ -35,14 +37,21 @@ test('convertPalmToOpenAIMessages', (t) => {

  // Test the getRequestParameters function
  test('getRequestParameters', async (t) => {
- const plugin = new OpenAIChatPlugin(mockConfig, mockPathwayString);
+ const plugin = new OpenAIChatPlugin(config, pathway, modelName, model);
  const text = 'Help me';
  const parameters = { name: 'John', age: 30 };
- const prompt = mockPathwayString.prompt;
+ const prompt = mockPathwayResolverMessages.pathway.prompt;
  const result = await plugin.getRequestParameters(text, parameters, prompt);
  t.deepEqual(result, {
  messages: [
- { role: 'user', content: 'User: Help me\nAssistant: Please help John who is 30 years old.' },
+ {
+ content: 'Translate this: Help me',
+ role: 'user',
+ },
+ {
+ content: 'Translating: Help me',
+ role: 'assistant',
+ },
  ],
  temperature: 0.7,
  });
@@ -50,10 +59,10 @@ test('getRequestParameters', async (t) => {

  // Test the execute function
  test('execute', async (t) => {
- const plugin = new OpenAIChatPlugin(mockConfig, mockPathwayString);
+ const plugin = new OpenAIChatPlugin(config, pathway, modelName, model);
  const text = 'Help me';
  const parameters = { name: 'John', age: 30 };
- const prompt = mockPathwayString.prompt;
+ const prompt = mockPathwayResolverMessages.pathway.prompt;

  // Mock the executeRequest function
  plugin.executeRequest = () => {
@@ -82,7 +91,7 @@ test('execute', async (t) => {

  // Test the parseResponse function
  test('parseResponse', (t) => {
- const plugin = new OpenAIChatPlugin(mockConfig, mockPathwayString);
+ const plugin = new OpenAIChatPlugin(config, pathway, modelName, model);
  const data = {
  choices: [
  {
@@ -98,7 +107,7 @@ test('parseResponse', (t) => {

  // Test the logRequestData function
  test('logRequestData', (t) => {
- const plugin = new OpenAIChatPlugin(mockConfig, mockPathwayString);
+ const plugin = new OpenAIChatPlugin(config, pathway, modelName, model);
  const data = {
  messages: [
  { role: 'user', content: 'User: Help me\nAssistant: Please help John who is 30 years old.' },
@@ -113,7 +122,7 @@ test('logRequestData', (t) => {
  },
  ],
  };
- const prompt = mockPathwayString.prompt;
+ const prompt = mockPathwayResolverMessages.pathway.prompt;

  // Mock console.log function
  const originalConsoleLog = console.log;
@@ -1,11 +1,12 @@
  // test_palmChatPlugin.js
  import test from 'ava';
  import PalmChatPlugin from '../server/plugins/palmChatPlugin.js';
- import { mockConfig } from './mocks.js';
+ import { mockPathwayResolverMessages } from './mocks.js';
+
+ const { config, pathway, modelName, model } = mockPathwayResolverMessages;

  test.beforeEach((t) => {
- const pathway = 'testPathway';
- const palmChatPlugin = new PalmChatPlugin(mockConfig, pathway);
+ const palmChatPlugin = new PalmChatPlugin(config, pathway, modelName, model);
  t.context = { palmChatPlugin };
  });

@@ -2,11 +2,12 @@

  import test from 'ava';
  import PalmCompletionPlugin from '../server/plugins/palmCompletionPlugin.js';
- import { mockConfig } from './mocks.js';
+ import { mockPathwayResolverString } from './mocks.js';
+
+ const { config, pathway, modelName, model } = mockPathwayResolverString;

  test.beforeEach((t) => {
- const pathway = 'testPathway';
- const palmCompletionPlugin = new PalmCompletionPlugin(mockConfig, pathway);
+ const palmCompletionPlugin = new PalmCompletionPlugin(config, pathway, modelName, model);
  t.context = { palmCompletionPlugin };
  });

@@ -2,12 +2,11 @@
  import test from 'ava';
  import ModelPlugin from '../server/plugins/modelPlugin.js';
  import { encode } from 'gpt-3-encoder';
- import { mockConfig, mockPathwayString } from './mocks.js';
+ import { mockPathwayResolverString } from './mocks.js';

- const config = mockConfig;
- const pathway = mockPathwayString;
+ const { config, pathway, modelName, model } = mockPathwayResolverString;

- const modelPlugin = new ModelPlugin(config, pathway);
+ const modelPlugin = new ModelPlugin(config, pathway, modelName, model);

  const generateMessage = (role, content) => ({ role, content });

package/tests/server.js DELETED
@@ -1,23 +0,0 @@
- import 'dotenv/config'
- import { ApolloServer } from 'apollo-server';
- import { config } from '../config.js';
- import typeDefsresolversFactory from '../index.js';
-
- let typeDefs;
- let resolvers;
-
- const initTypeDefsResolvers = async () => {
- const result = await typeDefsresolversFactory();
- typeDefs = result.typeDefs;
- resolvers = result.resolvers;
- };
-
- export const startTestServer = async () => {
- await initTypeDefsResolvers();
-
- return new ApolloServer({
- typeDefs,
- resolvers,
- context: () => ({ config, requestState: {} }),
- });
- };