wikontic 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wikontic/__init__.py +16 -0
- wikontic/create_ontological_triplets_db.py +193 -0
- wikontic/create_triplets_db.py +259 -0
- wikontic/create_wikidata_ontology_db.py +555 -0
- wikontic/utils/__init__.py +7 -0
- wikontic/utils/base_inference_with_db.py +329 -0
- wikontic/utils/dynamic_aligner.py +281 -0
- wikontic/utils/inference_with_db.py +224 -0
- wikontic/utils/ontology_mappings/entity_hierarchy.json +1 -0
- wikontic/utils/ontology_mappings/entity_names.json +1 -0
- wikontic/utils/ontology_mappings/entity_type2aliases.json +1 -0
- wikontic/utils/ontology_mappings/entity_type2hierarchy.json +1 -0
- wikontic/utils/ontology_mappings/entity_type2label.json +1 -0
- wikontic/utils/ontology_mappings/enum_entity_ids.json +1 -0
- wikontic/utils/ontology_mappings/enum_prop_ids.json +1 -0
- wikontic/utils/ontology_mappings/label2entity.json +1 -0
- wikontic/utils/ontology_mappings/obj_constraint2prop.json +1 -0
- wikontic/utils/ontology_mappings/prop2aliases.json +1 -0
- wikontic/utils/ontology_mappings/prop2constraints.json +1 -0
- wikontic/utils/ontology_mappings/prop2data_type.json +1 -0
- wikontic/utils/ontology_mappings/prop2label.json +1 -0
- wikontic/utils/ontology_mappings/propid2enum.json +1 -0
- wikontic/utils/ontology_mappings/subj_constraint2prop.json +1 -0
- wikontic/utils/ontology_mappings/subject_object_constraints.json +1 -0
- wikontic/utils/openai_utils.py +517 -0
- wikontic/utils/prompts/name_refinement/prompt_choose_relation_wo_entity_types.txt +17 -0
- wikontic/utils/prompts/name_refinement/prompt_choose_relation_wo_entity_types_dialog_bench.txt +18 -0
- wikontic/utils/prompts/name_refinement/rank_object_names.txt +17 -0
- wikontic/utils/prompts/name_refinement/rank_object_names_dialog_bench.txt +18 -0
- wikontic/utils/prompts/name_refinement/rank_object_qualifiers.txt +20 -0
- wikontic/utils/prompts/name_refinement/rank_subject_names.txt +18 -0
- wikontic/utils/prompts/name_refinement/rank_subject_names_dialog_bench.txt +20 -0
- wikontic/utils/prompts/ontology_refinement/prompt_choose_entity_types.txt +26 -0
- wikontic/utils/prompts/ontology_refinement/prompt_choose_relation.txt +24 -0
- wikontic/utils/prompts/ontology_refinement/prompt_choose_relation_and_types.txt +28 -0
- wikontic/utils/prompts/qa/prompt_choose_relevant_entities_for_question.txt +17 -0
- wikontic/utils/prompts/qa/prompt_choose_relevant_entities_for_question_wo_types.txt +16 -0
- wikontic/utils/prompts/qa/prompt_entity_extraction_from_question.txt +3 -0
- wikontic/utils/prompts/qa/prompt_is_answered.txt +43 -0
- wikontic/utils/prompts/qa/qa_collapsing_prompt.txt +22 -0
- wikontic/utils/prompts/qa/qa_prompt.txt +5 -0
- wikontic/utils/prompts/qa/qa_prompt_hotpot.txt +6 -0
- wikontic/utils/prompts/qa/question_decomposition_1.txt +7 -0
- wikontic/utils/prompts/triplet_extraction/prompt_1_types_qualifiers_dialog_bench.txt +75 -0
- wikontic/utils/prompts/triplet_extraction/prompt_1_types_qualifiers_dialog_bench_in_russian.txt +78 -0
- wikontic/utils/prompts/triplet_extraction/propmt_1_types_qualifiers.txt +91 -0
- wikontic/utils/structured_aligner.py +606 -0
- wikontic/utils/structured_inference_with_db.py +561 -0
- wikontic-0.0.3.dist-info/METADATA +111 -0
- wikontic-0.0.3.dist-info/RECORD +53 -0
- wikontic-0.0.3.dist-info/WHEEL +5 -0
- wikontic-0.0.3.dist-info/licenses/LICENSE +19 -0
- wikontic-0.0.3.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,606 @@
|
|
|
1
|
+
from typing import List, Tuple, Set, Dict
|
|
2
|
+
from transformers import AutoTokenizer, AutoModel
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from pydantic import BaseModel, ValidationError
|
|
5
|
+
from pymongo import MongoClient, UpdateOne
|
|
6
|
+
import torch
|
|
7
|
+
from dotenv import load_dotenv, find_dotenv
|
|
8
|
+
import os
|
|
9
|
+
|
|
10
|
+
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"
|
|
11
|
+
_ = load_dotenv(find_dotenv())
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@dataclass
class PropertyConstraints:
    """Two mutable sets of property ids, one per triplet role.

    ``Aligner.retrieve_properties_for_entity_type`` creates instances with
    two empty sets, fills them with ``set.update`` and finally intersects
    the two fields.
    """

    # Property ids gathered for the subject role.
    subject_properties: Set[str]
    # Property ids gathered for the object role.
    object_properties: Set[str]
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class EntityAlias(BaseModel):
    """Schema of one entity-alias record.

    The field names mirror the keys written by ``Aligner.add_entity``
    (``label``, ``entity_type``, ``alias``, ``sample_id``,
    ``alias_text_embedding``).
    """

    # NOTE(review): pydantic handles underscore-prefixed names differently
    # from regular fields -- confirm that ``_id`` is validated as intended.
    _id: int
    # Entity name the alias resolves to.
    label: str
    # Type label of the entity.
    entity_type: str
    # Alternative surface form of the entity.
    alias: str
    # Sample identifier the record belongs to.
    sample_id: str
    # Embedding vector computed from ``alias``.
    alias_text_embedding: List[float]
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class Aligner:
|
|
30
|
+
def __init__(
    self,
    ontology_db,
    triplets_db,
    device="cuda:0",
    model_name="facebook/contriever",
):
    """Store database handles, collection/index names and load the encoder.

    Args:
        ontology_db: Database handle exposing ``get_collection(name)``; used
            for the entity-type and property collections.
        triplets_db: Database handle exposing ``get_collection(name)``; used
            for the entity-alias and triplet collections.
        device: Anything accepted by ``torch.device``. When a CUDA device is
            requested but CUDA is not available, a ``RuntimeWarning`` is
            emitted and the CPU is used instead.
        model_name: Checkpoint name passed to ``AutoTokenizer`` and
            ``AutoModel``; defaults to the previously hard-coded value.
    """
    import warnings

    self.ontology_db = ontology_db
    self.triplets_db = triplets_db

    # Ontology-side collections and their vector indexes.
    self.entity_type_collection_name = "entity_types"
    self.entity_type_aliases_collection_name = "entity_type_aliases"
    self.property_collection_name = "properties"
    self.property_aliases_collection_name = "property_aliases"

    self.entity_type_vector_index_name = "entity_type_aliases"
    self.property_vector_index_name = "property_aliases"

    # Triplet-side collections and their vector index.
    self.entity_aliases_collection_name = "entity_aliases"
    self.triplets_collection_name = "triplets"
    self.filtered_triplets_collection_name = "filtered_triplets"
    self.ontology_filtered_triplets_collection_name = "ontology_filtered_triplets"
    self.initial_triplets_collection_name = "initial_triplets"
    self.entities_vector_index_name = "entity_aliases"

    resolved_device = torch.device(device)
    if resolved_device.type == "cuda" and not torch.cuda.is_available():
        # Moving the model to a missing CUDA device would fail; degrade
        # explicitly so CPU-only hosts can still use the aligner.
        warnings.warn(
            f"CUDA device {device!r} requested but CUDA is not available; "
            "falling back to CPU.",
            RuntimeWarning,
            stacklevel=2,
        )
        resolved_device = torch.device("cpu")
    self.device = resolved_device

    self.tokenizer = AutoTokenizer.from_pretrained(model_name)
    self.model = AutoModel.from_pretrained(
        model_name, use_safetensors=True
    ).to(self.device)
|
|
56
|
+
|
|
57
|
+
def get_embedding(self, text):
    """Embed ``text`` as the masked mean of the model's token embeddings.

    Args:
        text: String to embed.

    Returns:
        A list of floats, or ``None`` when ``text`` is empty or not a
        ``str``.
    """
    if not text or not isinstance(text, str):
        return None

    inputs = self.tokenizer(
        [text], padding=True, truncation=True, return_tensors="pt"
    ).to(self.device)

    # Inference only: without no_grad the forward pass records an autograd
    # graph that is thrown away immediately by the detach below.
    with torch.no_grad():
        outputs = self.model(**inputs)
        mask = inputs["attention_mask"]
        # Zero out positions the attention mask marks as padding, then
        # average over the remaining positions.
        token_embeddings = outputs[0].masked_fill(~mask[..., None].bool(), 0.0)
        pooled = token_embeddings.sum(dim=1) / mask.sum(dim=1)[..., None]

    return pooled.detach().cpu().tolist()[0]
|
|
77
|
+
|
|
78
|
+
def _get_unique_similar_entity_types(
    self, target_entity_type: str, k: int = 5, max_attempts: int = 10
) -> List[str]:
    """Return up to ``k`` distinct ``entity_type_id`` values, best match first.

    The vector index is searched over aliases, so several hits can carry the
    same ``entity_type_id``. The search is repeated with a doubled limit
    until ``k`` distinct ids are collected or ``max_attempts`` is reached.

    Args:
        target_entity_type: Text that is embedded and used as the query.
        k: Number of distinct ids wanted.
        max_attempts: Upper bound on the number of searches.

    Returns:
        Distinct ids in first-seen order; empty when the text cannot be
        embedded.
    """
    # Atlas rejects a $vectorSearch whose numCandidates exceeds 10000, and
    # limit may not exceed numCandidates.
    max_candidates = 10000

    query_embedding = self.get_embedding(target_entity_type)
    if query_embedding is None:
        # Nothing to search with; a null queryVector would be rejected.
        return []

    collection = self.ontology_db.get_collection(
        self.entity_type_aliases_collection_name
    )

    unique_ranked_entities: List[str] = []
    query_k = k * 2
    attempt = 0
    while len(unique_ranked_entities) < k and attempt < max_attempts:
        limit = min(query_k, max_candidates)
        search_pipeline = [
            {
                "$vectorSearch": {
                    "index": self.entity_type_vector_index_name,
                    "queryVector": query_embedding,
                    "path": "alias_text_embedding",
                    "numCandidates": min(max(150, query_k), max_candidates),
                    "limit": limit,
                }
            },
            {"$project": {"_id": 0, "entity_type_id": 1}},
        ]
        for hit in collection.aggregate(search_pipeline):
            type_id = hit["entity_type_id"]
            if type_id in unique_ranked_entities:
                continue
            unique_ranked_entities.append(type_id)
            if len(unique_ranked_entities) == k:
                break
        query_k *= 2
        attempt += 1

    return unique_ranked_entities
|
|
118
|
+
|
|
119
|
+
def retrieve_similar_entity_types(
    self, triplet: Dict[str, str], k: int = 10
) -> Tuple[List[str], List[str]]:
    """Look up similar entity types for both ends of a triplet.

    Args:
        triplet: Mapping that must contain ``"subject_type"`` and may
            contain ``"object_type"``.
        k: Forwarded to ``_get_unique_similar_entity_types``.

    Returns:
        ``(subject_matches, object_matches)``; the second list is empty when
        the triplet has no ``"object_type"`` key.
    """
    lookup = self._get_unique_similar_entity_types
    subject_matches = lookup(target_entity_type=triplet["subject_type"], k=k)
    object_matches = (
        lookup(target_entity_type=triplet["object_type"], k=k)
        if "object_type" in triplet
        else []
    )
    return subject_matches, object_matches
|
|
134
|
+
|
|
135
|
+
def _get_valid_property_ids_by_entity_type(
    self, entity_type: str, is_object: bool = True
) -> Tuple[Set[str], Set[str]]:
    """Get direct and inverse property ids for an entity type.

    The lookup covers ``entity_type`` itself, the literal ``"ANY"`` type and
    every id listed in the type's ``parent_type_ids``. A type that is not
    present in the collection is treated as having no parents.

    Args:
        entity_type: The ``entity_type_id`` to look up.
        is_object: Whether the type plays the object role in the triplet
            (True) or the subject role (False).

    Returns:
        ``(direct_props, inverse_props)``. For an object type the direct set
        comes from ``valid_object_property_ids`` and the inverse set from
        ``valid_subject_property_ids``; for a subject type they are swapped.
    """
    collection = self.ontology_db.get_collection(self.entity_type_collection_name)

    # Extend the requested type with "ANY" and its recorded supertypes.
    extended_types = [entity_type, "ANY"]
    hierarchy = collection.find_one(
        {"entity_type_id": entity_type}, {"parent_type_ids": 1, "_id": 0}
    )
    if hierarchy:
        # find_one yields None for an unknown type, and the field itself
        # may be absent; both mean "no parents".
        extended_types.extend(hierarchy.get("parent_type_ids") or [])

    pipeline = [
        {"$match": {"entity_type_id": {"$in": extended_types}}},
        {
            # Collect the per-document id arrays into one array of arrays...
            "$group": {
                "_id": None,
                "subject_ids": {
                    "$addToSet": {"$ifNull": ["$valid_subject_property_ids", []]}
                },
                "object_ids": {
                    "$addToSet": {"$ifNull": ["$valid_object_property_ids", []]}
                },
            }
        },
        {
            # ...and flatten each of them into a single de-duplicated array.
            "$project": {
                "subject_ids": {
                    "$reduce": {
                        "input": "$subject_ids",
                        "initialValue": [],
                        "in": {"$setUnion": ["$$value", "$$this"]},
                    }
                },
                "object_ids": {
                    "$reduce": {
                        "input": "$object_ids",
                        "initialValue": [],
                        "in": {"$setUnion": ["$$value", "$$this"]},
                    }
                },
            }
        },
    ]
    # The $group stage yields at most one document; none when nothing matched.
    result_data = next(collection.aggregate(pipeline), {})

    subject_props = set(result_data.get("subject_ids", []))
    object_props = set(result_data.get("object_ids", []))

    if is_object:
        return object_props, subject_props
    return subject_props, object_props
|
|
202
|
+
|
|
203
|
+
def _get_ranked_properties(
    self,
    prop_2_direction: Dict[str, List[str]],
    target_property: str,
    k: int,
) -> List[Tuple[str, str]]:
    """Rank candidate properties by similarity to the target relation.

    The property-alias vector index is searched with the embedding of
    ``target_property``, restricted to the ids in ``prop_2_direction``.
    Because several aliases can share a ``relation_id``, the search is
    repeated (at most five times) with a doubled limit until ``k`` distinct
    ids are found.

    Args:
        prop_2_direction: Candidate property id -> list of directions.
        target_property: Relation text to embed and search for.
        k: Number of distinct property ids wanted.

    Returns:
        ``(property_id, direction)`` tuples, best-ranked property first, one
        tuple per direction of each property. Empty when there are no
        candidates or the text cannot be embedded.
    """
    # Atlas rejects a $vectorSearch whose numCandidates exceeds 10000, and
    # limit may not exceed numCandidates.
    max_candidates = 10000
    max_attempts = 5

    if not prop_2_direction:
        # An empty "$in" filter can never match; skip the embedding and the
        # repeated searches altogether.
        return []

    query_embedding = self.get_embedding(target_property)
    if query_embedding is None:
        return []

    collection = self.ontology_db.get_collection(
        self.property_aliases_collection_name
    )
    props = list(prop_2_direction.keys())

    unique_ranked_properties: List[str] = []
    query_k = k * 2
    attempt = 0
    while len(unique_ranked_properties) < k and attempt < max_attempts:
        pipeline = [
            {
                "$vectorSearch": {
                    "index": self.property_vector_index_name,
                    "queryVector": query_embedding,
                    "path": "alias_text_embedding",
                    "numCandidates": min(max(150, query_k), max_candidates),
                    "limit": min(query_k, max_candidates),
                    "filter": {"relation_id": {"$in": props}},
                }
            },
            {"$project": {"_id": 0, "relation_id": 1}},
        ]
        for hit in collection.aggregate(pipeline):
            relation_id = hit["relation_id"]
            if relation_id in unique_ranked_properties:
                continue
            unique_ranked_properties.append(relation_id)
            if len(unique_ranked_properties) == k:
                break
        query_k *= 2
        attempt += 1

    # Expand every ranked property into one entry per allowed direction.
    return [
        (prop_id, direction)
        for prop_id in unique_ranked_properties
        for direction in prop_2_direction[prop_id]
    ]
|
|
263
|
+
|
|
264
|
+
def retrieve_properties_for_entity_type(
    self,
    target_relation: str,  # relation from triplet
    object_types: List[str],
    subject_types: List[str],
    k: int = 10,
) -> List[Tuple[str, str]]:  # List of tuples (<property_id>, <property_direction>)
    """Find properties compatible with the given types and rank them.

    A property is kept as ``"direct"`` when it is allowed both for some
    subject type (as subject) and for some object type (as object), and as
    ``"inverse"`` when the same holds with the two roles swapped.

    Args:
        target_relation: Relation text the candidates are ranked against.
        object_types: Entity type ids considered for the object.
        subject_types: Entity type ids considered for the subject.
        k: Forwarded to ``_get_ranked_properties``.

    Returns:
        Whatever ``_get_ranked_properties`` returns for the candidates: a
        list of ``(property_id, direction)`` tuples.
    """
    # Four accumulators: one per (direction, role) pair.
    direct_as_subject = set()
    direct_as_object = set()
    inverse_as_subject = set()
    inverse_as_object = set()

    # Object types first, then subject types (same lookup order as before).
    for object_type in object_types:
        as_direct, as_inverse = self._get_valid_property_ids_by_entity_type(
            object_type, is_object=True
        )
        direct_as_object.update(as_direct)
        inverse_as_subject.update(as_inverse)

    for subject_type in subject_types:
        as_direct, as_inverse = self._get_valid_property_ids_by_entity_type(
            subject_type, is_object=False
        )
        direct_as_subject.update(as_direct)
        inverse_as_object.update(as_inverse)

    # A candidate has to satisfy the subject and the object side at once.
    valid_direct = direct_as_subject & direct_as_object
    valid_inverse = inverse_as_subject & inverse_as_object

    directions = {prop_id: ["direct"] for prop_id in valid_direct}
    for prop_id in valid_inverse:
        directions.setdefault(prop_id, []).append("inverse")

    return self._get_ranked_properties(directions, target_relation, k)
|
|
317
|
+
|
|
318
|
+
def retrieve_properties_labels_and_constraints(
    self, property_id_list: List[str]
) -> Dict[str, Dict[str, str]]:
    """Fetch label and type constraints for the given property ids.

    Args:
        property_id_list: Values matched against the ``property_id`` field.

    Returns:
        ``property_id`` -> dict with the keys ``"label"``,
        ``"valid_subject_type_ids"`` and ``"valid_object_type_ids"`` copied
        from the matching document.
    """
    copied_fields = ("label", "valid_subject_type_ids", "valid_object_type_ids")
    properties = self.ontology_db.get_collection(self.property_collection_name)

    match_stage = {"$match": {"property_id": {"$in": property_id_list}}}
    project_stage = {
        "$project": {
            "_id": 0,
            "property_id": 1,
            "label": 1,
            "valid_subject_type_ids": 1,
            "valid_object_type_ids": 1,
        }
    }

    summary = {}
    for doc in properties.aggregate([match_stage, project_stage]):
        summary[doc["property_id"]] = {name: doc[name] for name in copied_fields}
    return summary
|
|
347
|
+
|
|
348
|
+
def retrieve_entity_type_labels(self, entity_type_ids: List[str]):
    """Map entity type ids to their labels.

    Args:
        entity_type_ids: Values matched against the ``entity_type_id`` field.

    Returns:
        Dict of ``entity_type_id`` -> ``label`` for every matching document.
    """
    entity_types = self.ontology_db.get_collection(
        self.entity_type_collection_name
    )
    stages = [
        {"$match": {"entity_type_id": {"$in": entity_type_ids}}},
        {"$project": {"_id": 0, "entity_type_id": 1, "label": 1}},
    ]

    labels = {}
    for doc in entity_types.aggregate(stages):
        labels[doc["entity_type_id"]] = doc["label"]
    return labels
|
|
365
|
+
|
|
366
|
+
def retrieve_entity_type_hierarchy(self, entity_type: str) -> List[str]:
    """Return the id of a type followed by the ids of its parent types.

    Args:
        entity_type: Label of the entity type (matched on the ``label``
            field).

    Returns:
        ``[entity_type_id, *parent ids found in the collection]``, or an
        empty list when no document carries that label.
    """
    collection = self.ontology_db.get_collection(self.entity_type_collection_name)
    type_doc = collection.find_one(
        {"label": entity_type},
        {"entity_type_id": 1, "parent_type_ids": 1, "label": 1, "_id": 0},
    )
    if not type_doc:
        # Unknown label: bail out before touching the missing document.
        return []

    parent_docs = collection.find(
        {"entity_type_id": {"$in": type_doc.get("parent_type_ids") or []}},
        {"_id": 0, "label": 1, "entity_type_id": 1},
    )
    # Only parents that exist as documents contribute to the result.
    return [type_doc["entity_type_id"]] + [
        parent["entity_type_id"] for parent in parent_docs
    ]
|
|
382
|
+
|
|
383
|
+
def retrieve_entity_by_type(self, entity_name, entity_type, sample_id, k=10):
    """Search stored entity aliases similar to a name, restricted by type.

    The type filter covers the label of ``entity_type`` and the labels of
    its parent types. When ``entity_type`` is not present in the entity-type
    collection, the filter falls back to the given label alone.

    Args:
        entity_name: Text that is embedded and used as the query vector.
        entity_type: Label of the entity type.
        sample_id: When truthy, results are additionally restricted to this
            ``sample_id``.
        k: Maximum number of hits.

    Returns:
        Dict of ``alias`` -> ``label`` for the hits; empty when the name
        cannot be embedded.
    """
    # Check the cheap failure first so no database round-trips are wasted.
    query_embedding = self.get_embedding(entity_name)
    if query_embedding is None:
        return {}

    type_collection = self.ontology_db.get_collection(
        self.entity_type_collection_name
    )
    type_doc = type_collection.find_one(
        {"label": entity_type},
        {"entity_type_id": 1, "parent_type_ids": 1, "label": 1, "_id": 0},
    )
    if type_doc:
        type_ids = [type_doc["entity_type_id"]] + list(
            type_doc.get("parent_type_ids") or []
        )
        extended_types = [
            elem["label"]
            for elem in type_collection.find(
                {"entity_type_id": {"$in": type_ids}},
                {"_id": 0, "label": 1, "entity_type_id": 1},
            )
        ]
    else:
        # Unknown type: find_one returned None. Keep honouring the type the
        # caller asked for instead of failing.
        extended_types = [entity_type]

    filter_query = {"entity_type": {"$in": extended_types}}
    if sample_id:
        filter_query["sample_id"] = {"$eq": sample_id}

    pipeline = [
        {
            "$vectorSearch": {
                "index": self.entities_vector_index_name,
                "queryVector": query_embedding,
                "path": "alias_text_embedding",
                "numCandidates": max(150, k),
                "limit": k,
                "filter": filter_query,
            }
        },
        {"$project": {"_id": 0, "label": 1, "alias": 1}},
    ]

    alias_collection = self.triplets_db.get_collection(
        self.entity_aliases_collection_name
    )
    return {
        item["alias"]: item["label"] for item in alias_collection.aggregate(pipeline)
    }
|
|
436
|
+
|
|
437
|
+
def add_entity(self, entity_name, alias, entity_type, sample_id):
    """Insert an entity-alias record unless an identical one already exists.

    Args:
        entity_name: Stored under ``"label"``.
        alias: Stored under ``"alias"``; also the text that gets embedded.
        entity_type: Stored under ``"entity_type"``.
        sample_id: Stored under ``"sample_id"``; a falsy value is replaced
            by ``"all"``.
    """
    aliases = self.triplets_db.get_collection(self.entity_aliases_collection_name)
    sample_id = sample_id or "all"

    identity = {
        "label": entity_name,
        "entity_type": entity_type,
        "alias": alias,
    }
    already_stored = aliases.find_one({**identity, "sample_id": {"$eq": sample_id}})
    if already_stored:
        return

    # The embedding is only computed when a record is actually written.
    aliases.insert_one(
        {
            **identity,
            "sample_id": sample_id,
            "alias_text_embedding": self.get_embedding(alias),
        }
    )
|
|
463
|
+
|
|
464
|
+
def add_triplets(self, triplets_list, sample_id):
    """Upsert triplets into the ``triplets`` collection.

    Every triplet dict gets its ``"sample_id"`` key set in place (``"all"``
    when ``sample_id`` is falsy) and is written with ``$setOnInsert``, so a
    document matching on subject, relation, object, both types and sample id
    is left untouched.

    Args:
        triplets_list: Iterable of triplet dicts.
        sample_id: Sample identifier to stamp on each triplet.
    """
    identity_fields = (
        "subject",
        "relation",
        "object",
        "subject_type",
        "object_type",
        "sample_id",
    )
    target = self.triplets_db.get_collection(self.triplets_collection_name)
    effective_sample_id = sample_id if sample_id else "all"

    upserts = []
    for triple in triplets_list:
        triple["sample_id"] = effective_sample_id
        identity = {field: triple[field] for field in identity_fields}
        upserts.append(UpdateOne(identity, {"$setOnInsert": triple}, upsert=True))

    # Nothing is sent for an empty batch.
    if upserts:
        target.bulk_write(upserts)
|
|
486
|
+
|
|
487
|
+
def add_filtered_triplets(self, triplets_list, sample_id):
    """Upsert triplets into the ``filtered_triplets`` collection.

    Each triplet dict is stamped in place with ``"sample_id"`` (``"all"``
    when the argument is falsy) and inserted only if no document with the
    same subject, relation, object, types and sample id exists.

    Args:
        triplets_list: Iterable of triplet dicts.
        sample_id: Sample identifier to stamp on each triplet.
    """
    collection = self.triplets_db.get_collection(
        self.filtered_triplets_collection_name
    )
    stamp = sample_id or "all"

    def to_upsert(triple):
        # Stamp first: the sample id is part of the match filter below.
        triple["sample_id"] = stamp
        match_on = {
            "subject": triple["subject"],
            "relation": triple["relation"],
            "object": triple["object"],
            "subject_type": triple["subject_type"],
            "object_type": triple["object_type"],
            "sample_id": triple["sample_id"],
        }
        return UpdateOne(match_on, {"$setOnInsert": triple}, upsert=True)

    pending = []
    for triple in triplets_list:
        pending.append(to_upsert(triple))

    if not pending:
        return
    collection.bulk_write(pending)
|
|
511
|
+
|
|
512
|
+
def add_ontology_filtered_triplets(self, triplets_list, sample_id):
    """Upsert triplets into the ``ontology_filtered_triplets`` collection.

    ``"sample_id"`` is written into every triplet dict in place (``"all"``
    for a falsy argument). A triplet is inserted only when no stored
    document matches it on subject, relation, object, both types and sample
    id.

    Args:
        triplets_list: Iterable of triplet dicts.
        sample_id: Sample identifier to stamp on each triplet.
    """
    store = self.triplets_db.get_collection(
        self.ontology_filtered_triplets_collection_name
    )
    if not sample_id:
        sample_id = "all"

    key_names = ["subject", "relation", "object", "subject_type", "object_type"]
    batch = []
    for record in triplets_list:
        record["sample_id"] = sample_id
        lookup = {name: record[name] for name in key_names}
        lookup["sample_id"] = record["sample_id"]
        batch.append(UpdateOne(lookup, {"$setOnInsert": record}, upsert=True))

    # An empty batch is never sent to the collection.
    if len(batch) > 0:
        store.bulk_write(batch)
|
|
536
|
+
|
|
537
|
+
def add_initial_triplets(self, triplets_list, sample_id):
    """Upsert triplets into the ``initial_triplets`` collection.

    The sample id (``"all"`` when the argument is falsy) is stored on each
    triplet dict in place. Writes use ``$setOnInsert`` with ``upsert=True``,
    keyed on subject, relation, object, both types and sample id.

    Args:
        triplets_list: Iterable of triplet dicts.
        sample_id: Sample identifier to stamp on each triplet.
    """
    resolved_id = "all" if not sample_id else sample_id
    initial = self.triplets_db.get_collection(self.initial_triplets_collection_name)

    writes = []
    for entry in triplets_list:
        entry["sample_id"] = resolved_id
        criteria = dict(
            subject=entry["subject"],
            relation=entry["relation"],
            object=entry["object"],
            subject_type=entry["subject_type"],
            object_type=entry["object_type"],
            sample_id=entry["sample_id"],
        )
        writes.append(UpdateOne(criteria, {"$setOnInsert": entry}, upsert=True))

    # Skip the round-trip when there is nothing to write.
    if writes:
        initial.bulk_write(writes)
|
|
559
|
+
|
|
560
|
+
def retrieve_similar_entity_names(
    self, entity_name: str, k: int = 10, sample_id: str = None
) -> List[Dict[str, str]]:
    """Search stored entity aliases similar to ``entity_name``.

    Args:
        entity_name: Text that is embedded and used as the query vector.
        k: Maximum number of hits.
        sample_id: When truthy, only records with this ``sample_id`` are
            searched; otherwise no filter is applied.

    Returns:
        One ``{"entity": <label>}`` dict per hit, in search order; empty
        when the name cannot be embedded.
    """
    embedded_query = self.get_embedding(entity_name)
    if embedded_query is None:
        return []

    collection = self.triplets_db.get_collection(
        self.entity_aliases_collection_name
    )

    vector_search = {
        "index": self.entities_vector_index_name,
        "queryVector": embedded_query,
        "path": "alias_text_embedding",
        # limit may not exceed numCandidates, so grow the candidate pool
        # together with k (as the other searches in this class do).
        "numCandidates": max(150, k),
        "limit": k,
    }
    if sample_id:
        vector_search["filter"] = {"sample_id": {"$eq": sample_id}}

    pipeline = [
        {"$vectorSearch": vector_search},
        {"$project": {"_id": 0, "label": 1, "entity_type": 1}},
    ]
    return [{"entity": item["label"]} for item in collection.aggregate(pipeline)]
|