mb-rag 1.1.46__py3-none-any.whl → 1.1.56.post0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of mb-rag might be problematic.
- mb_rag/basic.py +306 -0
- mb_rag/chatbot/chains.py +206 -206
- mb_rag/chatbot/conversation.py +185 -0
- mb_rag/chatbot/prompts.py +58 -58
- mb_rag/rag/embeddings.py +810 -753
- mb_rag/utils/all_data_extract.py +64 -64
- mb_rag/utils/bounding_box.py +231 -231
- mb_rag/utils/document_extract.py +354 -354
- mb_rag/utils/extra.py +73 -73
- mb_rag/utils/pdf_extract.py +428 -428
- mb_rag/version.py +1 -1
- {mb_rag-1.1.46.dist-info → mb_rag-1.1.56.post0.dist-info}/METADATA +11 -11
- mb_rag-1.1.56.post0.dist-info/RECORD +19 -0
- mb_rag/chatbot/basic.py +0 -644
- mb_rag-1.1.46.dist-info/RECORD +0 -18
- {mb_rag-1.1.46.dist-info → mb_rag-1.1.56.post0.dist-info}/WHEEL +0 -0
- {mb_rag-1.1.46.dist-info → mb_rag-1.1.56.post0.dist-info}/top_level.txt +0 -0
mb_rag/rag/embeddings.py
CHANGED
@@ -1,753 +1,810 @@
- [all 753 lines of the previous version removed; the old file shared its module docstring and core text-splitter imports with the new version below]
+"""
+RAG (Retrieval-Augmented Generation) Embeddings Module
+
+This module provides functionality for generating and managing embeddings for RAG models.
+It supports multiple embedding models (OpenAI, Ollama, Google, Anthropic) and includes
+features for text processing, embedding generation, vector store management, and
+conversation handling.
+
+Example Usage:
+    ```python
+    # Initialize embedding generator
+    em_gen = embedding_generator(
+        model="openai",
+        model_type="text-embedding-3-small",
+        vector_store_type="chroma"
+    )
+
+    # Generate embeddings from text
+    em_gen.generate_text_embeddings(
+        text_data_path=['./data/text.txt'],
+        chunk_size=500,
+        chunk_overlap=5,
+        folder_save_path='./embeddings'
+    )
+
+    # Load embeddings and create retriever
+    em_loading = em_gen.load_embeddings('./embeddings')
+    em_retriever = em_gen.load_retriever(
+        './embeddings',
+        search_params=[{"k": 2, "score_threshold": 0.1}]
+    )
+
+    # Query embeddings
+    results = em_retriever.invoke("What is the text about?")
+
+    # Generate RAG chain for conversation
+    rag_chain = em_gen.generate_rag_chain(retriever=em_retriever)
+    response = em_gen.conversation_chain("Tell me more", rag_chain)
+    ```
+
+Features:
+- Multiple model support (OpenAI, Ollama, Google, Anthropic)
+- Text processing and chunking
+- Embedding generation and storage
+- Vector store management
+- Retrieval operations
+- Conversation chains
+- Web crawling integration
+
+Classes:
+- ModelProvider: Base class for model loading and validation
+- TextProcessor: Handles text processing operations
+- embedding_generator: Main class for RAG operations
+"""
+
+import os
+import shutil
+import importlib.util
+from typing import List, Dict, Optional, Union, Any
+from langchain.text_splitter import (
+    CharacterTextSplitter,
+    RecursiveCharacterTextSplitter,
+    SentenceTransformersTokenTextSplitter,
+    TokenTextSplitter,
+    MarkdownHeaderTextSplitter,
+    SemanticChunker)
+from langchain_community.document_loaders import TextLoader, FireCrawlLoader
+from langchain_chroma import Chroma
+from ..utils.extra import load_env_file
+from langchain.chains import create_history_aware_retriever, create_retrieval_chain
+from langchain.chains.combine_documents import create_stuff_documents_chain
+from langchain_core.messages import HumanMessage, SystemMessage
+from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
+from langchain.retrievers import ContextualCompressionRetriever
+from langchain_community.document_compressors import FlashrankRerank
+
+load_env_file()
+
+__all__ = ['embedding_generator', 'load_embedding_model']
+
+class ModelProvider:
+    """
+    Base class for managing different model providers and their loading logic.
+
+    This class provides static methods for loading different types of embedding models
+    and checking package dependencies.
+
+    Methods:
+        check_package: Check if a Python package is installed
+        get_rag_openai: Load OpenAI embedding model
+        get_rag_ollama: Load Ollama embedding model
+        get_rag_anthropic: Load Anthropic model
+        get_rag_google: Load Google embedding model
+
+    Example:
+        ```python
+        # Check if a package is installed
+        has_openai = ModelProvider.check_package("langchain_openai")
+
+        # Load an OpenAI model
+        model = ModelProvider.get_rag_openai("text-embedding-3-small")
+        ```
+    """
+
+    @staticmethod
+    def check_package(package_name: str) -> bool:
+        """
+        Check if a Python package is installed.
+
+        Args:
+            package_name (str): Name of the package to check
+
+        """
+        return importlib.util.find_spec(package_name) is not None
+
+    @staticmethod
+    def get_rag_openai(model_type: str = 'text-embedding-3-small', **kwargs):
+        """
+        Load OpenAI embedding model.
+
+        Args:
+            model_type (str): Model identifier (default: 'text-embedding-3-small')
+            **kwargs: Additional arguments for model initialization
+
+        Returns:
+            OpenAIEmbeddings: Initialized OpenAI embeddings model
+        """
+        if not ModelProvider.check_package("langchain_openai"):
+            raise ImportError("OpenAI package not found. Please install: pip install langchain-openai")
+        from langchain_openai import OpenAIEmbeddings
+        return OpenAIEmbeddings(model=model_type, **kwargs)
+
+    @staticmethod
+    def get_rag_ollama(model_type: str = 'llama3', **kwargs):
+        """
+        Load Ollama embedding model.
+
+        Args:
+            model_type (str): Model identifier (default: 'llama3')
+            **kwargs: Additional arguments for model initialization
+
+        Returns:
+            OllamaEmbeddings: Initialized Ollama embeddings model
+        """
+        if not ModelProvider.check_package("langchain_ollama"):
+            raise ImportError("Ollama package not found. Please install: pip install langchain-ollama")
+        from langchain_ollama import OllamaEmbeddings
+        return OllamaEmbeddings(model=model_type, **kwargs)
+
+    @staticmethod
+    def get_rag_anthropic(model_name: str = "claude-3-opus-20240229", **kwargs):
+        """
+        Load Anthropic model.
+
+        Args:
+            model_name (str): Model identifier (default: "claude-3-opus-20240229")
+            **kwargs: Additional arguments for model initialization
+
+        Returns:
+            ChatAnthropic: Initialized Anthropic chat model
+
+        """
+        if not ModelProvider.check_package("langchain_anthropic"):
+            raise ImportError("Anthropic package not found. Please install: pip install langchain-anthropic")
+        from langchain_anthropic import ChatAnthropic
+        kwargs["model_name"] = model_name
+        return ChatAnthropic(**kwargs)
+
+    @staticmethod
+    def get_rag_google(model_name: str = "gemini-1.5-flash", **kwargs):
+        """
+        Load Google embedding model.
+
+        Args:
+            model_name (str): Model identifier (default: "gemini-1.5-flash")
+            **kwargs: Additional arguments for model initialization
+
+        Returns:
+            GoogleGenerativeAIEmbeddings: Initialized Google embeddings model
+        """
+        if not ModelProvider.check_package("google.generativeai"):
+            raise ImportError("Google Generative AI package not found. Please install: pip install langchain-google-genai")
+        from langchain_google_genai import GoogleGenerativeAIEmbeddings
+        kwargs["model"] = model_name
+        return GoogleGenerativeAIEmbeddings(**kwargs)
+
+    @staticmethod
+    def get_rag_qwen(model_name: str = "Qwen/Qwen3-Embedding-0.6B", **kwargs):
+        """
+        Load Qwen embedding model.
+        Uses Transformers for embedding generation.
+
+        Args:
+            model_name (str): Model identifier (default: "Qwen/Qwen3-Embedding-0.6B")
+            **kwargs: Additional arguments for model initialization
+
+        Returns:
+            QwenEmbeddings: Initialized Qwen embeddings model
+        """
+        from langchain.embeddings import HuggingFaceEmbeddings
+
+        return HuggingFaceEmbeddings(model_name=model_name, **kwargs)
+
+def load_embedding_model(model_name: str = 'openai', model_type: str = "text-embedding-ada-002", **kwargs):
+    """
+    Load a RAG model based on provider and type.
+
+    Args:
+        model_name (str): Name of the model provider (default: 'openai')
+        model_type (str): Type/identifier of the model (default: "text-embedding-ada-002")
+        **kwargs: Additional arguments for model initialization
+
+    Returns:
+        Any: Initialized model instance
+
+    Example:
+        ```python
+        model = load_embedding_model('openai', 'text-embedding-3-small')
+        ```
+    """
+    try:
+        if model_name == 'openai':
+            return ModelProvider.get_rag_openai(model_type, **kwargs)
+        elif model_name == 'ollama':
+            return ModelProvider.get_rag_ollama(model_type, **kwargs)
+        elif model_name == 'google':
+            return ModelProvider.get_rag_google(model_type, **kwargs)
+        elif model_name == 'anthropic':
+            return ModelProvider.get_rag_anthropic(model_type, **kwargs)
+        elif model_name == 'qwen':
+            return ModelProvider.get_rag_qwen(model_type, **kwargs)
+        else:
+            raise ValueError(f"Invalid model name: {model_name}")
+    except ImportError as e:
+        print(f"Error loading model: {str(e)}")
+        return None
+
+class TextProcessor:
+    """
+    Handles text processing operations including file checking and tokenization.
+
+    This class provides methods for loading text files, processing them into chunks,
+    and preparing them for embedding generation.
+
+    Args:
+        logger: Optional logger instance for logging operations
+
+    Example:
+        ```python
+        processor = TextProcessor()
+        docs = processor.tokenize(
+            ['./data.txt'],
+            'recursive_character',
+            chunk_size=1000,
+            chunk_overlap=5
+        )
+        ```
+    """
+
+    def __init__(self, logger=None):
+        self.logger = logger
+
+    def check_file(self, file_path: str) -> bool:
+        """Check if file exists."""
+        return os.path.exists(file_path)
+
+    def tokenize(self, text_data_path: List[str], text_splitter_type: str,
+                 chunk_size: int, chunk_overlap: int) -> List:
+        """
+        Process and tokenize text data from files.
+
+        Args:
+            text_data_path (List[str]): List of paths to text files
+            text_splitter_type (str): Type of text splitter to use
+            chunk_size (int): Size of text chunks
+            chunk_overlap (int): Overlap between chunks
+
+        Returns:
+            List: List of processed document chunks
+
+        """
+        doc_data = []
+        for path in text_data_path:
+            if self.check_file(path):
+                text_loader = TextLoader(path)
+                get_text = text_loader.load()
+                file_name = path.split('/')[-1]
+                metadata = {'source': file_name}
+                if metadata is not None:
+                    for doc in get_text:
+                        doc.metadata = metadata
+                        doc_data.append(doc)
+                if self.logger:
+                    self.logger.info(f"Text data loaded from {file_name}")
+            else:
+                return f"File {path} not found"
+
+        splitters = {
+            'character': CharacterTextSplitter(
+                chunk_size=chunk_size,
+                chunk_overlap=chunk_overlap,
+                separator=["\n", "\n\n", "\n\n\n", " "]
+            ),
+            'recursive_character': RecursiveCharacterTextSplitter(
+                chunk_size=chunk_size,
+                chunk_overlap=chunk_overlap,
+                separators=["\n", "\n\n", "\n\n\n", " "]
+            ),
+            'sentence_transformers_token': SentenceTransformersTokenTextSplitter(
+                chunk_size=chunk_size
+            ),
+            'token': TokenTextSplitter(
+                chunk_size=chunk_size,
+                chunk_overlap=chunk_overlap
+            ),
+            'markdown_header': MarkdownHeaderTextSplitter(
+                chunk_size=chunk_size,
+                chunk_overlap=chunk_overlap
+            ),
+            'semantic_chunker': SemanticChunker(
+                chunk_size=chunk_size,
+                chunk_overlap=chunk_overlap
+            )
+        }
+
+        if text_splitter_type not in splitters:
+            raise ValueError(f"Invalid text splitter type: {text_splitter_type}")
+
+        text_splitter = splitters[text_splitter_type]
+        docs = text_splitter.split_documents(doc_data)
+
+        if self.logger:
+            self.logger.info(f"Text data splitted into {len(docs)} chunks")
+        else:
+            print(f"Text data splitted into {len(docs)} chunks")
+        return docs
+
+
+class embedding_generator:
+    """
+    Main class for generating embeddings and managing RAG operations.
+
+    This class provides comprehensive functionality for generating embeddings,
+    managing vector stores, handling retrievers, and managing conversations.
+
+    Args:
+        model (str): Model provider name (default: 'openai')
+        model_type (str): Model type/identifier (default: 'text-embedding-3-small')
+        vector_store_type (str): Type of vector store (default: 'chroma')
+        collection_name (str): Name of the collection (default: 'test')
+        logger: Optional logger instance
+        model_kwargs (dict): Additional arguments for model initialization
+        vector_store_kwargs (dict): Additional arguments for vector store initialization
+
+    Example:
+        ```python
+        # Initialize generator
+        gen = embedding_generator(
+            model="openai",
+            model_type="text-embedding-3-small",
+            collection_name='test'
+        )
+
+        # Generate embeddings
+        gen.generate_text_embeddings(
+            text_data_path=['./data.txt'],
+            folder_save_path='./embeddings'
+        )
+
+        # Load retriever
+        retriever = gen.load_retriever('./embeddings', collection_name='test')
+
+        # Query embeddings
+        results = gen.query_embeddings("What is this about?")
+        ```
+    """
+
+    def __init__(self, model: str = 'openai', model_type: str = 'text-embedding-3-small',
+                 vector_store_type: str = 'chroma', collection_name: str = 'test',
+                 logger=None, model_kwargs: dict = None, vector_store_kwargs: dict = None) -> None:
+        """Initialize the embedding generator with specified configuration."""
+        self.logger = logger
+        self.model = load_embedding_model(model_name=model, model_type=model_type, **(model_kwargs or {}))
+        if self.model is None:
+            raise ValueError(f"Failed to initialize model {model}. Please ensure required packages are installed.")
+        self.vector_store_type = vector_store_type
+        self.vector_store = self.load_vectorstore(**(vector_store_kwargs or {}))
+        self.collection_name = collection_name
+        self.text_processor = TextProcessor(logger)
+        self.compression_retriever = None
+
+    def check_file(self, file_path: str) -> bool:
+        """Check if file exists."""
+        return self.text_processor.check_file(file_path)
+
+    def tokenize(self, text_data_path: List[str], text_splitter_type: str,
+                 chunk_size: int, chunk_overlap: int) -> List:
+        """Process and tokenize text data."""
+        return self.text_processor.tokenize(text_data_path, text_splitter_type,
+                                            chunk_size, chunk_overlap)
+
+    def generate_text_embeddings(self, text_data_path: List[str] = None,
+                                 text_splitter_type: str = 'recursive_character',
+                                 chunk_size: int = 1000, chunk_overlap: int = 5,
+                                 folder_save_path: str = './text_embeddings',
+                                 replace_existing: bool = False) -> str:
+        """
+        Generate text embeddings from input files.
+
+        Args:
+            text_data_path (List[str]): List of paths to text files
+            text_splitter_type (str): Type of text splitter
+            chunk_size (int): Size of text chunks
+            chunk_overlap (int): Overlap between chunks
+            folder_save_path (str): Path to save embeddings
+            replace_existing (bool): Whether to replace existing embeddings
+
+        Returns:
+            str: Status message
+
+        Example:
+            ```python
+            gen.generate_text_embeddings(
+                text_data_path=['./data.txt'],
+                folder_save_path='./embeddings'
+            )
+            ```
+        """
+        if self.logger:
+            self.logger.info("Performing basic checks")
+
+        if self.check_file(folder_save_path) and not replace_existing:
+            return "File already exists"
+        elif self.check_file(folder_save_path) and replace_existing:
+            shutil.rmtree(folder_save_path)
+
+        if text_data_path is None:
+            return "Please provide text data path"
+
+        if not isinstance(text_data_path, list):
+            raise ValueError("text_data_path should be a list")
+
+        if self.logger:
+            self.logger.info(f"Loading text data from {text_data_path}")
+
+        docs = self.tokenize(text_data_path, text_splitter_type, chunk_size, chunk_overlap)
+
+        if self.logger:
+            self.logger.info(f"Generating embeddings for {len(docs)} documents")
+
+        self.vector_store.from_documents(docs, self.model, collection_name=self.collection_name,
+                                         persist_directory=folder_save_path)
+
+        if self.logger:
+            self.logger.info(f"Embeddings generated and saved at {folder_save_path}")
+
+    def load_vectorstore(self, **kwargs):
+        """Load vector store."""
+        if self.vector_store_type == 'chroma':
+            vector_store = Chroma()
+            if self.logger:
+                self.logger.info(f"Loaded vector store {self.vector_store_type}")
+            return vector_store
+        else:
+            return "Vector store not found"
+
+    def load_embeddings(self, embeddings_folder_path: str, collection_name: str = 'test'):
+        """
+        Load embeddings from folder.
+
+        Args:
+            embeddings_folder_path (str): Path to embeddings folder
+            collection_name (str): Name of the collection. Default: 'test'
+
+        Returns:
+            Optional[Chroma]: Loaded vector store or None if not found
+        """
+        if self.check_file(embeddings_folder_path):
+            if self.vector_store_type == 'chroma':
+                return Chroma(persist_directory=embeddings_folder_path,
+                              embedding_function=self.model,
+                              collection_name=collection_name)
+        else:
+            if self.logger:
+                self.logger.info("Embeddings file not found")
+            return None
+
+    def load_retriever(self, embeddings_folder_path: str,
+                       search_type: List[str] = ["similarity_score_threshold"],
+                       search_params: List[Dict] = [{"k": 3, "score_threshold": 0.9}],
+                       collection_name: str = 'test'):
+        """
+        Load retriever with search configuration.
+
+        Args:
+            embeddings_folder_path (str): Path to embeddings folder
+            search_type (List[str]): List of search types
+            search_params (List[Dict]): List of search parameters
+            collection_name (str): Name of the collection. Default: 'test'
+
+        Returns:
+            Union[Any, List[Any]]: Single retriever or list of retrievers
+
+        Example:
+            ```python
+            retriever = gen.load_retriever(
+                './embeddings',
+                search_type=["similarity_score_threshold"],
+                search_params=[{"k": 3, "score_threshold": 0.9}]
+            )
+            ```
+        """
+        db = self.load_embeddings(embeddings_folder_path, collection_name)
+        if db is not None:
+            if self.vector_store_type == 'chroma':
+                if len(search_type) != len(search_params):
+                    raise ValueError("Length of search_type and search_params should be equal")
+                if len(search_type) == 1:
+                    self.retriever = db.as_retriever(search_type=search_type[0],
+                                                     search_kwargs=search_params[0])
+                    if self.logger:
+                        self.logger.info("Retriever loaded")
+                    return self.retriever
+                else:
+                    retriever_list = []
+                    for i in range(len(search_type)):
+                        retriever_list.append(db.as_retriever(search_type=search_type[i],
+                                                              search_kwargs=search_params[i]))
+                    if self.logger:
+                        self.logger.info("List of Retriever loaded")
+                    return retriever_list
+        else:
+            return "Embeddings file not found"
+
+    def add_data(self, embeddings_folder_path: str, data: List[str],
+                 text_splitter_type: str = 'recursive_character',
+                 chunk_size: int = 1000, chunk_overlap: int = 5, collection_name: str = 'test'):
+        """
+        Add data to existing embeddings.
+
+        Args:
+            embeddings_folder_path (str): Path to embeddings folder
+            data (List[str]): List of text data to add
+            text_splitter_type (str): Type of text splitter
+            chunk_size (int): Size of text chunks
+            chunk_overlap (int): Overlap between chunks
+            collection_name (str): Name of the collection. Default: 'test'
+        """
+        if self.vector_store_type == 'chroma':
+            db = self.load_embeddings(embeddings_folder_path, collection_name)
+            if db is not None:
+                docs = self.tokenize(data, text_splitter_type, chunk_size, chunk_overlap)
+                db.add_documents(docs)
+                if self.logger:
+                    self.logger.info("Data added to the existing db/embeddings")
+
+    def query_embeddings(self, query: str, retriever=None):
+        """
+        Query embeddings.
+
+        Args:
+            query (str): Query string
+            retriever: Optional retriever instance
+
+        Returns:
+            Any: Query results
+        """
+        if retriever is None:
+            retriever = self.retriever
+        return retriever.invoke(query)
+
+    def get_relevant_documents(self, query: str, retriever=None):
+        """
+        Get relevant documents for query.
+
+        Args:
+            query (str): Query string
+            retriever: Optional retriever instance
+
+        Returns:
+            List: List of relevant documents
+        """
+        if retriever is None:
+            retriever = self.retriever
+        return retriever.get_relevant_documents(query)
+
+    def load_flashrank_compression_retriever(self, base_retriever=None, model_name: str = "flashrank/flashrank-base", top_n: int = 5):
+        """
+        Load a ContextualCompressionRetriever using FlashrankRerank.
+
+        Args:
+            base_retriever: Existing retriever (if None, uses self.retriever)
+            model_name (str): Flashrank model identifier (default: "flashrank/flashrank-base")
+            top_n (int): Number of top documents to return after reranking
+
+        Returns:
+            ContextualCompressionRetriever: A compression-based retriever using Flashrank
+        """
+        if base_retriever is None:
+            base_retriever = self.retriever
+        if base_retriever is None:
+            raise ValueError("Base retriever is required.")
+
+        compressor = FlashrankRerank(model=model_name, top_n=top_n)
+        self.compression_retriever = ContextualCompressionRetriever(
+            base_compressor=compressor,
+            base_retriever=base_retriever
+        )
+
+        if self.logger:
+            self.logger.info("Loaded Flashrank compression retriever.")
+        return self.compression_retriever
+
+    def compression_invoke(self, query: str):
+        """
+        Invoke compression retriever. Only one compression retriever (Reranker) added right now.
+
+        Args:
+            query (str): Query string
+
+        Returns:
+            Any: Query results
+        """
+
+        if self.compression_retriever is None:
+            self.compression_retriever = self.load_flashrank_compression_retriever(base_retriever=self.retriever)
+            print("Compression retriever loaded.")
+        return self.compression_retriever.invoke(query)
+
+    def generate_rag_chain(self, context_prompt: str = None, retriever=None, llm=None):
+        """
+        Generate RAG chain for conversation.
+
+        Args:
+            context_prompt (str): Optional context prompt
+            retriever: Optional retriever instance
+            llm: Optional language model instance
+
+        Returns:
+            Any: Generated RAG chain
+
+        Example:
+            ```python
+            rag_chain = gen.generate_rag_chain(retriever=retriever)
+            ```
+        """
+        if context_prompt is None:
+            context_prompt = ("You are an assistant for question-answering tasks. "
+                              "Use the following pieces of retrieved context to answer the question. "
+                              "If you don't know the answer, just say that you don't know. "
+                              "Use three sentences maximum and keep the answer concise.\n\n{context}")
+
+        contextualize_q_system_prompt = ("Given a chat history and the latest user question "
+                                         "which might reference context in the chat history, "
+                                         "formulate a standalone question which can be understood, "
+                                         "just reformulate it if needed and otherwise return it as is.")
+
+        contextualize_q_prompt = ChatPromptTemplate.from_messages([
+            ("system", contextualize_q_system_prompt),
+            MessagesPlaceholder("chat_history"),
+            ("human", "{input}"),
+        ])
+
+        if retriever is None:
+            retriever = self.retriever
+        if llm is None:
+            if not ModelProvider.check_package("langchain_openai"):
+                raise ImportError("OpenAI package not found. Please install: pip install langchain-openai")
+            from langchain_openai import ChatOpenAI
+            llm = ChatOpenAI(model="gpt-4o", temperature=0.8)
+
+        history_aware_retriever = create_history_aware_retriever(llm, retriever,
+                                                                 contextualize_q_prompt)
+        qa_prompt = ChatPromptTemplate.from_messages([
+            ("system", context_prompt),
+            MessagesPlaceholder("chat_history"),
+            ("human", "{input}"),
+        ])
+        question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
+        rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
+        return rag_chain
+
+    def conversation_chain(self, query: str, rag_chain, file: str = None):
+        """
+        Create conversation chain.
+
+        Args:
+            query (str): User query
+            rag_chain: RAG chain instance
+            file (str): Optional file to save conversation
+
+        Returns:
+            List: Conversation history
+
+        Example:
+            ```python
+            history = gen.conversation_chain(
+                "Tell me about...",
+                rag_chain,
+                file='conversation.txt'
+            )
+            ```
+        """
+        if file is not None:
+            try:
+                chat_history = self.load_conversation(file, list_type=True)
+                if len(chat_history) == 0:
+                    chat_history = []
+            except:
+                chat_history = []
+        else:
+            chat_history = []
+
+        query = "You : " + query
+        res = rag_chain.invoke({"input": query, "chat_history": chat_history})
+        print(f"Response: {res['answer']}")
+        chat_history.append(HumanMessage(content=query))
+        chat_history.append(SystemMessage(content=res['answer']))
+        if file is not None:
+            self.save_conversation(chat_history, file)
+        return chat_history
+
+    def load_conversation(self, file: str, list_type: bool = False):
+        """
+        Load conversation history.
+
+        Args:
+            file (str): Path to conversation file
+            list_type (bool): Whether to return as list
+
+        Returns:
+            Union[str, List]: Conversation history
+        """
+        if list_type:
+            chat_history = []
+            with open(file, 'r') as f:
+                for line in f:
+                    chat_history.append(line.strip())
+        else:
+            with open(file, "r") as f:
+                chat_history = f.read()
+        return chat_history
+
+    def save_conversation(self, chat: Union[str, List], file: str):
+        """
+        Save conversation history.
+
+        Args:
+            chat (Union[str, List]): Conversation to save
+            file (str): Path to save file
+        """
+        if isinstance(chat, str):
+            with open(file, "a") as f:
+                f.write(chat)
+        elif isinstance(chat, list):
+            with open(file, "a") as f:
+                for i in chat[-2:]:
+                    f.write("%s\n" % i)
+        print(f"Saved file : {file}")
+
+    def firecrawl_web(self, website: str, api_key: str = None, mode: str = "scrape",
+                      file_to_save: str = './firecrawl_embeddings', **kwargs):
+        """
+        Get data from website using FireCrawl.
+
+        Args:
+            website (str): Website URL to crawl
+            api_key (str): Optional FireCrawl API key
+            mode (str): Crawl mode (default: "scrape")
+            file_to_save (str): Path to save embeddings
+            **kwargs: Additional arguments for FireCrawl
+
+        Returns:
+            Chroma: Vector store with crawled data
+
+        Example:
+            ```python
+            db = gen.firecrawl_web(
+                "https://example.com",
+                mode="scrape",
+                file_to_save='./crawl_embeddings'
+            )
+            ```
+        """
+        if not ModelProvider.check_package("firecrawl"):
+            raise ImportError("Firecrawl package not found. Please install: pip install firecrawl")
+
+        if api_key is None:
+            api_key = os.getenv("FIRECRAWL_API_KEY")
+
+        loader = FireCrawlLoader(api_key=api_key, url=website, mode=mode)
+        docs = loader.load()
+
+        for doc in docs:
+            for key, value in doc.metadata.items():
+                if isinstance(value, list):
+                    doc.metadata[key] = ", ".join(map(str, value))
+
+        text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
+        split_docs = text_splitter.split_documents(docs)
+
+        print("\n--- Document Chunks Information ---")
+        print(f"Number of document chunks: {len(split_docs)}")
+        print(f"Sample chunk:\n{split_docs[0].page_content}\n")
+
+        embeddings = self.model
+        db = Chroma.from_documents(split_docs, embeddings,
+                                   persist_directory=file_to_save)
+        print(f"Retriever saved at {file_to_save}")
+        return db
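The rewritten module wires chunking, Chroma persistence, optional Flashrank reranking, and a history-aware RAG chain into a single class. For orientation, the following is a minimal end-to-end sketch assembled from the module's own docstring examples and the signatures above, not an official script from the package. It assumes langchain-openai is installed, OPENAI_API_KEY is set, and a local ./data/text.txt exists; the paths, score threshold, and top_n are illustrative only.

```python
# Minimal usage sketch for the new mb_rag/rag/embeddings.py API.
# Assumptions: OPENAI_API_KEY in the environment, langchain-openai installed,
# ./data/text.txt present. Paths and search parameters are illustrative.
from mb_rag.rag.embeddings import embedding_generator

gen = embedding_generator(
    model="openai",
    model_type="text-embedding-3-small",
    vector_store_type="chroma",
    collection_name="test",
)

# Chunk the input file and persist a Chroma store under ./embeddings.
gen.generate_text_embeddings(
    text_data_path=["./data/text.txt"],
    chunk_size=500,
    chunk_overlap=5,
    folder_save_path="./embeddings",
)

# Build a retriever over the persisted store and run a similarity query.
retriever = gen.load_retriever(
    "./embeddings",
    search_type=["similarity_score_threshold"],
    search_params=[{"k": 2, "score_threshold": 0.1}],
    collection_name="test",
)
print(gen.query_embeddings("What is the text about?", retriever))

# Optional: rerank results with the new Flashrank compression retriever.
gen.load_flashrank_compression_retriever(base_retriever=retriever, top_n=3)
print(gen.compression_invoke("What is the text about?"))

# Conversational RAG: a history-aware chain whose turns are appended to a file.
rag_chain = gen.generate_rag_chain(retriever=retriever)
history = gen.conversation_chain("Tell me more", rag_chain, file="conversation.txt")
```

Because conversation_chain reloads and appends to the history file on each call, repeated calls with the same file argument continue a single conversation across runs.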