cognee 0.5.0__py3-none-any.whl → 0.5.0.dev0__py3-none-any.whl

This diff compares publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their public registries.
Files changed (131)
  1. cognee/api/client.py +5 -1
  2. cognee/api/v1/add/add.py +1 -2
  3. cognee/api/v1/cognify/code_graph_pipeline.py +119 -0
  4. cognee/api/v1/cognify/cognify.py +16 -24
  5. cognee/api/v1/cognify/routers/__init__.py +1 -0
  6. cognee/api/v1/cognify/routers/get_code_pipeline_router.py +90 -0
  7. cognee/api/v1/cognify/routers/get_cognify_router.py +1 -3
  8. cognee/api/v1/datasets/routers/get_datasets_router.py +3 -3
  9. cognee/api/v1/ontologies/ontologies.py +37 -12
  10. cognee/api/v1/ontologies/routers/get_ontology_router.py +25 -27
  11. cognee/api/v1/search/search.py +0 -4
  12. cognee/api/v1/ui/ui.py +68 -38
  13. cognee/context_global_variables.py +16 -61
  14. cognee/eval_framework/answer_generation/answer_generation_executor.py +0 -10
  15. cognee/eval_framework/answer_generation/run_question_answering_module.py +1 -1
  16. cognee/eval_framework/corpus_builder/task_getters/get_cascade_graph_tasks.py +2 -0
  17. cognee/eval_framework/corpus_builder/task_getters/get_default_tasks_by_indices.py +4 -4
  18. cognee/eval_framework/eval_config.py +2 -2
  19. cognee/eval_framework/modal_run_eval.py +28 -16
  20. cognee/infrastructure/databases/graph/config.py +0 -3
  21. cognee/infrastructure/databases/graph/get_graph_engine.py +0 -1
  22. cognee/infrastructure/databases/graph/graph_db_interface.py +0 -15
  23. cognee/infrastructure/databases/graph/kuzu/adapter.py +0 -228
  24. cognee/infrastructure/databases/graph/neo4j_driver/adapter.py +1 -80
  25. cognee/infrastructure/databases/utils/__init__.py +0 -3
  26. cognee/infrastructure/databases/utils/get_or_create_dataset_database.py +48 -62
  27. cognee/infrastructure/databases/vector/config.py +0 -2
  28. cognee/infrastructure/databases/vector/create_vector_engine.py +0 -1
  29. cognee/infrastructure/databases/vector/embeddings/FastembedEmbeddingEngine.py +6 -8
  30. cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py +7 -9
  31. cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py +10 -11
  32. cognee/infrastructure/databases/vector/embeddings/embedding_rate_limiter.py +544 -0
  33. cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py +0 -2
  34. cognee/infrastructure/databases/vector/vector_db_interface.py +0 -35
  35. cognee/infrastructure/files/storage/s3_config.py +0 -2
  36. cognee/infrastructure/llm/LLMGateway.py +2 -5
  37. cognee/infrastructure/llm/config.py +0 -35
  38. cognee/infrastructure/llm/extraction/knowledge_graph/extract_content_graph.py +2 -2
  39. cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extraction/acreate_structured_output.py +8 -23
  40. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py +16 -17
  41. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py +37 -40
  42. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py +36 -39
  43. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +1 -19
  44. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py +9 -11
  45. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py +21 -23
  46. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py +34 -42
  47. cognee/modules/cognify/config.py +0 -2
  48. cognee/modules/data/deletion/prune_system.py +2 -52
  49. cognee/modules/data/methods/delete_dataset.py +0 -26
  50. cognee/modules/engine/models/__init__.py +0 -1
  51. cognee/modules/graph/cognee_graph/CogneeGraph.py +37 -85
  52. cognee/modules/graph/cognee_graph/CogneeGraphElements.py +3 -8
  53. cognee/modules/memify/memify.py +7 -1
  54. cognee/modules/pipelines/operations/pipeline.py +2 -18
  55. cognee/modules/retrieval/__init__.py +1 -1
  56. cognee/modules/retrieval/code_retriever.py +232 -0
  57. cognee/modules/retrieval/graph_completion_context_extension_retriever.py +0 -4
  58. cognee/modules/retrieval/graph_completion_cot_retriever.py +0 -4
  59. cognee/modules/retrieval/graph_completion_retriever.py +0 -10
  60. cognee/modules/retrieval/graph_summary_completion_retriever.py +0 -4
  61. cognee/modules/retrieval/temporal_retriever.py +0 -4
  62. cognee/modules/retrieval/utils/brute_force_triplet_search.py +10 -42
  63. cognee/modules/run_custom_pipeline/run_custom_pipeline.py +1 -8
  64. cognee/modules/search/methods/get_search_type_tools.py +8 -54
  65. cognee/modules/search/methods/no_access_control_search.py +0 -4
  66. cognee/modules/search/methods/search.py +0 -21
  67. cognee/modules/search/types/SearchType.py +1 -1
  68. cognee/modules/settings/get_settings.py +0 -19
  69. cognee/modules/users/methods/get_authenticated_user.py +2 -2
  70. cognee/modules/users/models/DatasetDatabase.py +3 -15
  71. cognee/shared/logging_utils.py +0 -4
  72. cognee/tasks/code/enrich_dependency_graph_checker.py +35 -0
  73. cognee/tasks/code/get_local_dependencies_checker.py +20 -0
  74. cognee/tasks/code/get_repo_dependency_graph_checker.py +35 -0
  75. cognee/tasks/documents/__init__.py +1 -0
  76. cognee/tasks/documents/check_permissions_on_dataset.py +26 -0
  77. cognee/tasks/graph/extract_graph_from_data.py +10 -9
  78. cognee/tasks/repo_processor/__init__.py +2 -0
  79. cognee/tasks/repo_processor/get_local_dependencies.py +335 -0
  80. cognee/tasks/repo_processor/get_non_code_files.py +158 -0
  81. cognee/tasks/repo_processor/get_repo_file_dependencies.py +243 -0
  82. cognee/tasks/storage/add_data_points.py +2 -142
  83. cognee/tests/test_cognee_server_start.py +4 -2
  84. cognee/tests/test_conversation_history.py +1 -23
  85. cognee/tests/test_delete_bmw_example.py +60 -0
  86. cognee/tests/test_search_db.py +1 -37
  87. cognee/tests/unit/api/test_ontology_endpoint.py +89 -77
  88. cognee/tests/unit/infrastructure/mock_embedding_engine.py +7 -3
  89. cognee/tests/unit/infrastructure/test_embedding_rate_limiting_realistic.py +5 -0
  90. cognee/tests/unit/modules/graph/cognee_graph_elements_test.py +2 -2
  91. cognee/tests/unit/modules/graph/cognee_graph_test.py +0 -406
  92. {cognee-0.5.0.dist-info → cognee-0.5.0.dev0.dist-info}/METADATA +89 -76
  93. {cognee-0.5.0.dist-info → cognee-0.5.0.dev0.dist-info}/RECORD +97 -118
  94. {cognee-0.5.0.dist-info → cognee-0.5.0.dev0.dist-info}/WHEEL +1 -1
  95. cognee/api/v1/ui/node_setup.py +0 -360
  96. cognee/api/v1/ui/npm_utils.py +0 -50
  97. cognee/eval_framework/Dockerfile +0 -29
  98. cognee/infrastructure/databases/dataset_database_handler/__init__.py +0 -3
  99. cognee/infrastructure/databases/dataset_database_handler/dataset_database_handler_interface.py +0 -80
  100. cognee/infrastructure/databases/dataset_database_handler/supported_dataset_database_handlers.py +0 -18
  101. cognee/infrastructure/databases/dataset_database_handler/use_dataset_database_handler.py +0 -10
  102. cognee/infrastructure/databases/graph/kuzu/KuzuDatasetDatabaseHandler.py +0 -81
  103. cognee/infrastructure/databases/graph/neo4j_driver/Neo4jAuraDevDatasetDatabaseHandler.py +0 -168
  104. cognee/infrastructure/databases/utils/get_graph_dataset_database_handler.py +0 -10
  105. cognee/infrastructure/databases/utils/get_vector_dataset_database_handler.py +0 -10
  106. cognee/infrastructure/databases/utils/resolve_dataset_database_connection_info.py +0 -30
  107. cognee/infrastructure/databases/vector/lancedb/LanceDBDatasetDatabaseHandler.py +0 -50
  108. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/bedrock/__init__.py +0 -5
  109. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/bedrock/adapter.py +0 -153
  110. cognee/memify_pipelines/create_triplet_embeddings.py +0 -53
  111. cognee/modules/engine/models/Triplet.py +0 -9
  112. cognee/modules/retrieval/register_retriever.py +0 -10
  113. cognee/modules/retrieval/registered_community_retrievers.py +0 -1
  114. cognee/modules/retrieval/triplet_retriever.py +0 -182
  115. cognee/shared/rate_limiting.py +0 -30
  116. cognee/tasks/memify/get_triplet_datapoints.py +0 -289
  117. cognee/tests/integration/retrieval/test_triplet_retriever.py +0 -84
  118. cognee/tests/integration/tasks/test_add_data_points.py +0 -139
  119. cognee/tests/integration/tasks/test_get_triplet_datapoints.py +0 -69
  120. cognee/tests/test_dataset_database_handler.py +0 -137
  121. cognee/tests/test_dataset_delete.py +0 -76
  122. cognee/tests/test_edge_centered_payload.py +0 -170
  123. cognee/tests/test_pipeline_cache.py +0 -164
  124. cognee/tests/unit/infrastructure/llm/test_llm_config.py +0 -46
  125. cognee/tests/unit/modules/memify_tasks/test_get_triplet_datapoints.py +0 -214
  126. cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py +0 -608
  127. cognee/tests/unit/modules/retrieval/triplet_retriever_test.py +0 -83
  128. cognee/tests/unit/tasks/storage/test_add_data_points.py +0 -288
  129. {cognee-0.5.0.dist-info → cognee-0.5.0.dev0.dist-info}/entry_points.txt +0 -0
  130. {cognee-0.5.0.dist-info → cognee-0.5.0.dev0.dist-info}/licenses/LICENSE +0 -0
  131. {cognee-0.5.0.dist-info → cognee-0.5.0.dev0.dist-info}/licenses/NOTICE.md +0 -0
cognee/modules/retrieval/triplet_retriever.py
@@ -1,182 +0,0 @@
-import asyncio
-from typing import Any, Optional, Type, List
-
-from cognee.shared.logging_utils import get_logger
-from cognee.infrastructure.databases.vector import get_vector_engine
-from cognee.modules.retrieval.utils.completion import generate_completion, summarize_text
-from cognee.modules.retrieval.utils.session_cache import (
-    save_conversation_history,
-    get_conversation_history,
-)
-from cognee.modules.retrieval.base_retriever import BaseRetriever
-from cognee.modules.retrieval.exceptions.exceptions import NoDataError
-from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
-from cognee.context_global_variables import session_user
-from cognee.infrastructure.databases.cache.config import CacheConfig
-
-logger = get_logger("TripletRetriever")
-
-
-class TripletRetriever(BaseRetriever):
-    """
-    Retriever for handling LLM-based completion searches using triplets.
-
-    Public methods:
-    - get_context(query: str) -> str
-    - get_completion(query: str, context: Optional[Any] = None) -> Any
-    """
-
-    def __init__(
-        self,
-        user_prompt_path: str = "context_for_question.txt",
-        system_prompt_path: str = "answer_simple_question.txt",
-        system_prompt: Optional[str] = None,
-        top_k: Optional[int] = 5,
-    ):
-        """Initialize retriever with optional custom prompt paths."""
-        self.user_prompt_path = user_prompt_path
-        self.system_prompt_path = system_prompt_path
-        self.top_k = top_k if top_k is not None else 1
-        self.system_prompt = system_prompt
-
-    async def get_context(self, query: str) -> str:
-        """
-        Retrieves relevant triplets as context.
-
-        Fetches triplets based on a query from a vector engine and combines their text.
-        Returns empty string if no triplets are found. Raises NoDataError if the collection is not
-        found.
-
-        Parameters:
-        -----------
-
-        - query (str): The query string used to search for relevant triplets.
-
-        Returns:
-        --------
-
-        - str: A string containing the combined text of the retrieved triplets, or an
-          empty string if none are found.
-        """
-        vector_engine = get_vector_engine()
-
-        try:
-            if not await vector_engine.has_collection(collection_name="Triplet_text"):
-                logger.error("Triplet_text collection not found")
-                raise NoDataError(
-                    "In order to use TRIPLET_COMPLETION first use the create_triplet_embeddings memify pipeline. "
-                )
-
-            found_triplets = await vector_engine.search("Triplet_text", query, limit=self.top_k)
-
-            if len(found_triplets) == 0:
-                return ""
-
-            triplets_payload = [found_triplet.payload["text"] for found_triplet in found_triplets]
-            combined_context = "\n".join(triplets_payload)
-            return combined_context
-        except CollectionNotFoundError as error:
-            logger.error("Triplet_text collection not found")
-            raise NoDataError("No data found in the system, please add data first.") from error
-
-    async def get_completion(
-        self,
-        query: str,
-        context: Optional[Any] = None,
-        session_id: Optional[str] = None,
-        response_model: Type = str,
-    ) -> List[Any]:
-        """
-        Generates an LLM completion using the context.
-
-        Retrieves context if not provided and generates a completion based on the query and
-        context using an external completion generator.
-
-        Parameters:
-        -----------
-
-        - query (str): The query string to be used for generating a completion.
-        - context (Optional[Any]): Optional pre-fetched context to use for generating the
-          completion; if None, it retrieves the context for the query. (default None)
-        - session_id (Optional[str]): Optional session identifier for caching. If None,
-          defaults to 'default_session'. (default None)
-        - response_model (Type): The Pydantic model type for structured output. (default str)
-
-        Returns:
-        --------
-
-        - Any: The generated completion based on the provided query and context.
-        """
-        if context is None:
-            context = await self.get_context(query)
-
-        cache_config = CacheConfig()
-        user = session_user.get()
-        user_id = getattr(user, "id", None)
-        session_save = user_id and cache_config.caching
-
-        if session_save:
-            completion = await self._get_completion_with_session(
-                query=query,
-                context=context,
-                session_id=session_id,
-                response_model=response_model,
-            )
-        else:
-            completion = await self._get_completion_without_session(
-                query=query,
-                context=context,
-                response_model=response_model,
-            )
-
-        return [completion]
-
-    async def _get_completion_with_session(
-        self,
-        query: str,
-        context: str,
-        session_id: Optional[str],
-        response_model: Type,
-    ) -> Any:
-        """Generate completion with session history and caching."""
-        conversation_history = await get_conversation_history(session_id=session_id)
-
-        context_summary, completion = await asyncio.gather(
-            summarize_text(context),
-            generate_completion(
-                query=query,
-                context=context,
-                user_prompt_path=self.user_prompt_path,
-                system_prompt_path=self.system_prompt_path,
-                system_prompt=self.system_prompt,
-                conversation_history=conversation_history,
-                response_model=response_model,
-            ),
-        )
-
-        await save_conversation_history(
-            query=query,
-            context_summary=context_summary,
-            answer=completion,
-            session_id=session_id,
-        )
-
-        return completion
-
-    async def _get_completion_without_session(
-        self,
-        query: str,
-        context: str,
-        response_model: Type,
-    ) -> Any:
-        """Generate completion without session history."""
-        completion = await generate_completion(
-            query=query,
-            context=context,
-            user_prompt_path=self.user_prompt_path,
-            system_prompt_path=self.system_prompt_path,
-            system_prompt=self.system_prompt,
-            response_model=response_model,
-        )
-
-        return completion
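
Before its removal, TripletRetriever backed the TRIPLET_COMPLETION search type (its NoDataError message points callers at the create_triplet_embeddings memify pipeline, also removed in this diff). A minimal usage sketch of the removed class, assuming the Triplet_text collection was already populated; the query string is illustrative:

import asyncio

from cognee.modules.retrieval.triplet_retriever import TripletRetriever


async def main():
    retriever = TripletRetriever(top_k=3)

    # Returns the newline-joined text of the top-k matching triplets, "" when
    # nothing matches, and raises NoDataError when Triplet_text is missing.
    context = await retriever.get_context("Who does Alice know?")

    # Wraps the context in an LLM call; the result is a one-element list.
    answers = await retriever.get_completion("Who does Alice know?", context=context)
    print(answers[0])


asyncio.run(main())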
cognee/shared/rate_limiting.py
@@ -1,30 +0,0 @@
-from aiolimiter import AsyncLimiter
-from contextlib import nullcontext
-from cognee.infrastructure.llm.config import get_llm_config
-
-llm_config = get_llm_config()
-
-llm_rate_limiter = AsyncLimiter(
-    llm_config.llm_rate_limit_requests, llm_config.embedding_rate_limit_interval
-)
-embedding_rate_limiter = AsyncLimiter(
-    llm_config.embedding_rate_limit_requests, llm_config.embedding_rate_limit_interval
-)
-
-
-def llm_rate_limiter_context_manager():
-    global llm_rate_limiter
-    if llm_config.llm_rate_limit_enabled:
-        return llm_rate_limiter
-    else:
-        # Return a no-op context manager if rate limiting is disabled
-        return nullcontext()
-
-
-def embedding_rate_limiter_context_manager():
-    global embedding_rate_limiter
-    if llm_config.embedding_rate_limit_enabled:
-        return embedding_rate_limiter
-    else:
-        # Return a no-op context manager if rate limiting is disabled
-        return nullcontext()
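
This shared module appears to be superseded by the new cognee/infrastructure/databases/vector/embeddings/embedding_rate_limiter.py (+544) added in this diff; note that the removed llm_rate_limiter was constructed with llm_config.embedding_rate_limit_interval rather than an LLM-specific interval, which looks like a bug. A minimal sketch of how the removed context managers were entered; embed_batch is a hypothetical stand-in for any embedding call:

from cognee.shared.rate_limiting import embedding_rate_limiter_context_manager


async def rate_limited_embed(embed_batch, texts):
    # AsyncLimiter and nullcontext (async-capable since Python 3.10) both
    # support "async with", so callers look identical whether limiting is
    # enabled or disabled.
    async with embedding_rate_limiter_context_manager():
        return await embed_batch(texts)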
cognee/tasks/memify/get_triplet_datapoints.py
@@ -1,289 +0,0 @@
-from typing import AsyncGenerator, Dict, Any, List, Optional
-from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine
-from cognee.modules.engine.utils import generate_node_id
-from cognee.shared.logging_utils import get_logger
-from cognee.modules.graph.utils.convert_node_to_data_point import get_all_subclasses
-from cognee.infrastructure.engine import DataPoint
-from cognee.modules.engine.models import Triplet
-from cognee.tasks.storage import index_data_points
-
-logger = get_logger("get_triplet_datapoints")
-
-
-def _build_datapoint_type_index_mapping() -> Dict[str, List[str]]:
-    """
-    Build a mapping of DataPoint type names to their index_fields.
-
-    Returns:
-    --------
-    - Dict[str, List[str]]: Mapping of type name to list of index field names
-    """
-    logger.debug("Building DataPoint type to index_fields mapping")
-    subclasses = get_all_subclasses(DataPoint)
-    datapoint_type_index_property = {}
-
-    for subclass in subclasses:
-        if "metadata" in subclass.model_fields:
-            metadata_field = subclass.model_fields["metadata"]
-            default = getattr(metadata_field, "default", None)
-            if isinstance(default, dict):
-                index_fields = default.get("index_fields", [])
-                if index_fields:
-                    datapoint_type_index_property[subclass.__name__] = index_fields
-                    logger.debug(
-                        f"Registered {subclass.__name__} with index_fields: {index_fields}"
-                    )
-
-    logger.info(
-        f"Found {len(datapoint_type_index_property)} DataPoint types with index_fields: "
-        f"{list(datapoint_type_index_property.keys())}"
-    )
-    return datapoint_type_index_property
-
-
-def _extract_embeddable_text(node_or_edge: Dict[str, Any], index_fields: List[str]) -> str:
-    """
-    Extract and concatenate embeddable properties from a node or edge dictionary.
-
-    Parameters:
-    -----------
-    - node_or_edge (Dict[str, Any]): Dictionary containing node or edge properties.
-    - index_fields (List[str]): List of field names to extract and concatenate.
-
-    Returns:
-    --------
-    - str: Concatenated string of all embeddable property values, or empty string if none found.
-    """
-    if not node_or_edge or not index_fields:
-        return ""
-
-    embeddable_values = []
-    for field_name in index_fields:
-        field_value = node_or_edge.get(field_name)
-        if field_value is not None:
-            field_value = str(field_value).strip()
-
-            if field_value:
-                embeddable_values.append(field_value)
-
-    return " ".join(embeddable_values) if embeddable_values else ""
-
-
-def _extract_relationship_text(
-    relationship: Dict[str, Any], datapoint_type_index_property: Dict[str, List[str]]
-) -> str:
-    """
-    Extract relationship text from edge properties.
-
-    Parameters:
-    -----------
-    - relationship (Dict[str, Any]): Dictionary containing relationship properties
-    - datapoint_type_index_property (Dict[str, List[str]]): Mapping of type to index fields
-
-    Returns:
-    --------
-    - str: Extracted relationship text or empty string
-    """
-    if not relationship:
-        return ""
-
-    edge_text = relationship.get("edge_text")
-    if edge_text and isinstance(edge_text, str) and edge_text.strip():
-        return edge_text.strip()
-
-    # Fallback to extracting from EdgeType index_fields
-    edge_type_index_fields = datapoint_type_index_property.get("EdgeType", [])
-    return _extract_embeddable_text(relationship, edge_type_index_fields)
-
-
-def _process_single_triplet(
-    triplet_datapoint: Dict[str, Any],
-    datapoint_type_index_property: Dict[str, List[str]],
-    offset: int,
-    idx: int,
-) -> tuple[Optional[Triplet], Optional[str]]:
-    """
-    Process a single triplet and create a Triplet object.
-
-    Parameters:
-    -----------
-    - triplet_datapoint (Dict[str, Any]): Raw triplet data from graph engine
-    - datapoint_type_index_property (Dict[str, List[str]]): Type to index fields mapping
-    - offset (int): Current batch offset
-    - idx (int): Index within current batch
-
-    Returns:
-    --------
-    - tuple[Optional[Triplet], Optional[str]]: (Triplet object, error message if skipped)
-    """
-    start_node = triplet_datapoint.get("start_node", {})
-    end_node = triplet_datapoint.get("end_node", {})
-    relationship = triplet_datapoint.get("relationship_properties", {})
-
-    start_node_type = start_node.get("type")
-    end_node_type = end_node.get("type")
-
-    start_index_fields = datapoint_type_index_property.get(start_node_type, [])
-    end_index_fields = datapoint_type_index_property.get(end_node_type, [])
-
-    if not start_index_fields:
-        logger.debug(
-            f"No index_fields found for start_node type '{start_node_type}' in triplet {offset + idx}"
-        )
-    if not end_index_fields:
-        logger.debug(
-            f"No index_fields found for end_node type '{end_node_type}' in triplet {offset + idx}"
-        )
-
-    start_node_id = start_node.get("id", "")
-    end_node_id = end_node.get("id", "")
-
-    if not start_node_id or not end_node_id:
-        return None, (
-            f"Skipping triplet at offset {offset + idx}: missing node IDs "
-            f"(start: {start_node_id}, end: {end_node_id})"
-        )
-
-    relationship_text = _extract_relationship_text(relationship, datapoint_type_index_property)
-    start_node_text = _extract_embeddable_text(start_node, start_index_fields)
-    end_node_text = _extract_embeddable_text(end_node, end_index_fields)
-
-    if not start_node_text and not end_node_text and not relationship_text:
-        return None, (
-            f"Skipping triplet at offset {offset + idx}: empty embeddable text "
-            f"(start_node_id: {start_node_id}, end_node_id: {end_node_id})"
-        )
-
-    embeddable_text = f"{start_node_text}->{relationship_text}->{end_node_text}".strip()
-
-    relationship_name = relationship.get("relationship_name", "")
-    triplet_id = generate_node_id(str(start_node_id) + str(relationship_name) + str(end_node_id))
-
-    triplet_obj = Triplet(
-        id=triplet_id, from_node_id=start_node_id, to_node_id=end_node_id, text=embeddable_text
-    )
-
-    return triplet_obj, None
-
-
-async def get_triplet_datapoints(
-    data,
-    triplets_batch_size: int = 100,
-) -> AsyncGenerator[Triplet, None]:
-    """
-    Async generator that yields batches of triplet datapoints with embeddable text extracted.
-
-    Each triplet in the batch includes:
-    - Original triplet structure (start_node, relationship_properties, end_node)
-    - Extracted embeddable text for each element based on index_fields
-
-    Parameters:
-    -----------
-    - triplets_batch_size (int): Number of triplets to retrieve per batch. Default is 100.
-
-    Yields:
-    -------
-    - List[Dict[str, Any]]: A batch of triplets, each enriched with embeddable text.
-    """
-    if not data or data == [{}]:
-        logger.info("Fetching graph data for current user")
-
-    logger.info(f"Starting triplet datapoints extraction with batch size: {triplets_batch_size}")
-
-    graph_engine = await get_graph_engine()
-    graph_engine_type = type(graph_engine).__name__
-    logger.debug(f"Using graph engine: {graph_engine_type}")
-
-    if not hasattr(graph_engine, "get_triplets_batch"):
-        error_msg = f"Graph adapter {graph_engine_type} does not support get_triplets_batch method"
-        logger.error(error_msg)
-        raise NotImplementedError(error_msg)
-
-    datapoint_type_index_property = _build_datapoint_type_index_mapping()
-
-    offset = 0
-    total_triplets_processed = 0
-    batch_number = 0
-
-    while True:
-        try:
-            batch_number += 1
-            logger.debug(
-                f"Fetching triplet batch {batch_number} (offset: {offset}, limit: {triplets_batch_size})"
-            )
-
-            triplets_batch = await graph_engine.get_triplets_batch(
-                offset=offset, limit=triplets_batch_size
-            )
-
-            if not triplets_batch:
-                logger.info(f"No more triplets found at offset {offset}. Processing complete.")
-                break
-
-            logger.debug(f"Retrieved {len(triplets_batch)} triplets in batch {batch_number}")
-
-            triplet_datapoints = []
-            skipped_count = 0
-
-            for idx, triplet_datapoint in enumerate(triplets_batch):
-                try:
-                    triplet_obj, error_msg = _process_single_triplet(
-                        triplet_datapoint, datapoint_type_index_property, offset, idx
-                    )
-
-                    if error_msg:
-                        logger.warning(error_msg)
-                        skipped_count += 1
-                        continue
-
-                    if triplet_obj:
-                        triplet_datapoints.append(triplet_obj)
-                        yield triplet_obj
-
-                except Exception as e:
-                    logger.warning(
-                        f"Error processing triplet at offset {offset + idx}: {e}. "
-                        f"Skipping this triplet and continuing."
-                    )
-                    skipped_count += 1
-                    continue
-
-            if skipped_count > 0:
-                logger.warning(
-                    f"Skipped {skipped_count} out of {len(triplets_batch)} triplets in batch {batch_number}"
-                )
-
-            if not triplet_datapoints:
-                logger.warning(
-                    f"No valid triplet datapoints in batch {batch_number} after processing"
-                )
-                offset += len(triplets_batch)
-                if len(triplets_batch) < triplets_batch_size:
-                    break
-                continue
-
-            total_triplets_processed += len(triplet_datapoints)
-            logger.info(
-                f"Batch {batch_number} complete: processed {len(triplet_datapoints)} triplets "
-                f"(total processed: {total_triplets_processed})"
-            )
-
-            offset += len(triplets_batch)
-            if len(triplets_batch) < triplets_batch_size:
-                logger.info(
-                    f"Last batch retrieved (got {len(triplets_batch)} < {triplets_batch_size} triplets). "
-                    f"Processing complete."
-                )
-                break
-
-        except Exception as e:
-            logger.error(
-                f"Error retrieving triplet batch {batch_number} at offset {offset}: {e}",
-                exc_info=True,
-            )
-            raise
-
-    logger.info(
-        f"Triplet datapoints extraction complete. "
-        f"Processed {total_triplets_processed} triplets across {batch_number} batch(es)."
    )
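
A minimal consumption sketch for the removed generator, assuming a graph adapter that implements get_triplets_batch; as the code above shows, the data argument is only checked for emptiness before paging begins:

import asyncio

from cognee.tasks.memify.get_triplet_datapoints import get_triplet_datapoints


async def collect_triplets():
    triplets = []
    # Pages through the graph in batches of 50 and yields one Triplet per
    # valid (start_node, relationship, end_node) row; invalid rows are
    # logged and skipped rather than raising.
    async for triplet in get_triplet_datapoints(data=None, triplets_batch_size=50):
        triplets.append(triplet)
    return triplets


triplets = asyncio.run(collect_triplets())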
cognee/tests/integration/retrieval/test_triplet_retriever.py
@@ -1,84 +0,0 @@
-import os
-import pytest
-import pathlib
-import pytest_asyncio
-import cognee
-
-from cognee.low_level import setup
-from cognee.tasks.storage import add_data_points
-from cognee.modules.retrieval.exceptions.exceptions import NoDataError
-from cognee.modules.retrieval.triplet_retriever import TripletRetriever
-from cognee.modules.engine.models import Triplet
-
-
-@pytest_asyncio.fixture
-async def setup_test_environment_with_triplets():
-    """Set up a clean test environment with triplets."""
-    base_dir = pathlib.Path(__file__).parent.parent.parent.parent
-    system_directory_path = str(base_dir / ".cognee_system/test_triplet_retriever_context_simple")
-    data_directory_path = str(base_dir / ".data_storage/test_triplet_retriever_context_simple")
-
-    cognee.config.system_root_directory(system_directory_path)
-    cognee.config.data_root_directory(data_directory_path)
-
-    await cognee.prune.prune_data()
-    await cognee.prune.prune_system(metadata=True)
-    await setup()
-
-    triplet1 = Triplet(
-        from_node_id="node1",
-        to_node_id="node2",
-        text="Alice knows Bob",
-    )
-    triplet2 = Triplet(
-        from_node_id="node2",
-        to_node_id="node3",
-        text="Bob works at Tech Corp",
-    )
-
-    triplets = [triplet1, triplet2]
-    await add_data_points(triplets)
-
-    yield
-
-    try:
-        await cognee.prune.prune_data()
-        await cognee.prune.prune_system(metadata=True)
-    except Exception:
-        pass
-
-
-@pytest_asyncio.fixture
-async def setup_test_environment_empty():
-    """Set up a clean test environment without triplets."""
-    base_dir = pathlib.Path(__file__).parent.parent.parent.parent
-    system_directory_path = str(
-        base_dir / ".cognee_system/test_triplet_retriever_context_empty_collection"
-    )
-    data_directory_path = str(
-        base_dir / ".data_storage/test_triplet_retriever_context_empty_collection"
-    )
-
-    cognee.config.system_root_directory(system_directory_path)
-    cognee.config.data_root_directory(data_directory_path)
-
-    await cognee.prune.prune_data()
-    await cognee.prune.prune_system(metadata=True)
-
-    yield
-
-    try:
-        await cognee.prune.prune_data()
-        await cognee.prune.prune_system(metadata=True)
-    except Exception:
-        pass
-
-
-@pytest.mark.asyncio
-async def test_triplet_retriever_context_simple(setup_test_environment_with_triplets):
-    """Integration test: verify TripletRetriever can retrieve triplet context."""
-    retriever = TripletRetriever(top_k=5)
-
-    context = await retriever.get_context("Alice")
-
-    assert "Alice knows Bob" in context, "Failed to get Alice triplet"