wikontic-0.0.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. wikontic/__init__.py +16 -0
  2. wikontic/create_ontological_triplets_db.py +193 -0
  3. wikontic/create_triplets_db.py +259 -0
  4. wikontic/create_wikidata_ontology_db.py +555 -0
  5. wikontic/utils/__init__.py +7 -0
  6. wikontic/utils/base_inference_with_db.py +329 -0
  7. wikontic/utils/dynamic_aligner.py +281 -0
  8. wikontic/utils/inference_with_db.py +224 -0
  9. wikontic/utils/ontology_mappings/entity_hierarchy.json +1 -0
  10. wikontic/utils/ontology_mappings/entity_names.json +1 -0
  11. wikontic/utils/ontology_mappings/entity_type2aliases.json +1 -0
  12. wikontic/utils/ontology_mappings/entity_type2hierarchy.json +1 -0
  13. wikontic/utils/ontology_mappings/entity_type2label.json +1 -0
  14. wikontic/utils/ontology_mappings/enum_entity_ids.json +1 -0
  15. wikontic/utils/ontology_mappings/enum_prop_ids.json +1 -0
  16. wikontic/utils/ontology_mappings/label2entity.json +1 -0
  17. wikontic/utils/ontology_mappings/obj_constraint2prop.json +1 -0
  18. wikontic/utils/ontology_mappings/prop2aliases.json +1 -0
  19. wikontic/utils/ontology_mappings/prop2constraints.json +1 -0
  20. wikontic/utils/ontology_mappings/prop2data_type.json +1 -0
  21. wikontic/utils/ontology_mappings/prop2label.json +1 -0
  22. wikontic/utils/ontology_mappings/propid2enum.json +1 -0
  23. wikontic/utils/ontology_mappings/subj_constraint2prop.json +1 -0
  24. wikontic/utils/ontology_mappings/subject_object_constraints.json +1 -0
  25. wikontic/utils/openai_utils.py +517 -0
  26. wikontic/utils/prompts/name_refinement/prompt_choose_relation_wo_entity_types.txt +17 -0
  27. wikontic/utils/prompts/name_refinement/prompt_choose_relation_wo_entity_types_dialog_bench.txt +18 -0
  28. wikontic/utils/prompts/name_refinement/rank_object_names.txt +17 -0
  29. wikontic/utils/prompts/name_refinement/rank_object_names_dialog_bench.txt +18 -0
  30. wikontic/utils/prompts/name_refinement/rank_object_qualifiers.txt +20 -0
  31. wikontic/utils/prompts/name_refinement/rank_subject_names.txt +18 -0
  32. wikontic/utils/prompts/name_refinement/rank_subject_names_dialog_bench.txt +20 -0
  33. wikontic/utils/prompts/ontology_refinement/prompt_choose_entity_types.txt +26 -0
  34. wikontic/utils/prompts/ontology_refinement/prompt_choose_relation.txt +24 -0
  35. wikontic/utils/prompts/ontology_refinement/prompt_choose_relation_and_types.txt +28 -0
  36. wikontic/utils/prompts/qa/prompt_choose_relevant_entities_for_question.txt +17 -0
  37. wikontic/utils/prompts/qa/prompt_choose_relevant_entities_for_question_wo_types.txt +16 -0
  38. wikontic/utils/prompts/qa/prompt_entity_extraction_from_question.txt +3 -0
  39. wikontic/utils/prompts/qa/prompt_is_answered.txt +43 -0
  40. wikontic/utils/prompts/qa/qa_collapsing_prompt.txt +22 -0
  41. wikontic/utils/prompts/qa/qa_prompt.txt +5 -0
  42. wikontic/utils/prompts/qa/qa_prompt_hotpot.txt +6 -0
  43. wikontic/utils/prompts/qa/question_decomposition_1.txt +7 -0
  44. wikontic/utils/prompts/triplet_extraction/prompt_1_types_qualifiers_dialog_bench.txt +75 -0
  45. wikontic/utils/prompts/triplet_extraction/prompt_1_types_qualifiers_dialog_bench_in_russian.txt +78 -0
  46. wikontic/utils/prompts/triplet_extraction/propmt_1_types_qualifiers.txt +91 -0
  47. wikontic/utils/structured_aligner.py +606 -0
  48. wikontic/utils/structured_inference_with_db.py +561 -0
  49. wikontic-0.0.3.dist-info/METADATA +111 -0
  50. wikontic-0.0.3.dist-info/RECORD +53 -0
  51. wikontic-0.0.3.dist-info/WHEEL +5 -0
  52. wikontic-0.0.3.dist-info/licenses/LICENSE +19 -0
  53. wikontic-0.0.3.dist-info/top_level.txt +1 -0
wikontic/utils/base_inference_with_db.py
@@ -0,0 +1,329 @@
+ from typing import List, Optional
+ import logging
+
+ logger = logging.getLogger("BaseInferenceWithDB")
+ logger.setLevel(logging.ERROR)
+
+
+ class BaseInferenceWithDB:
+     """
+     Base class for inference with database functionality.
+     Contains common methods shared by InferenceWithDB and StructuredInferenceWithDB.
+
+     Note: This is an abstract base class. Child classes must define the following
+     attributes in their __init__ methods:
+     - self.extractor: The extractor instance
+     - self.aligner: The aligner instance
+     - self.triplets_db: The triplets database instance
+     """
+
+     def retrieve_similar_entity_names(
+         self, entity_name: str, k: int, sample_id: Optional[str] = None
+     ) -> List[str]:
+         """
+         Retrieve similar entity names from the knowledge graph using vector search.
+         Useful for linking entities from the question to the knowledge graph.
+         Args:
+             entity_name: The entity name to find similar entity names for.
+             k: The number of similar entity names to retrieve.
+             sample_id: The sample ID of the subgraph to search in. If None, search across all samples.
+         Returns:
+             A list of similar entity names.
+         """
+         similar_entity_names = self.aligner.retrieve_similar_entity_names(
+             entity_name=entity_name, k=k, sample_id=sample_id
+         )
+         # Some aligners return dicts of the form {"entity": ...}; unwrap them
+         # so callers always get a flat list of names.
+         if similar_entity_names and isinstance(similar_entity_names[0], dict):
+             similar_entity_names = [e["entity"] for e in similar_entity_names]
+         return similar_entity_names
+
+     def identify_relevant_entities_from_question_with_llm(
+         self, question, sample_id=None, use_entity_types=True
+     ):
+         """
+         Identify relevant entities in a question using an LLM.
+         Args:
+             question: The question to identify relevant entities in.
+             sample_id: The sample ID of the subgraph to search in. If None, search across all samples.
+             use_entity_types: Whether the LLM should see entity types when choosing relevant entities.
+         Returns:
+             The relevant entities.
+         """
+         entities = self.extractor.extract_entities_from_question(question)
+         identified_entities = []
+         linked_entities = []
+
+         if isinstance(entities, dict):
+             entities = [entities]
+
+         for ent in entities:
+             similar_entities = self.retrieve_similar_entity_names(
+                 entity_name=ent, k=10, sample_id=sample_id
+             )
+             logger.debug("Similar entities: %s", similar_entities)
+
+             # Exact matches are linked directly; the rest go to the LLM for
+             # refinement below.
+             exact_entity_match = [e for e in similar_entities if e == ent]
+             if len(exact_entity_match) > 0:
+                 linked_entities.extend(exact_entity_match)
+             else:
+                 identified_entities.extend(similar_entities)
+
+         logger.debug("Identified entities from question: %s", identified_entities)
+         logger.debug("Linked entities from question: %s", linked_entities)
+         if use_entity_types:
+             linked_identified_entities = self.extractor.identify_relevant_entities(
+                 question=question, entity_list=identified_entities
+             )
+         else:
+             linked_identified_entities = (
+                 self.extractor.identify_relevant_entities_wo_types(
+                     question=question, entity_list=identified_entities
+                 )
+             )
+         linked_entities.extend([e["entity"] for e in linked_identified_entities])
+
+         logger.debug("Linked entities after refinement: %s", linked_entities)
+         return linked_entities
+
+     def get_1_hop_supporting_triplets(
+         self,
+         entities4search: List[str],
+         sample_id=None,
+         use_qualifiers=False,
+         use_filtered_triplets=False,
+     ):
+         """
+         Get the 1-hop supporting triplets for the given entities.
+         Useful for answering a question about the given entities.
+         Can be invoked multiple times for more than 1-hop support.
+         Args:
+             entities4search: The entities to get the 1-hop supporting triplets for.
+             sample_id: The sample ID of the subgraph to search in. If None, search across all samples.
+             use_qualifiers: Whether to include qualifiers.
+             use_filtered_triplets: Whether to include triplets that violate the ontology constraints alongside the valid triplets.
+         Returns:
+             A list of dictionaries with the subject, relation, object, and qualifiers of the 1-hop supporting triplets.
+         """
+         # Match triplets where any of the entities appears as subject or object.
+         or_conditions = []
+         for ent in entities4search:
+             or_conditions.append({"subject": ent})
+             or_conditions.append({"object": ent})
+         if not or_conditions:
+             # MongoDB rejects an empty $or, so return early.
+             return []
+         if sample_id is None:
+             pipeline = [{"$match": {"$or": or_conditions}}]
+         else:
+             pipeline = [{"$match": {"sample_id": sample_id, "$or": or_conditions}}]
+         results = list(
+             self.triplets_db.get_collection(
+                 self.aligner.triplets_collection_name
+             ).aggregate(pipeline)
+         )
+
+         if use_filtered_triplets:
+             filtered_results = list(
+                 self.triplets_db.get_collection(
+                     self.aligner.ontology_filtered_triplets_collection_name
+                 ).aggregate(pipeline)
+             )
+             results.extend(filtered_results)
+
+         if use_qualifiers:
+             supporting_triplets = [
+                 {
+                     "subject": item["subject"],
+                     "relation": item["relation"],
+                     "object": item["object"],
+                     "qualifiers": item["qualifiers"],
+                 }
+                 for item in results
+             ]
+         else:
+             supporting_triplets = [
+                 {
+                     "subject": item["subject"],
+                     "relation": item["relation"],
+                     "object": item["object"],
+                 }
+                 for item in results
+             ]
+         logger.debug("Supporting triplets: %s\n%s", supporting_triplets, "-" * 100)
+         return supporting_triplets
+
+     def answer_question_with_llm(
+         self,
+         question,
+         linked_entities,
+         sample_id=None,
+         hop_depth=5,
+         use_filtered_triplets=False,
+         use_qualifiers=False,
+     ):
+         """
+         Answer a question using the given relevant entities.
+         Args:
+             question: The question to answer.
+             linked_entities: The linked entities to answer the question with.
+             sample_id: The sample ID of the subgraph to answer from. If None, search across all samples.
+             hop_depth: The maximum number of hops to expand around the linked entities.
+             use_filtered_triplets: Whether to use filtered triplets.
+             use_qualifiers: Whether to use qualifiers.
+         Returns:
+             A tuple of the supporting triplets and the answer to the question.
+         """
+         logger.debug("Linked entities: %s", linked_entities)
+
+         entities4search = list(set(linked_entities))
+         supporting_triplets = []
+
+         # Expand the entity frontier hop by hop, collecting supporting triplets.
+         for _ in range(hop_depth):
+             new_entities4search = []
+             new_supporting_triplets = self.get_1_hop_supporting_triplets(
+                 entities4search, sample_id, use_qualifiers, use_filtered_triplets
+             )
+             supporting_triplets.extend(new_supporting_triplets)
+
+             for doc in supporting_triplets:
+                 if doc["subject"] not in entities4search:
+                     new_entities4search.append(doc["subject"])
+                 if doc["object"] not in entities4search:
+                     new_entities4search.append(doc["object"])
+                 if use_qualifiers:
+                     for q in doc["qualifiers"]:
+                         if q["object"] not in entities4search:
+                             new_entities4search.append(q["object"])
+
+             entities4search = list(set(new_entities4search))
+
+         if use_qualifiers:
+             supporting_triplets = [
+                 {
+                     "subject": item["subject"],
+                     "relation": item["relation"],
+                     "object": item["object"],
+                     "qualifiers": item["qualifiers"],
+                 }
+                 for item in supporting_triplets
+             ]
+         else:
+             supporting_triplets = [
+                 {
+                     "subject": item["subject"],
+                     "relation": item["relation"],
+                     "object": item["object"],
+                 }
+                 for item in supporting_triplets
+             ]
+         logger.debug("Supporting triplets: %s\n%s", supporting_triplets, "-" * 100)
+
+         ans = self.extractor.answer_question(
+             question=question, triplets=supporting_triplets
+         )
+         return supporting_triplets, ans
+
+     def answer_with_qa_collapsing(
+         self,
+         question,
+         sample_id=None,
+         max_attempts=5,
+         use_qualifiers=False,
+         use_filtered_triplets=False,
+     ):
+         """
+         Answer a question by iteratively collapsing it into simpler sub-questions.
+         Args:
+             question: The question to answer.
+             sample_id: The sample ID of the subgraph to answer from. If None, search across all samples.
+             max_attempts: The maximum number of attempts to answer the question. Useful for complex questions that require multiple hops.
+             use_qualifiers: Whether to use qualifiers.
+             use_filtered_triplets: Whether to use filtered triplets.
+         Returns:
+             The answer to the question.
+         """
+         collapsed_question_answer = ""
+         collapsed_question_sequence = []
+         collapsed_answer_sequence = []
+
+         logger.debug("Question: %s", question)
+         collapsed_question = self.extractor.decompose_question(question)
+
+         for _ in range(max_attempts):
+             extracted_entities = self.extractor.extract_entities_from_question(
+                 collapsed_question
+             )
+             logger.debug("Collapsed question: %s", collapsed_question)
+             logger.debug("Extracted entities: %s", extracted_entities)
+
+             # Carry the previous partial answer over as a search entity.
+             if len(collapsed_question_answer) > 0:
+                 extracted_entities.append(collapsed_question_answer)
+
+             entities4search = []
+             for ent in extracted_entities:
+                 similar_entities = self.retrieve_similar_entity_names(
+                     entity_name=ent, k=10, sample_id=sample_id
+                 )
+                 entities4search.extend(similar_entities)
+
+             entities4search = list(set(entities4search))
+             logger.debug("Similar entities: %s", entities4search)
+
+             supporting_triplets = self.get_1_hop_supporting_triplets(
+                 entities4search, sample_id, use_qualifiers, use_filtered_triplets
+             )
+
+             logger.debug("Supporting triplets length: %s", len(supporting_triplets))
+
+             collapsed_question_answer = self.extractor.answer_question(
+                 collapsed_question, supporting_triplets
+             )
+             collapsed_question_sequence.append(collapsed_question)
+             collapsed_answer_sequence.append(collapsed_question_answer)
+
+             logger.debug("Collapsed question: %s", collapsed_question)
+             logger.debug("Collapsed question answer: %s", collapsed_question_answer)
+
+             is_answered = self.extractor.check_if_question_is_answered(
+                 question, collapsed_question_sequence, collapsed_answer_sequence
+             )
+
+             if is_answered == "NOT FINAL":
+                 # Fold the sub-answer back into the question and try again.
+                 collapsed_question = self.extractor.collapse_question(
+                     original_question=question,
+                     question=collapsed_question,
+                     answer=collapsed_question_answer,
+                 )
+                 continue
+             else:
+                 logger.debug("Final answer: %s", is_answered)
+                 return is_answered
+
+         logger.debug("Final answer: %s", collapsed_question_answer)
+         return collapsed_question_answer
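BaseInferenceWithDB never defines __init__ itself; it relies on the three attributes listed in its class docstring. The following is a minimal sketch of a concrete child class, assuming a pymongo database, the Aligner from dynamic_aligner.py below, and a hypothetical extractor object that provides the extract_entities_from_question, identify_relevant_entities, and answer_question methods used above. In the package itself, InferenceWithDB and StructuredInferenceWithDB play this role.

from pymongo import MongoClient

from wikontic.utils.base_inference_with_db import BaseInferenceWithDB
from wikontic.utils.dynamic_aligner import Aligner


class MyInference(BaseInferenceWithDB):
    # Hypothetical child class wiring up the three required attributes.
    def __init__(self, extractor, mongo_uri="mongodb://localhost:27017"):
        self.triplets_db = MongoClient(mongo_uri)["wikontic"]  # hypothetical DB name
        self.aligner = Aligner(self.triplets_db, device="cpu")
        self.extractor = extractor


# inference = MyInference(extractor=my_extractor)
# entities = inference.identify_relevant_entities_from_question_with_llm(
#     "Who founded Wikidata?", sample_id="sample-1"
# )
# triplets, answer = inference.answer_question_with_llm(
#     "Who founded Wikidata?", entities, sample_id="sample-1", hop_depth=2
# )

Note that with this Aligner, use_filtered_triplets=True would fail, since get_1_hop_supporting_triplets reads self.aligner.ontology_filtered_triplets_collection_name, which only the structured aligner defines.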
wikontic/utils/dynamic_aligner.py
@@ -0,0 +1,281 @@
+ from typing import List, Optional
+ from transformers import AutoTokenizer, AutoModel
+ from pydantic import BaseModel
+ from pymongo import UpdateOne
+ from dotenv import load_dotenv, find_dotenv
+ import os
+ import torch
+
+ # os.environ["CUDA_VISIBLE_DEVICES"] = "1"
+ _ = load_dotenv(find_dotenv())
+
+
+ # Document schemas of the alias collections (kept for reference).
+ class EntityAlias(BaseModel):
+     _id: int
+     label: str
+     alias: str
+     sample_id: str
+     alias_text_embedding: List[float]
+
+
+ class PropertyAlias(BaseModel):
+     _id: int
+     label: str
+     alias: str
+     sample_id: str
+     alias_text_embedding: List[float]
+
+
+ class Aligner:
+     def __init__(self, triplets_db, device="cuda:0"):
+         self.db = triplets_db
+
+         self.entity_aliases_collection_name = "entity_aliases"
+         self.property_aliases_collection_name = "property_aliases"
+
+         self.property_vector_index_name = "property_aliases"
+         self.entities_vector_index_name = "entity_aliases"
+
+         self.initial_triplets_collection_name = "initial_triplets"
+         self.triplets_collection_name = "triplets"
+         self.filtered_triplets_collection_name = "filtered_triplets"
+
+         self.device = torch.device(device)
+         # self.tokenizer = AutoTokenizer.from_pretrained('facebook/contriever', token=os.getenv("HF_KEY"))
+         self.tokenizer = AutoTokenizer.from_pretrained("facebook/contriever")
+         # self.model = AutoModel.from_pretrained('facebook/contriever', token=os.getenv("HF_KEY")).to(self.device)
+         self.model = AutoModel.from_pretrained(
+             "facebook/contriever", use_safetensors=True
+         ).to(self.device)
+
+     def get_embedding(self, text):
+         def mean_pooling(token_embeddings, mask):
+             # Zero out padding positions, then average over the sequence length.
+             token_embeddings = token_embeddings.masked_fill(
+                 ~mask[..., None].bool(), 0.0
+             )
+             sentence_embeddings = (
+                 token_embeddings.sum(dim=1) / mask.sum(dim=1)[..., None]
+             )
+             return sentence_embeddings
+
+         if not text or not isinstance(text, str):
+             return None
+
+         inputs = self.tokenizer(
+             [text], padding=True, truncation=True, return_tensors="pt"
+         )
+         with torch.no_grad():
+             outputs = self.model(**inputs.to(self.device))
+         embeddings = mean_pooling(outputs[0], inputs["attention_mask"])
+         return embeddings.detach().cpu().tolist()[0]
+
+     def retrieve_similar_properties(
+         self, target_relation: str, sample_id: str, k: int = 10
+     ) -> List[str]:
+         """
+         Retrieve and rank properties that match the given relation.
+
+         Args:
+             target_relation: The relation to search for
+             sample_id: The sample ID (currently unused: the filter below is commented out, so the search spans all samples)
+             k: Number of results to return
+
+         Returns:
+             List of property labels
+         """
+         collection = self.db.get_collection(self.property_aliases_collection_name)
+         query_embedding = self.get_embedding(target_relation)
+         if query_embedding is None:
+             return []
+
+         query_k = k * 2
+         max_attempts = 5
+         attempt = 0
+         unique_ranked_properties: List[str] = []
+
+         # Widen the search until k unique labels are found or the attempt
+         # budget is exhausted.
+         while len(unique_ranked_properties) < k and attempt < max_attempts:
+             pipeline = [
+                 {
+                     "$vectorSearch": {
+                         "index": self.property_vector_index_name,
+                         "queryVector": query_embedding,
+                         "path": "alias_text_embedding",
+                         "numCandidates": 150,
+                         "limit": min(query_k, 150),
+                         # "filter": {
+                         #     "sample_id": {"$eq": sample_id},
+                         # },
+                     }
+                 },
+                 {
+                     "$project": {
+                         "_id": 0,
+                         "label": 1,
+                         # "alias": 1,
+                         # "score": {"$meta": "vectorSearchScore"},
+                     }
+                 },
+             ]
+
+             similar_properties = collection.aggregate(pipeline)
+
+             for prop in similar_properties:
+                 if prop["label"] not in unique_ranked_properties:
+                     unique_ranked_properties.append(prop["label"])
+                 if len(unique_ranked_properties) == k:
+                     break
+
+             query_k *= 2
+             attempt += 1
+
+         return unique_ranked_properties
+
+     def retrieve_similar_entity_names(
+         self, entity_name: str, sample_id: Optional[str] = None, k: int = 10
+     ) -> List[str]:
+         """
+         Retrieve and rank entities that match the given entity name.
+
+         Args:
+             entity_name: The entity to search for
+             sample_id: The sample ID to restrict the search to. If None, search across all samples.
+             k: Number of results to return
+
+         Returns:
+             List of entity labels
+         """
+         collection = self.db.get_collection(self.entity_aliases_collection_name)
+         query_embedding = self.get_embedding(entity_name)
+         if query_embedding is None:
+             return []
+
+         query_k = k * 2
+         max_attempts = 5
+         attempt = 0
+         unique_ranked_entities: List[str] = []
+
+         if sample_id is not None:
+             search_filter = {"sample_id": {"$eq": sample_id}}
+         else:
+             search_filter = {}
+
+         while len(unique_ranked_entities) < k and attempt < max_attempts:
+             pipeline = [
+                 {
+                     "$vectorSearch": {
+                         "index": self.entities_vector_index_name,
+                         "queryVector": query_embedding,
+                         "path": "alias_text_embedding",
+                         "numCandidates": 150,
+                         "limit": min(query_k, 150),
+                         "filter": search_filter,
+                     }
+                 },
+                 {
+                     "$project": {
+                         "_id": 0,
+                         "label": 1,
+                         # "score": {"$meta": "vectorSearchScore"},
+                     }
+                 },
+             ]
+
+             similar_entities = collection.aggregate(pipeline)
+
+             for entity in similar_entities:
+                 if entity["label"] not in unique_ranked_entities:
+                     unique_ranked_entities.append(entity["label"])
+                 if len(unique_ranked_entities) == k:
+                     break
+
+             query_k *= 2
+             attempt += 1
+
+         return unique_ranked_entities
+
+     def add_entity(self, entity_name, alias, sample_id):
+         collection = self.db.get_collection(self.entity_aliases_collection_name)
+         if not collection.find_one(
+             {"label": entity_name, "alias": alias, "sample_id": sample_id}
+         ):
+             collection.insert_one(
+                 {
+                     "label": entity_name,
+                     "alias": alias,
+                     "sample_id": sample_id,
+                     "alias_text_embedding": self.get_embedding(alias),
+                 }
+             )
+
+     def add_property(self, property_name, alias, sample_id):
+         # Property aliases are shared across samples, so sample_id is not stored.
+         collection = self.db.get_collection(self.property_aliases_collection_name)
+         if not collection.find_one({"label": property_name, "alias": alias}):
+             collection.insert_one(
+                 {
+                     "label": property_name,
+                     "alias": alias,
+                     # "sample_id": sample_id,
+                     "alias_text_embedding": self.get_embedding(alias),
+                 }
+             )
+
+     def _upsert_triplets(self, collection_name, triplets_list, sample_id):
+         # Upsert triplets keyed by (subject, relation, object, sample_id)
+         # so that repeated inserts do not create duplicates.
+         collection = self.db.get_collection(collection_name)
+         operations = []
+         for triple in triplets_list:
+             triple["sample_id"] = sample_id
+             filter_query = {
+                 "subject": triple["subject"],
+                 "relation": triple["relation"],
+                 "object": triple["object"],
+                 "sample_id": triple["sample_id"],
+             }
+             operations.append(
+                 UpdateOne(filter_query, {"$setOnInsert": triple}, upsert=True)
+             )
+         if operations:
+             collection.bulk_write(operations)
+
+     def add_triplets(self, triplets_list, sample_id):
+         self._upsert_triplets(self.triplets_collection_name, triplets_list, sample_id)
+
+     def add_filtered_triplets(self, triplets_list, sample_id):
+         self._upsert_triplets(
+             self.filtered_triplets_collection_name, triplets_list, sample_id
+         )
+
+     def add_initial_triplets(self, triplets_list, sample_id):
+         self._upsert_triplets(
+             self.initial_triplets_collection_name, triplets_list, sample_id
+         )
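The retrieval methods above rely on the $vectorSearch aggregation stage, which only runs on MongoDB Atlas against pre-built vector search indexes; here those indexes are named entity_aliases and property_aliases and cover the alias_text_embedding path (contriever embeddings are 768-dimensional). A minimal usage sketch, assuming an Atlas connection string in a hypothetical MONGODB_URI environment variable and a hypothetical database name:

import os

from pymongo import MongoClient

from wikontic.utils.dynamic_aligner import Aligner

# Hypothetical connection string and database name.
db = MongoClient(os.environ["MONGODB_URI"])["wikontic"]
aligner = Aligner(db, device="cpu")

# Index an alias (embedded with contriever on insert) and a triplet.
aligner.add_entity("Douglas Adams", alias="D. Adams", sample_id="sample-1")
aligner.add_triplets(
    [{"subject": "Douglas Adams", "relation": "occupation", "object": "writer"}],
    sample_id="sample-1",
)

# Needs the "entity_aliases" vector index to exist on the Atlas side;
# returns up to k unique entity labels, widening the search as needed.
print(aligner.retrieve_similar_entity_names("Douglas Adams", sample_id="sample-1", k=5))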