cognee 0.2.4__py3-none-any.whl → 0.3.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163) hide show
  1. cognee/__init__.py +1 -0
  2. cognee/api/client.py +28 -3
  3. cognee/api/health.py +10 -13
  4. cognee/api/v1/add/add.py +3 -1
  5. cognee/api/v1/add/routers/get_add_router.py +12 -37
  6. cognee/api/v1/cloud/routers/__init__.py +1 -0
  7. cognee/api/v1/cloud/routers/get_checks_router.py +23 -0
  8. cognee/api/v1/cognify/code_graph_pipeline.py +9 -4
  9. cognee/api/v1/cognify/cognify.py +50 -3
  10. cognee/api/v1/cognify/routers/get_cognify_router.py +1 -1
  11. cognee/api/v1/datasets/routers/get_datasets_router.py +15 -4
  12. cognee/api/v1/memify/__init__.py +0 -0
  13. cognee/api/v1/memify/routers/__init__.py +1 -0
  14. cognee/api/v1/memify/routers/get_memify_router.py +100 -0
  15. cognee/api/v1/notebooks/routers/__init__.py +1 -0
  16. cognee/api/v1/notebooks/routers/get_notebooks_router.py +96 -0
  17. cognee/api/v1/search/routers/get_search_router.py +20 -1
  18. cognee/api/v1/search/search.py +11 -4
  19. cognee/api/v1/sync/__init__.py +17 -0
  20. cognee/api/v1/sync/routers/__init__.py +3 -0
  21. cognee/api/v1/sync/routers/get_sync_router.py +241 -0
  22. cognee/api/v1/sync/sync.py +877 -0
  23. cognee/api/v1/users/routers/get_auth_router.py +13 -1
  24. cognee/base_config.py +10 -1
  25. cognee/infrastructure/databases/graph/config.py +10 -4
  26. cognee/infrastructure/databases/graph/kuzu/adapter.py +135 -0
  27. cognee/infrastructure/databases/graph/neo4j_driver/adapter.py +89 -0
  28. cognee/infrastructure/databases/relational/__init__.py +2 -0
  29. cognee/infrastructure/databases/relational/get_async_session.py +15 -0
  30. cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py +6 -1
  31. cognee/infrastructure/databases/relational/with_async_session.py +25 -0
  32. cognee/infrastructure/databases/vector/chromadb/ChromaDBAdapter.py +1 -1
  33. cognee/infrastructure/databases/vector/config.py +13 -6
  34. cognee/infrastructure/databases/vector/embeddings/FastembedEmbeddingEngine.py +1 -1
  35. cognee/infrastructure/databases/vector/embeddings/embedding_rate_limiter.py +2 -6
  36. cognee/infrastructure/databases/vector/embeddings/get_embedding_engine.py +4 -1
  37. cognee/infrastructure/files/storage/LocalFileStorage.py +9 -0
  38. cognee/infrastructure/files/storage/S3FileStorage.py +5 -0
  39. cognee/infrastructure/files/storage/StorageManager.py +7 -1
  40. cognee/infrastructure/files/storage/storage.py +16 -0
  41. cognee/infrastructure/llm/LLMGateway.py +18 -0
  42. cognee/infrastructure/llm/config.py +4 -2
  43. cognee/infrastructure/llm/prompts/extract_query_time.txt +15 -0
  44. cognee/infrastructure/llm/prompts/generate_event_entity_prompt.txt +25 -0
  45. cognee/infrastructure/llm/prompts/generate_event_graph_prompt.txt +30 -0
  46. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/extraction/__init__.py +2 -0
  47. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/extraction/extract_event_entities.py +44 -0
  48. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/extraction/knowledge_graph/__init__.py +1 -0
  49. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/extraction/knowledge_graph/extract_event_graph.py +46 -0
  50. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py +25 -1
  51. cognee/infrastructure/utils/run_sync.py +8 -1
  52. cognee/modules/chunking/models/DocumentChunk.py +4 -3
  53. cognee/modules/cloud/exceptions/CloudApiKeyMissingError.py +15 -0
  54. cognee/modules/cloud/exceptions/CloudConnectionError.py +15 -0
  55. cognee/modules/cloud/exceptions/__init__.py +2 -0
  56. cognee/modules/cloud/operations/__init__.py +1 -0
  57. cognee/modules/cloud/operations/check_api_key.py +25 -0
  58. cognee/modules/data/deletion/prune_system.py +1 -1
  59. cognee/modules/data/methods/check_dataset_name.py +1 -1
  60. cognee/modules/data/methods/get_dataset_data.py +1 -1
  61. cognee/modules/data/methods/load_or_create_datasets.py +1 -1
  62. cognee/modules/engine/models/Event.py +16 -0
  63. cognee/modules/engine/models/Interval.py +8 -0
  64. cognee/modules/engine/models/Timestamp.py +13 -0
  65. cognee/modules/engine/models/__init__.py +3 -0
  66. cognee/modules/engine/utils/__init__.py +2 -0
  67. cognee/modules/engine/utils/generate_event_datapoint.py +46 -0
  68. cognee/modules/engine/utils/generate_timestamp_datapoint.py +51 -0
  69. cognee/modules/graph/cognee_graph/CogneeGraph.py +2 -2
  70. cognee/modules/graph/utils/__init__.py +1 -0
  71. cognee/modules/graph/utils/resolve_edges_to_text.py +71 -0
  72. cognee/modules/memify/__init__.py +1 -0
  73. cognee/modules/memify/memify.py +118 -0
  74. cognee/modules/notebooks/methods/__init__.py +5 -0
  75. cognee/modules/notebooks/methods/create_notebook.py +26 -0
  76. cognee/modules/notebooks/methods/delete_notebook.py +13 -0
  77. cognee/modules/notebooks/methods/get_notebook.py +21 -0
  78. cognee/modules/notebooks/methods/get_notebooks.py +18 -0
  79. cognee/modules/notebooks/methods/update_notebook.py +17 -0
  80. cognee/modules/notebooks/models/Notebook.py +53 -0
  81. cognee/modules/notebooks/models/__init__.py +1 -0
  82. cognee/modules/notebooks/operations/__init__.py +1 -0
  83. cognee/modules/notebooks/operations/run_in_local_sandbox.py +55 -0
  84. cognee/modules/pipelines/layers/reset_dataset_pipeline_run_status.py +19 -3
  85. cognee/modules/pipelines/operations/pipeline.py +1 -0
  86. cognee/modules/pipelines/operations/run_tasks.py +17 -41
  87. cognee/modules/retrieval/base_graph_retriever.py +18 -0
  88. cognee/modules/retrieval/base_retriever.py +1 -1
  89. cognee/modules/retrieval/code_retriever.py +8 -0
  90. cognee/modules/retrieval/coding_rules_retriever.py +31 -0
  91. cognee/modules/retrieval/completion_retriever.py +9 -3
  92. cognee/modules/retrieval/context_providers/TripletSearchContextProvider.py +1 -0
  93. cognee/modules/retrieval/graph_completion_context_extension_retriever.py +23 -14
  94. cognee/modules/retrieval/graph_completion_cot_retriever.py +21 -11
  95. cognee/modules/retrieval/graph_completion_retriever.py +32 -65
  96. cognee/modules/retrieval/graph_summary_completion_retriever.py +3 -1
  97. cognee/modules/retrieval/insights_retriever.py +14 -3
  98. cognee/modules/retrieval/summaries_retriever.py +1 -1
  99. cognee/modules/retrieval/temporal_retriever.py +152 -0
  100. cognee/modules/retrieval/utils/brute_force_triplet_search.py +7 -32
  101. cognee/modules/retrieval/utils/completion.py +10 -3
  102. cognee/modules/search/methods/get_search_type_tools.py +168 -0
  103. cognee/modules/search/methods/no_access_control_search.py +47 -0
  104. cognee/modules/search/methods/search.py +219 -139
  105. cognee/modules/search/types/SearchResult.py +21 -0
  106. cognee/modules/search/types/SearchType.py +2 -0
  107. cognee/modules/search/types/__init__.py +1 -0
  108. cognee/modules/search/utils/__init__.py +2 -0
  109. cognee/modules/search/utils/prepare_search_result.py +41 -0
  110. cognee/modules/search/utils/transform_context_to_graph.py +38 -0
  111. cognee/modules/sync/__init__.py +1 -0
  112. cognee/modules/sync/methods/__init__.py +23 -0
  113. cognee/modules/sync/methods/create_sync_operation.py +53 -0
  114. cognee/modules/sync/methods/get_sync_operation.py +107 -0
  115. cognee/modules/sync/methods/update_sync_operation.py +248 -0
  116. cognee/modules/sync/models/SyncOperation.py +142 -0
  117. cognee/modules/sync/models/__init__.py +3 -0
  118. cognee/modules/users/__init__.py +0 -1
  119. cognee/modules/users/methods/__init__.py +4 -1
  120. cognee/modules/users/methods/create_user.py +26 -1
  121. cognee/modules/users/methods/get_authenticated_user.py +36 -42
  122. cognee/modules/users/methods/get_default_user.py +3 -1
  123. cognee/modules/users/permissions/methods/get_specific_user_permission_datasets.py +2 -1
  124. cognee/root_dir.py +19 -0
  125. cognee/shared/logging_utils.py +1 -1
  126. cognee/tasks/codingagents/__init__.py +0 -0
  127. cognee/tasks/codingagents/coding_rule_associations.py +127 -0
  128. cognee/tasks/ingestion/save_data_item_to_storage.py +23 -0
  129. cognee/tasks/memify/__init__.py +2 -0
  130. cognee/tasks/memify/extract_subgraph.py +7 -0
  131. cognee/tasks/memify/extract_subgraph_chunks.py +11 -0
  132. cognee/tasks/repo_processor/get_repo_file_dependencies.py +52 -27
  133. cognee/tasks/temporal_graph/__init__.py +1 -0
  134. cognee/tasks/temporal_graph/add_entities_to_event.py +85 -0
  135. cognee/tasks/temporal_graph/enrich_events.py +34 -0
  136. cognee/tasks/temporal_graph/extract_events_and_entities.py +32 -0
  137. cognee/tasks/temporal_graph/extract_knowledge_graph_from_events.py +41 -0
  138. cognee/tasks/temporal_graph/models.py +49 -0
  139. cognee/tests/test_kuzu.py +4 -4
  140. cognee/tests/test_neo4j.py +4 -4
  141. cognee/tests/test_permissions.py +3 -3
  142. cognee/tests/test_relational_db_migration.py +7 -5
  143. cognee/tests/test_search_db.py +18 -24
  144. cognee/tests/test_temporal_graph.py +167 -0
  145. cognee/tests/unit/api/__init__.py +1 -0
  146. cognee/tests/unit/api/test_conditional_authentication_endpoints.py +246 -0
  147. cognee/tests/unit/modules/retrieval/chunks_retriever_test.py +18 -2
  148. cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py +13 -16
  149. cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +11 -16
  150. cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py +5 -4
  151. cognee/tests/unit/modules/retrieval/insights_retriever_test.py +4 -2
  152. cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py +18 -2
  153. cognee/tests/unit/modules/retrieval/temporal_retriever_test.py +225 -0
  154. cognee/tests/unit/modules/users/__init__.py +1 -0
  155. cognee/tests/unit/modules/users/test_conditional_authentication.py +277 -0
  156. cognee/tests/unit/processing/utils/utils_test.py +20 -1
  157. {cognee-0.2.4.dist-info → cognee-0.3.0.dev0.dist-info}/METADATA +8 -6
  158. {cognee-0.2.4.dist-info → cognee-0.3.0.dev0.dist-info}/RECORD +162 -89
  159. cognee/tests/unit/modules/search/search_methods_test.py +0 -225
  160. {cognee-0.2.4.dist-info → cognee-0.3.0.dev0.dist-info}/WHEEL +0 -0
  161. {cognee-0.2.4.dist-info → cognee-0.3.0.dev0.dist-info}/entry_points.txt +0 -0
  162. {cognee-0.2.4.dist-info → cognee-0.3.0.dev0.dist-info}/licenses/LICENSE +0 -0
  163. {cognee-0.2.4.dist-info → cognee-0.3.0.dev0.dist-info}/licenses/NOTICE.md +0 -0
@@ -1,24 +1,48 @@
1
1
  import asyncio
2
2
  import math
3
3
  import os
4
-
5
- # from concurrent.futures import ProcessPoolExecutor
6
- from typing import AsyncGenerator
4
+ from pathlib import Path
5
+ from typing import Set
6
+ from typing import AsyncGenerator, Optional, List
7
7
  from uuid import NAMESPACE_OID, uuid5
8
8
 
9
9
  from cognee.infrastructure.engine import DataPoint
10
10
  from cognee.shared.CodeGraphEntities import CodeFile, Repository
11
11
 
12
-
13
- async def get_source_code_files(repo_path, language_config: dict[str, list[str]] | None = None):
12
+ # constant, declared only once
13
+ EXCLUDED_DIRS: Set[str] = {
14
+ ".venv",
15
+ "venv",
16
+ "env",
17
+ ".env",
18
+ "site-packages",
19
+ "node_modules",
20
+ "dist",
21
+ "build",
22
+ ".git",
23
+ "tests",
24
+ "test",
25
+ }
26
+
27
+
28
+ async def get_source_code_files(
29
+ repo_path,
30
+ language_config: dict[str, list[str]] | None = None,
31
+ excluded_paths: Optional[List[str]] = None,
32
+ ):
14
33
  """
15
- Retrieve source code files from the specified repository path for multiple languages.
34
+ Retrieve Python source code files from the specified repository path.
35
+
36
+ This function scans the given repository path for files that have the .py extension
37
+ while excluding test files and files within a virtual environment. It returns a list of
38
+ absolute paths to the source code files that are not empty.
16
39
 
17
40
  Parameters:
18
41
  -----------
19
- - repo_path: The file path to the repository to search for source files.
20
- - language_config: dict mapping language names to file extensions, e.g.,
42
+ - repo_path: Root path of the repository to search
43
+ - language_config: dict mapping language names to file extensions, e.g.,
21
44
  {'python': ['.py'], 'javascript': ['.js', '.jsx'], ...}
45
+ - excluded_paths: Optional list of path fragments or glob patterns to exclude
22
46
 
23
47
  Returns:
24
48
  --------
@@ -54,28 +78,23 @@ async def get_source_code_files(repo_path, language_config: dict[str, list[str]]
54
78
  lang = _get_language_from_extension(file, language_config)
55
79
  if lang is None:
56
80
  continue
57
- # Exclude tests and common build/venv directories
58
- excluded_dirs = {
59
- ".venv",
60
- "venv",
61
- "env",
62
- ".env",
63
- "site-packages",
64
- "node_modules",
65
- "dist",
66
- "build",
67
- ".git",
68
- "tests",
69
- "test",
70
- }
71
- root_parts = set(os.path.normpath(root).split(os.sep))
81
+ # Exclude tests, common build/venv directories and files provided in exclude_paths
82
+ excluded_dirs = EXCLUDED_DIRS
83
+ excluded_paths = {Path(p).resolve() for p in (excluded_paths or [])} # full paths
84
+
85
+ root_path = Path(root).resolve()
86
+ root_parts = set(root_path.parts) # same as before
72
87
  base_name, _ext = os.path.splitext(file)
73
88
  if (
74
89
  base_name.startswith("test_")
75
- or base_name.endswith("_test") # catches Go's *_test.go and similar
90
+ or base_name.endswith("_test")
76
91
  or ".test." in file
77
92
  or ".spec." in file
78
- or (excluded_dirs & root_parts)
93
+ or (excluded_dirs & root_parts) # name match
94
+ or any(
95
+ root_path.is_relative_to(p) # full-path match
96
+ for p in excluded_paths
97
+ )
79
98
  ):
80
99
  continue
81
100
  file_path = os.path.abspath(os.path.join(root, file))
@@ -115,7 +134,10 @@ def run_coroutine(coroutine_func, *args, **kwargs):
115
134
 
116
135
 
117
136
  async def get_repo_file_dependencies(
118
- repo_path: str, detailed_extraction: bool = False, supported_languages: list = None
137
+ repo_path: str,
138
+ detailed_extraction: bool = False,
139
+ supported_languages: list = None,
140
+ excluded_paths: Optional[List[str]] = None,
119
141
  ) -> AsyncGenerator[DataPoint, None]:
120
142
  """
121
143
  Generate a dependency graph for source files (multi-language) in the given repository path.
@@ -150,6 +172,7 @@ async def get_repo_file_dependencies(
150
172
  "go": [".go"],
151
173
  "rust": [".rs"],
152
174
  "cpp": [".cpp", ".c", ".h", ".hpp"],
175
+ "c": [".c", ".h"],
153
176
  }
154
177
  if supported_languages is not None:
155
178
  language_config = {
@@ -158,7 +181,9 @@ async def get_repo_file_dependencies(
158
181
  else:
159
182
  language_config = default_language_config
160
183
 
161
- source_code_files = await get_source_code_files(repo_path, language_config=language_config)
184
+ source_code_files = await get_source_code_files(
185
+ repo_path, language_config=language_config, excluded_paths=excluded_paths
186
+ )
162
187
 
163
188
  repo = Repository(
164
189
  id=uuid5(NAMESPACE_OID, repo_path),
@@ -0,0 +1 @@
1
+
@@ -0,0 +1,85 @@
1
+ from cognee.modules.engine.models import Event
2
+ from cognee.tasks.temporal_graph.models import EventWithEntities
3
+ from cognee.modules.engine.models.Entity import Entity
4
+ from cognee.modules.engine.models.EntityType import EntityType
5
+ from cognee.infrastructure.engine.models.Edge import Edge
6
+ from cognee.modules.engine.utils import generate_node_id, generate_node_name
7
+
8
+
9
def add_entities_to_event(event: Event, event_with_entities: EventWithEntities) -> None:
    """
    Attach the entities extracted for an event to the Event object itself.

    Each attribute of ``event_with_entities`` is turned into one
    ``(Edge, [Entity])`` pair that is appended to ``event.attributes``. The Edge
    carries the attribute's relationship, and the Entity is typed with an
    EntityType obtained through ``get_or_create_entity_type``. Entity *types* are
    cached for the duration of the call, so attributes that share a type name
    share one EntityType object; Entity objects themselves are created per
    attribute.

    Args:
        event (Event): The event to enrich; it is modified in place.
        event_with_entities (EventWithEntities): Extraction result whose
            attributes supply entity, entity_type and relationship values.

    Returns:
        None. When there are no attributes the event is left untouched.
    """
    extracted_attributes = event_with_entities.attributes
    if not extracted_attributes:
        return

    # Per-call cache: entity type name -> EntityType.
    known_types = {}

    for extracted in extracted_attributes:
        resolved_type = get_or_create_entity_type(known_types, extracted.entity_type)

        linked_entity = Entity(
            id=generate_node_id(extracted.entity),
            name=generate_node_name(extracted.entity),
            is_a=resolved_type,
            description=f"Entity {extracted.entity} of type {extracted.entity_type}",
            ontology_valid=False,
            belongs_to_set=None,
        )
        relation = Edge(relationship_type=extracted.relationship)

        # The attribute list is created lazily, only once there is something to store.
        if event.attributes is None:
            event.attributes = []
        event.attributes.append((relation, [linked_entity]))
57
+
58
+
59
def get_or_create_entity_type(entity_types: dict, entity_type_name: str) -> EntityType:
    """
    Return the EntityType cached under ``entity_type_name``, creating it on a miss.

    On a cache miss a new EntityType is built, with its id from
    ``generate_node_id`` and its name and type from ``generate_node_name``. It is
    stored in ``entity_types`` under the raw ``entity_type_name`` key and returned.

    Args:
        entity_types (dict): Cache mapping entity type names to EntityType
            objects; it is mutated when a new type is created.
        entity_type_name (str): Name of the entity type to look up or create.

    Returns:
        EntityType: The cached or newly created EntityType.
    """
    # Cache hit: hand back the stored object without touching the generators.
    if entity_type_name in entity_types:
        return entity_types[entity_type_name]

    node_id = generate_node_id(entity_type_name)
    normalized_name = generate_node_name(entity_type_name)

    created_type = EntityType(
        id=node_id,
        name=normalized_name,
        type=normalized_name,
        description=f"Type for {entity_type_name}",
        ontology_valid=False,
    )
    entity_types[entity_type_name] = created_type
    return created_type
@@ -0,0 +1,34 @@
1
+ from typing import List
2
+
3
+ from cognee.infrastructure.llm import LLMGateway
4
+ from cognee.modules.engine.models import Event
5
+ from cognee.tasks.temporal_graph.models import EventWithEntities, EventEntityList
6
+
7
+
8
async def enrich_events(events: List[Event]) -> List[EventWithEntities]:
    """
    Enrich a list of events by extracting their entities with an LLM.

    Each event is reduced to its name and description (a missing description
    becomes an empty string), the list is serialized to a JSON string, and that
    string is passed to ``LLMGateway.extract_event_entities`` with
    ``EventEntityList`` as the response model.

    Args:
        events (List[Event]): Events to enrich.

    Returns:
        List[EventWithEntities]: The ``events`` field of the LLM response. An
        empty input returns an empty list without calling the LLM.
    """
    import json

    # Nothing to enrich: skip the LLM round-trip instead of prompting with "[]".
    if not events:
        return []

    events_json_str = json.dumps(
        [{"event_name": event.name, "description": event.description or ""} for event in events]
    )

    entity_result = await LLMGateway.extract_event_entities(events_json_str, EventEntityList)

    return entity_result.events
@@ -0,0 +1,32 @@
1
+ import asyncio
2
+ from typing import Type, List
3
+ from cognee.infrastructure.llm.LLMGateway import LLMGateway
4
+ from cognee.modules.chunking.models import DocumentChunk
5
+ from cognee.tasks.temporal_graph.models import EventList
6
+ from cognee.modules.engine.utils.generate_event_datapoint import generate_event_datapoint
7
+
8
+
9
async def extract_events_and_timestamps(data_chunks: List[DocumentChunk]) -> List[DocumentChunk]:
    """
    Extract events and their timestamps from document chunks using an LLM.

    ``LLMGateway.extract_event_graph`` is called concurrently for every chunk's
    text with ``EventList`` as the response model. Each returned event is
    converted with ``generate_event_datapoint`` and appended to that chunk's
    ``contains`` list.

    Args:
        data_chunks (List[DocumentChunk]): Document chunks whose text is processed.

    Returns:
        List[DocumentChunk]: The same list object, with the chunks mutated in place.
    """
    # asyncio.gather preserves argument order, so results line up with data_chunks.
    extraction_results = await asyncio.gather(
        *(LLMGateway.extract_event_graph(chunk.text, EventList) for chunk in data_chunks)
    )

    for data_chunk, event_list in zip(data_chunks, extraction_results):
        if not event_list.events:
            continue

        # NOTE(review): guards against a chunk whose `contains` was never
        # initialised (None) — confirm DocumentChunk's default against the model.
        if data_chunk.contains is None:
            data_chunk.contains = []

        data_chunk.contains.extend(
            generate_event_datapoint(event) for event in event_list.events
        )

    return data_chunks
@@ -0,0 +1,41 @@
1
+ from typing import List
2
+ from cognee.modules.chunking.models import DocumentChunk
3
+ from cognee.modules.engine.models import Event
4
+ from cognee.tasks.temporal_graph.enrich_events import enrich_events
5
+ from cognee.tasks.temporal_graph.add_entities_to_event import add_entities_to_event
6
+
7
+
8
async def extract_knowledge_graph_from_events(
    data_chunks: List[DocumentChunk],
) -> List[DocumentChunk]:
    """
    Enrich the events already attached to document chunks with extracted entities.

    All ``Event`` instances found in the chunks' ``contains`` lists are collected,
    sent to ``enrich_events`` in a single call, and each event is then updated
    through ``add_entities_to_event``.

    Args:
        data_chunks (List[DocumentChunk]): Document chunks containing extracted events.

    Returns:
        List[DocumentChunk]: The same list object; events are mutated in place.
        When no events are found the chunks are returned unchanged and
        ``enrich_events`` is not called.
    """
    # A chunk whose `contains` is None is treated as holding no events.
    all_events = [
        item
        for chunk in data_chunks
        for item in (chunk.contains or [])
        if isinstance(item, Event)
    ]

    if not all_events:
        return data_chunks

    enriched_events = await enrich_events(all_events)

    # Pairing is positional: this relies on enrich_events returning results in
    # the order the events were sent. zip stops at the shorter sequence, so any
    # surplus on either side is silently ignored.
    for event, enriched_event in zip(all_events, enriched_events):
        add_entities_to_event(event, enriched_event)

    return data_chunks
@@ -0,0 +1,49 @@
1
+ from typing import Optional, List
2
+ from pydantic import BaseModel, Field
3
+
4
+
5
# Calendar timestamp with second resolution. Every component is required and
# range-checked on its own; the components are not cross-validated, so a
# combination such as month=2, day=31 satisfies these constraints.
class Timestamp(BaseModel):
    year: int = Field(..., ge=1, le=9999)  # 1..9999
    month: int = Field(..., ge=1, le=12)  # 1..12
    day: int = Field(..., ge=1, le=31)  # 1..31, independent of month
    hour: int = Field(..., ge=0, le=23)  # 0..23
    minute: int = Field(..., ge=0, le=59)  # 0..59
    second: int = Field(..., ge=0, le=59)  # 0..59
12
+
13
+
14
# Time interval with both endpoints required. No ordering between starts_at
# and ends_at is enforced by this model.
class Interval(BaseModel):
    starts_at: Timestamp
    ends_at: Timestamp
17
+
18
+
19
# Time interval in which either endpoint may be omitted; both default to None.
class QueryInterval(BaseModel):
    starts_at: Optional[Timestamp] = None
    ends_at: Optional[Timestamp] = None
22
+
23
+
24
# Extracted event. Only the name is required; description, time bounds and
# location are optional and default to None.
class Event(BaseModel):
    name: str
    description: Optional[str] = None
    time_from: Optional[Timestamp] = None
    time_to: Optional[Timestamp] = None
    location: Optional[str] = None
30
+
31
+
32
# Container model wrapping a required list of Event objects.
class EventList(BaseModel):
    events: List[Event]
34
+
35
+
36
# One entity linked to an event: the entity itself, its type name, and the
# relationship between the event and the entity. All three are required strings.
class EntityAttribute(BaseModel):
    entity: str
    entity_type: str
    relationship: str
40
+
41
+
42
# Event identified by name together with the entity attributes extracted for it.
class EventWithEntities(BaseModel):
    event_name: str
    description: Optional[str] = None
    # Defaults to an empty list when no attributes are supplied.
    attributes: List[EntityAttribute] = []
46
+
47
+
48
# Container model wrapping a required list of EventWithEntities objects.
class EventEntityList(BaseModel):
    events: List[EventWithEntities]
cognee/tests/test_kuzu.py CHANGED
@@ -94,21 +94,21 @@ async def main():
94
94
 
95
95
  await cognee.cognify([dataset_name])
96
96
 
97
- context_nonempty, _ = await GraphCompletionRetriever(
97
+ context_nonempty = await GraphCompletionRetriever(
98
98
  node_type=NodeSet,
99
99
  node_name=["first"],
100
100
  ).get_context("What is in the context?")
101
101
 
102
- context_empty, _ = await GraphCompletionRetriever(
102
+ context_empty = await GraphCompletionRetriever(
103
103
  node_type=NodeSet,
104
104
  node_name=["nonexistent"],
105
105
  ).get_context("What is in the context?")
106
106
 
107
- assert isinstance(context_nonempty, str) and context_nonempty != "", (
107
+ assert isinstance(context_nonempty, list) and context_nonempty != [], (
108
108
  f"Nodeset_search_test:Expected non-empty string for context_nonempty, got: {context_nonempty!r}"
109
109
  )
110
110
 
111
- assert context_empty == "", (
111
+ assert context_empty == [], (
112
112
  f"Nodeset_search_test:Expected empty string for context_empty, got: {context_empty!r}"
113
113
  )
114
114
 
@@ -98,21 +98,21 @@ async def main():
98
98
 
99
99
  await cognee.cognify([dataset_name])
100
100
 
101
- context_nonempty, _ = await GraphCompletionRetriever(
101
+ context_nonempty = await GraphCompletionRetriever(
102
102
  node_type=NodeSet,
103
103
  node_name=["first"],
104
104
  ).get_context("What is in the context?")
105
105
 
106
- context_empty, _ = await GraphCompletionRetriever(
106
+ context_empty = await GraphCompletionRetriever(
107
107
  node_type=NodeSet,
108
108
  node_name=["nonexistent"],
109
109
  ).get_context("What is in the context?")
110
110
 
111
- assert isinstance(context_nonempty, str) and context_nonempty != "", (
111
+ assert isinstance(context_nonempty, list) and context_nonempty != [], (
112
112
  f"Nodeset_search_test:Expected non-empty string for context_nonempty, got: {context_nonempty!r}"
113
113
  )
114
114
 
115
- assert context_empty == "", (
115
+ assert context_empty == [], (
116
116
  f"Nodeset_search_test:Expected empty string for context_empty, got: {context_empty!r}"
117
117
  )
118
118
 
@@ -79,7 +79,7 @@ async def main():
79
79
  print("\n\nExtracted sentences are:\n")
80
80
  for result in search_results:
81
81
  print(f"{result}\n")
82
- assert search_results[0]["dataset_name"] == "NLP", (
82
+ assert search_results[0].dataset_name == "NLP", (
83
83
  f"Dict must contain dataset name 'NLP': {search_results[0]}"
84
84
  )
85
85
 
@@ -93,7 +93,7 @@ async def main():
93
93
  print("\n\nExtracted sentences are:\n")
94
94
  for result in search_results:
95
95
  print(f"{result}\n")
96
- assert search_results[0]["dataset_name"] == "QUANTUM", (
96
+ assert search_results[0].dataset_name == "QUANTUM", (
97
97
  f"Dict must contain dataset name 'QUANTUM': {search_results[0]}"
98
98
  )
99
99
 
@@ -170,7 +170,7 @@ async def main():
170
170
  for result in search_results:
171
171
  print(f"{result}\n")
172
172
 
173
- assert search_results[0]["dataset_name"] == "QUANTUM", (
173
+ assert search_results[0].dataset_name == "QUANTUM", (
174
174
  f"Dict must contain dataset name 'QUANTUM': {search_results[0]}"
175
175
  )
176
176
 
@@ -1,6 +1,6 @@
1
- import json
2
1
  import pathlib
3
2
  import os
3
+ from typing import List
4
4
  from cognee.infrastructure.databases.graph import get_graph_engine
5
5
  from cognee.infrastructure.databases.relational import (
6
6
  get_migration_relational_engine,
@@ -10,7 +10,7 @@ from cognee.infrastructure.databases.vector.pgvector import (
10
10
  create_db_and_tables as create_pgvector_db_and_tables,
11
11
  )
12
12
  from cognee.tasks.ingestion import migrate_relational_database
13
- from cognee.modules.search.types import SearchType
13
+ from cognee.modules.search.types import SearchResult, SearchType
14
14
  import cognee
15
15
 
16
16
 
@@ -45,13 +45,15 @@ async def relational_db_migration():
45
45
  await migrate_relational_database(graph_engine, schema=schema)
46
46
 
47
47
  # 1. Search the graph
48
- search_results = await cognee.search(
48
+ search_results: List[SearchResult] = await cognee.search(
49
49
  query_type=SearchType.GRAPH_COMPLETION, query_text="Tell me about the artist AC/DC"
50
- )
50
+ ) # type: ignore
51
51
  print("Search results:", search_results)
52
52
 
53
53
  # 2. Assert that the search results contain "AC/DC"
54
- assert any("AC/DC" in r for r in search_results), "AC/DC not found in search results!"
54
+ assert any("AC/DC" in r.search_result for r in search_results), (
55
+ "AC/DC not found in search results!"
56
+ )
55
57
 
56
58
  migration_db_provider = migration_engine.engine.dialect.name
57
59
  if migration_db_provider == "postgresql":
@@ -1,11 +1,7 @@
1
- import os
2
- import pathlib
3
-
4
- from dns.e164 import query
5
-
6
1
  import cognee
7
2
  from cognee.infrastructure.databases.graph import get_graph_engine
8
3
  from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
4
+ from cognee.modules.graph.utils import resolve_edges_to_text
9
5
  from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
10
6
  from cognee.modules.retrieval.graph_completion_context_extension_retriever import (
11
7
  GraphCompletionContextExtensionRetriever,
@@ -14,11 +10,8 @@ from cognee.modules.retrieval.graph_completion_cot_retriever import GraphComplet
14
10
  from cognee.modules.retrieval.graph_summary_completion_retriever import (
15
11
  GraphSummaryCompletionRetriever,
16
12
  )
17
- from cognee.modules.search.operations import get_history
18
- from cognee.modules.users.methods import get_default_user
19
13
  from cognee.shared.logging_utils import get_logger
20
14
  from cognee.modules.search.types import SearchType
21
- from cognee.modules.engine.models import NodeSet
22
15
  from collections import Counter
23
16
 
24
17
  logger = get_logger()
@@ -46,16 +39,16 @@ async def main():
46
39
 
47
40
  await cognee.cognify([dataset_name])
48
41
 
49
- context_gk, _ = await GraphCompletionRetriever().get_context(
42
+ context_gk = await GraphCompletionRetriever().get_context(
50
43
  query="Next to which country is Germany located?"
51
44
  )
52
- context_gk_cot, _ = await GraphCompletionCotRetriever().get_context(
45
+ context_gk_cot = await GraphCompletionCotRetriever().get_context(
53
46
  query="Next to which country is Germany located?"
54
47
  )
55
- context_gk_ext, _ = await GraphCompletionContextExtensionRetriever().get_context(
48
+ context_gk_ext = await GraphCompletionContextExtensionRetriever().get_context(
56
49
  query="Next to which country is Germany located?"
57
50
  )
58
- context_gk_sum, _ = await GraphSummaryCompletionRetriever().get_context(
51
+ context_gk_sum = await GraphSummaryCompletionRetriever().get_context(
59
52
  query="Next to which country is Germany located?"
60
53
  )
61
54
 
@@ -65,9 +58,11 @@ async def main():
65
58
  ("GraphCompletionContextExtensionRetriever", context_gk_ext),
66
59
  ("GraphSummaryCompletionRetriever", context_gk_sum),
67
60
  ]:
68
- assert isinstance(context, str), f"{name}: Context should be a string"
69
- assert context.strip(), f"{name}: Context should not be empty"
70
- lower = context.lower()
61
+ assert isinstance(context, list), f"{name}: Context should be a list"
62
+ assert len(context) > 0, f"{name}: Context should not be empty"
63
+
64
+ context_text = await resolve_edges_to_text(context)
65
+ lower = context_text.lower()
71
66
  assert "germany" in lower or "netherlands" in lower, (
72
67
  f"{name}: Context did not contain 'germany' or 'netherlands'; got: {context!r}"
73
68
  )
@@ -143,20 +138,19 @@ async def main():
143
138
  last_k=1,
144
139
  )
145
140
 
146
- for name, completion in [
141
+ for name, search_results in [
147
142
  ("GRAPH_COMPLETION", completion_gk),
148
143
  ("GRAPH_COMPLETION_COT", completion_cot),
149
144
  ("GRAPH_COMPLETION_CONTEXT_EXTENSION", completion_ext),
150
145
  ("GRAPH_SUMMARY_COMPLETION", completion_sum),
151
146
  ]:
152
- assert isinstance(completion, list), f"{name}: should return a list"
153
- assert len(completion) == 1, f"{name}: expected single-element list, got {len(completion)}"
154
- text = completion[0]
155
- assert isinstance(text, str), f"{name}: element should be a string"
156
- assert text.strip(), f"{name}: string should not be empty"
157
- assert "netherlands" in text.lower(), (
158
- f"{name}: expected 'netherlands' in result, got: {text!r}"
159
- )
147
+ for search_result in search_results:
148
+ completion = search_result.search_result
149
+ assert isinstance(completion, str), f"{name}: should return a string"
150
+ assert completion.strip(), f"{name}: string should not be empty"
151
+ assert "netherlands" in completion.lower(), (
152
+ f"{name}: expected 'netherlands' in result, got: {completion!r}"
153
+ )
160
154
 
161
155
  graph_engine = await get_graph_engine()
162
156
  graph = await graph_engine.get_graph_data()