wikontic-0.0.3-py3-none-any.whl
This diff shows the content of a publicly available package version released to one of the supported registries. It is provided for informational purposes only and reflects the package's contents as they appear in the public registry.
- wikontic/__init__.py +16 -0
- wikontic/create_ontological_triplets_db.py +193 -0
- wikontic/create_triplets_db.py +259 -0
- wikontic/create_wikidata_ontology_db.py +555 -0
- wikontic/utils/__init__.py +7 -0
- wikontic/utils/base_inference_with_db.py +329 -0
- wikontic/utils/dynamic_aligner.py +281 -0
- wikontic/utils/inference_with_db.py +224 -0
- wikontic/utils/ontology_mappings/entity_hierarchy.json +1 -0
- wikontic/utils/ontology_mappings/entity_names.json +1 -0
- wikontic/utils/ontology_mappings/entity_type2aliases.json +1 -0
- wikontic/utils/ontology_mappings/entity_type2hierarchy.json +1 -0
- wikontic/utils/ontology_mappings/entity_type2label.json +1 -0
- wikontic/utils/ontology_mappings/enum_entity_ids.json +1 -0
- wikontic/utils/ontology_mappings/enum_prop_ids.json +1 -0
- wikontic/utils/ontology_mappings/label2entity.json +1 -0
- wikontic/utils/ontology_mappings/obj_constraint2prop.json +1 -0
- wikontic/utils/ontology_mappings/prop2aliases.json +1 -0
- wikontic/utils/ontology_mappings/prop2constraints.json +1 -0
- wikontic/utils/ontology_mappings/prop2data_type.json +1 -0
- wikontic/utils/ontology_mappings/prop2label.json +1 -0
- wikontic/utils/ontology_mappings/propid2enum.json +1 -0
- wikontic/utils/ontology_mappings/subj_constraint2prop.json +1 -0
- wikontic/utils/ontology_mappings/subject_object_constraints.json +1 -0
- wikontic/utils/openai_utils.py +517 -0
- wikontic/utils/prompts/name_refinement/prompt_choose_relation_wo_entity_types.txt +17 -0
- wikontic/utils/prompts/name_refinement/prompt_choose_relation_wo_entity_types_dialog_bench.txt +18 -0
- wikontic/utils/prompts/name_refinement/rank_object_names.txt +17 -0
- wikontic/utils/prompts/name_refinement/rank_object_names_dialog_bench.txt +18 -0
- wikontic/utils/prompts/name_refinement/rank_object_qualifiers.txt +20 -0
- wikontic/utils/prompts/name_refinement/rank_subject_names.txt +18 -0
- wikontic/utils/prompts/name_refinement/rank_subject_names_dialog_bench.txt +20 -0
- wikontic/utils/prompts/ontology_refinement/prompt_choose_entity_types.txt +26 -0
- wikontic/utils/prompts/ontology_refinement/prompt_choose_relation.txt +24 -0
- wikontic/utils/prompts/ontology_refinement/prompt_choose_relation_and_types.txt +28 -0
- wikontic/utils/prompts/qa/prompt_choose_relevant_entities_for_question.txt +17 -0
- wikontic/utils/prompts/qa/prompt_choose_relevant_entities_for_question_wo_types.txt +16 -0
- wikontic/utils/prompts/qa/prompt_entity_extraction_from_question.txt +3 -0
- wikontic/utils/prompts/qa/prompt_is_answered.txt +43 -0
- wikontic/utils/prompts/qa/qa_collapsing_prompt.txt +22 -0
- wikontic/utils/prompts/qa/qa_prompt.txt +5 -0
- wikontic/utils/prompts/qa/qa_prompt_hotpot.txt +6 -0
- wikontic/utils/prompts/qa/question_decomposition_1.txt +7 -0
- wikontic/utils/prompts/triplet_extraction/prompt_1_types_qualifiers_dialog_bench.txt +75 -0
- wikontic/utils/prompts/triplet_extraction/prompt_1_types_qualifiers_dialog_bench_in_russian.txt +78 -0
- wikontic/utils/prompts/triplet_extraction/propmt_1_types_qualifiers.txt +91 -0
- wikontic/utils/structured_aligner.py +606 -0
- wikontic/utils/structured_inference_with_db.py +561 -0
- wikontic-0.0.3.dist-info/METADATA +111 -0
- wikontic-0.0.3.dist-info/RECORD +53 -0
- wikontic-0.0.3.dist-info/WHEEL +5 -0
- wikontic-0.0.3.dist-info/licenses/LICENSE +19 -0
- wikontic-0.0.3.dist-info/top_level.txt +1 -0

wikontic/utils/base_inference_with_db.py
@@ -0,0 +1,329 @@
from typing import Dict, List, Optional
import logging

logger = logging.getLogger("BaseInferenceWithDB")
logger.setLevel(logging.ERROR)


class BaseInferenceWithDB:
    """
    Base class for inference with database functionality.
    Contains common methods shared by InferenceWithDB and StructuredInferenceWithDB.

    Note: This is an abstract base class. Child classes must define the following
    attributes in their __init__ methods:
    - self.extractor: The extractor instance
    - self.aligner: The aligner instance
    - self.triplets_db: The triplets database instance
    """

    def retrieve_similar_entity_names(
        self, entity_name: str, k: int, sample_id: Optional[str] = None
    ) -> List[Dict[str, str]]:
        """
        Retrieve similar entity names from the knowledge graph using vector search.
        Useful to link entities from the question to the knowledge graph.
        Args:
            entity_name: The entity name to retrieve similar entity names for.
            k: The number of similar entity names to retrieve.
            sample_id: The sample ID of the subgraph to retrieve similar entity names from. If None, perform the search across all samples.
        Returns:
            A list of dictionaries with the entity name and entity type.
        """

        similar_entity_names = self.aligner.retrieve_similar_entity_names(
            entity_name=entity_name, k=k, sample_id=sample_id
        )
        if isinstance(similar_entity_names, dict):
            similar_entity_names = [e["entity"] for e in similar_entity_names]
        return similar_entity_names

    def identify_relevant_entities_from_question_with_llm(
        self, question, sample_id=None, use_entity_types=True
    ):
        """
        Identify relevant entities from a question using an LLM.
        Args:
            question: The question to identify relevant entities from.
            sample_id: The sample ID of the subgraph to identify relevant entities from. If None, perform the search across all samples.
        Returns:
            The relevant entities.
        """

        entities = self.extractor.extract_entities_from_question(question)
        identified_entities = []
        linked_entities = []

        if isinstance(entities, dict):
            entities = [entities]

        for ent in entities:
            similar_entities = self.retrieve_similar_entity_names(
                entity_name=ent, k=10, sample_id=sample_id
            )
            logger.log(logging.DEBUG, "Similar entities: %s" % (str(similar_entities)))

            exact_entity_match = [e for e in similar_entities if e == ent]
            if len(exact_entity_match) > 0:
                linked_entities.extend(exact_entity_match)
            else:
                identified_entities.extend(similar_entities)

        logger.log(
            logging.DEBUG,
            "Identified entities from question: %s" % (str(identified_entities)),
        )
        logger.log(
            logging.DEBUG, "Linked entities from question: %s" % (str(linked_entities))
        )
        if use_entity_types:
            linked_identified_entities = self.extractor.identify_relevant_entities(
                question=question, entity_list=identified_entities
            )
        else:
            linked_identified_entities = (
                self.extractor.identify_relevant_entities_wo_types(
                    question=question, entity_list=identified_entities
                )
            )
        linked_entities.extend([e["entity"] for e in linked_identified_entities])

        logger.log(
            logging.DEBUG,
            "Linked entities after refinement: %s" % (str(linked_entities)),
        )
        return linked_entities

    def get_1_hop_supporting_triplets(
        self,
        entities4search: List[str],
        sample_id=None,
        use_qualifiers=False,
        use_filtered_triplets=False,
    ):
        """
        Get the 1-hop supporting triplets for the given entities.
        Useful to answer the question with the given entities.
        Can be invoked multiple times for more than 1-hop support.
        Args:
            entities4search: The entities to get the 1-hop supporting triplets for.
            sample_id: The sample ID of the subgraph to get the 1-hop supporting triplets from. If None, perform the search across all samples.
            use_qualifiers: Whether to use qualifiers.
            use_filtered_triplets: Whether to use the triplets that violate the ontology constraints along with the valid triplets.
        Returns:
            A list of dictionaries with the subject, relation, object, and qualifiers that correspond to the 1-hop supporting triplets for the given entities.
        """
        or_conditions = []
        for ent in entities4search:
            or_conditions.append({"$and": [{"subject": ent}]})
            or_conditions.append({"$and": [{"object": ent}]})
        if sample_id is None:
            pipeline = [{"$match": {"$or": or_conditions}}]
        else:
            pipeline = [{"$match": {"sample_id": sample_id, "$or": or_conditions}}]
        results = list(
            self.triplets_db.get_collection(
                self.aligner.triplets_collection_name
            ).aggregate(pipeline)
        )

        if use_filtered_triplets:
            filtered_results = list(
                self.triplets_db.get_collection(
                    self.aligner.ontology_filtered_triplets_collection_name
                ).aggregate(pipeline)
            )
            results.extend(filtered_results)

        if use_qualifiers:
            supporting_triplets = [
                {
                    "subject": item["subject"],
                    "relation": item["relation"],
                    "object": item["object"],
                    "qualifiers": item["qualifiers"],
                }
                for item in results
            ]
        else:
            supporting_triplets = [
                {
                    "subject": item["subject"],
                    "relation": item["relation"],
                    "object": item["object"],
                }
                for item in results
            ]
        logger.log(
            logging.DEBUG,
            "Supporting triplets: %s\n%s" % (str(supporting_triplets), "-" * 100),
        )
        return supporting_triplets

    def answer_question_with_llm(
        self,
        question,
        linked_entities,
        sample_id=None,
        hop_depth=5,
        use_filtered_triplets=False,
        use_qualifiers=False,
    ):
        """
        Answer a question with relevant entities.
        Args:
            question: The question to answer.
            linked_entities: The linked entities to answer the question with.
            sample_id: The sample ID of the subgraph to answer the question from. If None, perform the search across all samples.
            hop_depth: The maximum number of hops used to expand the set of supporting triplets.
            use_filtered_triplets: Whether to use filtered triplets.
            use_qualifiers: Whether to use qualifiers.
        Returns:
            The supporting triplets and the answer to the question.
        """
        logger.log(logging.DEBUG, "Linked entities: %s" % (str(linked_entities)))

        entity_set = {e for e in linked_entities}
        entities4search = list(entity_set)
        supporting_triplets = []

        for _ in range(hop_depth):
            new_entities4search = []
            new_supporting_triplets = self.get_1_hop_supporting_triplets(
                entities4search, sample_id, use_qualifiers, use_filtered_triplets
            )
            supporting_triplets.extend(new_supporting_triplets)

            for doc in supporting_triplets:
                if doc["subject"] not in entities4search:
                    new_entities4search.append(doc["subject"])
                if doc["object"] not in entities4search:
                    new_entities4search.append(doc["object"])
                if use_qualifiers:
                    for q in doc["qualifiers"]:
                        if q["object"] not in entities4search:
                            new_entities4search.append(q["object"])

            entities4search = list(set(new_entities4search))

        if use_qualifiers:
            supporting_triplets = [
                {
                    "subject": item["subject"],
                    "relation": item["relation"],
                    "object": item["object"],
                    "qualifiers": item["qualifiers"],
                }
                for item in supporting_triplets
            ]
        else:
            supporting_triplets = [
                {
                    "subject": item["subject"],
                    "relation": item["relation"],
                    "object": item["object"],
                }
                for item in supporting_triplets
            ]
        logger.log(
            logging.DEBUG,
            "Supporting triplets: %s\n%s" % (str(supporting_triplets), "-" * 100),
        )

        ans = self.extractor.answer_question(
            question=question, triplets=supporting_triplets
        )
        return supporting_triplets, ans

    def answer_with_qa_collapsing(
        self,
        question,
        sample_id=None,
        max_attempts=5,
        use_qualifiers=False,
        use_filtered_triplets=False,
    ):
        """
        Answer a question with QA collapsing.
        Args:
            question: The question to answer.
            sample_id: The sample ID of the subgraph to answer the question from. If None, perform the search across all samples.
            max_attempts: The maximum number of attempts to answer the question. Useful to handle complex questions that require multiple hops to answer.
            use_qualifiers: Whether to use qualifiers.
            use_filtered_triplets: Whether to use filtered triplets.
        Returns:
            The answer to the question.
        """
        collapsed_question_answer = ""
        collapsed_question_sequence = []
        collapsed_answer_sequence = []

        logger.log(logging.DEBUG, "Question: %s" % (str(question)))
        collapsed_question = self.extractor.decompose_question(question)

        for i in range(max_attempts):
            extracted_entities = self.extractor.extract_entities_from_question(
                collapsed_question
            )
            logger.log(
                logging.DEBUG, "Collapsed question: %s" % (str(collapsed_question))
            )
            logger.log(
                logging.DEBUG, "Extracted entities: %s" % (str(extracted_entities))
            )

            if len(collapsed_question_answer) > 0:
                extracted_entities.append(collapsed_question_answer)

            entities4search = []
            for ent in extracted_entities:
                similar_entities = self.retrieve_similar_entity_names(
                    entity_name=ent, k=10, sample_id=sample_id
                )
                entities4search.extend([e for e in similar_entities])

            entities4search = list(set(entities4search))
            logger.log(logging.DEBUG, "Similar entities: %s" % (str(entities4search)))

            supporting_triplets = self.get_1_hop_supporting_triplets(
                entities4search, sample_id, use_qualifiers, use_filtered_triplets
            )

            logger.log(
                logging.DEBUG,
                "Supporting triplets length: %s" % (str(len(supporting_triplets))),
            )

            collapsed_question_answer = self.extractor.answer_question(
                collapsed_question, supporting_triplets
            )
            collapsed_question_sequence.append(collapsed_question)
            collapsed_answer_sequence.append(collapsed_question_answer)

            logger.log(
                logging.DEBUG, "Collapsed question: %s" % (str(collapsed_question))
            )
            logger.log(
                logging.DEBUG,
                "Collapsed question answer: %s" % (str(collapsed_question_answer)),
            )

            is_answered = self.extractor.check_if_question_is_answered(
                question, collapsed_question_sequence, collapsed_answer_sequence
            )
            question_answer_sequence = list(
                zip(collapsed_question_sequence, collapsed_answer_sequence)
            )

            if is_answered == "NOT FINAL":
                collapsed_question = self.extractor.collapse_question(
                    original_question=question,
                    question=collapsed_question,
                    answer=collapsed_question_answer,
                )
                continue
            else:
                logger.log(logging.DEBUG, "Final answer: %s" % (str(is_answered)))
                return is_answered

        logger.log(logging.DEBUG, "Final answer: %s" % (str(collapsed_question_answer)))
        return collapsed_question_answer
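
For orientation, here is a minimal sketch of how a subclass might satisfy the attribute contract the BaseInferenceWithDB docstring describes. MyInference and the constructor arguments are hypothetical stand-ins, not names from this package; only the three attribute names come from the source.

# Hypothetical wiring sketch; extractor, aligner, and triplets_db are
# placeholder objects, not identifiers from the wikontic package.
class MyInference(BaseInferenceWithDB):
    def __init__(self, extractor, aligner, triplets_db):
        self.extractor = extractor      # LLM-backed extractor (entity extraction, QA)
        self.aligner = aligner          # vector-search aligner over alias collections
        self.triplets_db = triplets_db  # MongoDB database holding the triplet collections

inference = MyInference(extractor, aligner, triplets_db)
linked = inference.identify_relevant_entities_from_question_with_llm(
    "Who directed the film?", sample_id="sample-1"
)
triplets, answer = inference.answer_question_with_llm(
    question="Who directed the film?",
    linked_entities=linked,
    sample_id="sample-1",
    hop_depth=2,
)

The two answering entry points trade off differently: answer_question_with_llm expands supporting triplets breadth-first up to hop_depth and answers once, while answer_with_qa_collapsing iteratively decomposes the question and re-queries until the extractor reports a final answer.
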
wikontic/utils/dynamic_aligner.py
@@ -0,0 +1,281 @@
from typing import List, Tuple, Set, Dict, Optional
from transformers import AutoTokenizer, AutoModel
from dataclasses import dataclass
from pydantic import BaseModel, ValidationError
from pymongo import MongoClient, UpdateOne
from dotenv import load_dotenv, find_dotenv
import os
import torch

# os.environ["CUDA_VISIBLE_DEVICES"] = "1"
_ = load_dotenv(find_dotenv())


class EntityAlias(BaseModel):
    _id: int
    label: str
    alias: str
    sample_id: str
    alias_text_embedding: List[float]


class PropertyAlias(BaseModel):
    _id: int
    label: str
    alias: str
    sample_id: str
    alias_text_embedding: List[float]


class Aligner:
    def __init__(self, triplets_db, device="cuda:0"):
        self.db = triplets_db

        self.entity_aliases_collection_name = "entity_aliases"
        self.property_aliases_collection_name = "property_aliases"

        self.property_vector_index_name = "property_aliases"
        self.entities_vector_index_name = "entity_aliases"

        self.initial_triplets_collection_name = "initial_triplets"
        self.triplets_collection_name = "triplets"
        self.filtered_triplets_collection_name = "filtered_triplets"

        self.device = torch.device(device)
        # self.tokenizer = AutoTokenizer.from_pretrained('facebook/contriever', token=os.getenv("HF_KEY"))
        self.tokenizer = AutoTokenizer.from_pretrained("facebook/contriever")
        # self.model = AutoModel.from_pretrained('facebook/contriever', token=os.getenv("HF_KEY")).to(self.device)
        self.model = AutoModel.from_pretrained(
            "facebook/contriever", use_safetensors=True
        ).to(self.device)

    def get_embedding(self, text):

        def mean_pooling(token_embeddings, mask):
            token_embeddings = token_embeddings.masked_fill(
                ~mask[..., None].bool(), 0.0
            )
            sentence_embeddings = (
                token_embeddings.sum(dim=1) / mask.sum(dim=1)[..., None]
            )
            return sentence_embeddings

        if not text or not isinstance(text, str):
            return None

        inputs = self.tokenizer(
            [text], padding=True, truncation=True, return_tensors="pt"
        )
        outputs = self.model(**inputs.to(self.device))
        embeddings = mean_pooling(outputs[0], inputs["attention_mask"])
        return embeddings.detach().cpu().tolist()[0]

    def retrieve_similar_properties(
        self, target_relation: str, sample_id: str, k: int = 10
    ) -> List[str]:  # List of property labels
        """
        Retrieve and rank properties that match the given relation.

        Args:
            target_relation: The relation to search for
            sample_id: The sample ID (the per-sample filter is currently commented out below)
            k: Number of results to return

        Returns:
            List of property labels
        """

        collection = self.db.get_collection(self.property_aliases_collection_name)
        query_embedding = self.get_embedding(target_relation)
        if query_embedding is None:
            return []

        query_k = k * 2
        max_attempts = 5
        attempt = 0
        unique_ranked_properties: List[str] = []

        while len(unique_ranked_properties) < k and attempt < max_attempts:

            pipeline = [
                {
                    "$vectorSearch": {
                        "index": self.property_vector_index_name,
                        "queryVector": query_embedding,
                        "path": "alias_text_embedding",
                        "numCandidates": 150,
                        "limit": query_k if query_k < 150 else 150,
                        # "filter": {
                        #     "sample_id": {"$eq": sample_id},
                        # },
                    }
                },
                {
                    "$project": {
                        "_id": 0,
                        "label": 1,
                        # "alias": 1
                        # "score": {"$meta": "vectorSearchScore"}
                    }
                },
            ]

            similar_properties = collection.aggregate(pipeline)

            for prop in similar_properties:
                if prop["label"] not in unique_ranked_properties:
                    unique_ranked_properties.append(prop["label"])
                if len(unique_ranked_properties) == k:
                    break

            query_k *= 2
            attempt += 1

        return unique_ranked_properties

    def retrieve_similar_entity_names(
        self, entity_name: str, sample_id: Optional[str] = None, k: int = 10
    ) -> List[str]:  # List of entity labels
        """
        Retrieve and rank entities that match the given entity.

        Args:
            entity_name: The entity to search for
            sample_id: The sample ID of the subgraph to search in. If None, search across all samples
            k: Number of results to return

        Returns:
            List of entity labels
        """

        collection = self.db.get_collection(self.entity_aliases_collection_name)
        query_embedding = self.get_embedding(entity_name)
        if query_embedding is None:
            return []

        query_k = k * 2
        max_attempts = 5
        attempt = 0
        unique_ranked_entities: List[str] = []

        while len(unique_ranked_entities) < k and attempt < max_attempts:

            if sample_id is not None:
                filter = {
                    "sample_id": {"$eq": sample_id},
                }
            else:
                filter = {}

            pipeline = [
                {
                    "$vectorSearch": {
                        "index": self.entities_vector_index_name,
                        "queryVector": query_embedding,
                        "path": "alias_text_embedding",
                        "numCandidates": 150,
                        "limit": query_k if query_k < 150 else 150,
                        "filter": filter,
                    }
                },
                {
                    "$project": {
                        "_id": 0,
                        "label": 1,
                        # "score": {"$meta": "vectorSearchScore"}
                    }
                },
            ]

            similar_entities = collection.aggregate(pipeline)

            for entity in similar_entities:
                if entity["label"] not in unique_ranked_entities:
                    unique_ranked_entities.append(entity["label"])
                if len(unique_ranked_entities) == k:
                    break

            query_k *= 2
            attempt += 1

        return unique_ranked_entities

    def add_entity(self, entity_name, alias, sample_id):
        collection = self.db.get_collection(self.entity_aliases_collection_name)
        if not collection.find_one(
            {"label": entity_name, "alias": alias, "sample_id": sample_id}
        ):

            collection.insert_one(
                {
                    "label": entity_name,
                    "alias": alias,
                    "sample_id": sample_id,
                    "alias_text_embedding": self.get_embedding(alias),
                }
            )

    def add_property(self, property_name, alias, sample_id):
        collection = self.db.get_collection(self.property_aliases_collection_name)
        if not collection.find_one({"label": property_name, "alias": alias}):
            collection.insert_one(
                {
                    "label": property_name,
                    "alias": alias,
                    # "sample_id": sample_id,
                    "alias_text_embedding": self.get_embedding(alias),
                }
            )

    def add_triplets(self, triplets_list, sample_id):
        collection = self.db.get_collection(self.triplets_collection_name)

        operations = []
        for triple in triplets_list:
            triple["sample_id"] = sample_id
            filter_query = {
                "subject": triple["subject"],
                "relation": triple["relation"],
                "object": triple["object"],
                "sample_id": triple["sample_id"],
            }
            operations.append(
                UpdateOne(filter_query, {"$setOnInsert": triple}, upsert=True)
            )

        if operations:
            collection.bulk_write(operations)

    def add_filtered_triplets(self, triplets_list, sample_id):
        collection = self.db.get_collection(self.filtered_triplets_collection_name)

        operations = []
        for triple in triplets_list:
            triple["sample_id"] = sample_id
            filter_query = {
                "subject": triple["subject"],
                "relation": triple["relation"],
                "object": triple["object"],
                "sample_id": triple["sample_id"],
            }
            operations.append(
                UpdateOne(filter_query, {"$setOnInsert": triple}, upsert=True)
            )

        if operations:
            collection.bulk_write(operations)

    def add_initial_triplets(self, triplets_list, sample_id):
        collection = self.db.get_collection(self.initial_triplets_collection_name)
        operations = []
        for triple in triplets_list:
            triple["sample_id"] = sample_id
            filter_query = {
                "subject": triple["subject"],
                "relation": triple["relation"],
                "object": triple["object"],
                "sample_id": triple["sample_id"],
            }
            operations.append(
                UpdateOne(filter_query, {"$setOnInsert": triple}, upsert=True)
            )
        if operations:
            collection.bulk_write(operations)
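
As a usage sketch: the $vectorSearch aggregation stage used above is specific to MongoDB Atlas, so the Aligner presumably expects an Atlas database whose entity_aliases and property_aliases collections carry vector indexes (named entity_aliases and property_aliases) over the alias_text_embedding field. The connection string, database name, and sample data below are placeholders, not values from this package.

from pymongo import MongoClient

# Placeholder URI; $vectorSearch requires MongoDB Atlas with vector indexes
# built over the alias_text_embedding field of both alias collections.
client = MongoClient("mongodb+srv://<user>:<password>@<cluster>/")
aligner = Aligner(client["wikontic"], device="cpu")  # "cpu" when no GPU is present

# Register an alias (embedded on insert) and a few triplets, then query.
aligner.add_entity("Douglas Adams", "Douglas Noel Adams", sample_id="s1")
aligner.add_property("notable work", "famous for", sample_id="s1")
aligner.add_triplets(
    [{"subject": "Douglas Adams", "relation": "notable work",
      "object": "The Hitchhiker's Guide to the Galaxy", "qualifiers": []}],
    sample_id="s1",
)
print(aligner.retrieve_similar_entity_names("D. Adams", sample_id="s1", k=5))

Note that add_triplets and its filtered/initial variants use UpdateOne with $setOnInsert and upsert=True, so re-adding an identical (subject, relation, object, sample_id) combination is a no-op rather than a duplicate insert.
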