semantic-compressor: semantic_compressor-1.6-py3-none-any.whl → semantic_compressor-1.7-py3-none-any.whl
- compressor/semantic.py +80 -18
- {semantic_compressor-1.6.dist-info → semantic_compressor-1.7.dist-info}/METADATA +1 -1
- {semantic_compressor-1.6.dist-info → semantic_compressor-1.7.dist-info}/RECORD +6 -6
- {semantic_compressor-1.6.dist-info → semantic_compressor-1.7.dist-info}/LICENSE +0 -0
- {semantic_compressor-1.6.dist-info → semantic_compressor-1.7.dist-info}/WHEEL +0 -0
- {semantic_compressor-1.6.dist-info → semantic_compressor-1.7.dist-info}/top_level.txt +0 -0
compressor/semantic.py
CHANGED
@@ -4,6 +4,7 @@ from sklearn.decomposition import LatentDirichletAllocation
 from sklearn.metrics.pairwise import cosine_similarity
 from onnxruntime_extensions import get_library_path
 from compressor.minbpe.regex import RegexTokenizer
+from concurrent.futures import ProcessPoolExecutor
 from nltk.tokenize import sent_tokenize
 from multiprocessing import cpu_count
 from spellchecker import SpellChecker
@@ -31,7 +32,7 @@ english_stopwords = pickle.load(open(english_stopwords_path, "rb"))
 portuguese_stopwords = pickle.load(open(portuguese_stopwords_path, "rb"))
 langdetect_model = fasttext.load_model(fasttext_model_path)
 
-embedding_model_cpu_count = os.environ.get('EMBEDDING_MODEL_CPU_COUNT',
+embedding_model_cpu_count = os.environ.get('EMBEDDING_MODEL_CPU_COUNT', 1)
 
 _options = ort.SessionOptions()
 _options.inter_op_num_threads, _options.intra_op_num_threads = embedding_model_cpu_count, embedding_model_cpu_count
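Worth noting when configuring this: `os.environ.get` returns a string whenever `EMBEDDING_MODEL_CPU_COUNT` is actually set, so the integer default `1` only survives when the variable is absent, while the `ort.SessionOptions` thread counts below expect an integer. A minimal defensive sketch (my suggestion, not part of the 1.7 code):

    import os

    # Hedged sketch: coerce explicitly, since os.environ.get returns a str
    # when EMBEDDING_MODEL_CPU_COUNT is set in the environment.
    embedding_model_cpu_count = int(os.environ.get('EMBEDDING_MODEL_CPU_COUNT', 1))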
@@ -263,8 +264,25 @@ def correct_spelling(sentence, detected_lang="pt"):
 
     return " ".join(final_words)
 
+def preprocess_and_extract_textual_embedding(block, use_stemming, lang):
+    """
+    Preprocesses a block (lowercasing and stemming if required) and extracts textual embeddings.
+
+    Args:
+        block (str): The text block to process.
+        use_stemming (bool): Whether to apply stemming.
+        lang (str): Language of the text for stemming.
+
+    Returns:
+        np.array: The textual embedding of the processed block.
+    """
+    processed_block = block.lower() if not use_stemming else stem_text(block.lower(), lang)
+    return extract_textual_embeddings(processed_block)
+
+
 def find_needle_in_haystack(
-    *, haystack: str, needle: str, block_size
+    *, haystack: str, needle: str, block_size=300,
+    embedding_mode: str = 'both',  # 'semantic', 'textual', or 'both'
     semantic_embeddings_weight: float = 0.3,
     textual_embeddings_weight: float = 0.7,
     use_stemming: bool = False,
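The new `preprocess_and_extract_textual_embedding` helper is defined at module level because `ProcessPoolExecutor` (used further down) pickles the callable it ships to worker processes, and lambdas or local closures cannot be pickled. A toy illustration of that constraint (hypothetical functions, not from the package):

    from concurrent.futures import ProcessPoolExecutor

    def shout(text):
        # Module-level functions can be pickled and sent to worker processes.
        return text.upper()

    if __name__ == '__main__':
        with ProcessPoolExecutor() as executor:
            print(list(executor.map(shout, ['a', 'b'])))  # ['A', 'B']
            # executor.map(lambda t: t.upper(), ['a', 'b'])  # would fail: lambdas can't be pickled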
@@ -277,16 +295,21 @@ def find_needle_in_haystack(
         haystack (str): The haystack string.
         needle (str): The needle string.
         block_size (int, optional): The size of each string block. The needle will be searched in each block. Defaults to 350.
+        embedding_mode (str, optional): The embedding type to use: 'semantic', 'textual', or 'both'. Defaults to 'both'.
         semantic_embeddings_weight (float, optional): The weight of the semantic embeddings in the similarity calculation. Defaults to 0.3.
         textual_embeddings_weight (float, optional): The weight of the textual embeddings in the similarity calculation. Defaults to 0.7.
         use_stemming (bool, optional): Whether to use stemming for the text. Defaults to False.
         correct_spelling_needle (bool, optional): Whether to correct the spelling of the needle. Defaults to False.
-
+
     Returns:
         str: The string block in the haystack that contains the needle. The size of the needle will be less than or equal to the block size.
     """
 
     try:
+        # Validate embedding_mode
+        if embedding_mode not in {'semantic', 'textual', 'both'}:
+            raise ValueError("Invalid embedding_mode. Choose 'semantic', 'textual', or 'both'.")
+
         # Split the haystack into blocks
         blocks = structurize_text(haystack, tokens_per_chunk=block_size)
 
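A quick usage sketch of the new parameter (a hypothetical call; the function and defaults come from the diff above, the sample strings are mine). Note that the docstring still says `block_size` defaults to 350 while the new signature uses 300:

    from compressor.semantic import find_needle_in_haystack

    document = "..."  # hypothetical: the full text to search

    # Keyword-only call; embedding_mode is new in 1.7.
    best_region = find_needle_in_haystack(
        haystack=document,
        needle="payment terms",
        embedding_mode='both',
    )

    # With the default weights (0.3 semantic, 0.7 textual), a block scoring
    # 0.9 semantic and 0.6 textual ranks at 0.3*0.9 + 0.7*0.6 = 0.69.
    # Any other mode string fails fast with the ValueError shown above.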
@@ -295,33 +318,72 @@ def find_needle_in_haystack(
         if correct_spelling_needle:
             needle = correct_spelling(needle, lang)
 
-        # Compute the embeddings of the needle
-        needle_semantic_embedding =
-        needle_textual_embedding =
+        # Compute the embeddings of the needle based on the embedding mode
+        needle_semantic_embedding = None
+        needle_textual_embedding = None
 
+        if embedding_mode in {'semantic', 'both'}:
+            needle_semantic_embedding = extract_semantic_embeddings(needle)
+
+        if embedding_mode in {'textual', 'both'}:
+            needle_textual_embedding = extract_textual_embeddings(
+                needle.lower() if not use_stemming else stem_text(needle, lang)
+            )
+
         # Compute the embeddings of the haystack (each block)
-        haystack_semantic_embeddings = [
-        haystack_textual_embeddings = [
-
-        # Compute the similarity between the needle and each block
-        semantic_similarities = [calculate_similarity(needle_semantic_embedding, block_embedding) for block_embedding in haystack_semantic_embeddings]
-        textual_similarities = [calculate_similarity(needle_textual_embedding, block_embedding) for block_embedding in haystack_textual_embeddings]
-
-        # Sort the blocks by similarity, using the weighted average of semantic and textual similarity
-        sorted_blocks = sorted(zip(blocks, semantic_similarities, textual_similarities), key=lambda x: x[1] * semantic_embeddings_weight + x[2] * textual_embeddings_weight, reverse=True)
+        haystack_semantic_embeddings = []
+        haystack_textual_embeddings = []
 
+        if embedding_mode in {'semantic', 'both'}:
+            with ProcessPoolExecutor() as executor:
+                haystack_semantic_embeddings = list(executor.map(extract_semantic_embeddings, blocks))
+
+        if embedding_mode in {'textual', 'both'}:
+            with ProcessPoolExecutor(max_workers=cpu_count()//1.5) as executor:
+                haystack_textual_embeddings = list(
+                    executor.map(preprocess_and_extract_textual_embedding, blocks, [use_stemming]*len(blocks), [lang]*len(blocks))
+                )
+
+        # Compute similarities based on the embedding mode
+        semantic_similarities = []
+        textual_similarities = []
+
+        if embedding_mode in {'semantic', 'both'}:
+            semantic_similarities = [
+                calculate_similarity(needle_semantic_embedding, block_embedding)
+                for block_embedding in haystack_semantic_embeddings
+            ]
+
+        if embedding_mode in {'textual', 'both'}:
+            textual_similarities = [
+                calculate_similarity(needle_textual_embedding, block_embedding)
+                for block_embedding in haystack_textual_embeddings
+            ]
+
+        # Calculate the overall similarity score
+        if embedding_mode == 'semantic':
+            sorted_blocks = sorted(zip(blocks, semantic_similarities), key=lambda x: x[1], reverse=True)
+        elif embedding_mode == 'textual':
+            sorted_blocks = sorted(zip(blocks, textual_similarities), key=lambda x: x[1], reverse=True)
+        else:  # both
+            sorted_blocks = sorted(
+                zip(blocks, semantic_similarities, textual_similarities),
+                key=lambda x: x[1] * semantic_embeddings_weight + x[2] * textual_embeddings_weight,
+                reverse=True
+            )
+
         # The most similar block is the one that contains the needle
         most_similar_block = sorted_blocks[0][0]
 
         # Find the index of the needle in all the blocks
         most_similar_block_index = blocks.index(most_similar_block)
 
-        start_index = most_similar_block_index-1 if most_similar_block_index > 0 else 0
+        start_index = most_similar_block_index - 1 if most_similar_block_index > 0 else 0
 
-        needle_region = blocks[start_index:most_similar_block_index+2]
+        needle_region = blocks[start_index:most_similar_block_index + 2]
 
         return ''.join(needle_region).strip()
     except Exception:
         traceback.print_exc()
 
-    return haystack
+    return haystack
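One caveat in the new parallel path: `cpu_count() // 1.5` floor-divides by a float and therefore yields a float (for example, `8 // 1.5 == 5.0`), and `ProcessPoolExecutor` cannot spawn workers from a non-integer `max_workers`, so this raises a TypeError at runtime. A minimal sketch of an int-safe variant (my suggestion, not what 1.7 ships):

    from concurrent.futures import ProcessPoolExecutor
    from multiprocessing import cpu_count

    def tokenize(text):
        # Toy stand-in for the real embedding work.
        return text.split()

    if __name__ == '__main__':
        # Suggested fix (not in the released code): coerce to int and keep
        # at least one worker; cpu_count() // 1.5 alone yields a float.
        workers = max(1, int(cpu_count() // 1.5))
        with ProcessPoolExecutor(max_workers=workers) as executor:
            print(list(executor.map(tokenize, ["a b", "c d"])))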
{semantic_compressor-1.6.dist-info → semantic_compressor-1.7.dist-info}/RECORD
CHANGED
@@ -1,5 +1,5 @@
 compressor/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-compressor/semantic.py,sha256=
+compressor/semantic.py,sha256=CJ5WhWKDuBT19qB_5EvMqWw5mtU2jCqBmOkVWXODLX0,16257
 compressor/minbpe/__init__.py,sha256=wZ1z2QKkncvGgiZDBc91AP5m7-M-MVenPStKbS6xylE,95
 compressor/minbpe/base.py,sha256=tTKag04RRFnc4ppoieBbDV0V6thzi_ZvZTlhOYIoY7Q,6881
 compressor/minbpe/basic.py,sha256=0kD4tU8l2MZegfPaHMfDo5CnaSzb9i1v9tDBy6GwMbg,2883
@@ -8,8 +8,8 @@ compressor/resources/embedding_model.onnx,sha256=uLBbAfCGEJTwR1yyiK0bMDroruLr6W5
 compressor/resources/en_stopwords.pkl,sha256=Q2PyGQnphPUs_jxN9NMSqp2EQjYv4b4oMJY2aMYvbSY,1310
 compressor/resources/lid.176.ftz,sha256=jzRyz-hzintgmejpmcPL-uDc0VaWqsfXc4qAOdtgPoM,938013
 compressor/resources/pt_stopwords.pkl,sha256=-9bJaxJWjeOFxLHLT9D-rI3XTzGC0iLJfMiwBDnkCYI,1716
-semantic_compressor-1.
-semantic_compressor-1.
-semantic_compressor-1.
-semantic_compressor-1.
-semantic_compressor-1.
+semantic_compressor-1.7.dist-info/LICENSE,sha256=DFRihXonZ3qVRaTrzuXNaDI_-h2jyT2SqWqjtTDHfqI,1067
+semantic_compressor-1.7.dist-info/METADATA,sha256=I4nO2VQxeIOJAAzs2DMhxmotVV6IvdVMfeheUwAFCTQ,6178
+semantic_compressor-1.7.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+semantic_compressor-1.7.dist-info/top_level.txt,sha256=qb2SlKrEmMrQDVrhwxu3Wr7U6JupPXtDGrJpIQr8xSc,11
+semantic_compressor-1.7.dist-info/RECORD,,
{semantic_compressor-1.6.dist-info → semantic_compressor-1.7.dist-info}/LICENSE
File without changes
{semantic_compressor-1.6.dist-info → semantic_compressor-1.7.dist-info}/WHEEL
File without changes
{semantic_compressor-1.6.dist-info → semantic_compressor-1.7.dist-info}/top_level.txt
File without changes