@aj-archipelago/cortex 0.0.5 → 0.0.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -8,19 +8,20 @@ const { getFirstNToken, getLastNToken, getSemanticChunks } = require('./chunker'
8
8
  const { PathwayResponseParser } = require('./pathwayResponseParser');
9
9
  const { Prompt } = require('./prompt');
10
10
  const { getv, setv } = require('../lib/keyValueStorageClient');
11
+ const { requestState } = require('./requestState');
11
12
 
12
13
  const MAX_PREVIOUS_RESULT_TOKEN_LENGTH = 1000;
13
14
 
14
- const callPathway = async (config, pathwayName, requestState, { text, ...parameters }) => {
15
- const pathwayResolver = new PathwayResolver({ config, pathway: config.get(`pathways.${pathwayName}`), requestState });
15
+ const callPathway = async (config, pathwayName, args, requestState, { text, ...parameters }) => {
16
+ const pathwayResolver = new PathwayResolver({ config, pathway: config.get(`pathways.${pathwayName}`), args, requestState });
16
17
  return await pathwayResolver.resolve({ text, ...parameters });
17
18
  }
18
19
 
19
20
  class PathwayResolver {
20
- constructor({ config, pathway, requestState }) {
21
+ constructor({ config, pathway, args }) {
21
22
  this.config = config;
22
- this.requestState = requestState;
23
23
  this.pathway = pathway;
24
+ this.args = args;
24
25
  this.useInputChunking = pathway.useInputChunking;
25
26
  this.chunkMaxTokenLength = 0;
26
27
  this.warnings = [];
@@ -29,38 +30,89 @@ class PathwayResolver {
29
30
  this.pathwayPrompter = new PathwayPrompter({ config, pathway });
30
31
  this.previousResult = '';
31
32
  this.prompts = [];
32
- this._pathwayPrompt = '';
33
33
 
34
34
  Object.defineProperty(this, 'pathwayPrompt', {
35
35
  get() {
36
- return this._pathwayPrompt;
36
+ return this.prompts
37
37
  },
38
38
  set(value) {
39
- this._pathwayPrompt = value;
40
- if (!Array.isArray(this._pathwayPrompt)) {
41
- this._pathwayPrompt = [this._pathwayPrompt];
39
+ if (!Array.isArray(value)) {
40
+ value = [value];
42
41
  }
43
- this.prompts = this._pathwayPrompt.map(p => (p instanceof Prompt) ? p : new Prompt({ prompt:p }));
42
+ this.prompts = value.map(p => (p instanceof Prompt) ? p : new Prompt({ prompt:p }));
44
43
  this.chunkMaxTokenLength = this.getChunkMaxTokenLength();
45
44
  }
46
45
  });
47
46
 
47
+ // set up initial prompt
48
48
  this.pathwayPrompt = pathway.prompt;
49
49
  }
50
50
 
51
- async resolve(args) {
52
- if (args.async) {
53
- // Asynchronously process the request
54
- this.promptAndParse(args).then((data) => {
55
- this.requestState[this.requestId].data = data;
56
- pubsub.publish('REQUEST_PROGRESS', {
57
- requestProgress: {
58
- requestId: this.requestId,
59
- data: JSON.stringify(data)
51
+ async asyncResolve(args) {
52
+ // Wait with a sleep promise for the race condition to resolve
53
+ // const results = await Promise.all([this.promptAndParse(args), await new Promise(resolve => setTimeout(resolve, 250))]);
54
+ const data = await this.promptAndParse(args);
55
+ // Process the results for async
56
+ if(args.async || typeof data === 'string') { // if async flag set or processed async and got string response
57
+ const { completedCount, totalCount } = requestState[this.requestId];
58
+ requestState[this.requestId].data = data;
59
+ pubsub.publish('REQUEST_PROGRESS', {
60
+ requestProgress: {
61
+ requestId: this.requestId,
62
+ progress: completedCount / totalCount,
63
+ data: JSON.stringify(data),
64
+ }
65
+ });
66
+ } else { //stream
67
+ for (const handle of data) {
68
+ handle.on('data', data => {
69
+ console.log(data.toString());
70
+ const lines = data.toString().split('\n').filter(line => line.trim() !== '');
71
+ for (const line of lines) {
72
+ const message = line.replace(/^data: /, '');
73
+ if (message === '[DONE]') {
74
+ // Send stream finished message
75
+ pubsub.publish('REQUEST_PROGRESS', {
76
+ requestProgress: {
77
+ requestId: this.requestId,
78
+ data: null,
79
+ progress: 1,
80
+ }
81
+ });
82
+ return; // Stream finished
83
+ }
84
+ try {
85
+ const parsed = JSON.parse(message);
86
+ const result = this.pathwayPrompter.plugin.parseResponse(parsed)
87
+
88
+ pubsub.publish('REQUEST_PROGRESS', {
89
+ requestProgress: {
90
+ requestId: this.requestId,
91
+ data: JSON.stringify(result)
92
+ }
93
+ });
94
+ } catch (error) {
95
+ console.error('Could not JSON parse stream message', message, error);
96
+ }
60
97
  }
61
98
  });
62
- });
63
99
 
100
+ // data.on('end', () => {
101
+ // console.log("stream done");
102
+ // });
103
+ }
104
+
105
+ }
106
+ }
107
+
108
+ async resolve(args) {
109
+ if (args.async || args.stream) {
110
+ // Asynchronously process the request
111
+ // this.asyncResolve(args);
112
+ if (!requestState[this.requestId]) {
113
+ requestState[this.requestId] = {}
114
+ }
115
+ requestState[this.requestId] = { ...requestState[this.requestId], args, resolver: this.asyncResolve.bind(this) };
64
116
  return this.requestId;
65
117
  }
66
118
  else {
@@ -70,7 +122,6 @@ class PathwayResolver {
70
122
  }
71
123
 
72
124
  async promptAndParse(args) {
73
-
74
125
  // Get saved context from contextId or change contextId if needed
75
126
  const { contextId } = args;
76
127
  this.savedContextId = contextId ? contextId : null;
@@ -94,25 +145,25 @@ class PathwayResolver {
94
145
 
95
146
  // Here we choose how to handle long input - either summarize or chunk
96
147
  processInputText(text) {
97
- let chunkMaxChunkTokenLength = 0;
148
+ let chunkTokenLength = 0;
98
149
  if (this.pathway.inputChunkSize) {
99
- chunkMaxChunkTokenLength = Math.min(this.pathway.inputChunkSize, this.chunkMaxTokenLength);
150
+ chunkTokenLength = Math.min(this.pathway.inputChunkSize, this.chunkMaxTokenLength);
100
151
  } else {
101
- chunkMaxChunkTokenLength = this.chunkMaxTokenLength;
152
+ chunkTokenLength = this.chunkMaxTokenLength;
102
153
  }
103
154
  const encoded = encode(text);
104
- if (!this.useInputChunking || encoded.length <= chunkMaxChunkTokenLength) { // no chunking, return as is
105
- if (encoded.length >= chunkMaxChunkTokenLength) {
106
- const warnText = `Your input is possibly too long, truncating! Text length: ${text.length}`;
155
+ if (!this.useInputChunking || encoded.length <= chunkTokenLength) { // no chunking, return as is
156
+ if (encoded.length >= chunkTokenLength) {
157
+ const warnText = `Truncating long input text. Text length: ${text.length}`;
107
158
  this.warnings.push(warnText);
108
159
  console.warn(warnText);
109
- text = this.truncate(text, chunkMaxChunkTokenLength);
160
+ text = this.truncate(text, chunkTokenLength);
110
161
  }
111
162
  return [text];
112
163
  }
113
164
 
114
165
  // chunk the text and return the chunks with newline separators
115
- return getSemanticChunks({ text, maxChunkToken: chunkMaxChunkTokenLength });
166
+ return getSemanticChunks({ text, maxChunkToken: chunkTokenLength });
116
167
  }
117
168
 
118
169
  truncate(str, n) {
@@ -124,7 +175,7 @@ class PathwayResolver {
124
175
 
125
176
  async summarizeIfEnabled({ text, ...parameters }) {
126
177
  if (this.pathway.useInputSummarization) {
127
- return await callPathway(this.config, 'summary', this.requestState, { text, targetLength: 1000, ...parameters });
178
+ return await callPathway(this.config, 'summary', this.args, requestState, { text, targetLength: 1000, ...parameters });
128
179
  }
129
180
  return text;
130
181
  }
@@ -132,46 +183,44 @@ class PathwayResolver {
132
183
  // Calculate the maximum token length for a chunk
133
184
  getChunkMaxTokenLength() {
134
185
  // find the longest prompt
135
- const maxPromptTokenLength = Math.max(...this.prompts.map(({ prompt }) => prompt ? encode(String(prompt)).length : 0));
136
- const maxMessagesTokenLength = Math.max(...this.prompts.map(({ messages }) => messages ? messages.reduce((acc, {role, content}) => {
137
- return (role && content) ? acc + encode(role).length + encode(content).length : acc;
138
- }, 0) : 0));
139
-
140
- const maxTokenLength = Math.max(maxPromptTokenLength, maxMessagesTokenLength);
141
-
186
+ const maxPromptTokenLength = Math.max(...this.prompts.map((promptData) => this.pathwayPrompter.plugin.getCompiledPrompt('', this.args, promptData).tokenLength));
187
+
142
188
  // find out if any prompts use both text input and previous result
143
- const hasBothProperties = this.prompts.some(prompt => prompt.usesInputText && prompt.usesPreviousResult);
189
+ const hasBothProperties = this.prompts.some(prompt => prompt.usesTextInput && prompt.usesPreviousResult);
144
190
 
145
191
  // the token ratio is the ratio of the total prompt to the result text - both have to be included
146
192
  // in computing the max token length
147
193
  const promptRatio = this.pathwayPrompter.plugin.getPromptTokenRatio();
148
- let maxChunkToken = promptRatio * this.pathwayPrompter.plugin.getModelMaxTokenLength() - maxTokenLength;
149
-
194
+ let chunkMaxTokenLength = promptRatio * this.pathwayPrompter.plugin.getModelMaxTokenLength() - maxPromptTokenLength;
195
+
150
196
  // if we have to deal with prompts that have both text input
151
197
  // and previous result, we need to split the maxChunkToken in half
152
- maxChunkToken = hasBothProperties ? maxChunkToken / 2 : maxChunkToken;
153
-
154
- // detect if the longest prompt might be too long to allow any chunk size
155
- if (maxChunkToken && maxChunkToken <= 0) {
156
- throw new Error(`Your prompt is too long! Split to multiple prompts or reduce length of your prompt, prompt length: ${maxPromptLength}`);
157
- }
158
- return maxChunkToken;
198
+ chunkMaxTokenLength = hasBothProperties ? chunkMaxTokenLength / 2 : chunkMaxTokenLength;
199
+
200
+ return chunkMaxTokenLength;
159
201
  }
160
202
 
161
203
  // Process the request and return the result
162
204
  async processRequest({ text, ...parameters }) {
163
-
164
205
  text = await this.summarizeIfEnabled({ text, ...parameters }); // summarize if flag enabled
165
206
  const chunks = this.processInputText(text);
166
207
 
167
208
  const anticipatedRequestCount = chunks.length * this.prompts.length;
168
209
 
169
- if ((this.requestState[this.requestId] || {}).canceled) {
210
+ if ((requestState[this.requestId] || {}).canceled) {
170
211
  throw new Error('Request canceled');
171
212
  }
172
213
 
173
214
  // Store the request state
174
- this.requestState[this.requestId] = { totalCount: anticipatedRequestCount, completedCount: 0 };
215
+ requestState[this.requestId] = { ...requestState[this.requestId], totalCount: anticipatedRequestCount, completedCount: 0 };
216
+
217
+ if (chunks.length > 1) {
218
+ // stream behaves as async if there are multiple chunks
219
+ if (parameters.stream) {
220
+ parameters.async = true;
221
+ parameters.stream = false;
222
+ }
223
+ }
175
224
 
176
225
  // If pre information is needed, apply current prompt with previous prompt info, only parallelize current call
177
226
  if (this.pathway.useParallelChunkProcessing) {
@@ -189,17 +238,31 @@ class PathwayResolver {
189
238
  let result = '';
190
239
 
191
240
  for (let i = 0; i < this.prompts.length; i++) {
241
+ const currentParameters = { ...parameters, previousResult };
242
+
243
+ if (currentParameters.stream) { // stream special flow
244
+ if (i < this.prompts.length - 1) {
245
+ currentParameters.stream = false; // if not the last prompt then don't stream
246
+ }
247
+ else {
248
+ // use the stream parameter if not async
249
+ currentParameters.stream = currentParameters.async ? false : currentParameters.stream;
250
+ }
251
+ }
252
+
192
253
  // If the prompt doesn't contain {{text}} then we can skip the chunking, and also give that token space to the previous result
193
254
  if (!this.prompts[i].usesTextInput) {
194
255
  // Limit context to its N + text's characters
195
256
  previousResult = this.truncate(previousResult, 2 * this.chunkMaxTokenLength);
196
- result = await this.applyPrompt(this.prompts[i], null, { ...parameters, previousResult });
257
+ result = await this.applyPrompt(this.prompts[i], null, currentParameters);
197
258
  } else {
198
259
  // Limit context to N characters
199
260
  previousResult = this.truncate(previousResult, this.chunkMaxTokenLength);
200
261
  result = await Promise.all(chunks.map(chunk =>
201
- this.applyPrompt(this.prompts[i], chunk, { ...parameters, previousResult })));
202
- result = result.join("\n\n")
262
+ this.applyPrompt(this.prompts[i], chunk, currentParameters)));
263
+ if (!currentParameters.stream) {
264
+ result = result.join("\n\n")
265
+ }
203
266
  }
204
267
 
205
268
  // If this is any prompt other than the last, use the result as the previous context
@@ -225,20 +288,22 @@ class PathwayResolver {
225
288
  }
226
289
 
227
290
  async applyPrompt(prompt, text, parameters) {
228
- if (this.requestState[this.requestId].canceled) {
291
+ if (requestState[this.requestId].canceled) {
229
292
  return;
230
293
  }
231
- const result = await this.pathwayPrompter.execute(text, { ...parameters, ...this.savedContext }, prompt);
232
- this.requestState[this.requestId].completedCount++;
294
+ const result = await this.pathwayPrompter.execute(text, { ...parameters, ...this.savedContext }, prompt, this);
295
+ requestState[this.requestId].completedCount++;
233
296
 
234
- const { completedCount, totalCount } = this.requestState[this.requestId];
297
+ const { completedCount, totalCount } = requestState[this.requestId];
235
298
 
236
- pubsub.publish('REQUEST_PROGRESS', {
237
- requestProgress: {
238
- requestId: this.requestId,
239
- progress: completedCount / totalCount,
240
- }
241
- });
299
+ if (completedCount < totalCount) {
300
+ pubsub.publish('REQUEST_PROGRESS', {
301
+ requestProgress: {
302
+ requestId: this.requestId,
303
+ progress: completedCount / totalCount,
304
+ }
305
+ });
306
+ }
242
307
 
243
308
  if (prompt.saveResultTo) {
244
309
  this.savedContext[prompt.saveResultTo] = result;
@@ -1,19 +1,26 @@
1
1
  // AzureTranslatePlugin.js
2
2
  const ModelPlugin = require('./modelPlugin');
3
3
  const handlebars = require("handlebars");
4
+ const { encode } = require("gpt-3-encoder");
4
5
 
5
6
  class AzureTranslatePlugin extends ModelPlugin {
6
- constructor(config, modelName, pathway) {
7
- super(config, modelName, pathway);
7
+ constructor(config, pathway) {
8
+ super(config, pathway);
8
9
  }
9
10
 
10
- // Set up parameters specific to the Azure Translate API
11
- requestParameters(text, parameters, prompt) {
11
+ getCompiledPrompt(text, parameters, prompt) {
12
12
  const combinedParameters = { ...this.promptParameters, ...parameters };
13
13
  const modelPrompt = this.getModelPrompt(prompt, parameters);
14
14
  const modelPromptText = modelPrompt.prompt ? handlebars.compile(modelPrompt.prompt)({ ...combinedParameters, text }) : '';
15
-
16
- return {
15
+
16
+ return { modelPromptText, tokenLength: encode(modelPromptText).length };
17
+ }
18
+
19
+ // Set up parameters specific to the Azure Translate API
20
+ getRequestParameters(text, parameters, prompt) {
21
+ const combinedParameters = { ...this.promptParameters, ...parameters };
22
+ const { modelPromptText } = this.getCompiledPrompt(text, parameters, prompt);
23
+ const requestParameters = {
17
24
  data: [
18
25
  {
19
26
  Text: modelPromptText,
@@ -23,11 +30,12 @@ class AzureTranslatePlugin extends ModelPlugin {
23
30
  to: combinedParameters.to
24
31
  }
25
32
  };
33
+ return requestParameters;
26
34
  }
27
35
 
28
36
  // Execute the request to the Azure Translate API
29
37
  async execute(text, parameters, prompt) {
30
- const requestParameters = this.requestParameters(text, parameters, prompt);
38
+ const requestParameters = this.getRequestParameters(text, parameters, prompt);
31
39
 
32
40
  const url = this.requestUrl(text);
33
41
 
@@ -35,7 +43,7 @@ class AzureTranslatePlugin extends ModelPlugin {
35
43
  const params = requestParameters.params;
36
44
  const headers = this.model.headers || {};
37
45
 
38
- return this.executeRequest(url, data, params, headers);
46
+ return this.executeRequest(url, data, params, headers, prompt);
39
47
  }
40
48
  }
41
49
 
@@ -1,7 +1,7 @@
1
1
  // ModelPlugin.js
2
2
  const handlebars = require('handlebars');
3
3
  const { request } = require("../../lib/request");
4
- const { getResponseResult } = require("../parser");
4
+ const { encode } = require("gpt-3-encoder");
5
5
 
6
6
  const DEFAULT_MAX_TOKENS = 4096;
7
7
  const DEFAULT_PROMPT_TOKEN_RATIO = 0.5;
@@ -35,6 +35,42 @@ class ModelPlugin {
35
35
  }
36
36
 
37
37
  this.requestCount = 1;
38
+ this.shouldCache = config.get('enableCache') && (pathway.enableCache || pathway.temperature == 0);
39
+ }
40
+
41
+ // Function to remove non-system messages until token length is less than target
42
+ removeMessagesUntilTarget = (messages, targetTokenLength) => {
43
+ let chatML = this.messagesToChatML(messages);
44
+ let tokenLength = encode(chatML).length;
45
+
46
+ while (tokenLength > targetTokenLength) {
47
+ for (let i = 0; i < messages.length; i++) {
48
+ if (messages[i].role !== 'system') {
49
+ messages.splice(i, 1);
50
+ chatML = this.messagesToChatML(messages);
51
+ tokenLength = encode(chatML).length;
52
+ break;
53
+ }
54
+ }
55
+ if (messages.every(message => message.role === 'system')) {
56
+ break; // All remaining messages are 'system', stop removing messages
57
+ }
58
+ }
59
+ return messages;
60
+ }
61
+
62
+ //convert a messages array to a simple chatML format
63
+ messagesToChatML = (messages) => {
64
+ let output = "";
65
+ if (messages && messages.length) {
66
+ for (let message of messages) {
67
+ output += (message.role && message.content) ? `<|im_start|>${message.role}\n${message.content}\n<|im_end|>\n` : `${message}\n`;
68
+ }
69
+ // you always want the assistant to respond next so add a
70
+ // directive for that
71
+ output += "<|im_start|>assistant\n";
72
+ }
73
+ return output;
38
74
  }
39
75
 
40
76
  getModelMaxTokenLength() {
@@ -102,6 +138,8 @@ class ModelPlugin {
102
138
  if (!choices || !choices.length) {
103
139
  if (Array.isArray(data) && data.length > 0 && data[0].translations) {
104
140
  return data[0].translations[0].text.trim();
141
+ } else {
142
+ return data;
105
143
  }
106
144
  }
107
145
 
@@ -114,20 +152,40 @@ class ModelPlugin {
114
152
  const textResult = choices[0].text && choices[0].text.trim();
115
153
  const messageResult = choices[0].message && choices[0].message.content && choices[0].message.content.trim();
116
154
 
117
- return messageResult || textResult || null;
155
+ return messageResult ?? textResult ?? null;
118
156
  }
119
157
 
120
- async executeRequest(url, data, params, headers) {
121
- const responseData = await request({ url, data, params, headers }, this.modelName);
122
- const modelInput = data.prompt || (data.messages && data.messages[0].content) || data[0].Text || null;
123
- console.log(`=== ${this.pathwayName}.${this.requestCount++} ===`)
124
- console.log(`\x1b[36m${modelInput}\x1b[0m`)
158
+ logRequestData(data, responseData, prompt) {
159
+ const separator = `\n=== ${this.pathwayName}.${this.requestCount++} ===\n`;
160
+ console.log(separator);
161
+
162
+ const modelInput = data.prompt || (data.messages && data.messages[0].content) || (data.length > 0 && data[0].Text) || null;
163
+
164
+ if (data.messages && data.messages.length > 1) {
165
+ data.messages.forEach((message, index) => {
166
+ const words = message.content.split(" ");
167
+ const tokenCount = encode(message.content).length;
168
+ const preview = words.length < 41 ? message.content : words.slice(0, 20).join(" ") + " ... " + words.slice(-20).join(" ");
169
+
170
+ console.log(`\x1b[36mMessage ${index + 1}: Role: ${message.role}, Tokens: ${tokenCount}, Content: "${preview}"\x1b[0m`);
171
+ });
172
+ } else {
173
+ console.log(`\x1b[36m${modelInput}\x1b[0m`);
174
+ }
175
+
125
176
  console.log(`\x1b[34m> ${this.parseResponse(responseData)}\x1b[0m`);
126
-
177
+
178
+ prompt.debugInfo += `${separator}${JSON.stringify(data)}`;
179
+ }
180
+
181
+ async executeRequest(url, data, params, headers, prompt) {
182
+ const responseData = await request({ url, data, params, headers, cache: this.shouldCache }, this.modelName);
183
+
127
184
  if (responseData.error) {
128
185
  throw new Exception(`An error was returned from the server: ${JSON.stringify(responseData.error)}`);
129
186
  }
130
-
187
+
188
+ this.logRequestData(data, responseData, prompt);
131
189
  return this.parseResponse(responseData);
132
190
  }
133
191
 
@@ -1,34 +1,61 @@
1
1
  // OpenAIChatPlugin.js
2
2
  const ModelPlugin = require('./modelPlugin');
3
3
  const handlebars = require("handlebars");
4
+ const { encode } = require("gpt-3-encoder");
4
5
 
5
6
  class OpenAIChatPlugin extends ModelPlugin {
6
7
  constructor(config, pathway) {
7
8
  super(config, pathway);
8
9
  }
9
10
 
10
- // Set up parameters specific to the OpenAI Chat API
11
- requestParameters(text, parameters, prompt) {
11
+ getCompiledPrompt(text, parameters, prompt) {
12
12
  const combinedParameters = { ...this.promptParameters, ...parameters };
13
13
  const modelPrompt = this.getModelPrompt(prompt, parameters);
14
14
  const modelPromptText = modelPrompt.prompt ? handlebars.compile(modelPrompt.prompt)({ ...combinedParameters, text }) : '';
15
15
  const modelPromptMessages = this.getModelPromptMessages(modelPrompt, combinedParameters, text);
16
+ const modelPromptMessagesML = this.messagesToChatML(modelPromptMessages);
17
+
18
+ if (modelPromptMessagesML) {
19
+ return { modelPromptMessages, tokenLength: encode(modelPromptMessagesML).length };
20
+ } else {
21
+ return { modelPromptText, tokenLength: encode(modelPromptText).length };
22
+ }
23
+ }
16
24
 
17
- return {
18
- messages: modelPromptMessages || [{ "role": "user", "content": modelPromptText }],
19
- temperature: this.temperature ?? 0.7,
25
+ // Set up parameters specific to the OpenAI Chat API
26
+ getRequestParameters(text, parameters, prompt) {
27
+ const { modelPromptText, modelPromptMessages, tokenLength } = this.getCompiledPrompt(text, parameters, prompt);
28
+ const { stream } = parameters;
29
+
30
+ // Define the model's max token length
31
+ const modelMaxTokenLength = this.getModelMaxTokenLength() * this.getPromptTokenRatio();
32
+
33
+ let requestMessages = modelPromptMessages || [{ "role": "user", "content": modelPromptText }];
34
+
35
+ // Check if the token length exceeds the model's max token length
36
+ if (tokenLength > modelMaxTokenLength) {
37
+ // Remove older messages until the token length is within the model's limit
38
+ requestMessages = this.removeMessagesUntilTarget(requestMessages, modelMaxTokenLength);
39
+ }
40
+
41
+ const requestParameters = {
42
+ messages: requestMessages,
43
+ temperature: this.temperature ?? 0.7,
44
+ stream
20
45
  };
46
+
47
+ return requestParameters;
21
48
  }
22
49
 
23
50
  // Execute the request to the OpenAI Chat API
24
51
  async execute(text, parameters, prompt) {
25
52
  const url = this.requestUrl(text);
26
- const requestParameters = this.requestParameters(text, parameters, prompt);
53
+ const requestParameters = this.getRequestParameters(text, parameters, prompt);
27
54
 
28
55
  const data = { ...(this.model.params || {}), ...requestParameters };
29
56
  const params = {};
30
57
  const headers = this.model.headers || {};
31
- return this.executeRequest(url, data, params, headers);
58
+ return this.executeRequest(url, data, params, headers, prompt);
32
59
  }
33
60
  }
34
61
 
@@ -3,61 +3,81 @@ const ModelPlugin = require('./modelPlugin');
3
3
  const handlebars = require("handlebars");
4
4
  const { encode } = require("gpt-3-encoder");
5
5
 
6
- //convert a messages array to a simple chatML format
7
- const messagesToChatML = (messages) => {
8
- let output = "";
9
- if (messages && messages.length) {
10
- for (let message of messages) {
11
- output += (message.role && message.content) ? `<|im_start|>${message.role}\n${message.content}\n<|im_end|>\n` : `${message}\n`;
12
- }
13
- // you always want the assistant to respond next so add a
14
- // directive for that
15
- output += "<|im_start|>assistant\n";
16
- }
17
- return output;
18
- }
19
-
20
6
  class OpenAICompletionPlugin extends ModelPlugin {
21
7
  constructor(config, pathway) {
22
8
  super(config, pathway);
23
9
  }
24
10
 
25
- // Set up parameters specific to the OpenAI Completion API
26
- requestParameters(text, parameters, prompt) {
11
+ getCompiledPrompt(text, parameters, prompt) {
27
12
  const combinedParameters = { ...this.promptParameters, ...parameters };
28
13
  const modelPrompt = this.getModelPrompt(prompt, parameters);
29
14
  const modelPromptText = modelPrompt.prompt ? handlebars.compile(modelPrompt.prompt)({ ...combinedParameters, text }) : '';
30
15
  const modelPromptMessages = this.getModelPromptMessages(modelPrompt, combinedParameters, text);
31
- const modelPromptMessagesML = messagesToChatML(modelPromptMessages);
16
+ const modelPromptMessagesML = this.messagesToChatML(modelPromptMessages);
32
17
 
33
18
  if (modelPromptMessagesML) {
34
- return {
35
- prompt: modelPromptMessagesML,
36
- max_tokens: this.getModelMaxTokenLength() - encode(modelPromptMessagesML).length - 1,
37
- temperature: this.temperature ?? 0.7,
38
- top_p: 0.95,
39
- frequency_penalty: 0,
40
- presence_penalty: 0,
41
- stop: ["<|im_end|>"]
42
- };
19
+ return { modelPromptMessages, tokenLength: encode(modelPromptMessagesML).length };
20
+ } else {
21
+ return { modelPromptText, tokenLength: encode(modelPromptText).length };
22
+ }
23
+ }
24
+
25
+ // Set up parameters specific to the OpenAI Completion API
26
+ getRequestParameters(text, parameters, prompt) {
27
+ let { modelPromptMessages, modelPromptText, tokenLength } = this.getCompiledPrompt(text, parameters, prompt);
28
+ const { stream } = parameters;
29
+ let modelPromptMessagesML = '';
30
+ const modelMaxTokenLength = this.getModelMaxTokenLength();
31
+ let requestParameters = {};
32
+
33
+ if (modelPromptMessages) {
34
+ const requestMessages = this.removeMessagesUntilTarget(modelPromptMessages, modelMaxTokenLength - 1);
35
+ modelPromptMessagesML = this.messagesToChatML(requestMessages);
36
+ tokenLength = encode(modelPromptMessagesML).length;
37
+
38
+ if (tokenLength >= modelMaxTokenLength) {
39
+ throw new Error(`The maximum number of tokens for this model is ${modelMaxTokenLength}. Please reduce the number of messages in the prompt.`);
40
+ }
41
+
42
+ const max_tokens = modelMaxTokenLength - tokenLength - 1;
43
+
44
+ requestParameters = {
45
+ prompt: modelPromptMessagesML,
46
+ max_tokens: max_tokens,
47
+ temperature: this.temperature ?? 0.7,
48
+ top_p: 0.95,
49
+ frequency_penalty: 0,
50
+ presence_penalty: 0,
51
+ stop: ["<|im_end|>"],
52
+ stream
53
+ };
43
54
  } else {
44
- return {
45
- prompt: modelPromptText,
46
- max_tokens: this.getModelMaxTokenLength() - encode(modelPromptText).length - 1,
47
- temperature: this.temperature ?? 0.7,
48
- };
55
+ if (tokenLength >= modelMaxTokenLength) {
56
+ throw new Error(`The maximum number of tokens for this model is ${modelMaxTokenLength}. Please reduce the length of the prompt.`);
57
+ }
58
+
59
+ const max_tokens = modelMaxTokenLength - tokenLength - 1;
60
+
61
+ requestParameters = {
62
+ prompt: modelPromptText,
63
+ max_tokens: max_tokens,
64
+ temperature: this.temperature ?? 0.7,
65
+ stream
66
+ };
49
67
  }
68
+
69
+ return requestParameters;
50
70
  }
51
71
 
52
72
  // Execute the request to the OpenAI Completion API
53
73
  async execute(text, parameters, prompt) {
54
74
  const url = this.requestUrl(text);
55
- const requestParameters = this.requestParameters(text, parameters, prompt);
75
+ const requestParameters = this.getRequestParameters(text, parameters, prompt);
56
76
 
57
77
  const data = { ...(this.model.params || {}), ...requestParameters };
58
78
  const params = {};
59
79
  const headers = this.model.headers || {};
60
- return this.executeRequest(url, data, params, headers);
80
+ return this.executeRequest(url, data, params, headers, prompt);
61
81
  }
62
82
  }
63
83