wikontic-0.0.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. wikontic/__init__.py +16 -0
  2. wikontic/create_ontological_triplets_db.py +193 -0
  3. wikontic/create_triplets_db.py +259 -0
  4. wikontic/create_wikidata_ontology_db.py +555 -0
  5. wikontic/utils/__init__.py +7 -0
  6. wikontic/utils/base_inference_with_db.py +329 -0
  7. wikontic/utils/dynamic_aligner.py +281 -0
  8. wikontic/utils/inference_with_db.py +224 -0
  9. wikontic/utils/ontology_mappings/entity_hierarchy.json +1 -0
  10. wikontic/utils/ontology_mappings/entity_names.json +1 -0
  11. wikontic/utils/ontology_mappings/entity_type2aliases.json +1 -0
  12. wikontic/utils/ontology_mappings/entity_type2hierarchy.json +1 -0
  13. wikontic/utils/ontology_mappings/entity_type2label.json +1 -0
  14. wikontic/utils/ontology_mappings/enum_entity_ids.json +1 -0
  15. wikontic/utils/ontology_mappings/enum_prop_ids.json +1 -0
  16. wikontic/utils/ontology_mappings/label2entity.json +1 -0
  17. wikontic/utils/ontology_mappings/obj_constraint2prop.json +1 -0
  18. wikontic/utils/ontology_mappings/prop2aliases.json +1 -0
  19. wikontic/utils/ontology_mappings/prop2constraints.json +1 -0
  20. wikontic/utils/ontology_mappings/prop2data_type.json +1 -0
  21. wikontic/utils/ontology_mappings/prop2label.json +1 -0
  22. wikontic/utils/ontology_mappings/propid2enum.json +1 -0
  23. wikontic/utils/ontology_mappings/subj_constraint2prop.json +1 -0
  24. wikontic/utils/ontology_mappings/subject_object_constraints.json +1 -0
  25. wikontic/utils/openai_utils.py +517 -0
  26. wikontic/utils/prompts/name_refinement/prompt_choose_relation_wo_entity_types.txt +17 -0
  27. wikontic/utils/prompts/name_refinement/prompt_choose_relation_wo_entity_types_dialog_bench.txt +18 -0
  28. wikontic/utils/prompts/name_refinement/rank_object_names.txt +17 -0
  29. wikontic/utils/prompts/name_refinement/rank_object_names_dialog_bench.txt +18 -0
  30. wikontic/utils/prompts/name_refinement/rank_object_qualifiers.txt +20 -0
  31. wikontic/utils/prompts/name_refinement/rank_subject_names.txt +18 -0
  32. wikontic/utils/prompts/name_refinement/rank_subject_names_dialog_bench.txt +20 -0
  33. wikontic/utils/prompts/ontology_refinement/prompt_choose_entity_types.txt +26 -0
  34. wikontic/utils/prompts/ontology_refinement/prompt_choose_relation.txt +24 -0
  35. wikontic/utils/prompts/ontology_refinement/prompt_choose_relation_and_types.txt +28 -0
  36. wikontic/utils/prompts/qa/prompt_choose_relevant_entities_for_question.txt +17 -0
  37. wikontic/utils/prompts/qa/prompt_choose_relevant_entities_for_question_wo_types.txt +16 -0
  38. wikontic/utils/prompts/qa/prompt_entity_extraction_from_question.txt +3 -0
  39. wikontic/utils/prompts/qa/prompt_is_answered.txt +43 -0
  40. wikontic/utils/prompts/qa/qa_collapsing_prompt.txt +22 -0
  41. wikontic/utils/prompts/qa/qa_prompt.txt +5 -0
  42. wikontic/utils/prompts/qa/qa_prompt_hotpot.txt +6 -0
  43. wikontic/utils/prompts/qa/question_decomposition_1.txt +7 -0
  44. wikontic/utils/prompts/triplet_extraction/prompt_1_types_qualifiers_dialog_bench.txt +75 -0
  45. wikontic/utils/prompts/triplet_extraction/prompt_1_types_qualifiers_dialog_bench_in_russian.txt +78 -0
  46. wikontic/utils/prompts/triplet_extraction/propmt_1_types_qualifiers.txt +91 -0
  47. wikontic/utils/structured_aligner.py +606 -0
  48. wikontic/utils/structured_inference_with_db.py +561 -0
  49. wikontic-0.0.3.dist-info/METADATA +111 -0
  50. wikontic-0.0.3.dist-info/RECORD +53 -0
  51. wikontic-0.0.3.dist-info/WHEEL +5 -0
  52. wikontic-0.0.3.dist-info/licenses/LICENSE +19 -0
  53. wikontic-0.0.3.dist-info/top_level.txt +1 -0
wikontic/utils/structured_aligner.py
@@ -0,0 +1,606 @@
+ from typing import Dict, List, Optional, Set, Tuple
+ from transformers import AutoTokenizer, AutoModel
+ from dataclasses import dataclass
+ from pydantic import BaseModel
+ from pymongo import UpdateOne
+ import torch
+ from dotenv import load_dotenv, find_dotenv
+ import os
+
+ # os.environ["CUDA_VISIBLE_DEVICES"] = "1"
+ _ = load_dotenv(find_dotenv())
+
+
+ @dataclass
+ class PropertyConstraints:
+     subject_properties: Set[str]
+     object_properties: Set[str]
+
+
+ class EntityAlias(BaseModel):
+     _id: int
+     label: str
+     entity_type: str
+     alias: str
+     sample_id: str
+     alias_text_embedding: List[float]
+
+
+ class Aligner:
+     def __init__(self, ontology_db, triplets_db, device="cuda:0"):
+         self.ontology_db = ontology_db
+         self.triplets_db = triplets_db
+
+         self.entity_type_collection_name = "entity_types"
+         self.entity_type_aliases_collection_name = "entity_type_aliases"
+         self.property_collection_name = "properties"
+         self.property_aliases_collection_name = "property_aliases"
+
+         self.entity_type_vector_index_name = "entity_type_aliases"
+         self.property_vector_index_name = "property_aliases"
+
+         self.entity_aliases_collection_name = "entity_aliases"
+         self.triplets_collection_name = "triplets"
+         self.filtered_triplets_collection_name = "filtered_triplets"
+         self.ontology_filtered_triplets_collection_name = "ontology_filtered_triplets"
+         self.initial_triplets_collection_name = "initial_triplets"
+         self.entities_vector_index_name = "entity_aliases"
+
+         self.device = torch.device(device)
+         # self.tokenizer = AutoTokenizer.from_pretrained('facebook/contriever', token=os.getenv("HF_KEY"))
+         self.tokenizer = AutoTokenizer.from_pretrained("facebook/contriever")
+         # self.model = AutoModel.from_pretrained('facebook/contriever', token=os.getenv("HF_KEY")).to(self.device)
+         self.model = AutoModel.from_pretrained(
+             "facebook/contriever", use_safetensors=True
+         ).to(self.device)
+
+     def get_embedding(self, text):
+
+         def mean_pooling(token_embeddings, mask):
+             token_embeddings = token_embeddings.masked_fill(
+                 ~mask[..., None].bool(), 0.0
+             )
+             sentence_embeddings = (
+                 token_embeddings.sum(dim=1) / mask.sum(dim=1)[..., None]
+             )
+             return sentence_embeddings
+
+         if not text or not isinstance(text, str):
+             return None
+
+         inputs = self.tokenizer(
+             [text], padding=True, truncation=True, return_tensors="pt"
+         )
+         outputs = self.model(**inputs.to(self.device))
+         embeddings = mean_pooling(outputs[0], inputs["attention_mask"])
+         return embeddings.detach().cpu().tolist()[0]
+
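
get_embedding collapses the token-level Contriever outputs into a single sentence vector by masked mean pooling. For illustration only (not part of the package), a minimal standalone sketch of the same pooling step on dummy tensors:

    import torch

    token_embeddings = torch.randn(1, 4, 8)  # (batch, tokens, hidden)
    mask = torch.tensor([[1, 1, 1, 0]])      # last token is padding
    # zero out padded positions, then average over the real tokens only
    summed = token_embeddings.masked_fill(~mask[..., None].bool(), 0.0).sum(dim=1)
    sentence_embedding = summed / mask.sum(dim=1)[..., None]
    print(sentence_embedding.shape)          # torch.Size([1, 8])
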
+     def _get_unique_similar_entity_types(
+         self, target_entity_type: str, k: int = 5, max_attempts: int = 10
+     ) -> List[str]:
+         # retrieve the k entity types most similar to the given type name
+         # using the entity type alias vector index, and return their
+         # Wikidata IDs
+
+         query_k = k * 2
+         attempt = 0
+         unique_ranked_entities: List[str] = []
+         query_embedding = self.get_embedding(target_entity_type)
+         collection = self.ontology_db.get_collection(
+             self.entity_type_aliases_collection_name
+         )
+
+         # since we search among aliases, the same original entity type can
+         # appear more than once; because we want k unique entity types in the
+         # result, we query the index repeatedly until we have exactly k
+         while len(unique_ranked_entities) < k and attempt < max_attempts:
+             search_pipeline = [
+                 {
+                     "$vectorSearch": {
+                         "index": self.entity_type_vector_index_name,
+                         "queryVector": query_embedding,
+                         "path": "alias_text_embedding",
+                         "numCandidates": 150 if query_k < 150 else query_k,
+                         "limit": query_k,
+                     }
+                 },
+                 {"$project": {"_id": 0, "entity_type_id": 1}},
+             ]
+             result = collection.aggregate(search_pipeline)
+             for res in result:
+                 if res["entity_type_id"] not in unique_ranked_entities:
+                     unique_ranked_entities.append(res["entity_type_id"])
+                 if len(unique_ranked_entities) == k:
+                     break
+             query_k *= 2
+             attempt += 1
+
+         return unique_ranked_entities
+
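
The widen-and-retry loop above deduplicates alias-level hits back to entity types, doubling query_k until k unique IDs are found or max_attempts is exhausted. A self-contained sketch of the same pattern, with a hypothetical search() stand-in for the alias index (the hit list is invented):

    def search(query_k):
        # stand-in for the vector index: one entity_type_id per alias hit
        hits = ["Q5", "Q5", "Q215627", "Q5", "Q43229", "Q215627", "Q16334295"]
        return hits[:query_k]

    def unique_top_k(k, max_attempts=10):
        query_k, attempt, unique = k * 2, 0, []
        while len(unique) < k and attempt < max_attempts:
            for hit in search(query_k):
                if hit not in unique:
                    unique.append(hit)
                if len(unique) == k:
                    break
            query_k *= 2  # widen the net and retry
            attempt += 1
        return unique

    print(unique_top_k(3))  # ['Q5', 'Q215627', 'Q43229']
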
+     def retrieve_similar_entity_types(
+         self, triplet: Dict[str, str], k: int = 10
+     ) -> Tuple[List[str], List[str]]:
+         similar_subject_types = self._get_unique_similar_entity_types(
+             target_entity_type=triplet["subject_type"], k=k
+         )
+         if "object_type" in triplet:
+             similar_object_types = self._get_unique_similar_entity_types(
+                 target_entity_type=triplet["object_type"], k=k
+             )
+         else:
+             similar_object_types = []
+         return similar_subject_types, similar_object_types
+
+     def _get_valid_property_ids_by_entity_type(
+         self, entity_type: str, is_object: bool = True
+     ) -> Tuple[Set[str], Set[str]]:
+         """
+         Get direct and inverse properties for an entity type.
+
+         Args:
+             entity_type: The entity type to look up
+             is_object: Whether this is an object type in the triplet (True) or a subject type (False)
+         """
+         collection = self.ontology_db.get_collection(self.entity_type_collection_name)
+
+         # Get extended types, including supertypes
+         extended_types = [entity_type, "ANY"]
+         hierarchy = collection.find_one(
+             {"entity_type_id": entity_type}, {"parent_type_ids": 1, "_id": 0}
+         )
+         if hierarchy:
+             extended_types.extend(hierarchy.get("parent_type_ids", []))
+
+         pipeline = [
+             {"$match": {"entity_type_id": {"$in": extended_types}}},
+             {
+                 "$group": {
+                     "_id": None,
+                     "subject_ids": {
+                         "$addToSet": {"$ifNull": ["$valid_subject_property_ids", []]}
+                     },
+                     "object_ids": {
+                         "$addToSet": {"$ifNull": ["$valid_object_property_ids", []]}
+                     },
+                 }
+             },
+             {
+                 "$project": {
+                     "subject_ids": {
+                         "$reduce": {
+                             "input": "$subject_ids",
+                             "initialValue": [],
+                             "in": {"$setUnion": ["$$value", "$$this"]},
+                         }
+                     },
+                     "object_ids": {
+                         "$reduce": {
+                             "input": "$object_ids",
+                             "initialValue": [],
+                             "in": {"$setUnion": ["$$value", "$$this"]},
+                         }
+                     },
+                 }
+             },
+         ]
+         result = collection.aggregate(pipeline)
+
+         result_data = next(result, {})
+
+         subject_props = result_data.get("subject_ids", [])
+         object_props = result_data.get("object_ids", [])
+
+         if is_object:
+             direct_props = set(object_props)
+             inverse_props = set(subject_props)
+         else:
+             direct_props = set(subject_props)
+             inverse_props = set(object_props)
+
+         return direct_props, inverse_props
+
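
The $group/$addToSet/$reduce pipeline above amounts to unioning the per-type property lists over the type and its parents. A toy equivalent in plain Python, with invented documents that mirror the entity_types schema used in the code:

    docs = [
        {"entity_type_id": "Q5", "valid_subject_property_ids": ["P26", "P39"],
         "valid_object_property_ids": ["P40"]},
        {"entity_type_id": "ANY", "valid_subject_property_ids": ["P31"]},
    ]

    subject_ids, object_ids = set(), set()
    for doc in docs:  # $match on extended_types already applied
        subject_ids |= set(doc.get("valid_subject_property_ids", []))  # $ifNull + $setUnion
        object_ids |= set(doc.get("valid_object_property_ids", []))

    print(sorted(subject_ids))  # ['P26', 'P31', 'P39']
    print(sorted(object_ids))   # ['P40']
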
+     def _get_ranked_properties(
+         self,
+         prop_2_direction: Dict[str, List[str]],
+         target_property: str,
+         k: int,
+     ) -> List[Tuple[str, str]]:
+         """
+         Rank properties by similarity to the target relation.
+         """
+         collection = self.ontology_db.get_collection(
+             self.property_aliases_collection_name
+         )
+         query_embedding = self.get_embedding(target_property)
+         if query_embedding is None:
+             return []
+         props = list(prop_2_direction.keys())
+
+         query_k = k * 2
+         max_attempts = 5
+         attempt = 0
+         unique_ranked_properties: List[str] = []
+
+         while len(unique_ranked_properties) < k and attempt < max_attempts:
+             pipeline = [
+                 {
+                     "$vectorSearch": {
+                         "index": self.property_vector_index_name,
+                         "queryVector": query_embedding,
+                         "path": "alias_text_embedding",
+                         "numCandidates": 150 if query_k < 150 else query_k,
+                         "limit": query_k,
+                         "filter": {"relation_id": {"$in": props}},
+                     }
+                 },
+                 {
+                     "$project": {
+                         "_id": 0,
+                         "relation_id": 1,
+                     }
+                 },
+             ]
+
+             similar_properties = collection.aggregate(pipeline)
+
+             for prop in similar_properties:
+                 if prop["relation_id"] not in unique_ranked_properties:
+                     unique_ranked_properties.append(prop["relation_id"])
+                 if len(unique_ranked_properties) == k:
+                     break
+
+             query_k *= 2
+             attempt += 1
+
+         # attach each property's admissible direction(s)
+         unique_ranked_properties_with_direction = []
+         for prop_id in unique_ranked_properties:
+             for direction in prop_2_direction[prop_id]:
+                 unique_ranked_properties_with_direction.append((prop_id, direction))
+         return unique_ranked_properties_with_direction
+
+     def retrieve_properties_for_entity_type(
+         self,
+         target_relation: str,  # relation from the triplet
+         object_types: List[str],
+         subject_types: List[str],
+         k: int = 10,
+     ) -> List[Tuple[str, str]]:
+         """
+         Retrieve and rank properties that match the given entity types and relation.
+
+         Args:
+             target_relation: The relation to search for
+             object_types: List of valid object types
+             subject_types: List of valid subject types
+             k: Number of results to return
+
+         Returns:
+             List of tuples (<property_id>, <property_direction>)
+         """
+         # Initialize property constraints
+         direct_props = PropertyConstraints(set(), set())
+         inverse_props = PropertyConstraints(set(), set())
+
+         # Collect object type properties
+         for obj_type in object_types:
+             obj_direct, obj_inverse = self._get_valid_property_ids_by_entity_type(
+                 obj_type, is_object=True
+             )
+             direct_props.object_properties.update(obj_direct)
+             inverse_props.subject_properties.update(obj_inverse)
+
+         # Collect subject type properties
+         for subj_type in subject_types:
+             subj_direct, subj_inverse = self._get_valid_property_ids_by_entity_type(
+                 subj_type, is_object=False
+             )
+             direct_props.subject_properties.update(subj_direct)
+             inverse_props.object_properties.update(subj_inverse)
+
+         # Keep only properties that satisfy both the subject and the object constraints
+         valid_direct = direct_props.subject_properties & direct_props.object_properties
+         valid_inverse = (
+             inverse_props.subject_properties & inverse_props.object_properties
+         )
+
+         prop_id_2_direction = {prop_id: ["direct"] for prop_id in valid_direct}
+         for prop_id in valid_inverse:
+             if prop_id in prop_id_2_direction:
+                 prop_id_2_direction[prop_id].append("inverse")
+             else:
+                 prop_id_2_direction[prop_id] = ["inverse"]
+
+         return self._get_ranked_properties(prop_id_2_direction, target_relation, k)
+
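
The method above keeps a property only if it is admissible for some subject type and some object type, tracking both reading directions. A toy run of the intersection logic with invented property IDs:

    subject_side = {"P50", "P57"}           # admissible for the subject slot
    object_side = {"P50", "P123"}           # admissible for the object slot
    valid_direct = subject_side & object_side           # {'P50'}

    inverse_subject_side = {"P50", "P800"}  # slots swapped: object acts as subject
    inverse_object_side = {"P800"}
    valid_inverse = inverse_subject_side & inverse_object_side  # {'P800'}

    directions = {p: ["direct"] for p in valid_direct}
    for p in valid_inverse:
        directions.setdefault(p, []).append("inverse")
    print(directions)  # {'P50': ['direct'], 'P800': ['inverse']}
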
+     def retrieve_properties_labels_and_constraints(
+         self, property_id_list: List[str]
+     ) -> Dict[str, Dict[str, str]]:
+         collection = self.ontology_db.get_collection(self.property_collection_name)
+
+         pipeline = [
+             {"$match": {"property_id": {"$in": property_id_list}}},
+             {
+                 "$project": {
+                     "_id": 0,
+                     "property_id": 1,
+                     "label": 1,
+                     "valid_subject_type_ids": 1,
+                     "valid_object_type_ids": 1,
+                 }
+             },
+         ]
+         result = collection.aggregate(pipeline)
+
+         result_dict = {
+             item["property_id"]: {
+                 "label": item["label"],
+                 "valid_subject_type_ids": item["valid_subject_type_ids"],
+                 "valid_object_type_ids": item["valid_object_type_ids"],
+             }
+             for item in result
+         }
+
+         return result_dict
+
+     def retrieve_entity_type_labels(self, entity_type_ids: List[str]):
+         collection = self.ontology_db.get_collection(self.entity_type_collection_name)
+         pipeline = [
+             {"$match": {"entity_type_id": {"$in": entity_type_ids}}},
+             {
+                 "$project": {
+                     "_id": 0,
+                     "entity_type_id": 1,
+                     "label": 1,
+                 }
+             },
+         ]
+         result = collection.aggregate(pipeline)
+
+         result_dict = {item["entity_type_id"]: item["label"] for item in result}
+
+         return result_dict
+
+     def retrieve_entity_type_hierarchy(self, entity_type: str) -> List[str]:
+         collection = self.ontology_db.get_collection(self.entity_type_collection_name)
+         entity_id_parent_types = collection.find_one(
+             {"label": entity_type},
+             {"entity_type_id": 1, "parent_type_ids": 1, "label": 1, "_id": 0},
+         )
+         # guard before dereferencing: find_one returns None for an unknown label
+         if not entity_id_parent_types:
+             return []
+         parent_type_id_labels = collection.find(
+             {"entity_type_id": {"$in": entity_id_parent_types["parent_type_ids"]}},
+             {"_id": 0, "label": 1, "entity_type_id": 1},
+         )
+         extended_types = [entity_id_parent_types["entity_type_id"]] + [
+             item["entity_type_id"] for item in parent_type_id_labels
+         ]
+
+         return extended_types
+
+     def retrieve_entity_by_type(self, entity_name, entity_type, sample_id, k=10):
+         collection = self.ontology_db.get_collection(self.entity_type_collection_name)
+         entity_id_parent_types = collection.find_one(
+             {"label": entity_type},
+             {"entity_type_id": 1, "parent_type_ids": 1, "label": 1, "_id": 0},
+         )
+         # unknown entity type: nothing to retrieve
+         if not entity_id_parent_types:
+             return {}
+         extended_types = [
+             entity_id_parent_types["entity_type_id"]
+         ] + entity_id_parent_types["parent_type_ids"]
+         extended_types = [
+             elem["label"]
+             for elem in collection.find(
+                 {"entity_type_id": {"$in": extended_types}},
+                 {"_id": 0, "label": 1, "entity_type_id": 1},
+             )
+         ]
+
+         collection = self.triplets_db.get_collection(
+             self.entity_aliases_collection_name
+         )
+
+         query_embedding = self.get_embedding(entity_name)
+         if query_embedding is None:
+             return {}
+
+         if not sample_id:
+             filter_query = {
+                 "entity_type": {"$in": extended_types},
+             }
+         else:
+             filter_query = {
+                 "entity_type": {"$in": extended_types},
+                 "sample_id": {"$eq": sample_id},
+             }
+         pipeline = [
+             {
+                 "$vectorSearch": {
+                     "index": self.entities_vector_index_name,
+                     "queryVector": query_embedding,
+                     "path": "alias_text_embedding",
+                     "numCandidates": 150 if k < 150 else k,
+                     "limit": k,
+                     "filter": filter_query,
+                 }
+             },
+             {"$project": {"_id": 0, "label": 1, "alias": 1}},
+         ]
+
+         result = collection.aggregate(pipeline)
+         result_dict = {item["alias"]: item["label"] for item in result}
+
+         return result_dict
+
+     def add_entity(self, entity_name, alias, entity_type, sample_id):
+         collection = self.triplets_db.get_collection(
+             self.entity_aliases_collection_name
+         )
+         if not sample_id:
+             sample_id = "all"
+
+         # insert only if this (label, type, alias, sample) combination is new
+         if not collection.find_one(
+             {
+                 "label": entity_name,
+                 "entity_type": entity_type,
+                 "alias": alias,
+                 "sample_id": {"$eq": sample_id},
+             }
+         ):
+             collection.insert_one(
+                 {
+                     "label": entity_name,
+                     "entity_type": entity_type,
+                     "alias": alias,
+                     "sample_id": sample_id,
+                     "alias_text_embedding": self.get_embedding(alias),
+                 }
+             )
+
+     def add_triplets(self, triplets_list, sample_id):
+         collection = self.triplets_db.get_collection(self.triplets_collection_name)
+
+         operations = []
+         if not sample_id:
+             sample_id = "all"
+         for triple in triplets_list:
+             triple["sample_id"] = sample_id
+             filter_query = {
+                 "subject": triple["subject"],
+                 "relation": triple["relation"],
+                 "object": triple["object"],
+                 "subject_type": triple["subject_type"],
+                 "object_type": triple["object_type"],
+                 "sample_id": triple["sample_id"],
+             }
+             operations.append(
+                 UpdateOne(filter_query, {"$setOnInsert": triple}, upsert=True)
+             )
+
+         if operations:
+             collection.bulk_write(operations)
+
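
Each add_* writer below follows the same idempotent pattern: UpdateOne with upsert=True and $setOnInsert, so re-adding an identical triplet is a no-op rather than a duplicate. A minimal sketch against a local MongoDB (connection string, database, and the triplet are placeholders):

    from pymongo import MongoClient, UpdateOne

    collection = MongoClient("mongodb://localhost:27017")["demo"]["triplets"]
    triple = {"subject": "Ada Lovelace", "relation": "field of work",
              "object": "mathematics", "subject_type": "human",
              "object_type": "academic discipline", "sample_id": "all"}

    op = UpdateOne(triple, {"$setOnInsert": triple}, upsert=True)
    collection.bulk_write([op])  # first call inserts the document
    collection.bulk_write([op])  # second call matches it; nothing is written
    print(collection.count_documents(triple))  # 1
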
+     def add_filtered_triplets(self, triplets_list, sample_id):
+         collection = self.triplets_db.get_collection(
+             self.filtered_triplets_collection_name
+         )
+
+         operations = []
+         if not sample_id:
+             sample_id = "all"
+         for triple in triplets_list:
+             triple["sample_id"] = sample_id
+             filter_query = {
+                 "subject": triple["subject"],
+                 "relation": triple["relation"],
+                 "object": triple["object"],
+                 "subject_type": triple["subject_type"],
+                 "object_type": triple["object_type"],
+                 "sample_id": triple["sample_id"],
+             }
+             operations.append(
+                 UpdateOne(filter_query, {"$setOnInsert": triple}, upsert=True)
+             )
+
+         if operations:
+             collection.bulk_write(operations)
+
+     def add_ontology_filtered_triplets(self, triplets_list, sample_id):
+         collection = self.triplets_db.get_collection(
+             self.ontology_filtered_triplets_collection_name
+         )
+
+         operations = []
+         if not sample_id:
+             sample_id = "all"
+         for triple in triplets_list:
+             triple["sample_id"] = sample_id
+             filter_query = {
+                 "subject": triple["subject"],
+                 "relation": triple["relation"],
+                 "object": triple["object"],
+                 "subject_type": triple["subject_type"],
+                 "object_type": triple["object_type"],
+                 "sample_id": triple["sample_id"],
+             }
+             operations.append(
+                 UpdateOne(filter_query, {"$setOnInsert": triple}, upsert=True)
+             )
+
+         if operations:
+             collection.bulk_write(operations)
+
+     def add_initial_triplets(self, triplets_list, sample_id):
+         if not sample_id:
+             sample_id = "all"
+         collection = self.triplets_db.get_collection(
+             self.initial_triplets_collection_name
+         )
+         operations = []
+         for triple in triplets_list:
+             triple["sample_id"] = sample_id
+             filter_query = {
+                 "subject": triple["subject"],
+                 "relation": triple["relation"],
+                 "object": triple["object"],
+                 "subject_type": triple["subject_type"],
+                 "object_type": triple["object_type"],
+                 "sample_id": triple["sample_id"],
+             }
+             operations.append(
+                 UpdateOne(filter_query, {"$setOnInsert": triple}, upsert=True)
+             )
+         if operations:
+             collection.bulk_write(operations)
+
+     def retrieve_similar_entity_names(
+         self, entity_name: str, k: int = 10, sample_id: Optional[str] = None
+     ) -> List[Dict[str, str]]:
+         embedded_query = self.get_embedding(entity_name)
+         if embedded_query is None:
+             return []
+         collection = self.triplets_db.get_collection(
+             self.entity_aliases_collection_name
+         )
+
+         # restrict the search to one sample when a sample_id is provided
+         if sample_id:
+             pipeline = [
+                 {
+                     "$vectorSearch": {
+                         "index": self.entities_vector_index_name,
+                         "queryVector": embedded_query,
+                         "path": "alias_text_embedding",
+                         "numCandidates": 150,
+                         "limit": k,
+                         "filter": {
+                             "sample_id": {"$eq": sample_id},
+                         },
+                     }
+                 },
+                 {"$project": {"_id": 0, "label": 1, "entity_type": 1}},
+             ]
+         else:
+             pipeline = [
+                 {
+                     "$vectorSearch": {
+                         "index": self.entities_vector_index_name,
+                         "queryVector": embedded_query,
+                         "path": "alias_text_embedding",
+                         "numCandidates": 150,
+                         "limit": k,
+                     }
+                 },
+                 {"$project": {"_id": 0, "label": 1, "entity_type": 1}},
+             ]
+
+         result = collection.aggregate(pipeline)
+         result_list = list(result)
+
+         return [{"entity": item["label"]} for item in result_list]
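
Taken together, a typical call sequence for this class looks roughly as follows. This is an illustrative sketch, not documented usage: the connection URI, database names, and the triplet are invented, and the MongoDB Atlas vector indexes referenced by the $vectorSearch stages must already exist.

    from pymongo import MongoClient
    from wikontic.utils.structured_aligner import Aligner

    client = MongoClient("mongodb+srv://...")  # Atlas URI elided
    aligner = Aligner(ontology_db=client["ontology"],
                      triplets_db=client["triplets"],
                      device="cpu")

    triplet = {"subject": "Marie Curie", "subject_type": "person",
               "relation": "worked at", "object": "University of Paris",
               "object_type": "university"}

    # map free-text types onto Wikidata entity types, then rank candidate properties
    subj_types, obj_types = aligner.retrieve_similar_entity_types(triplet, k=5)
    candidates = aligner.retrieve_properties_for_entity_type(
        target_relation=triplet["relation"],
        object_types=obj_types,
        subject_types=subj_types,
        k=10,
    )
    print(candidates)  # e.g. [('P108', 'direct'), ...] property IDs with direction
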