unique_toolkit 0.0.2__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- unique_toolkit/app/init_logging.py +31 -0
- unique_toolkit/app/init_sdk.py +41 -0
- unique_toolkit/app/performance/async_executor.py +186 -0
- unique_toolkit/app/performance/async_wrapper.py +28 -0
- unique_toolkit/app/schemas.py +54 -0
- unique_toolkit/app/verification.py +58 -0
- unique_toolkit/chat/schemas.py +30 -0
- unique_toolkit/chat/service.py +380 -0
- unique_toolkit/chat/state.py +60 -0
- unique_toolkit/chat/utils.py +25 -0
- unique_toolkit/content/schemas.py +90 -0
- unique_toolkit/content/service.py +356 -0
- unique_toolkit/content/utils.py +188 -0
- unique_toolkit/embedding/schemas.py +5 -0
- unique_toolkit/embedding/service.py +89 -0
- unique_toolkit/language_model/infos.py +305 -0
- unique_toolkit/language_model/schemas.py +168 -0
- unique_toolkit/language_model/service.py +261 -0
- unique_toolkit/language_model/utils.py +44 -0
- unique_toolkit-0.5.1.dist-info/METADATA +138 -0
- unique_toolkit-0.5.1.dist-info/RECORD +24 -0
- unique_toolkit-0.0.2.dist-info/METADATA +0 -33
- unique_toolkit-0.0.2.dist-info/RECORD +0 -5
- {unique_toolkit-0.0.2.dist-info → unique_toolkit-0.5.1.dist-info}/LICENSE +0 -0
- {unique_toolkit-0.0.2.dist-info → unique_toolkit-0.5.1.dist-info}/WHEEL +0 -0
--- /dev/null
+++ b/unique_toolkit/content/service.py
@@ -0,0 +1,356 @@
import logging
import os
import tempfile
from pathlib import Path
from typing import Optional, cast

import requests
import unique_sdk

from unique_toolkit.app.performance.async_wrapper import async_warning, to_async
from unique_toolkit.chat.state import ChatState
from unique_toolkit.content.schemas import (
    Content,
    ContentChunk,
    ContentSearchType,
    ContentUploadInput,
)


class ContentService:
    def __init__(self, state: ChatState, logger: Optional[logging.Logger] = None):
        self.state = state
        self.logger = logger or logging.getLogger(__name__)

    def search_content_chunks(
        self,
        search_string: str,
        search_type: ContentSearchType,
        limit: int,
        scope_ids: Optional[list[str]] = None,
    ) -> list[ContentChunk]:
        """
        Performs a synchronous search for content chunks in the knowledge base.

        Args:
            search_string (str): The search string.
            search_type (ContentSearchType): The type of search to perform.
            limit (int): The maximum number of results to return.
            scope_ids (Optional[list[str]]): The scope IDs. Defaults to None.

        Returns:
            list[ContentChunk]: The search results.
        """
        return self._trigger_search_content_chunks(
            search_string,
            search_type,
            limit,
            scope_ids,
        )

    @to_async
    @async_warning
    def async_search_content_chunks(
        self,
        search_string: str,
        search_type: ContentSearchType,
        limit: int,
        scope_ids: Optional[list[str]],
    ):
        """
        Performs an asynchronous search for content chunks in the knowledge base.

        Args:
            search_string (str): The search string.
            search_type (ContentSearchType): The type of search to perform.
            limit (int): The maximum number of results to return.
            scope_ids (Optional[list[str]]): The scope IDs.

        Returns:
            list[ContentChunk]: The search results.
        """
        return self._trigger_search_content_chunks(
            search_string,
            search_type,
            limit,
            scope_ids,
        )

    def _trigger_search_content_chunks(
        self,
        search_string: str,
        search_type: ContentSearchType,
        limit: int,
        scope_ids: Optional[list[str]],
    ) -> list[ContentChunk]:
        scope_ids = scope_ids or self.state.scope_ids or []

        if not scope_ids:
            self.logger.warning("No scope IDs provided for search.")

        try:
            searches = unique_sdk.Search.create(
                user_id=self.state.user_id,
                company_id=self.state.company_id,
                chatId=self.state.chat_id,
                searchString=search_string,
                searchType=search_type.name,
                scopeIds=scope_ids,
                limit=limit,
                chatOnly=self.state.chat_only,
            )
        except Exception as e:
            self.logger.error(f"Error while searching content chunks: {e}")
            raise e

        def map_to_content_chunks(searches: list[unique_sdk.Search]):
            return [ContentChunk(**search) for search in searches]

        # TODO change return type in sdk from Search to list[Search]
        searches = cast(list[unique_sdk.Search], searches)
        return map_to_content_chunks(searches)

    def search_contents(
        self,
        where: dict,
    ) -> list[Content]:
        """
        Performs a search in the knowledge base by filter (not a similarity search).
        In contrast to search_content_chunks, this function loads the complete
        content of the matching files from the knowledge base.

        Args:
            where (dict): The search criteria.

        Returns:
            list[Content]: The search results.
        """
        return self._trigger_search_contents(where)

    @to_async
    @async_warning
    def async_search_contents(
        self,
        where: dict,
    ) -> list[Content]:
        """
        Performs an asynchronous search for content files in the knowledge base by filter.

        Args:
            where (dict): The search criteria.

        Returns:
            list[Content]: The search results.
        """
        return self._trigger_search_contents(where)

    def _trigger_search_contents(
        self,
        where: dict,
    ) -> list[Content]:
        def map_content_chunk(content_chunk):
            return ContentChunk(
                id=content_chunk["id"],
                text=content_chunk["text"],
                start_page=content_chunk["startPage"],
                end_page=content_chunk["endPage"],
                order=content_chunk["order"],
            )

        def map_content(content):
            return Content(
                id=content["id"],
                key=content["key"],
                title=content["title"],
                url=content["url"],
                chunks=[map_content_chunk(chunk) for chunk in content["chunks"]],
            )

        def map_contents(contents):
            return [map_content(content) for content in contents]

        try:
            contents = unique_sdk.Content.search(
                user_id=self.state.user_id,
                company_id=self.state.company_id,
                chatId=self.state.chat_id,
                # TODO add type parameter
                where=where,  # type: ignore
            )
        except Exception as e:
            self.logger.error(f"Error while searching contents: {e}")
            raise e

        return map_contents(contents)

    def upload_content(
        self,
        path_to_content: str,
        content_name: str,
        mime_type: str,
        scope_id: Optional[str] = None,
        chat_id: Optional[str] = None,
    ):
        """
        Uploads content to the knowledge base.

        Args:
            path_to_content (str): The path to the content to upload.
            content_name (str): The name of the content.
            mime_type (str): The MIME type of the content.
            scope_id (Optional[str]): The scope ID. Defaults to None.
            chat_id (Optional[str]): The chat ID. Defaults to None.

        Returns:
            Content: The uploaded content.
        """

        byte_size = os.path.getsize(path_to_content)
        created_content = self._trigger_upsert_content(
            input=ContentUploadInput(
                key=content_name, title=content_name, mime_type=mime_type
            ),
            scope_id=scope_id,
            chat_id=chat_id,
        )

        write_url = created_content.write_url

        if not write_url:
            error_msg = "Write url for uploaded content is missing"
            self.logger.error(error_msg)
            raise ValueError(error_msg)

        # Upload the file to the Azure Blob Storage SAS write URL, making sure
        # it is stored with the given MIME type.
        with open(path_to_content, "rb") as file:
            requests.put(
                url=write_url,
                data=file,
                headers={
                    "X-Ms-Blob-Content-Type": mime_type,
                    "X-Ms-Blob-Type": "BlockBlob",
                },
            )

        read_url = created_content.read_url

        if not read_url:
            error_msg = "Read url for uploaded content is missing"
            self.logger.error(error_msg)
            raise ValueError(error_msg)

        if chat_id:
            self._trigger_upsert_content(
                input=ContentUploadInput(
                    key=content_name,
                    title=content_name,
                    mime_type=mime_type,
                    byte_size=byte_size,
                ),
                content_url=read_url,
                chat_id=chat_id,
            )
        else:
            self._trigger_upsert_content(
                input=ContentUploadInput(
                    key=content_name,
                    title=content_name,
                    mime_type=mime_type,
                    byte_size=byte_size,
                ),
                content_url=read_url,
                scope_id=scope_id,
            )

        return created_content

    def _trigger_upsert_content(
        self,
        input: ContentUploadInput,
        scope_id: Optional[str] = None,
        chat_id: Optional[str] = None,
        content_url: Optional[str] = None,
    ):
        if not chat_id and not scope_id:
            raise ValueError("chat_id or scope_id must be provided")

        try:
            if input.byte_size:
                input_json = {
                    "key": input.key,
                    "title": input.title,
                    "mimeType": input.mime_type,
                    "byteSize": input.byte_size,
                }
            else:
                input_json = {
                    "key": input.key,
                    "title": input.title,
                    "mimeType": input.mime_type,
                }
            content = unique_sdk.Content.upsert(
                user_id=self.state.user_id,
                company_id=self.state.company_id,
                input=input_json,  # type: ignore
                fileUrl=content_url,
                scopeId=scope_id,
                chatId=chat_id,
                sourceOwnerType=None,  # type: ignore
                storeInternally=False,
            )
            return Content(**content)
        except Exception as e:
            self.logger.error(f"Error while uploading content: {e}")
            raise e

    def download_content(
        self,
        content_id: str,
        content_name: str,
        chat_id: Optional[str] = None,
    ) -> Path:
        """
        Downloads content to a temporary directory.

        Args:
            content_id (str): The ID of the uploaded content.
            content_name (str): The name of the uploaded content.
            chat_id (Optional[str]): The chat ID. Defaults to None.

        Returns:
            content_path: The path to the downloaded content in the temporary directory.

        Raises:
            Exception: If the download fails.
        """
        self.logger.debug(f"Downloading content {content_id} for chat {chat_id}")

        url = f"{unique_sdk.api_base}/content/{content_id}/file"
        if chat_id:
            url = f"{url}?chatId={chat_id}"

        # Create a random directory inside /tmp
        random_dir = tempfile.mkdtemp(dir="/tmp")

        # Create the full file path
        content_path = Path(random_dir) / content_name

        # Download the file and save it to the random directory
        headers = {
            "x-api-version": unique_sdk.api_version,
            "x-app-id": unique_sdk.app_id,
            "x-user-id": self.state.user_id,
            "x-company-id": self.state.company_id,
            "Authorization": "Bearer %s" % (unique_sdk.api_key,),
        }

        response = requests.get(url, headers=headers)
        if response.status_code == 200:
            with open(content_path, "wb") as file:
                file.write(response.content)
        else:
            error_msg = f"Error downloading file: Status code {response.status_code}"
            self.logger.error(error_msg)
            raise Exception(error_msg)

        return content_path
--- /dev/null
+++ b/unique_toolkit/content/utils.py
@@ -0,0 +1,188 @@
import re

import tiktoken

from unique_toolkit.content.schemas import (
    ContentChunk,
)


def _map_content_id_to_chunks(content_chunks: list[ContentChunk]):
    doc_id_to_chunks: dict[str, list[ContentChunk]] = {}
    for chunk in content_chunks:
        source_chunks = doc_id_to_chunks.get(chunk.id)
        if not source_chunks:
            doc_id_to_chunks[chunk.id] = [chunk]
        else:
            source_chunks.append(chunk)
    return doc_id_to_chunks


def sort_content_chunks(content_chunks: list[ContentChunk]):
    """
    Sorts the content chunks based on their 'order' in the original content.

    This function sorts the search results by 'order' in ascending order.
    It also inserts ' text part {order}' before each closing <|/content|> tag
    and removes any <|info|> tags (which is useful when referencing the chunk).

    Parameters:
    - content_chunks (list): A list of ContentChunk objects.

    Returns:
    - list: A list of ContentChunk objects sorted according to their order.
    """
    doc_id_to_chunks = _map_content_id_to_chunks(content_chunks)
    sorted_chunks: list[ContentChunk] = []
    for chunks in doc_id_to_chunks.values():
        chunks.sort(key=lambda x: x.order)
        for i, s in enumerate(chunks):
            s.text = re.sub(
                r"<\|/content\|>", f" text part {s.order}<|/content|>", s.text
            )
            s.text = re.sub(r"<\|info\|>(.*?)<\|\/info\|>", "", s.text)
            pages_postfix = _generate_pages_postfix([s])
            s.key = s.key + pages_postfix if s.key else s.key
            s.title = s.title + pages_postfix if s.title else s.title
        sorted_chunks.extend(chunks)
    return sorted_chunks


def merge_content_chunks(content_chunks: list[ContentChunk]):
    """
    Merges multiple search results based on their 'id', removing redundant content and info markers.

    This function groups search results by their 'id' and then concatenates their texts,
    cleaning up any content or info markers in subsequent chunks beyond the first one.

    Parameters:
    - content_chunks (list): A list of objects, each representing a search result with 'id' and 'text' keys.

    Returns:
    - list: A list of objects with merged texts for each unique 'id'.
    """

    doc_id_to_chunks = _map_content_id_to_chunks(content_chunks)
    merged_chunks: list[ContentChunk] = []
    for chunks in doc_id_to_chunks.values():
        chunks.sort(key=lambda x: x.order)
        for i, s in enumerate(chunks):
            # Skip the first element; strip the <|content|>...<|/content|>
            # and <|info|>...<|/info|> markers from all subsequent chunks.
            if i > 0:
                s.text = re.sub(r"<\|content\|>(.*?)<\|\/content\|>", "", s.text)
                s.text = re.sub(r"<\|info\|>(.*?)<\|\/info\|>", "", s.text)

        pages_postfix = _generate_pages_postfix(chunks)
        chunks[0].text = "\n".join(str(s.text) for s in chunks)
        chunks[0].key = (
            chunks[0].key + pages_postfix if chunks[0].key else chunks[0].key
        )
        chunks[0].title = (
            chunks[0].title + pages_postfix if chunks[0].title else chunks[0].title
        )
        chunks[0].end_page = chunks[-1].end_page
        merged_chunks.append(chunks[0])

    return merged_chunks


def _generate_pages_postfix(chunks: list[ContentChunk]) -> str:
    """
    Generates a postfix string of page numbers from a list of source objects.
    Each source object contains startPage and endPage numbers. The function
    compiles a list of all unique page numbers greater than 0 from all chunks,
    and then returns them as a string prefixed with " : " if there are any.

    Parameters:
    - chunks (list): A list of objects with 'startPage' and 'endPage' keys.

    Returns:
    - string: A string of page numbers separated by commas, prefixed with " : ".
    """

    def gen_all_numbers_in_between(start, end) -> list[int]:
        """
        Generates a list of all numbers between start and end, inclusive.
        If start or end is -1, it behaves as follows:
        - If both start and end are -1, it returns an empty list.
        - If only end is -1, it returns a list containing only the start.
        - If start is -1, it returns an empty list.

        Parameters:
        - start (int): The starting page number.
        - end (int): The ending page number.

        Returns:
        - list: A list of numbers from start to end, inclusive.
        """
        if start == -1 and end == -1:
            return []
        if end == -1:
            return [start]
        if start == -1:
            return []
        return list(range(start, end + 1))

    page_numbers_array = [
        gen_all_numbers_in_between(s.start_page, s.end_page) for s in chunks
    ]
    page_numbers = [number for sublist in page_numbers_array for number in sublist]
    page_numbers = [p for p in page_numbers if p > 0]
    page_numbers = sorted(set(page_numbers))
    pages_postfix = (
        " : " + ",".join(str(p) for p in page_numbers) if page_numbers else ""
    )
    return pages_postfix


def pick_content_chunks_for_token_window(
    content_chunks: list[ContentChunk],
    token_limit: int,
    encoding_model="cl100k_base",
):
    """
    Selects and returns a list of search results that fit within a specified token limit.

    This function iterates over a list of search results, each with a 'text' field, and
    encodes the text using a predefined encoding scheme. It accumulates search results
    until the token limit is reached or exceeded.

    Parameters:
    - content_chunks (list): A list of ContentChunk objects, each with a 'text' attribute.
    - token_limit (int): The maximum number of tokens to include in the output.

    Returns:
    - list: A list of ContentChunk objects that fit within the token limit.
    """
    picked_chunks: list[ContentChunk] = []
    token_count = 0

    encoding = tiktoken.get_encoding(encoding_model)

    for chunk in content_chunks:
        try:
            searchtoken_count = len(encoding.encode(chunk.text))
        except Exception:
            searchtoken_count = 0
        if token_count + searchtoken_count > token_limit:
            break

        picked_chunks.append(chunk)
        token_count += searchtoken_count

    return picked_chunks


def count_tokens(text, encoding_model="cl100k_base") -> int:
    """
    Counts the number of tokens in the provided text.

    This function encodes the input text using a predefined encoding scheme
    and returns the number of tokens in the encoded text.

    Parameters:
    - text (str): The text to count tokens for.

    Returns:
    - int: The number of tokens in the text.
    """
    encoding = tiktoken.get_encoding(encoding_model)
    return len(encoding.encode(text))
--- /dev/null
+++ b/unique_toolkit/embedding/service.py
@@ -0,0 +1,89 @@
import logging
from typing import Optional

import numpy as np
import unique_sdk

from unique_toolkit.app.performance.async_wrapper import async_warning, to_async
from unique_toolkit.chat.state import ChatState
from unique_toolkit.embedding.schemas import Embeddings


class EmbeddingService:
    def __init__(self, state: ChatState, logger: Optional[logging.Logger] = None):
        self.state = state
        self.logger = logger or logging.getLogger(__name__)

    _DEFAULT_TIMEOUT = 600_000

    def embed_texts(
        self,
        texts: list[str],
        timeout: int = _DEFAULT_TIMEOUT,
    ) -> Embeddings:
        """
        Embeds texts.

        Args:
            texts (list[str]): The texts to embed.
            timeout (int): The timeout in milliseconds. Defaults to 600_000 (10 minutes).

        Returns:
            Embeddings: The Embeddings object.

        Raises:
            Exception: If an error occurs.
        """
        return self._trigger_embed_texts(
            texts=texts,
            timeout=timeout,
        )

    @to_async
    @async_warning
    def async_embed_texts(
        self,
        texts: list[str],
        timeout: int = _DEFAULT_TIMEOUT,
    ) -> Embeddings:
        """
        Embeds texts asynchronously.

        Args:
            texts (list[str]): The texts to embed.
            timeout (int): The timeout in milliseconds. Defaults to 600_000 (10 minutes).

        Returns:
            Embeddings: The Embeddings object.

        Raises:
            Exception: If an error occurs.
        """
        return self._trigger_embed_texts(
            texts=texts,
            timeout=timeout,
        )

    def _trigger_embed_texts(self, texts: list[str], timeout: int) -> Embeddings:
        request = {
            "user_id": self.state.user_id,
            "company_id": self.state.company_id,
            "texts": texts,
            "timeout": timeout,
        }
        try:
            response = unique_sdk.Embeddings.create(**request)
            return Embeddings(**response)
        except Exception as e:
            self.logger.error(f"Error embedding texts: {e}")
            raise e

    def get_cosine_similarity(
        self,
        embedding_1: list[float],
        embedding_2: list[float],
    ) -> float:
        """Computes the cosine similarity between two embeddings."""
        return np.dot(embedding_1, embedding_2) / (
            np.linalg.norm(embedding_1) * np.linalg.norm(embedding_2)
        )