kolzchut-ragbot 1.3.0__py3-none-any.whl → 1.7.13__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
- kolzchut_ragbot/Document.py +101 -101
- kolzchut_ragbot/IntegrateService.py +4 -4
- kolzchut_ragbot/__init__.py +0 -2
- kolzchut_ragbot/config.py +4 -5
- kolzchut_ragbot/engine.py +390 -246
- kolzchut_ragbot/get_full_documents_utilities.py +45 -0
- kolzchut_ragbot/llm_client.py +11 -11
- kolzchut_ragbot/model.py +182 -166
- kolzchut_ragbot-1.7.13.dist-info/METADATA +76 -0
- kolzchut_ragbot-1.7.13.dist-info/RECORD +12 -0
- {kolzchut_ragbot-1.3.0.dist-info → kolzchut_ragbot-1.7.13.dist-info}/WHEEL +1 -1
- kolzchut_ragbot-1.3.0.dist-info/METADATA +0 -67
- kolzchut_ragbot-1.3.0.dist-info/RECORD +0 -11
- {kolzchut_ragbot-1.3.0.dist-info → kolzchut_ragbot-1.7.13.dist-info}/top_level.txt +0 -0
kolzchut_ragbot/engine.py
CHANGED
```diff
@@ -1,246 +1,390 @@
-import time
-from collections import defaultdict
-from datetime import datetime
-from .llm_client import LLMClient
-from . import config
-from .model import es_client_factory
-from .Document import factory
-from sentence_transformers import SentenceTransformer
-from transformers import AutoTokenizer, AutoModelForSequenceClassification
-
-import torch
-import os
[... old lines 13-246 not captured by the diff viewer ...]
+import time
+from collections import defaultdict
+from datetime import datetime
+from .llm_client import LLMClient
+from . import config
+from .model import es_client_factory
+from .Document import factory
+from sentence_transformers import SentenceTransformer
+from transformers import AutoTokenizer, AutoModelForSequenceClassification
+from .get_full_documents_utilities import find_page_id_in_all_indices, unite_docs_to_single_instance
+import torch
+import os
+import asyncio
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+definitions = factory()
+
+
+class Engine:
+    """
+    Engine class for handling document search and retrieval using Elasticsearch and LLMs.
+
+    Attributes:
+        llms_client (LLMClient): The LLM client instance.
+        elastic_model (Model): The Elasticsearch model instance.
+        models (dict): A dictionary of SentenceTransformer models.
+        reranker_tokenizer (AutoTokenizer): The tokenizer for the reranker model.
+        reranker_model (AutoModelForSequenceClassification): The reranker model.
+        identifier_field (str): The identifier field for documents.
+
+    Methods:
+        rerank_with_me5(query, documents, k=5):
+            Reranks documents based on the query using the reranker model.
+
+        update_docs(list_of_docs, embed_only_fields=None, delete_existing=False):
+            Updates or creates documents in the Elasticsearch index.
+
+        reciprocal_rank_fusion(ranking_lists, k=60, weights=None):
+            Performs Reciprocal Rank Fusion on a list of ranking lists.
+
+        search_documents(query, top_k):
+            Searches for documents based on the query and returns the top_k results.
+
+        answer_query(query, top_k, model):
+            Answers a query using the top_k documents and the specified model.
+    """
+
+    def __init__(self, llms_client: LLMClient, elastic_model=None, models=None, reranker_tokenizer=None,
+                 reranker_model=None, es_client=None):
+        """
+        Initializes the Engine instance.
+
+        Args:
+            llms_client (LLMClient): The LLM client instance.
+            elastic_model (Model, optional): The Elasticsearch model instance. Default is None.
+            models (dict, optional): A dictionary of SentenceTransformer models. Default is None.
+            reranker_tokenizer (AutoTokenizer, optional): The tokenizer for the reranker model. Default is None.
+            reranker_model (AutoModelForSequenceClassification, optional): The reranker model. Default is None.
+            es_client (optional): The Elasticsearch client instance. Default is None.
+        """
+        if elastic_model is None:
+            self.elastic_model = es_client_factory(es_client)
+        else:
+            self.elastic_model = elastic_model
+
+        self.llms_client = llms_client
+
+        self.identifier_field = factory().identifier
+
+        if models is None:
+            self.models = {f"{model_name}": SentenceTransformer(config.MODELS_LOCATION + "/" + model_name).to(device)
+                           for model_name in definitions.models.keys()}
+        else:
+            self.models = models
+        for model in self.models.values():
+            model.eval()
+
+        if reranker_tokenizer is None:
+            self.reranker_tokenizer = AutoTokenizer.from_pretrained(os.getenv("TOKENIZER_LOCATION"))
+        else:
+            self.reranker_tokenizer = reranker_tokenizer
+
+        if reranker_model is None:
+            self.reranker_model = AutoModelForSequenceClassification.from_pretrained(os.getenv("TOKENIZER_LOCATION"))
+        else:
+            self.reranker_model = reranker_model
+        self.reranker_model.eval()
+
+    def change_llm(self, llms_client: LLMClient):
+        """
+        Changes the LLM client for the Engine instance.
+
+        Args:
+            llms_client (LLMClient): The new LLM client instance.
+        """
+        self.llms_client = llms_client
```
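As the constructor shows, every dependency can be injected or built lazily: the Elasticsearch model falls back to `es_client_factory`, the sentence-transformer models load from `config.MODELS_LOCATION`, and both reranker pieces default to the path in the `TOKENIZER_LOCATION` environment variable. A minimal construction sketch, assuming that variable points at a local cross-encoder checkpoint and that `my_llm_client` and `other_llm_client` are `LLMClient` instances you provide (the path is hypothetical):

```python
import os

from kolzchut_ragbot.engine import Engine

# Assumption: one local checkpoint serves both the reranker tokenizer and model.
os.environ["TOKENIZER_LOCATION"] = "/models/reranker"  # hypothetical path

engine = Engine(llms_client=my_llm_client)  # my_llm_client: an LLMClient you have built
engine.change_llm(other_llm_client)         # the LLM can be swapped later without rebuilding
```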
```diff
+
+    def rerank_with_me5(self, query, documents, k=5):
+        """
+        Reranks documents based on the query using the reranker model.
+
+        Args:
+            query (str): The query string.
+            documents (list): A list of documents to be reranked.
+            k (int, optional): The number of top documents to return. Default is 5.
+
+        Returns:
+            list: A list of top k reranked documents.
+        """
+        pairs = [(query, doc) for doc in set(documents)]
+        inputs = self.reranker_tokenizer(pairs, return_tensors='pt', padding=True, truncation=True, max_length=512)
+
+        # Make predictions
+        with torch.no_grad():
+            outputs = self.reranker_model(**inputs)
+
+        scores = outputs.logits.squeeze()
+
+        if scores.ndim > 1:
+            scores = scores[:, 1]  # Assuming binary classification and index 1 is the relevance score
+
+        sorted_indices = torch.argsort(scores, descending=True)
+        # If there is only one document, return it to avoid torch error
+        if len(sorted_indices) == 1:
+            return [pairs[0][1]]
+        # Sort documents by their highest score
+        sorted_docs = [pairs[i][1] for i in sorted_indices]
+        return sorted_docs[:k]
```
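The method scores every (query, document) pair with the cross-encoder and keeps the k best. A hedged usage sketch, assuming an initialized `engine` (the document texts are illustrative):

```python
docs = [
    "How to apply for unemployment benefits",
    "Office opening hours and contact details",
    "Eligibility rules for income support",
]
# Returns the 2 documents whose (query, doc) pair the cross-encoder scored highest.
top_2 = engine.rerank_with_me5("who is eligible for income support", docs, k=2)
```

Note that the method passes inputs through `set(documents)`, so duplicate strings are scored once and the original input order is not preserved.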
```diff
+
+    def update_docs(self, list_of_docs: list[dict], embed_only_fields=None, delete_existing=False):
+        """
+        Updates or creates documents in the Elasticsearch index.
+
+        Args:
+            list_of_docs (list[dict]): A list of dictionaries representing the documents to be indexed.
+            embed_only_fields (list, optional): A list of fields to be embedded. Default is None.
+            delete_existing (bool, optional): Whether to delete existing documents. Default is False.
+        """
+        embed_only_fields = embed_only_fields or definitions.models.values()
+        for doc in list_of_docs:
+            for semantic_model, field in definitions.models.items():
+                if field in doc.keys() and field in embed_only_fields:
+                    content_vectors = self.models[semantic_model].encode(doc[field])
+                    doc[f'{field}_{semantic_model}_vectors'] = content_vectors
+
+            doc['last_update'] = datetime.now()
+        self.elastic_model.create_or_update_documents(list_of_docs, delete_existing)
```
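Each document gains one vector field per (model, source field) pair, named `{field}_{model}_vectors`, plus a `last_update` timestamp, before being written to Elasticsearch. A sketch of the input and resulting shape; the model and field names are hypothetical, since the real ones come from the package's `Document` definitions:

```python
# Hypothetical: definitions.models == {"me5": "content"}
docs = [{"page_id": 42, "title": "Income support", "content": "Eligibility rules..."}]
engine.update_docs(docs, delete_existing=False)
# After the call each dict also carries:
#   doc["content_me5_vectors"]  # the SentenceTransformer embedding of doc["content"]
#   doc["last_update"]          # datetime.now() at indexing time
```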
```diff
+
+    def reciprocal_rank_fusion(self, ranking_lists, k=60, weights=None):
+        """
+        Performs Reciprocal Rank Fusion on a list of ranking lists.
+
+        Args:
+        :param ranking_lists: List of ranking lists, where each ranking list is a list of documents returned by a model.
+        :param k: The parameter for the reciprocal rank calculation (default is 60).
+        :param: weights: Optional. Weights for each ranking list.
+
+        Returns:
+            list: A fused ranking list of documents.
+        """
+        scores = defaultdict(float)
+
+        for list_index, rank_list in enumerate(ranking_lists):
+            for rank, identifier in enumerate(rank_list):
+                # Reciprocal rank score
+                w = weights[list_index] if weights else 1
+                scores[identifier] += w / (k + rank + 1)
+
+        # Sort the documents by their cumulative scores in descending order
+        fused_list = sorted(scores, key=scores.get, reverse=True)
+
+        return fused_list
```
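The fusion step implements the standard RRF formula: each list contributes `weight / (k + rank + 1)` to every identifier it ranks, and identifiers are sorted by their cumulative score. A standalone sketch of the same computation on toy page-id lists:

```python
from collections import defaultdict

def rrf(ranking_lists, k=60, weights=None):
    # Mirrors Engine.reciprocal_rank_fusion: sum reciprocal-rank
    # contributions per identifier, then sort by total score.
    scores = defaultdict(float)
    for i, rank_list in enumerate(ranking_lists):
        w = weights[i] if weights else 1
        for rank, identifier in enumerate(rank_list):
            scores[identifier] += w / (k + rank + 1)
    return sorted(scores, key=scores.get, reverse=True)

# Page "42" is ranked first by both models, "99" appears in both,
# "7" in only one, so the fused order is ['42', '99', '7'].
print(rrf([["42", "7", "99"], ["42", "99"]], k=60))
```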
```diff
+
+    def search_documents(self, query: str, top_k: int, retrieval_size: int, max_documents_from_same_page: int):
+        """
+        Searches for documents based on the query and returns the top_k results.
+
+        Args:
+            query (str): The query string.
+            top_k (int): The number of top documents to return.
+            retrieval_size (int, optional): The number of documents to fetch from each model.
+            max_documents_from_same_page (int, optional): The maximum number of documents (paragraphs acutually) to return from the same page.
+        Returns:
+            list: A list of top k documents.
+        """
+        query_embeddings = {f"{semantic_model}": self.models[semantic_model].encode(query) for semantic_model in
+                            definitions.models.keys()}
+        all_docs_by_model = self.elastic_model.search(embedded_search=query_embeddings, size=retrieval_size)
+        all_docs = []
+        ids_for_fusion = []
+        all_docs_and_scores = {}
+
+        for key, values in all_docs_by_model.items():
+            print(f"\nFound {len(values)} documents for model\n")
+            model_ids = []
+            scores_for_model = []
+
+            for doc in values:
+                model_ids.append(doc["_source"]["page_id"])
+                all_docs.append(doc)
+                scores_for_model.append({"doc": doc["_source"]["title"], "score": doc["_score"]})
+            ids_for_fusion.append(model_ids)
+            all_docs_and_scores[f'{key}'] = scores_for_model
+        print(f"\nFusing {len(ids_for_fusion)} results\n")
+        fused_ids = self.reciprocal_rank_fusion(ids_for_fusion, k=top_k)
+        top_k_documents = []
+        count_per_id = {}
+
+        for fused_id in fused_ids[:top_k]:
+            for doc in all_docs:
+                if doc["_source"]["page_id"] == fused_id:
+                    count = count_per_id.get(fused_id, 0)
+                    if count >= max_documents_from_same_page:
+                        break;
+                    top_k_documents.append(doc["_source"])
+                    count_per_id[fused_id] = count + 1
+
+        return top_k_documents, all_docs_and_scores
```
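Per the code above, the pipeline is: embed the query once per model, run one vector search per model, fuse the per-model `page_id` rankings with RRF, then collect paragraphs while the `break` caps how many a single page may contribute. A hedged call sketch showing the two return values, assuming an initialized `engine` (titles, scores, and the model name are illustrative):

```python
top_docs, per_model_scores = engine.search_documents(
    query="child allowance", top_k=5, retrieval_size=50, max_documents_from_same_page=3
)
# top_docs: up to top_k Elasticsearch `_source` dicts (each one paragraph of a page),
#           with at most 3 paragraphs drawn from any single page_id.
# per_model_scores, roughly:
# {"me5": [{"doc": "Child allowance", "score": 17.2}, ...]}  # model name hypothetical
```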
```diff
+
+    def get_page_content_by_page_id(self, page_id: int) -> tuple:
+        """
+        Fetches the full content of a page and measures how long it takes.
+
+        Args:
+            page_id (int): The ID of the page to retrieve.
+
+        Returns:
+            tuple: (page_content, elapsed_time) where `page_content` is the retrieved content
+                   and `elapsed_time` is the time in seconds.
+        """
+        before_getting_additional_page = time.perf_counter()
+        additional_page_content = self.get_full_document_by_page_id(page_id)
+        after_getting_additional_page = time.perf_counter()
+        elapsed_time = after_getting_additional_page - before_getting_additional_page
+        return additional_page_content, elapsed_time
+
+    def retrieve_documents(self, query: str, top_k: int, retrieval_size: int,
+                           max_documents_from_same_page: int, send_complete_pages_to_llm: bool) -> tuple:
+        """
+        Retrieves documents matching a query and optionally converts them to full pages.
+
+        Args:
+            query (str): Search query.
+            top_k (int): Number of top documents to return.
+            retrieval_size (int): Number of documents to fetch from the source.
+            max_documents_from_same_page (int): Max documents from a single page.
+            send_complete_pages_to_llm (bool): If True, returns full page content.
+
+        Returns:
+            tuple: (top_k_documents, all_docs_and_scores, retrieval_time)
+        """
+        before_retrieval = time.perf_counter()
+        top_k_documents, all_docs_and_scores = self.search_documents(
+            query=query,
+            top_k=top_k,
+            retrieval_size=retrieval_size,
+            max_documents_from_same_page=max_documents_from_same_page
+        )
+
+        if send_complete_pages_to_llm:
+            top_k_documents = [self.transform_document_to_full_page(doc) for doc in top_k_documents]
+
+        retrieval_time = round(time.perf_counter() - before_retrieval, 4)
+        print(f"retrieval time: {retrieval_time}")
+
+        return top_k_documents, all_docs_and_scores, retrieval_time
+
+    async def answer_query(self, query: str, top_k: int, model, page_id: int | None = None,
+                           send_complete_pages_to_llm: bool = False, retrieval_size: int = 50,
+                           max_documents_from_same_page: int = 3) -> tuple:
+        """
+        Answers a query using top documents and an LLM model, optionally including a full page.
+
+        Args:
+            query (str): Query string.
+            top_k (int): Number of top documents to use.
+            model: LLM model to generate the answer.
+            page_id (int | None): Optional page to include.
+            send_complete_pages_to_llm (bool): If True, sends full pages to the LLM.
+            retrieval_size (int): Number of documents to fetch (default 50).
+            max_documents_from_same_page (int): Max documents from one page (default 3).
+
+        Returns:
+            tuple: (top_k_documents, gpt_answer, stats, all_docs_and_scores, request_params)
+        """
+        before_answer = time.perf_counter()
+
+        tasks = [
+            asyncio.to_thread(
+                self.retrieve_documents,
+                query, top_k, retrieval_size, max_documents_from_same_page, send_complete_pages_to_llm
+            )
+        ]
+
+        if page_id:
+            tasks.append(asyncio.to_thread(self.get_page_content_by_page_id, page_id))
+
+        results = await asyncio.gather(*tasks)
+
+        # Unpack results
+        top_k_documents, all_docs_and_scores, retrieval_time = results[0]
+        additional_document = None
+        additional_page_time = None
+        if page_id:
+            additional_document, additional_page_time = results[1]
+            print(f'retrived document {page_id} in {additional_page_time} ms: \n {additional_document}')
+
+        # Combine documents
+        top_k_documents_and_additional_document = top_k_documents.copy()
+        # Remove documents with the same page_id as the additional_document before appending
+        if additional_document:
+            additional_page_id = additional_document.get("page_id")
+            top_k_documents_and_additional_document = [
+                doc for doc in top_k_documents_and_additional_document
+                if doc.get("page_id") != additional_page_id
+            ]
+            additional_document['is_additional_page'] = True
+            top_k_documents_and_additional_document.append(additional_document)
+
+        # Query LLM
+        gpt_answer, gpt_elapsed, tokens, request_params, full_user_prompt = await asyncio.to_thread(
+            self.llms_client.answer,
+            query, top_k_documents_and_additional_document
+        )
+        after_answer = time.perf_counter()
+        answer_time = after_answer - before_answer
+
+        stats = {
+            "retrieval_time": retrieval_time,
+            "gpt_model": model,
+            "gpt_time": gpt_elapsed,
+            "tokens": tokens,
+            "answer_time": answer_time
+        }
+
+        request_params['timers_ms'] = {
+            "answer_time": int(answer_time * 1000),
+            "retrieval_time": int(retrieval_time * 1000),
+            "llm_time": int(gpt_elapsed * 1000),
+            "additional_page_time": int((additional_page_time or 0) * 1000)
+
+        }
+
+        return top_k_documents, gpt_answer, stats, all_docs_and_scores, request_params, full_user_prompt
```
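`answer_query` fans retrieval and the optional extra-page fetch out to worker threads with `asyncio.to_thread`, awaits both via `asyncio.gather`, then calls the LLM client in another thread. A hedged end-to-end sketch, assuming an initialized `engine` (see `engine_factory` below); the query, model name, and page id are illustrative:

```python
import asyncio

async def main():
    docs, answer, stats, scores, request_params, prompt = await engine.answer_query(
        query="Who is eligible for a disabled parking permit?",
        top_k=5,
        model="gpt-4o",                   # recorded in stats; illustrative name
        page_id=1234,                     # optional page to force-include (hypothetical id)
        send_complete_pages_to_llm=True,  # expand paragraphs to full pages before the LLM call
    )
    print(answer, stats["answer_time"], request_params["timers_ms"])

asyncio.run(main())
```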
```diff
+
+    def transform_document_to_full_page(self, document: dict) -> dict:
+        """
+        Adds the full page content to the document by retrieving it from Elasticsearch.
+
+        Args:
+            document (dict): The document to which the full page content will be added.
+
+        Returns:
+            dict: The updated document with the full page content added.
+        """
+        if not document.get("page_id"):
+            return document
+        full_document = self.get_full_document_by_page_id(document["page_id"])
+        if full_document and full_document.get("content"):
+            document["content"] = full_document["content"]
+        return document
+
+    def get_full_document_by_page_id(self, page_id: int) -> dict | None:
+        """
+        Retrieves a unified document instance for a given page_id by searching all indices in Elasticsearch.
+
+        Args:
+            page_id (int): The page ID to search for.
+
+        Returns:
+            dict | None: A single dict representing the united document (with metadata and concatenated content),
+                         or None if no documents are found.
+        """
+        es_client = self.elastic_model.es_client
+        indices = es_client.indices.get_alias(index="*").keys()
+        parts_of_documents = find_page_id_in_all_indices(page_id=page_id, es_client=es_client, indices=indices)
+        if not parts_of_documents:
+            return None
+        full_document = unite_docs_to_single_instance(parts_of_documents)
+        return full_document
+
+
+engine = None
+
+
+def engine_factory(llms_client: LLMClient, es_client=None):
+    global engine
+    if engine is None:
+        engine = Engine(llms_client=llms_client, es_client=es_client)
+    return engine
```
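Finally, `engine_factory` keeps one process-wide `Engine`: arguments passed on later calls are ignored once the singleton exists, so `change_llm` is the way to swap LLM clients afterwards. Sketch, with the client objects assumed:

```python
e1 = engine_factory(llms_client=my_llm_client, es_client=my_es_client)
e2 = engine_factory(llms_client=other_llm_client)  # existing instance returned; args ignored
assert e1 is e2
e1.change_llm(other_llm_client)  # the supported way to switch LLMs on the shared engine
```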