@aj-archipelago/cortex 1.0.5 → 1.0.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58)
  1. package/README.md +2 -2
  2. package/config/default.example.json +4 -2
  3. package/config.js +14 -8
  4. package/helper_apps/WhisperX/.dockerignore +27 -0
  5. package/helper_apps/WhisperX/Dockerfile +31 -0
  6. package/helper_apps/WhisperX/app-ts.py +76 -0
  7. package/helper_apps/WhisperX/app.py +115 -0
  8. package/helper_apps/WhisperX/docker-compose.debug.yml +12 -0
  9. package/helper_apps/WhisperX/docker-compose.yml +10 -0
  10. package/helper_apps/WhisperX/requirements.txt +6 -0
  11. package/index.js +1 -1
  12. package/lib/redisSubscription.js +1 -1
  13. package/package.json +8 -7
  14. package/pathways/basePathway.js +3 -2
  15. package/pathways/index.js +4 -0
  16. package/pathways/summary.js +2 -2
  17. package/pathways/sys_openai_chat.js +19 -0
  18. package/pathways/sys_openai_completion.js +11 -0
  19. package/pathways/test_palm_chat.js +1 -1
  20. package/pathways/transcribe.js +2 -1
  21. package/{graphql → server}/chunker.js +48 -3
  22. package/{graphql → server}/graphql.js +70 -62
  23. package/{graphql → server}/pathwayPrompter.js +14 -17
  24. package/{graphql → server}/pathwayResolver.js +59 -42
  25. package/{graphql → server}/plugins/azureTranslatePlugin.js +2 -2
  26. package/{graphql → server}/plugins/localModelPlugin.js +2 -2
  27. package/{graphql → server}/plugins/modelPlugin.js +8 -10
  28. package/{graphql → server}/plugins/openAiChatPlugin.js +13 -8
  29. package/{graphql → server}/plugins/openAiCompletionPlugin.js +9 -3
  30. package/{graphql → server}/plugins/openAiWhisperPlugin.js +30 -7
  31. package/{graphql → server}/plugins/palmChatPlugin.js +4 -6
  32. package/server/plugins/palmCodeCompletionPlugin.js +46 -0
  33. package/{graphql → server}/plugins/palmCompletionPlugin.js +13 -15
  34. package/server/rest.js +321 -0
  35. package/{graphql → server}/typeDef.js +30 -13
  36. package/tests/chunkfunction.test.js +112 -26
  37. package/tests/config.test.js +1 -1
  38. package/tests/main.test.js +282 -43
  39. package/tests/mocks.js +43 -2
  40. package/tests/modelPlugin.test.js +4 -4
  41. package/tests/openAiChatPlugin.test.js +21 -14
  42. package/tests/openai_api.test.js +147 -0
  43. package/tests/palmChatPlugin.test.js +10 -11
  44. package/tests/palmCompletionPlugin.test.js +3 -4
  45. package/tests/pathwayResolver.test.js +1 -1
  46. package/tests/truncateMessages.test.js +4 -5
  47. package/pathways/completions.js +0 -17
  48. package/pathways/test_oai_chat.js +0 -18
  49. package/pathways/test_oai_cmpl.js +0 -13
  50. package/tests/chunking.test.js +0 -157
  51. package/tests/translate.test.js +0 -126
  52. package/{graphql → server}/parser.js +0 -0
  53. package/{graphql → server}/pathwayResponseParser.js +0 -0
  54. package/{graphql → server}/prompt.js +0 -0
  55. package/{graphql → server}/pubsub.js +0 -0
  56. package/{graphql → server}/requestState.js +0 -0
  57. package/{graphql → server}/resolver.js +0 -0
  58. package/{graphql → server}/subscriptions.js +0 -0
package/README.md CHANGED
@@ -198,8 +198,8 @@ The core pathway `summary.js` below is implemented using custom pathway logic an
  // This module exports a prompt that takes an input text and generates a summary using a custom resolver.

  // Import required modules
- import { semanticTruncate } from '../graphql/chunker.js';
- import { PathwayResolver } from '../graphql/pathwayResolver.js';
+ import { semanticTruncate } from '../server/chunker.js';
+ import { PathwayResolver } from '../server/pathwayResolver.js';

  export default {
  // The main prompt function that takes the input text and asks to generate a summary.
package/config/default.example.json CHANGED
@@ -58,7 +58,8 @@
  "Content-Type": "application/json"
  },
  "requestsPerSecond": 10,
- "maxTokenLength": 2048
+ "maxTokenLength": 2048,
+ "maxReturnTokens": 1024
  },
  "palm-chat": {
  "type": "PALM-CHAT",
@@ -67,7 +68,8 @@
  "Content-Type": "application/json"
  },
  "requestsPerSecond": 10,
- "maxTokenLength": 2048
+ "maxTokenLength": 2048,
+ "maxReturnTokens": 1024
  },
  "local-llama13B": {
  "type": "LOCAL-CPP-MODEL",
package/config.js CHANGED
@@ -69,20 +69,21 @@ var config = convict({
  models: {
  format: Object,
  default: {
- "oai-td3": {
- "type": "OPENAI-COMPLETION",
- "url": "{{openaiApiUrl}}",
+ "oai-gpturbo": {
+ "type": "OPENAI-CHAT",
+ "url": "https://api.openai.com/v1/chat/completions",
  "headers": {
- "Authorization": "Bearer {{openaiApiKey}}",
+ "Authorization": "Bearer {{OPENAI_API_KEY}}",
  "Content-Type": "application/json"
  },
  "params": {
- "model": "{{openaiDefaultModel}}"
+ "model": "gpt-3.5-turbo"
  },
- "requestsPerSecond": 2,
+ "requestsPerSecond": 10,
+ "maxTokenLength": 8192
  },
  "oai-whisper": {
- "type": "OPENAI_WHISPER",
+ "type": "OPENAI-WHISPER",
  "url": "https://api.openai.com/v1/audio/transcriptions",
  "headers": {
  "Authorization": "Bearer {{OPENAI_API_KEY}}"
@@ -96,7 +97,7 @@ var config = convict({
  },
  openaiDefaultModel: {
  format: String,
- default: 'text-davinci-003',
+ default: 'gpt-3.5-turbo',
  env: 'OPENAI_DEFAULT_MODEL'
  },
  openaiApiKey: {
@@ -120,6 +121,11 @@ var config = convict({
  default: 'null',
  env: 'WHISPER_MEDIA_API_URL'
  },
+ whisperTSApiUrl: {
+ format: String,
+ default: 'null',
+ env: 'WHISPER_TS_API_URL'
+ },
  gcpServiceAccountKey: {
  format: String,
  default: null,
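
Note on usage: the replacement of the oai-td3 completion model with oai-gpturbo and the new whisperTSApiUrl key can both be overridden at startup. A minimal sketch, assuming only what the index.js diff below shows (the package's default export passes its configParams argument to config.load); the values here are placeholders:

    // hypothetical startup script, not part of the package
    import cortex from '@aj-archipelago/cortex';

    await cortex({
        openaiApiKey: process.env.OPENAI_API_KEY,
        whisperTSApiUrl: 'http://localhost:8000', // or set WHISPER_TS_API_URL in the environment
    });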
package/helper_apps/WhisperX/.dockerignore ADDED
@@ -0,0 +1,27 @@
+ **/__pycache__
+ **/.venv
+ **/.classpath
+ **/.dockerignore
+ **/.env
+ **/.git
+ **/.gitignore
+ **/.project
+ **/.settings
+ **/.toolstarget
+ **/.vs
+ **/.vscode
+ **/*.*proj.user
+ **/*.dbmdl
+ **/*.jfm
+ **/bin
+ **/charts
+ **/docker-compose*
+ **/compose*
+ **/Dockerfile*
+ **/node_modules
+ **/npm-debug.log
+ **/obj
+ **/secrets.dev.yaml
+ **/values.dev.yaml
+ LICENSE
+ README.md
package/helper_apps/WhisperX/Dockerfile ADDED
@@ -0,0 +1,31 @@
+ # For more information, please refer to https://aka.ms/vscode-docker-python
+ FROM python:3.10-slim
+
+ EXPOSE 8000
+
+ ## following 3 lines are for installing ffmepg
+ RUN apt-get -y update
+ RUN apt-get -y upgrade
+ RUN apt-get install -y ffmpeg
+
+ # Keeps Python from generating .pyc files in the container
+ ENV PYTHONDONTWRITEBYTECODE=1
+
+ # Turns off buffering for easier container logging
+ ENV PYTHONUNBUFFERED=1
+
+ # Install pip requirements
+ COPY requirements.txt .
+ RUN python -m pip install -r requirements.txt
+
+ WORKDIR /app
+ COPY ./models /app/models
+ COPY . /app
+
+ # Creates a non-root user with an explicit UID and adds permission to access the /app folder
+ # For more info, please refer to https://aka.ms/vscode-docker-python-configure-containers
+ RUN adduser -u 5678 --disabled-password --gecos "" appuser && chown -R appuser /app
+ USER appuser
+
+ # During debugging, this entry point will be overridden. For more information, please refer to https://aka.ms/vscode-docker-python-debug
+ CMD ["gunicorn", "--bind", "0.0.0.0:8000", "--timeout", "0", "-k", "uvicorn.workers.UvicornWorker", "app:app"]
package/helper_apps/WhisperX/app-ts.py ADDED
@@ -0,0 +1,76 @@
+ import uvicorn
+ from fastapi import FastAPI
+ import stable_whisper
+ from uuid import uuid4
+ import requests
+ import os
+
+ model_download_root = './models'
+ model = stable_whisper.load_model('large', download_root=model_download_root) #large, tiny
+
+ app = FastAPI()
+
+ save_directory = "./tmp" # folder for downloaded files
+ os.makedirs(save_directory, exist_ok=True)
+
+
+ def download_remote_file(url, save_directory):
+     # Generate a unique file name with a UUID
+     unique_name = str(uuid4()) + os.path.splitext(url)[-1]
+     save_path = os.path.join(save_directory, unique_name)
+
+     # Download the remote file
+     response = requests.get(url, stream=True)
+     response.raise_for_status()
+
+     # Save the downloaded file with the unique name
+     with open(save_path, 'wb') as file:
+         for chunk in response.iter_content(chunk_size=8192):
+             file.write(chunk)
+
+     return [unique_name, save_path]
+
+
+ def delete_tmp_file(file_path):
+     try:
+         os.remove(file_path)
+         print(f"Temporary file '{file_path}' has been deleted.")
+     except OSError as e:
+         print(f"Error: {e.strerror}")
+
+
+ async def transcribe(fileurl):
+     print(f"Downloading file from: {fileurl}")
+     [unique_file_name, save_path] = download_remote_file(
+         fileurl, save_directory)
+     print(f"Downloaded file saved as: {unique_file_name}")
+
+     print(f"Transcribing file")
+     result = model.transcribe(save_path)
+
+     srtpath = os.path.join(save_directory, str(uuid4()) + ".srt")
+
+     print(f"Saving transcription as : {srtpath}")
+     result.to_srt_vtt(srtpath, segment_level=False)
+
+     with open(srtpath, "r") as f:
+         srtstr = f.read()
+
+     # clean up tmp files
+     delete_tmp_file(save_path)
+     delete_tmp_file(srtpath)
+
+     print(f"Transcription done.")
+     return srtstr
+
+
+ @app.get("/")
+ async def root(fileurl: str):
+     if not fileurl:
+         return "No fileurl given!"
+     result = await transcribe(fileurl)
+     return result
+
+ if __name__ == "__main__":
+     print("Starting server", flush=True)
+     uvicorn.run(app, host="0.0.0.0", port=8000)
package/helper_apps/WhisperX/app.py ADDED
@@ -0,0 +1,115 @@
+ import uvicorn
+ from fastapi import FastAPI, HTTPException, Request
+ from uuid import uuid4
+ import os
+ import requests
+ import asyncio
+ import whisper
+ from whisper.utils import get_writer
+ from fastapi.encoders import jsonable_encoder
+
+ model_download_root = './models'
+ model = whisper.load_model("large", download_root=model_download_root) #large, tiny
+
+ # Create a semaphore with a limit of 1
+ semaphore = asyncio.Semaphore(1)
+
+ app = FastAPI()
+
+ save_directory = "./tmp" # folder for downloaded files
+ os.makedirs(save_directory, exist_ok=True)
+
+
+ def download_remote_file(url, save_directory):
+     # Generate a unique file name with a UUID
+     unique_name = str(uuid4()) + os.path.splitext(url)[-1]
+     save_path = os.path.join(save_directory, unique_name)
+
+     # Download the remote file
+     response = requests.get(url, stream=True)
+     response.raise_for_status()
+
+     # Save the downloaded file with the unique name
+     with open(save_path, 'wb') as file:
+         for chunk in response.iter_content(chunk_size=8192):
+             file.write(chunk)
+
+     return [unique_name, save_path]
+
+ def delete_tmp_file(file_path):
+     try:
+         os.remove(file_path)
+         print(f"Temporary file '{file_path}' has been deleted.")
+     except OSError as e:
+         print(f"Error: {e.strerror}")
+
+ def modify_segments(result):
+     modified_segments = []
+
+     id = 0
+     for segment in result["segments"]:
+         for word_info in segment['words']:
+             word = word_info['word']
+             start = word_info['start']
+             end = word_info['end']
+
+             modified_segment = {} #segment.copy()
+             modified_segment['id'] = id
+             modified_segment['text'] = word
+             modified_segment['start'] = start
+             modified_segment['end'] = end
+             modified_segments.append(modified_segment)
+             id+=1
+
+     result["segments"] = modified_segments
+
+ def transcribe(fileurl):
+     print(f"Downloading file from: {fileurl}")
+     [unique_file_name, save_path] = download_remote_file(
+         fileurl, save_directory)
+     print(f"Downloaded file saved as: {unique_file_name}")
+
+     print(f"Transcribing file")
+     result = model.transcribe(save_path, word_timestamps=True)
+
+     modify_segments(result)
+
+     srtpath = os.path.join(save_directory, str(uuid4()) + ".srt")
+
+     print(f"Saving transcription as : {srtpath}")
+     writer = get_writer("srt", save_directory)
+     with open(srtpath, 'w', encoding='utf-8') as file_obj :
+         writer.write_result(result, file_obj)
+
+     with open(srtpath, "r") as f:
+         srtstr = f.read()
+
+     # clean up tmp files
+     delete_tmp_file(save_path)
+     delete_tmp_file(srtpath)
+
+     print(f"Transcription done.")
+     return srtstr
+
+
+ @app.get("/")
+ @app.post("/")
+ async def root(request: Request):
+     if request.method == "POST":
+         body = jsonable_encoder(await request.json())
+         fileurl = body.get("fileurl")
+     else:
+         fileurl = request.query_params.get("fileurl")
+     if not fileurl:
+         return "No fileurl given!"
+
+     if semaphore.locked():
+         raise HTTPException(status_code=429, detail="Too Many Requests")
+
+     async with semaphore:
+         result = await asyncio.to_thread(transcribe, fileurl)
+         return result
+
+ if __name__ == "__main__":
+     print("Starting APPWhisper server", flush=True)
+     uvicorn.run(app, host="0.0.0.0", port=8000)
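
A hedged client sketch for the service above: the host and port assume the docker-compose.yml mapping below, and the audio URL is a placeholder. Because the app guards transcription with a single-slot semaphore, concurrent callers receive HTTP 429:

    // illustrative client, not part of the package
    const res = await fetch('http://localhost:8000/', {
        method: 'POST',
        headers: { 'Content-Type': 'application/json' },
        body: JSON.stringify({ fileurl: 'https://example.com/audio.mp3' }),
    });
    if (res.status === 429) {
        throw new Error('Transcription already in progress; retry later');
    }
    const srt = await res.json(); // FastAPI returns the SRT text as a JSON-encoded string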
package/helper_apps/WhisperX/docker-compose.debug.yml ADDED
@@ -0,0 +1,12 @@
+ version: '3.4'
+
+ services:
+   cortex:
+     image: cortex
+     build:
+       context: .
+       dockerfile: ./Dockerfile
+     command: ["sh", "-c", "pip install debugpy -t /tmp && python /tmp/debugpy --wait-for-client --listen 0.0.0.0:5678 -m uvicorn helper_apps.WhisperX/app:app --host 0.0.0.0 --port 8000"]
+     ports:
+       - 8000:8000
+       - 5678:5678
package/helper_apps/WhisperX/docker-compose.yml ADDED
@@ -0,0 +1,10 @@
+ version: '3.4'
+
+ services:
+   cortex:
+     image: cortex
+     build:
+       context: .
+       dockerfile: ./Dockerfile
+     ports:
+       - 8000:8000
package/helper_apps/WhisperX/requirements.txt ADDED
@@ -0,0 +1,6 @@
+ # To ensure app dependencies are ported from your virtual environment/host machine into your container, run 'pip freeze > requirements.txt' in the terminal to overwrite this file
+ fastapi[all]==0.89.0
+ uvicorn[standard]==0.20.0
+ gunicorn==20.1.0
+ whisper
+ stable-ts
package/index.js CHANGED
@@ -1,5 +1,5 @@
  import { config } from './config.js';
- import { build } from './graphql/graphql.js';
+ import { build } from './server/graphql.js';

  export default async (configParams) => {
  configParams && config.load(configParams);
package/lib/redisSubscription.js CHANGED
@@ -1,6 +1,6 @@
  import Redis from 'ioredis';
  import { config } from '../config.js';
- import pubsub from '../graphql/pubsub.js';
+ import pubsub from '../server/pubsub.js';

  const connectionString = config.get('storageConnectionString');
  const client = new Redis(connectionString);
package/package.json CHANGED
@@ -1,6 +1,6 @@
  {
  "name": "@aj-archipelago/cortex",
- "version": "1.0.5",
+ "version": "1.0.7",
  "description": "Cortex is a GraphQL API for AI. It provides a simple, extensible interface for using AI services from OpenAI, Azure and others.",
  "repository": {
  "type": "git",
@@ -28,16 +28,15 @@
  "type": "module",
  "homepage": "https://github.com/aj-archipelago/cortex#readme",
  "dependencies": {
- "@apollo/utils.keyvadapter": "^1.1.2",
+ "@apollo/server": "^4.7.3",
+ "@apollo/server-plugin-response-cache": "^4.1.2",
+ "@apollo/utils.keyvadapter": "^3.0.0",
  "@graphql-tools/schema": "^9.0.12",
  "@keyv/redis": "^2.5.4",
- "apollo-server": "^3.12.0",
- "apollo-server-core": "^3.11.1",
- "apollo-server-express": "^3.11.1",
- "apollo-server-plugin-response-cache": "^3.8.1",
  "axios": "^1.3.4",
  "axios-cache-interceptor": "^1.0.1",
  "bottleneck": "^2.19.5",
+ "cheerio": "^1.0.0-rc.12",
  "compromise": "^14.8.1",
  "compromise-paragraphs": "^0.1.0",
  "convict": "^6.2.3",
@@ -61,6 +60,7 @@
  "dotenv": "^16.0.3",
  "eslint": "^8.38.0",
  "eslint-plugin-import": "^2.27.5",
+ "got": "^13.0.0",
  "sinon": "^15.0.3"
  },
  "publishConfig": {
@@ -72,6 +72,7 @@
  ],
  "require": [
  "dotenv/config"
- ]
+ ],
+ "concurrency": 1
  }
  }
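
The dependency churn above is the Apollo Server 3 → 4 migration: the apollo-server-* packages are consolidated into @apollo/server, with the response-cache plugin moving to its scoped equivalent. For orientation, a minimal sketch of the v4 bootstrap pattern; this is generic @apollo/server usage, not Cortex's actual server/graphql.js, which this diff does not expand:

    // generic Apollo Server 4 startup, illustrative only
    import { ApolloServer } from '@apollo/server';
    import { startStandaloneServer } from '@apollo/server/standalone';

    const typeDefs = `#graphql
        type Query { hello: String }
    `;
    const resolvers = { Query: { hello: () => 'world' } };

    const server = new ApolloServer({ typeDefs, resolvers });
    const { url } = await startStandaloneServer(server, { listen: { port: 4000 } });
    console.log(`Server ready at ${url}`);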
package/pathways/basePathway.js CHANGED
@@ -1,5 +1,5 @@
- import { rootResolver, resolver } from '../graphql/resolver.js';
- import { typeDef } from '../graphql/typeDef.js';
+ import { rootResolver, resolver } from '../server/resolver.js';
+ import { typeDef } from '../server/typeDef.js';

  // all default definitions of a single pathway
  export default {
@@ -14,6 +14,7 @@ export default {
  typeDef,
  rootResolver,
  resolver,
+ inputFormat: 'text',
  useInputChunking: true,
  useParallelChunkProcessing: false,
  useInputSummarization: false,
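
The new inputFormat default pairs with the HTML-aware chunker later in this diff. A hypothetical pathway opting into HTML chunking; this assumes inputFormat is forwarded to getSemanticChunks, consistent with the chunker.js change, and the pathway itself is made up:

    // my_html_summary.js -- illustrative, not part of the package
    export default {
        prompt: `Summarize the following: {{text}}`,
        inputFormat: 'html', // chunk on top-level element boundaries instead of sentences
        useInputChunking: true,
    };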
package/pathways/index.js CHANGED
@@ -6,6 +6,8 @@ import entities from './entities.js';
  import paraphrase from './paraphrase.js';
  import sentiment from './sentiment.js';
  import summary from './summary.js';
+ import sys_openai_chat from './sys_openai_chat.js';
+ import sys_openai_completion from './sys_openai_completion.js';
  import test_langchain from './test_langchain.mjs';
  import test_palm_chat from './test_palm_chat.js';
  import transcribe from './transcribe.js';
@@ -20,6 +22,8 @@ export {
  paraphrase,
  sentiment,
  summary,
+ sys_openai_chat,
+ sys_openai_completion,
  test_langchain,
  test_palm_chat,
  transcribe,
package/pathways/summary.js CHANGED
@@ -3,8 +3,8 @@
  // This module exports a prompt that takes an input text and generates a summary using a custom resolver.

  // Import required modules
- import { semanticTruncate } from '../graphql/chunker.js';
- import { PathwayResolver } from '../graphql/pathwayResolver.js';
+ import { semanticTruncate } from '../server/chunker.js';
+ import { PathwayResolver } from '../server/pathwayResolver.js';

  export default {
  // The main prompt function that takes the input text and asks to generate a summary.
package/pathways/sys_openai_chat.js ADDED
@@ -0,0 +1,19 @@
+ // sys_openai_chat.js
+ // default handler for openAI chat endpoints when REST endpoints are enabled
+
+ import { Prompt } from '../server/prompt.js';
+
+ export default {
+     prompt:
+         [
+             new Prompt({ messages: [
+                 "{{messages}}",
+             ]}),
+         ],
+     inputParameters: {
+         messages: [],
+     },
+     model: 'oai-gpturbo',
+     useInputChunking: false,
+     emulateOpenAIChatModel: '*',
+ }
package/pathways/sys_openai_completion.js ADDED
@@ -0,0 +1,11 @@
+ // sys_openai_completion.js
+ // default handler for openAI completion endpoints when REST endpoints are enabled
+
+ import { Prompt } from '../server/prompt.js';
+
+ export default {
+     prompt: `{{text}}`,
+     model: 'oai-gpturbo',
+     useInputChunking: false,
+     emulateOpenAICompletionModel: '*',
+ }
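
Both sys_ pathways back the new REST layer in server/rest.js (+321 lines, not expanded in this diff). A hedged sketch of how a client might call the chat emulation; the /v1/chat/completions route, port, and response shape are assumptions based on the OpenAI wire format these pathways emulate:

    // illustrative client; route and port are assumed, see above
    const res = await fetch('http://localhost:4000/v1/chat/completions', {
        method: 'POST',
        headers: { 'Content-Type': 'application/json' },
        body: JSON.stringify({
            model: 'gpt-3.5-turbo', // emulateOpenAIChatModel: '*' accepts any model name
            messages: [{ role: 'user', content: 'Hello' }],
        }),
    });
    const completion = await res.json();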
package/pathways/test_palm_chat.js CHANGED
@@ -1,7 +1,7 @@
  //test_palm_chat.mjs
  // Test for handling of prompts in the PaLM chat format for Cortex

- import { Prompt } from '../graphql/prompt.js';
+ import { Prompt } from '../server/prompt.js';

  // Description: Have a chat with a bot that uses context to understand the conversation
  export default {
package/pathways/transcribe.js CHANGED
@@ -5,8 +5,9 @@ export default {
  file: ``,
  language: ``,
  responseFormat: `text`,
+ wordTimestamped: false,
  },
- timeout: 1800, // in seconds
+ timeout: 3600, // in seconds
  };
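
The new wordTimestamped parameter routes transcription through the word-level Whisper service added above (app.py, configured via whisperTSApiUrl). A hedged example call; the field and argument names follow the pathway definition, but Cortex generates the GraphQL schema, so the exact result shape is an assumption:

    // illustrative GraphQL request against a running Cortex instance
    const res = await fetch('http://localhost:4000/graphql', {
        method: 'POST',
        headers: { 'Content-Type': 'application/json' },
        body: JSON.stringify({
            query: `query Transcribe($file: String!) {
                transcribe(file: $file, wordTimestamped: true) { result }
            }`,
            variables: { file: 'https://example.com/audio.mp3' },
        }),
    });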
package/{graphql → server}/chunker.js CHANGED
@@ -1,4 +1,5 @@
  import { encode, decode } from 'gpt-3-encoder';
+ import cheerio from 'cheerio';

  const getLastNToken = (text, maxTokenLen) => {
  const encoded = encode(text);
@@ -18,8 +19,18 @@ const getFirstNToken = (text, maxTokenLen) => {
  return text;
  }

- const getSemanticChunks = (text, chunkSize) => {
+ const determineTextFormat = (text) => {
+ const htmlTagPattern = /<[^>]*>/g;
+
+ if (htmlTagPattern.test(text)) {
+ return 'html';
+ }
+ else {
+ return 'text';
+ }
+ }

+ const getSemanticChunks = (text, chunkSize, inputFormat = 'text') => {
  const breakByRegex = (str, regex, preserveWhitespace = false) => {
  const result = [];
  let match;
@@ -46,6 +57,19 @@ const getSemanticChunks = (text, chunkSize) => {
  const breakBySentences = (str) => breakByRegex(str, /(?<=[.。؟!?!\n])\s+/, true);
  const breakByWords = (str) => breakByRegex(str, /(\s,;:.+)/);

+ const breakByHtmlElements = (str) => {
+ const $ = cheerio.load(str, null, true);
+
+ // the .filter() call is important to get the text nodes
+ // https://stackoverflow.com/questions/54878673/cheerio-get-normal-text-nodes
+ let rootNodes = $('body').contents();
+
+ // create an array with the outerHTML of each node
+ const nodes = rootNodes.map((i, el) => $(el).prop('outerHTML') || $(el).text()).get();
+
+ return nodes;
+ };
+
  const createChunks = (tokens) => {
  let chunks = [];
  let currentChunk = '';
@@ -115,7 +139,28 @@ const getSemanticChunks = (text, chunkSize) => {
  return createChunks([...str]); // Split by characters
  };

- return breakText(text);
+ if (inputFormat === 'html') {
+ const tokens = breakByHtmlElements(text);
+ let chunks = createChunks(tokens);
+ chunks = combineChunks(chunks);
+
+ chunks = chunks.flatMap(chunk => {
+ if (determineTextFormat(chunk) === 'text') {
+ return getSemanticChunks(chunk, chunkSize);
+ } else {
+ return chunk;
+ }
+ });
+
+ if (chunks.some(chunk => encode(chunk).length > chunkSize)) {
+ throw new Error('The HTML contains elements that are larger than the chunk size. Please try again with HTML that has smaller elements.');
+ }
+
+ return chunks;
+ }
+ else {
+ return breakText(text);
+ }
  }


@@ -133,5 +178,5 @@ const semanticTruncate = (text, maxLength) => {
  };

  export {
- getSemanticChunks, semanticTruncate, getLastNToken, getFirstNToken
+ getSemanticChunks, semanticTruncate, getLastNToken, getFirstNToken, determineTextFormat
  };
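
A short sketch of the new HTML path using the exports above; the markup and chunk size are made up for illustration. Chunks are cut on top-level element boundaries, plain-text chunks are re-split recursively, and any single element larger than chunkSize throws:

    // illustrative usage of the chunker exports
    import { getSemanticChunks, determineTextFormat } from './server/chunker.js';

    const html = '<p>First paragraph.</p><p>Second paragraph.</p>';
    console.log(determineTextFormat(html)); // 'html'

    const chunks = getSemanticChunks(html, 100, 'html');
    console.log(chunks); // whole <p> elements, grouped up to ~100 tokens per chunk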