@aj-archipelago/cortex 0.0.9 → 0.0.11

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. package/config.js +17 -11
  2. package/graphql/chunker.js +97 -107
  3. package/graphql/graphql.js +19 -22
  4. package/graphql/parser.js +1 -1
  5. package/graphql/pathwayPrompter.js +8 -9
  6. package/graphql/pathwayResolver.js +12 -14
  7. package/graphql/pathwayResponseParser.js +2 -2
  8. package/graphql/plugins/azureTranslatePlugin.js +2 -2
  9. package/graphql/plugins/modelPlugin.js +67 -25
  10. package/graphql/plugins/openAiChatPlugin.js +3 -3
  11. package/graphql/plugins/openAiCompletionPlugin.js +5 -4
  12. package/graphql/plugins/openAiWhisperPlugin.js +7 -6
  13. package/graphql/prompt.js +1 -1
  14. package/graphql/pubsub.js +2 -2
  15. package/graphql/requestState.js +1 -1
  16. package/graphql/resolver.js +4 -4
  17. package/graphql/subscriptions.js +5 -4
  18. package/graphql/typeDef.js +53 -53
  19. package/index.js +5 -5
  20. package/lib/fileChunker.js +15 -11
  21. package/lib/keyValueStorageClient.js +5 -5
  22. package/lib/promiser.js +2 -2
  23. package/lib/request.js +11 -9
  24. package/lib/requestMonitor.js +2 -2
  25. package/package.json +15 -5
  26. package/pathways/basePathway.js +5 -4
  27. package/pathways/bias.js +2 -2
  28. package/pathways/chat.js +3 -2
  29. package/pathways/complete.js +4 -2
  30. package/pathways/edit.js +3 -2
  31. package/pathways/entities.js +3 -2
  32. package/pathways/index.js +25 -12
  33. package/pathways/lc_test.mjs +99 -0
  34. package/pathways/paraphrase.js +3 -2
  35. package/pathways/sentiment.js +3 -2
  36. package/pathways/summary.js +27 -10
  37. package/pathways/transcribe.js +4 -2
  38. package/pathways/translate.js +3 -2
  39. package/start.js +5 -2
  40. package/tests/chunkfunction.test.js +125 -0
  41. package/tests/chunking.test.js +25 -19
  42. package/tests/main.test.js +52 -38
  43. package/tests/translate.test.js +13 -10
package/config.js CHANGED
@@ -1,7 +1,8 @@
-const path = require('path');
-const convict = require('convict');
-const handlebars = require("handlebars");
-const fs = require('fs');
+import path from 'path';
+const __dirname = path.dirname(new URL(import.meta.url).pathname);
+import convict from 'convict';
+import handlebars from 'handlebars';
+import fs from 'fs';
 
 // Schema for config
 var config = convict({
@@ -108,7 +109,13 @@ var config = convict({
         format: String,
         default: null,
         env: 'CORTEX_CONFIG_FILE'
-    }
+    },
+    serpApiKey: {
+        format: String,
+        default: null,
+        env: 'SERPAPI_API_KEY',
+        sensitive: true
+    },
 });
 
 // Read in environment variables and set up service configuration
@@ -127,22 +134,21 @@ if (configFile && fs.existsSync(configFile)) {
     }
 }
 
-
 // Build and load pathways to config
-const buildPathways = (config) => {
+const buildPathways = async (config) => {
     const { pathwaysPath, corePathwaysPath, basePathwayPath } = config.getProperties();
 
     // Load cortex base pathway
-    const basePathway = require(basePathwayPath);
+    const basePathway = await import(basePathwayPath).then(module => module.default);
 
     // Load core pathways, default from the Cortex package
     console.log('Loading core pathways from', corePathwaysPath)
-    let loadedPathways = require(corePathwaysPath);
+    let loadedPathways = await import(`${corePathwaysPath}/index.js`).then(module => module);
 
     // Load custom pathways and override core pathways if same
     if (pathwaysPath && fs.existsSync(pathwaysPath)) {
         console.log('Loading custom pathways from', pathwaysPath)
-        const customPathways = require(pathwaysPath);
+        const customPathways = await import(`${pathwaysPath}/index.js`).then(module => module);
         loadedPathways = { ...loadedPathways, ...customPathways };
     }
 
@@ -191,4 +197,4 @@ const buildModels = (config) => {
 // TODO: Perform validation
 // config.validate({ allowed: 'strict' });
 
-module.exports = { config, buildPathways, buildModels };
+export { config, buildPathways, buildModels };
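With the switch to ES modules, buildPathways now loads pathway files through dynamic import() and returns a promise, so callers must await it. A minimal sketch of the new loading flow (the caller here is hypothetical; graphql.js below shows the real call site):

    import { config, buildPathways, buildModels } from './config.js';

    await buildPathways(config); // now async: pathways arrive via dynamic import()
    buildModels(config);         // model setup remains synchronous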
package/graphql/chunker.js CHANGED
@@ -1,21 +1,4 @@
-const { encode, decode } = require('gpt-3-encoder')
-
-const estimateCharPerToken = (text) => {
-    // check text only contains asciish characters
-    if (/^[ -~\t\n\r]+$/.test(text)) {
-        return 4;
-    }
-    return 1;
-}
-
-const getLastNChar = (text, maxLen) => {
-    if (text.length > maxLen) {
-        //slice text to avoid maxLen limit but keep the last n characters up to a \n or space to avoid cutting words
-        text = text.slice(-maxLen);
-        text = text.slice(text.search(/\s/) + 1);
-    }
-    return text;
-}
+import { encode, decode } from 'gpt-3-encoder';
 
 const getLastNToken = (text, maxTokenLen) => {
     const encoded = encode(text);
@@ -35,113 +18,120 @@ const getFirstNToken = (text, maxTokenLen) => {
     return text;
 }
 
-const isBigChunk = ({ text, maxChunkLength, maxChunkToken }) => {
-    if (maxChunkLength && text.length > maxChunkLength) {
-        return true;
-    }
-    if (maxChunkToken && encode(text).length > maxChunkToken) {
-        return true;
-    }
-    return false;
-}
-
-const getSemanticChunks = ({ text, maxChunkLength, maxChunkToken,
-    enableParagraphChunks = true, enableSentenceChunks = true, enableLineChunks = true,
-    enableWordChunks = true, finallyMergeChunks = true }) => {
-
-    if (maxChunkLength && maxChunkLength <= 0) {
-        throw new Error(`Invalid maxChunkLength: ${maxChunkLength}`);
-    }
-    if (maxChunkToken && maxChunkToken <= 0) {
-        throw new Error(`Invalid maxChunkToken: ${maxChunkToken}`);
-    }
-
-    const isBig = (text) => {
-        return isBigChunk({ text, maxChunkLength, maxChunkToken });
-    }
-
-    // split into paragraphs
-    let paragraphChunks = enableParagraphChunks ? text.split('\n\n') : [text];
-
-    // Chunk paragraphs into sentences if needed
-    const sentenceChunks = enableSentenceChunks ? [] : paragraphChunks;
-    for (let i = 0; enableSentenceChunks && i < paragraphChunks.length; i++) {
-        if (isBig(paragraphChunks[i])) { // too long paragraph, chunk into sentences
-            sentenceChunks.push(...paragraphChunks[i].split('.\n')); // split into sentences
-        } else {
-            sentenceChunks.push(paragraphChunks[i]);
+const getSemanticChunks = (text, chunkSize) => {
+
+    const breakByRegex = (str, regex, preserveWhitespace = false) => {
+        const result = [];
+        let match;
+
+        while ((match = regex.exec(str)) !== null) {
+            const value = str.slice(0, match.index);
+            result.push(value);
+
+            if (preserveWhitespace || /\S/.test(match[0])) {
+                result.push(match[0]);
+            }
+
+            str = str.slice(match.index + match[0].length);
+        }
+
+        if (str) {
+            result.push(str);
+        }
+
+        return result.filter(Boolean);
+    };
+
+    const breakByParagraphs = (str) => breakByRegex(str, /[\r\n]+/, true);
+    const breakBySentences = (str) => breakByRegex(str, /(?<=[.。؟!\?!\n])\s+/, true);
+    const breakByWords = (str) => breakByRegex(str, /(\s,;:.+)/);
+
+    const createChunks = (tokens) => {
+        let chunks = [];
+        let currentChunk = '';
+
+        for (const token of tokens) {
+            const currentTokenLength = encode(currentChunk + token).length;
+            if (currentTokenLength <= chunkSize) {
+                currentChunk += token;
+            } else {
+                if (currentChunk) {
+                    chunks.push(currentChunk);
                 }
+                currentChunk = token;
+            }
         }
-    }
-
-    // Chunk sentences with newlines if needed
-    const newlineChunks = enableLineChunks ? [] : sentenceChunks;
-    for (let i = 0; enableLineChunks && i < sentenceChunks.length; i++) {
-        if (isBig(sentenceChunks[i])) { // too long, split into lines
-            newlineChunks.push(...sentenceChunks[i].split('\n'));
-        } else {
-            newlineChunks.push(sentenceChunks[i]);
-        }
+
+        if (currentChunk) {
+            chunks.push(currentChunk);
         }
+
+        return chunks;
+    };
 
-    // Chunk sentences into word chunks if needed
-    let chunks = enableWordChunks ? [] : newlineChunks;
-    for (let j = 0; enableWordChunks && j < newlineChunks.length; j++) {
-        if (isBig(newlineChunks[j])) { // too long sentence, chunk into words
-            const words = newlineChunks[j].split(' ');
-            // merge words into chunks up to max
-            let chunk = '';
-            for (let k = 0; k < words.length; k++) {
-                if (isBig( chunk + ' ' + words[k]) ) {
-                    chunks.push(chunk.trim());
-                    chunk = '';
-                }
-                chunk += words[k] + ' ';
-            }
-            if (chunk.length > 0) {
-                chunks.push(chunk.trim());
-            }
+    const combineChunks = (chunks) => {
+        let optimizedChunks = [];
+
+        for (let i = 0; i < chunks.length; i++) {
+            if (i < chunks.length - 1) {
+                const combinedChunk = chunks[i] + chunks[i + 1];
+                const combinedLen = encode(combinedChunk).length;
+
+                if (combinedLen <= chunkSize) {
+                    optimizedChunks.push(combinedChunk);
+                    i += 1;
                 } else {
-            chunks.push(newlineChunks[j]);
+                    optimizedChunks.push(chunks[i]);
                 }
+            } else {
+                optimizedChunks.push(chunks[i]);
+            }
         }
-    }
+
+        return optimizedChunks;
+    };
 
-    chunks = chunks.filter(Boolean).map(chunk => '\n' + chunk + '\n'); //filter empty chunks and add newlines
-
-    return finallyMergeChunks ? mergeChunks({ chunks, maxChunkLength, maxChunkToken }) : chunks;
-}
+    const breakText = (str) => {
+        const tokenLength = encode(str).length;
 
-const mergeChunks = ({ chunks, maxChunkLength, maxChunkToken }) => {
-    const isBig = (text) => {
-        return isBigChunk({ text, maxChunkLength, maxChunkToken });
+        if (tokenLength <= chunkSize) {
+            return [str];
         }
 
-    // Merge chunks into maxChunkLength chunks
-    let mergedChunks = [];
-    let chunk = '';
-    for (let i = 0; i < chunks.length; i++) {
-        if (isBig(chunk + ' ' + chunks[i])) {
-            mergedChunks.push(chunk);
-            chunk = '';
+        const breakers = [breakByParagraphs, breakBySentences, breakByWords];
+
+        for (let i = 0; i < breakers.length; i++) {
+            const tokens = breakers[i](str);
+            if (tokens.length > 1) {
+                let chunks = createChunks(tokens);
+                chunks = combineChunks(chunks);
+                const brokenChunks = chunks.flatMap(breakText);
+                if (brokenChunks.every(chunk => encode(chunk).length <= chunkSize)) {
+                    return brokenChunks;
                 }
-        chunk += chunks[i];
-    }
-    if (chunk.length > 0) {
-        mergedChunks.push(chunk);
+            }
         }
-    return mergedChunks;
+
+        return createChunks([...str]); // Split by characters
+    };
+
+    return breakText(text);
 }
 
 
 const semanticTruncate = (text, maxLength) => {
-    if (text.length > maxLength) {
-        text = getSemanticChunks({ text, maxChunkLength: maxLength })[0].slice(0, maxLength - 3).trim() + "...";
-    }
+    if (text.length <= maxLength) {
         return text;
-}
+    }
 
+    const truncatedText = text.slice(0, maxLength - 3).trim();
+    const lastSpaceIndex = truncatedText.lastIndexOf(" ");
 
+    return (lastSpaceIndex !== -1)
+        ? truncatedText.slice(0, lastSpaceIndex) + "..."
+        : truncatedText + "...";
+};
 
-module.exports = {
-    getSemanticChunks, semanticTruncate, mergeChunks,
-    getLastNChar, getLastNToken, getFirstNToken, estimateCharPerToken
-}
+export {
+    getSemanticChunks, semanticTruncate, getLastNToken, getFirstNToken
+};
package/graphql/graphql.js CHANGED
@@ -1,23 +1,21 @@
-const { createServer } = require('http');
-const {
+import { createServer } from 'http';
+import {
     ApolloServerPluginDrainHttpServer,
     ApolloServerPluginLandingPageLocalDefault,
-} = require("apollo-server-core");
-const { makeExecutableSchema } = require('@graphql-tools/schema');
-const { WebSocketServer } = require('ws');
-const { useServer } = require('graphql-ws/lib/use/ws');
-const express = require('express');
-
-/// Create apollo graphql server
-const Keyv = require("keyv");
-const { KeyvAdapter } = require("@apollo/utils.keyvadapter");
-const responseCachePlugin = require('apollo-server-plugin-response-cache').default
-
-const subscriptions = require('./subscriptions');
-const { buildLimiters } = require('../lib/request');
-const { cancelRequestResolver } = require('./resolver');
-const { buildPathways, buildModels } = require('../config');
-const { requestState } = require('./requestState');
+} from 'apollo-server-core';
+import { makeExecutableSchema } from '@graphql-tools/schema';
+import { WebSocketServer } from 'ws';
+import { useServer } from 'graphql-ws/lib/use/ws';
+import express from 'express';
+import { ApolloServer } from 'apollo-server-express';
+import Keyv from 'keyv';
+import { KeyvAdapter } from '@apollo/utils.keyvadapter';
+import responseCachePlugin from 'apollo-server-plugin-response-cache';
+import subscriptions from './subscriptions.js';
+import { buildLimiters } from '../lib/request.js';
+import { cancelRequestResolver } from './resolver.js';
+import { buildPathways, buildModels } from '../config.js';
+import { requestState } from './requestState.js';
 
 const getPlugins = (config) => {
     // server plugins
@@ -134,9 +132,9 @@ const getResolvers = (config, pathways) => {
 }
 
 //graphql api build factory method
-const build = (config) => {
+const build = async (config) => {
     // First perform config build
-    buildPathways(config);
+    await buildPathways(config);
     buildModels(config);
 
     // build api limiters
@@ -152,7 +150,6 @@ const build = (config) => {
 
     const { plugins, cache } = getPlugins(config);
 
-    const { ApolloServer, gql } = require('apollo-server-express');
     const app = express()
 
     const httpServer = createServer(app);
@@ -221,6 +218,6 @@ const build = (config) => {
 }
 
 
-module.exports = {
+export {
     build
 };
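Since build now awaits buildPathways, it is itself async, and any entry point has to await it before the server is live. A sketch of the startup this implies (the entry-point shape is an assumption; the package's actual entry point is start.js, also updated in this release):

    import { build } from './graphql/graphql.js';
    import { config } from './config.js';

    await build(config); // builds pathways and models, then starts the Apollo and WebSocket servers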
package/graphql/parser.js CHANGED
@@ -31,7 +31,7 @@ const parseNumberedObjectList = (text, format) => {
     return result;
 }
 
-module.exports = {
+export {
     regexParser,
     parseNumberedList,
     parseNumberedObjectList,
package/graphql/pathwayPrompter.js CHANGED
@@ -1,10 +1,9 @@
 // PathwayPrompter.js
-const OpenAIChatPlugin = require('./plugins/openAIChatPlugin');
-const OpenAICompletionPlugin = require('./plugins/openAICompletionPlugin');
-const AzureTranslatePlugin = require('./plugins/azureTranslatePlugin');
-const OpenAIWhisperPlugin = require('./plugins/openAiWhisperPlugin');
-const handlebars = require("handlebars");
-const { Exception } = require("handlebars");
+import OpenAIChatPlugin from './plugins/openAIChatPlugin.js';
+import OpenAICompletionPlugin from './plugins/openAICompletionPlugin.js';
+import AzureTranslatePlugin from './plugins/azureTranslatePlugin.js';
+import OpenAIWhisperPlugin from './plugins/openAiWhisperPlugin.js';
+import handlebars from 'handlebars';
 
 // register functions that can be called directly in the prompt markdown
 handlebars.registerHelper('stripHTML', function (value) {
@@ -27,7 +26,7 @@ class PathwayPrompter {
         const model = config.get('models')[modelName];
 
         if (!model) {
-            throw new Exception(`Model ${modelName} not found in config`);
+            throw new handlebars.Exception(`Model ${modelName} not found in config`);
         }
 
         let plugin;
@@ -46,7 +45,7 @@ class PathwayPrompter {
                 plugin = new OpenAIWhisperPlugin(config, pathway);
                 break;
             default:
-                throw new Exception(`Unsupported model type: ${model.type}`);
+                throw new handlebars.Exception(`Unsupported model type: ${model.type}`);
         }
 
         this.plugin = plugin;
@@ -57,6 +56,6 @@ class PathwayPrompter {
     }
 }
 
-module.exports = {
+export {
     PathwayPrompter
 };
package/graphql/pathwayResolver.js CHANGED
@@ -1,14 +1,12 @@
-const { PathwayPrompter } = require('./pathwayPrompter');
-const {
-    v4: uuidv4,
-} = require('uuid');
-const pubsub = require('./pubsub');
-const { encode } = require('gpt-3-encoder')
-const { getFirstNToken, getLastNToken, getSemanticChunks } = require('./chunker');
-const { PathwayResponseParser } = require('./pathwayResponseParser');
-const { Prompt } = require('./prompt');
-const { getv, setv } = require('../lib/keyValueStorageClient');
-const { requestState } = require('./requestState');
+import { PathwayPrompter } from './pathwayPrompter.js';
+import { v4 as uuidv4 } from 'uuid';
+import pubsub from './pubsub.js';
+import { encode } from 'gpt-3-encoder';
+import { getFirstNToken, getLastNToken, getSemanticChunks } from './chunker.js';
+import { PathwayResponseParser } from './pathwayResponseParser.js';
+import { Prompt } from './prompt.js';
+import { getv, setv } from '../lib/keyValueStorageClient.js';
+import { requestState } from './requestState.js';
 
 const MAX_PREVIOUS_RESULT_TOKEN_LENGTH = 1000;
 
@@ -125,7 +123,7 @@ class PathwayResolver {
         // Get saved context from contextId or change contextId if needed
         const { contextId } = args;
         this.savedContextId = contextId ? contextId : null;
-        this.savedContext = contextId ? (getv && await getv(contextId) || {}) : {};
+        this.savedContext = contextId ? (getv && (await getv(contextId)) || {}) : {};
 
         // Save the context before processing the request
         const savedContextStr = JSON.stringify(this.savedContext);
@@ -163,7 +161,7 @@ class PathwayResolver {
         }
 
         // chunk the text and return the chunks with newline separators
-        return getSemanticChunks({ text, maxChunkToken: chunkTokenLength });
+        return getSemanticChunks(text, chunkTokenLength);
     }
 
     truncate(str, n) {
@@ -312,4 +310,4 @@ class PathwayResolver {
     }
 }
 
-module.exports = { PathwayResolver };
+export { PathwayResolver };
package/graphql/pathwayResponseParser.js CHANGED
@@ -1,4 +1,4 @@
-const { parseNumberedList, parseNumberedObjectList } = require('./parser')
+import { parseNumberedList, parseNumberedObjectList } from './parser.js';
 
 class PathwayResponseParser {
     constructor(pathway) {
@@ -21,4 +21,4 @@ class PathwayResponseParser {
     }
 }
 
-module.exports = { PathwayResponseParser };
+export { PathwayResponseParser };
package/graphql/plugins/azureTranslatePlugin.js CHANGED
@@ -1,5 +1,5 @@
 // AzureTranslatePlugin.js
-const ModelPlugin = require('./modelPlugin');
+import ModelPlugin from './modelPlugin.js';
 
 class AzureTranslatePlugin extends ModelPlugin {
     constructor(config, pathway) {
@@ -37,4 +37,4 @@ class AzureTranslatePlugin extends ModelPlugin {
     }
 }
 
-module.exports = AzureTranslatePlugin;
+export default AzureTranslatePlugin;
package/graphql/plugins/modelPlugin.js CHANGED
@@ -1,7 +1,9 @@
 // ModelPlugin.js
-const handlebars = require('handlebars');
-const { request } = require("../../lib/request");
-const { encode } = require("gpt-3-encoder");
+import handlebars from 'handlebars';
+
+import { request } from '../../lib/request.js';
+import { encode } from 'gpt-3-encoder';
+import { getFirstNToken } from '../chunker.js';
 
 const DEFAULT_MAX_TOKENS = 4096;
 const DEFAULT_PROMPT_TOKEN_RATIO = 0.5;
@@ -38,37 +40,77 @@ class ModelPlugin {
         this.shouldCache = config.get('enableCache') && (pathway.enableCache || pathway.temperature == 0);
     }
 
-    // Function to remove non-system messages until token length is less than target
-    removeMessagesUntilTarget = (messages, targetTokenLength) => {
-        let chatML = this.messagesToChatML(messages);
-        let tokenLength = encode(chatML).length;
+    truncateMessagesToTargetLength = (messages, targetTokenLength) => {
+        // Calculate the token length of each message
+        const tokenLengths = messages.map((message) => ({
+            message,
+            tokenLength: encode(this.messagesToChatML([message], false)).length,
+        }));
 
-        while (tokenLength > targetTokenLength) {
-            for (let i = 0; i < messages.length; i++) {
-                if (messages[i].role !== 'system') {
-                    messages.splice(i, 1);
-                    chatML = this.messagesToChatML(messages);
-                    tokenLength = encode(chatML).length;
-                    break;
-                }
+        // Calculate the total token length of all messages
+        let totalTokenLength = tokenLengths.reduce(
+            (sum, { tokenLength }) => sum + tokenLength,
+            0
+        );
+
+        // If we're already under the target token length, just bail
+        if (totalTokenLength <= targetTokenLength) return messages;
+
+        // Remove and/or truncate messages until the target token length is reached
+        let index = 0;
+        while (totalTokenLength > targetTokenLength) {
+            const message = tokenLengths[index].message;
+
+            // Skip system messages
+            if (message.role === 'system') {
+                index++;
+                continue;
             }
-            if (messages.every(message => message.role === 'system')) {
-                break; // All remaining messages are 'system', stop removing messages
+
+            const currentTokenLength = tokenLengths[index].tokenLength;
+
+            if (totalTokenLength - currentTokenLength >= targetTokenLength) {
+                // Remove the message entirely if doing so won't go below the target token length
+                totalTokenLength -= currentTokenLength;
+                tokenLengths.splice(index, 1);
+            } else {
+                // Truncate the message to fit the remaining target token length
+                const emptyContentLength = encode(this.messagesToChatML([{ ...message, content: '' }], false)).length;
+                const otherMessageTokens = totalTokenLength - currentTokenLength;
+                const tokensToKeep = targetTokenLength - (otherMessageTokens + emptyContentLength);
+
+                const truncatedContent = getFirstNToken(message.content, tokensToKeep);
+                const truncatedMessage = { ...message, content: truncatedContent };
+
+                tokenLengths[index] = {
+                    message: truncatedMessage,
+                    tokenLength: encode(this.messagesToChatML([ truncatedMessage ], false)).length
+                }
+
+                // calculate the length again to keep us honest
+                totalTokenLength = tokenLengths.reduce(
+                    (sum, { tokenLength }) => sum + tokenLength,
+                    0
+                );
             }
         }
-        return messages;
-    }
-
+
+        // Return the modified messages array
+        return tokenLengths.map(({ message }) => message);
+    };
+
     //convert a messages array to a simple chatML format
-    messagesToChatML = (messages) => {
+    messagesToChatML(messages, addAssistant = true) {
         let output = "";
         if (messages && messages.length) {
             for (let message of messages) {
-                output += (message.role && message.content) ? `<|im_start|>${message.role}\n${message.content}\n<|im_end|>\n` : `${message}\n`;
+                output += (message.role && (message.content || message.content === '')) ? `<|im_start|>${message.role}\n${message.content}\n<|im_end|>\n` : `${message}\n`;
             }
             // you always want the assistant to respond next so add a
             // directive for that
-            output += "<|im_start|>assistant\n";
+            if (addAssistant) {
+                output += "<|im_start|>assistant\n";
+            }
         }
         return output;
     }
@@ -196,7 +238,7 @@ class ModelPlugin {
         const responseData = await request({ url, data, params, headers, cache: this.shouldCache }, this.modelName);
 
         if (responseData.error) {
-            throw new Exception(`An error was returned from the server: ${JSON.stringify(responseData.error)}`);
+            throw new Error(`An error was returned from the server: ${JSON.stringify(responseData.error)}`);
        }
 
         this.logRequestData(data, responseData, prompt);
@@ -205,6 +247,6 @@ class ModelPlugin {
 
 }
 
-module.exports = ModelPlugin;
+export default ModelPlugin;
 
 
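truncateMessagesToTargetLength is gentler than the removeMessagesUntilTarget it replaces: it drops leading non-system messages only while the total stays at or above the target, then token-truncates the first message that cannot be removed whole using getFirstNToken. An illustrative call against a ModelPlugin instance (the instance and message contents are made up):

    const messages = [
        { role: 'system', content: 'You are a helpful assistant.' }, // system turns are never touched
        { role: 'user', content: veryLongPastedDocument },           // token-truncated to fit the budget
        { role: 'user', content: 'Summarize the document above.' },
    ];
    const trimmed = plugin.truncateMessagesToTargetLength(messages, 4096);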
package/graphql/plugins/openAiChatPlugin.js CHANGED
@@ -1,5 +1,5 @@
 // OpenAIChatPlugin.js
-const ModelPlugin = require('./modelPlugin');
+import ModelPlugin from './modelPlugin.js';
 
 class OpenAIChatPlugin extends ModelPlugin {
     constructor(config, pathway) {
@@ -19,7 +19,7 @@ class OpenAIChatPlugin extends ModelPlugin {
         // Check if the token length exceeds the model's max token length
         if (tokenLength > modelMaxTokenLength) {
             // Remove older messages until the token length is within the model's limit
-            requestMessages = this.removeMessagesUntilTarget(requestMessages, modelMaxTokenLength);
+            requestMessages = this.truncateMessagesToTargetLength(requestMessages, modelMaxTokenLength);
         }
 
         const requestParameters = {
@@ -43,4 +43,4 @@ class OpenAIChatPlugin extends ModelPlugin {
     }
 }
 
-module.exports = OpenAIChatPlugin;
+export default OpenAIChatPlugin;