cognee-0.5.0.dev0-py3-none-any.whl → cognee-0.5.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132)
  1. cognee/api/client.py +1 -5
  2. cognee/api/v1/add/add.py +2 -1
  3. cognee/api/v1/cognify/cognify.py +24 -16
  4. cognee/api/v1/cognify/routers/__init__.py +0 -1
  5. cognee/api/v1/cognify/routers/get_cognify_router.py +3 -1
  6. cognee/api/v1/datasets/routers/get_datasets_router.py +3 -3
  7. cognee/api/v1/ontologies/ontologies.py +12 -37
  8. cognee/api/v1/ontologies/routers/get_ontology_router.py +27 -25
  9. cognee/api/v1/search/search.py +8 -0
  10. cognee/api/v1/ui/node_setup.py +360 -0
  11. cognee/api/v1/ui/npm_utils.py +50 -0
  12. cognee/api/v1/ui/ui.py +38 -68
  13. cognee/context_global_variables.py +61 -16
  14. cognee/eval_framework/Dockerfile +29 -0
  15. cognee/eval_framework/answer_generation/answer_generation_executor.py +10 -0
  16. cognee/eval_framework/answer_generation/run_question_answering_module.py +1 -1
  17. cognee/eval_framework/corpus_builder/task_getters/get_cascade_graph_tasks.py +0 -2
  18. cognee/eval_framework/corpus_builder/task_getters/get_default_tasks_by_indices.py +4 -4
  19. cognee/eval_framework/eval_config.py +2 -2
  20. cognee/eval_framework/modal_run_eval.py +16 -28
  21. cognee/infrastructure/databases/dataset_database_handler/__init__.py +3 -0
  22. cognee/infrastructure/databases/dataset_database_handler/dataset_database_handler_interface.py +80 -0
  23. cognee/infrastructure/databases/dataset_database_handler/supported_dataset_database_handlers.py +18 -0
  24. cognee/infrastructure/databases/dataset_database_handler/use_dataset_database_handler.py +10 -0
  25. cognee/infrastructure/databases/graph/config.py +3 -0
  26. cognee/infrastructure/databases/graph/get_graph_engine.py +1 -0
  27. cognee/infrastructure/databases/graph/graph_db_interface.py +15 -0
  28. cognee/infrastructure/databases/graph/kuzu/KuzuDatasetDatabaseHandler.py +81 -0
  29. cognee/infrastructure/databases/graph/kuzu/adapter.py +228 -0
  30. cognee/infrastructure/databases/graph/neo4j_driver/Neo4jAuraDevDatasetDatabaseHandler.py +168 -0
  31. cognee/infrastructure/databases/graph/neo4j_driver/adapter.py +80 -1
  32. cognee/infrastructure/databases/utils/__init__.py +3 -0
  33. cognee/infrastructure/databases/utils/get_graph_dataset_database_handler.py +10 -0
  34. cognee/infrastructure/databases/utils/get_or_create_dataset_database.py +62 -48
  35. cognee/infrastructure/databases/utils/get_vector_dataset_database_handler.py +10 -0
  36. cognee/infrastructure/databases/utils/resolve_dataset_database_connection_info.py +30 -0
  37. cognee/infrastructure/databases/vector/config.py +2 -0
  38. cognee/infrastructure/databases/vector/create_vector_engine.py +1 -0
  39. cognee/infrastructure/databases/vector/embeddings/FastembedEmbeddingEngine.py +8 -6
  40. cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py +9 -7
  41. cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py +11 -10
  42. cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py +2 -0
  43. cognee/infrastructure/databases/vector/lancedb/LanceDBDatasetDatabaseHandler.py +50 -0
  44. cognee/infrastructure/databases/vector/vector_db_interface.py +35 -0
  45. cognee/infrastructure/files/storage/s3_config.py +2 -0
  46. cognee/infrastructure/llm/LLMGateway.py +5 -2
  47. cognee/infrastructure/llm/config.py +35 -0
  48. cognee/infrastructure/llm/extraction/knowledge_graph/extract_content_graph.py +2 -2
  49. cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extraction/acreate_structured_output.py +23 -8
  50. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py +17 -16
  51. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/bedrock/__init__.py +5 -0
  52. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/bedrock/adapter.py +153 -0
  53. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py +40 -37
  54. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py +39 -36
  55. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +19 -1
  56. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py +11 -9
  57. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py +23 -21
  58. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py +42 -34
  59. cognee/memify_pipelines/create_triplet_embeddings.py +53 -0
  60. cognee/modules/cognify/config.py +2 -0
  61. cognee/modules/data/deletion/prune_system.py +52 -2
  62. cognee/modules/data/methods/delete_dataset.py +26 -0
  63. cognee/modules/engine/models/Triplet.py +9 -0
  64. cognee/modules/engine/models/__init__.py +1 -0
  65. cognee/modules/graph/cognee_graph/CogneeGraph.py +85 -37
  66. cognee/modules/graph/cognee_graph/CogneeGraphElements.py +8 -3
  67. cognee/modules/memify/memify.py +1 -7
  68. cognee/modules/pipelines/operations/pipeline.py +18 -2
  69. cognee/modules/retrieval/__init__.py +1 -1
  70. cognee/modules/retrieval/graph_completion_context_extension_retriever.py +4 -0
  71. cognee/modules/retrieval/graph_completion_cot_retriever.py +4 -0
  72. cognee/modules/retrieval/graph_completion_retriever.py +10 -0
  73. cognee/modules/retrieval/graph_summary_completion_retriever.py +4 -0
  74. cognee/modules/retrieval/register_retriever.py +10 -0
  75. cognee/modules/retrieval/registered_community_retrievers.py +1 -0
  76. cognee/modules/retrieval/temporal_retriever.py +4 -0
  77. cognee/modules/retrieval/triplet_retriever.py +182 -0
  78. cognee/modules/retrieval/utils/brute_force_triplet_search.py +42 -10
  79. cognee/modules/run_custom_pipeline/run_custom_pipeline.py +8 -1
  80. cognee/modules/search/methods/get_search_type_tools.py +54 -8
  81. cognee/modules/search/methods/no_access_control_search.py +4 -0
  82. cognee/modules/search/methods/search.py +46 -18
  83. cognee/modules/search/types/SearchType.py +1 -1
  84. cognee/modules/settings/get_settings.py +19 -0
  85. cognee/modules/users/methods/get_authenticated_user.py +2 -2
  86. cognee/modules/users/models/DatasetDatabase.py +15 -3
  87. cognee/shared/logging_utils.py +4 -0
  88. cognee/shared/rate_limiting.py +30 -0
  89. cognee/tasks/documents/__init__.py +0 -1
  90. cognee/tasks/graph/extract_graph_from_data.py +9 -10
  91. cognee/tasks/memify/get_triplet_datapoints.py +289 -0
  92. cognee/tasks/storage/add_data_points.py +142 -2
  93. cognee/tests/integration/retrieval/test_triplet_retriever.py +84 -0
  94. cognee/tests/integration/tasks/test_add_data_points.py +139 -0
  95. cognee/tests/integration/tasks/test_get_triplet_datapoints.py +69 -0
  96. cognee/tests/test_cognee_server_start.py +2 -4
  97. cognee/tests/test_conversation_history.py +23 -1
  98. cognee/tests/test_dataset_database_handler.py +137 -0
  99. cognee/tests/test_dataset_delete.py +76 -0
  100. cognee/tests/test_edge_centered_payload.py +170 -0
  101. cognee/tests/test_pipeline_cache.py +164 -0
  102. cognee/tests/test_search_db.py +37 -1
  103. cognee/tests/unit/api/test_ontology_endpoint.py +77 -89
  104. cognee/tests/unit/infrastructure/llm/test_llm_config.py +46 -0
  105. cognee/tests/unit/infrastructure/mock_embedding_engine.py +3 -7
  106. cognee/tests/unit/infrastructure/test_embedding_rate_limiting_realistic.py +0 -5
  107. cognee/tests/unit/modules/graph/cognee_graph_elements_test.py +2 -2
  108. cognee/tests/unit/modules/graph/cognee_graph_test.py +406 -0
  109. cognee/tests/unit/modules/memify_tasks/test_get_triplet_datapoints.py +214 -0
  110. cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py +608 -0
  111. cognee/tests/unit/modules/retrieval/triplet_retriever_test.py +83 -0
  112. cognee/tests/unit/modules/search/test_search.py +100 -0
  113. cognee/tests/unit/tasks/storage/test_add_data_points.py +288 -0
  114. {cognee-0.5.0.dev0.dist-info → cognee-0.5.1.dist-info}/METADATA +76 -89
  115. {cognee-0.5.0.dev0.dist-info → cognee-0.5.1.dist-info}/RECORD +119 -97
  116. {cognee-0.5.0.dev0.dist-info → cognee-0.5.1.dist-info}/WHEEL +1 -1
  117. cognee/api/v1/cognify/code_graph_pipeline.py +0 -119
  118. cognee/api/v1/cognify/routers/get_code_pipeline_router.py +0 -90
  119. cognee/infrastructure/databases/vector/embeddings/embedding_rate_limiter.py +0 -544
  120. cognee/modules/retrieval/code_retriever.py +0 -232
  121. cognee/tasks/code/enrich_dependency_graph_checker.py +0 -35
  122. cognee/tasks/code/get_local_dependencies_checker.py +0 -20
  123. cognee/tasks/code/get_repo_dependency_graph_checker.py +0 -35
  124. cognee/tasks/documents/check_permissions_on_dataset.py +0 -26
  125. cognee/tasks/repo_processor/__init__.py +0 -2
  126. cognee/tasks/repo_processor/get_local_dependencies.py +0 -335
  127. cognee/tasks/repo_processor/get_non_code_files.py +0 -158
  128. cognee/tasks/repo_processor/get_repo_file_dependencies.py +0 -243
  129. cognee/tests/test_delete_bmw_example.py +0 -60
  130. {cognee-0.5.0.dev0.dist-info → cognee-0.5.1.dist-info}/entry_points.txt +0 -0
  131. {cognee-0.5.0.dev0.dist-info → cognee-0.5.1.dist-info}/licenses/LICENSE +0 -0
  132. {cognee-0.5.0.dev0.dist-info → cognee-0.5.1.dist-info}/licenses/NOTICE.md +0 -0
cognee/modules/retrieval/code_retriever.py
@@ -1,232 +0,0 @@
- from typing import Any, Optional, List
- import asyncio
- import aiofiles
- from pydantic import BaseModel
-
- from cognee.shared.logging_utils import get_logger
- from cognee.modules.retrieval.base_retriever import BaseRetriever
- from cognee.infrastructure.databases.graph import get_graph_engine
- from cognee.infrastructure.databases.vector import get_vector_engine
- from cognee.infrastructure.llm.prompts import read_query_prompt
- from cognee.infrastructure.llm.LLMGateway import LLMGateway
-
- logger = get_logger("CodeRetriever")
-
-
- class CodeRetriever(BaseRetriever):
-     """Retriever for handling code-based searches."""
-
-     class CodeQueryInfo(BaseModel):
-         """
-         Model for representing the result of a query related to code files.
-
-         This class holds a list of filenames and the corresponding source code extracted from a
-         query. It is used to encapsulate response data in a structured format.
-         """
-
-         filenames: List[str] = []
-         sourcecode: str
-
-     def __init__(self, top_k: int = 3):
-         """Initialize retriever with search parameters."""
-         self.top_k = top_k
-         self.file_name_collections = ["CodeFile_name"]
-         self.classes_and_functions_collections = [
-             "ClassDefinition_source_code",
-             "FunctionDefinition_source_code",
-         ]
-
-     async def _process_query(self, query: str) -> "CodeRetriever.CodeQueryInfo":
-         """Process the query using LLM to extract file names and source code parts."""
-         logger.debug(
-             f"Processing query with LLM: '{query[:100]}{'...' if len(query) > 100 else ''}'"
-         )
-
-         system_prompt = read_query_prompt("codegraph_retriever_system.txt")
-
-         try:
-             result = await LLMGateway.acreate_structured_output(
-                 text_input=query,
-                 system_prompt=system_prompt,
-                 response_model=self.CodeQueryInfo,
-             )
-             logger.info(
-                 f"LLM extracted {len(result.filenames)} filenames and {len(result.sourcecode)} chars of source code"
-             )
-             return result
-         except Exception as e:
-             logger.error(f"Failed to retrieve structured output from LLM: {str(e)}")
-             raise RuntimeError("Failed to retrieve structured output from LLM") from e
-
-     async def get_context(self, query: str) -> Any:
-         """Find relevant code files based on the query."""
-         logger.info(
-             f"Starting code retrieval for query: '{query[:100]}{'...' if len(query) > 100 else ''}'"
-         )
-
-         if not query or not isinstance(query, str):
-             logger.error("Invalid query: must be a non-empty string")
-             raise ValueError("The query must be a non-empty string.")
-
-         try:
-             vector_engine = get_vector_engine()
-             graph_engine = await get_graph_engine()
-             logger.debug("Successfully initialized vector and graph engines")
-         except Exception as e:
-             logger.error(f"Database initialization error: {str(e)}")
-             raise RuntimeError("Database initialization error in code_graph_retriever, ") from e
-
-         files_and_codeparts = await self._process_query(query)
-
-         similar_filenames = []
-         similar_codepieces = []
-
-         if not files_and_codeparts.filenames or not files_and_codeparts.sourcecode:
-             logger.info("No specific files/code extracted from query, performing general search")
-
-             for collection in self.file_name_collections:
-                 logger.debug(f"Searching {collection} collection with general query")
-                 search_results_file = await vector_engine.search(
-                     collection, query, limit=self.top_k
-                 )
-                 logger.debug(f"Found {len(search_results_file)} results in {collection}")
-                 for res in search_results_file:
-                     similar_filenames.append(
-                         {"id": res.id, "score": res.score, "payload": res.payload}
-                     )
-
-             existing_collection = []
-             for collection in self.classes_and_functions_collections:
-                 if await vector_engine.has_collection(collection):
-                     existing_collection.append(collection)
-
-             if not existing_collection:
-                 raise RuntimeError("No collection found for code retriever")
-
-             for collection in existing_collection:
-                 logger.debug(f"Searching {collection} collection with general query")
-                 search_results_code = await vector_engine.search(
-                     collection, query, limit=self.top_k
-                 )
-                 logger.debug(f"Found {len(search_results_code)} results in {collection}")
-                 for res in search_results_code:
-                     similar_codepieces.append(
-                         {"id": res.id, "score": res.score, "payload": res.payload}
-                     )
-         else:
-             logger.info(
-                 f"Using extracted filenames ({len(files_and_codeparts.filenames)}) and source code for targeted search"
-             )
-
-             for collection in self.file_name_collections:
-                 for file_from_query in files_and_codeparts.filenames:
-                     logger.debug(f"Searching {collection} for specific file: {file_from_query}")
-                     search_results_file = await vector_engine.search(
-                         collection, file_from_query, limit=self.top_k
-                     )
-                     logger.debug(
-                         f"Found {len(search_results_file)} results for file {file_from_query}"
-                     )
-                     for res in search_results_file:
-                         similar_filenames.append(
-                             {"id": res.id, "score": res.score, "payload": res.payload}
-                         )
-
-             for collection in self.classes_and_functions_collections:
-                 logger.debug(f"Searching {collection} with extracted source code")
-                 search_results_code = await vector_engine.search(
-                     collection, files_and_codeparts.sourcecode, limit=self.top_k
-                 )
-                 logger.debug(f"Found {len(search_results_code)} results for source code search")
-                 for res in search_results_code:
-                     similar_codepieces.append(
-                         {"id": res.id, "score": res.score, "payload": res.payload}
-                     )
-
-         total_items = len(similar_filenames) + len(similar_codepieces)
-         logger.info(
-             f"Total search results: {total_items} items ({len(similar_filenames)} filenames, {len(similar_codepieces)} code pieces)"
-         )
-
-         if total_items == 0:
-             logger.warning("No search results found, returning empty list")
-             return []
-
-         logger.debug("Getting graph connections for all search results")
-         relevant_triplets = await asyncio.gather(
-             *[
-                 graph_engine.get_connections(similar_piece["id"])
-                 for similar_piece in similar_filenames + similar_codepieces
-             ]
-         )
-         logger.info(f"Retrieved graph connections for {len(relevant_triplets)} items")
-
-         paths = set()
-         for i, sublist in enumerate(relevant_triplets):
-             logger.debug(f"Processing connections for item {i}: {len(sublist)} connections")
-             for tpl in sublist:
-                 if isinstance(tpl, tuple) and len(tpl) >= 3:
-                     if "file_path" in tpl[0]:
-                         paths.add(tpl[0]["file_path"])
-                     if "file_path" in tpl[2]:
-                         paths.add(tpl[2]["file_path"])
-
-         logger.info(f"Found {len(paths)} unique file paths to read")
-
-         retrieved_files = {}
-         read_tasks = []
-         for file_path in paths:
-
-             async def read_file(fp):
-                 try:
-                     logger.debug(f"Reading file: {fp}")
-                     async with aiofiles.open(fp, "r", encoding="utf-8") as f:
-                         content = await f.read()
-                         retrieved_files[fp] = content
-                         logger.debug(f"Successfully read {len(content)} characters from {fp}")
-                 except Exception as e:
-                     logger.error(f"Error reading {fp}: {e}")
-                     retrieved_files[fp] = ""
-
-             read_tasks.append(read_file(file_path))
-
-         await asyncio.gather(*read_tasks)
-         logger.info(
-             f"Successfully read {len([f for f in retrieved_files.values() if f])} files (out of {len(paths)} total)"
-         )
-
-         result = [
-             {
-                 "name": file_path,
-                 "description": file_path,
-                 "content": retrieved_files[file_path],
-             }
-             for file_path in paths
-         ]
-
-         logger.info(f"Returning {len(result)} code file contexts")
-         return result
-
-     async def get_completion(
-         self, query: str, context: Optional[Any] = None, session_id: Optional[str] = None
-     ) -> Any:
-         """
-         Returns the code files context.
-
-         Parameters:
-         -----------
-
-         - query (str): The query string to retrieve code context for.
-         - context (Optional[Any]): Optional pre-fetched context; if None, it retrieves
-           the context for the query. (default None)
-         - session_id (Optional[str]): Optional session identifier for caching. If None,
-           defaults to 'default_session'. (default None)
-
-         Returns:
-         --------
-
-         - Any: The code files context, either provided or retrieved.
-         """
-         if context is None:
-             context = await self.get_context(query)
-         return context
cognee/tasks/code/enrich_dependency_graph_checker.py
@@ -1,35 +0,0 @@
- import os
- import asyncio
- import argparse
- from cognee.tasks.repo_processor.get_repo_file_dependencies import get_repo_file_dependencies
- from cognee.tasks.repo_processor.enrich_dependency_graph import enrich_dependency_graph
-
-
- def main():
-     """
-     Execute the main logic of the dependency graph processor.
-
-     This function sets up argument parsing to retrieve the repository path, checks the
-     existence of the specified path, and processes the repository to produce a dependency
-     graph. If the repository path does not exist, it logs an error message and terminates
-     without further execution.
-     """
-     parser = argparse.ArgumentParser()
-     parser.add_argument("repo_path", help="Path to the repository")
-     args = parser.parse_args()
-
-     repo_path = args.repo_path
-     if not os.path.exists(repo_path):
-         print(f"Error: The provided repository path does not exist: {repo_path}")
-         return
-
-     graph = asyncio.run(get_repo_file_dependencies(repo_path))
-     graph = asyncio.run(enrich_dependency_graph(graph))
-     for node in graph.nodes:
-         print(f"Node: {node}")
-         for _, target, data in graph.out_edges(node, data=True):
-             print(f"  Edge to {target}, data: {data}")
-
-
- if __name__ == "__main__":
-     main()
cognee/tasks/code/get_local_dependencies_checker.py
@@ -1,20 +0,0 @@
- import argparse
- import asyncio
- from cognee.tasks.repo_processor.get_local_dependencies import get_local_script_dependencies
-
- if __name__ == "__main__":
-     parser = argparse.ArgumentParser(description="Get local script dependencies.")
-
-     # Suggested path: .../cognee/examples/python/simple_example.py
-     parser.add_argument("script_path", type=str, help="Absolute path to the Python script file")
-
-     # Suggested path: .../cognee
-     parser.add_argument("repo_path", type=str, help="Absolute path to the repository root")
-
-     args = parser.parse_args()
-
-     dependencies = asyncio.run(get_local_script_dependencies(args.script_path, args.repo_path))
-
-     print("Dependencies:")
-     for dependency in dependencies:
-         print(dependency)
cognee/tasks/code/get_repo_dependency_graph_checker.py
@@ -1,35 +0,0 @@
- import os
- import asyncio
- import argparse
- from cognee.tasks.repo_processor.get_repo_file_dependencies import get_repo_file_dependencies
-
-
- def main():
-     """
-     Parse the command line arguments and print the repository file dependencies.
-
-     This function sets up an argument parser to retrieve the path of a repository. It checks
-     if the provided path exists and if it doesn't, it prints an error message and exits. If
-     the path is valid, it calls an asynchronous function to get the dependencies and prints
-     the nodes and their relations in the dependency graph.
-     """
-     parser = argparse.ArgumentParser()
-     parser.add_argument("repo_path", help="Path to the repository")
-     args = parser.parse_args()
-
-     repo_path = args.repo_path
-     if not os.path.exists(repo_path):
-         print(f"Error: The provided repository path does not exist: {repo_path}")
-         return
-
-     graph = asyncio.run(get_repo_file_dependencies(repo_path))
-
-     for node in graph.nodes:
-         print(f"Node: {node}")
-         edges = graph.edges(node, data=True)
-         for _, target, data in edges:
-             print(f"  Edge to {target}, Relation: {data.get('relation')}")
-
-
- if __name__ == "__main__":
-     main()
cognee/tasks/documents/check_permissions_on_dataset.py
@@ -1,26 +0,0 @@
- from cognee.modules.data.processing.document_types import Document
- from cognee.modules.users.permissions.methods import check_permission_on_dataset
- from typing import List
-
-
- async def check_permissions_on_dataset(
-     documents: List[Document], context: dict, user, permissions
- ) -> List[Document]:
-     """
-     Validates a user's permissions on a list of documents.
-
-     Notes:
-     - This function assumes that `check_permission_on_documents` raises an exception if the permission check fails.
-     - It is designed to validate multiple permissions in a sequential manner for the same set of documents.
-     - Ensure that the `Document` and `user` objects conform to the expected structure and interfaces.
-     """
-
-     for permission in permissions:
-         await check_permission_on_dataset(
-             user,
-             permission,
-             # TODO: pass dataset through argument instead of context
-             context["dataset"].id,
-         )
-
-     return documents
cognee/tasks/repo_processor/__init__.py
@@ -1,2 +0,0 @@
- from .get_non_code_files import get_non_py_files
- from .get_repo_file_dependencies import get_repo_file_dependencies
cognee/tasks/repo_processor/get_local_dependencies.py
@@ -1,335 +0,0 @@
- import os
- import aiofiles
- import importlib
- from typing import AsyncGenerator, Optional
- from uuid import NAMESPACE_OID, uuid5
- import tree_sitter_python as tspython
- from tree_sitter import Language, Node, Parser, Tree
- from cognee.shared.logging_utils import get_logger
-
- from cognee.low_level import DataPoint
- from cognee.shared.CodeGraphEntities import (
-     CodeFile,
-     ImportStatement,
-     FunctionDefinition,
-     ClassDefinition,
- )
-
- logger = get_logger()
-
-
- class FileParser:
-     """
-     Handles the parsing of files into source code and an abstract syntax tree
-     representation. Public methods include:
-
-     - parse_file: Parses a file and returns its source code and syntax tree representation.
-     """
-
-     def __init__(self):
-         self.parsed_files = {}
-
-     async def parse_file(self, file_path: str) -> tuple[str, Tree]:
-         """
-         Parse a file and return its source code along with its syntax tree representation.
-
-         If the file has already been parsed, retrieve the result from memory instead of reading
-         the file again.
-
-         Parameters:
-         -----------
-
-         - file_path (str): The path of the file to parse.
-
-         Returns:
-         --------
-
-         - tuple[str, Tree]: A tuple containing the source code of the file and its
-           corresponding syntax tree representation.
-         """
-         PY_LANGUAGE = Language(tspython.language())
-         source_code_parser = Parser(PY_LANGUAGE)
-
-         if file_path not in self.parsed_files:
-             source_code = await get_source_code(file_path)
-             source_code_tree = source_code_parser.parse(bytes(source_code, "utf-8"))
-             self.parsed_files[file_path] = (source_code, source_code_tree)
-
-         return self.parsed_files[file_path]
-
-
- async def get_source_code(file_path: str):
-     """
-     Read source code from a file asynchronously.
-
-     This function attempts to open a file specified by the given file path, read its
-     contents, and return the source code. In case of any errors during the file reading
-     process, it logs an error message and returns None.
-
-     Parameters:
-     -----------
-
-     - file_path (str): The path to the file from which to read the source code.
-
-     Returns:
-     --------
-
-     Returns the contents of the file as a string if successful, or None if an error
-     occurs.
-     """
-     try:
-         async with aiofiles.open(file_path, "r", encoding="utf-8") as f:
-             source_code = await f.read()
-             return source_code
-     except Exception as error:
-         logger.error(f"Error reading file {file_path}: {str(error)}")
-         return None
-
-
- def resolve_module_path(module_name):
-     """
-     Find the file path of a module.
-
-     Return the file path of the specified module if found, or return None if the module does
-     not exist or cannot be located.
-
-     Parameters:
-     -----------
-
-     - module_name: The name of the module whose file path is to be resolved.
-
-     Returns:
-     --------
-
-     The file path of the module as a string or None if the module is not found.
-     """
-     try:
-         spec = importlib.util.find_spec(module_name)
-         if spec and spec.origin:
-             return spec.origin
-     except ModuleNotFoundError:
-         return None
-     return None
-
-
- def find_function_location(
-     module_path: str, function_name: str, parser: FileParser
- ) -> Optional[tuple[str, str]]:
-     """
-     Find the location of a function definition in a specified module.
-
-     Parameters:
-     -----------
-
-     - module_path (str): The path to the module where the function is defined.
-     - function_name (str): The name of the function whose location is to be found.
-     - parser (FileParser): An instance of FileParser used to parse the module's source
-       code.
-
-     Returns:
-     --------
-
-     - Optional[tuple[str, str]]: Returns a tuple containing the module path and the
-       start point of the function if found; otherwise, returns None.
-     """
-     if not module_path or not os.path.exists(module_path):
-         return None
-
-     source_code, tree = parser.parse_file(module_path)
-     root_node: Node = tree.root_node
-
-     for node in root_node.children:
-         if node.type == "function_definition":
-             func_name_node = node.child_by_field_name("name")
-
-             if func_name_node and func_name_node.text.decode() == function_name:
-                 return (module_path, node.start_point)  # (line, column)
-
-     return None
-
-
- async def get_local_script_dependencies(
-     repo_path: str, script_path: str, detailed_extraction: bool = False
- ) -> CodeFile:
-     """
-     Retrieve local script dependencies and create a CodeFile object.
-
-     Parameters:
-     -----------
-
-     - repo_path (str): The path to the repository that contains the script.
-     - script_path (str): The path of the script for which dependencies are being
-       extracted.
-     - detailed_extraction (bool): A flag indicating whether to perform a detailed
-       extraction of code components.
-
-     Returns:
-     --------
-
-     - CodeFile: Returns a CodeFile object containing information about the script,
-       including its dependencies and definitions.
-     """
-     code_file_parser = FileParser()
-     source_code, source_code_tree = await code_file_parser.parse_file(script_path)
-
-     file_path_relative_to_repo = script_path[len(repo_path) + 1 :]
-
-     if not detailed_extraction:
-         code_file_node = CodeFile(
-             id=uuid5(NAMESPACE_OID, script_path),
-             name=file_path_relative_to_repo,
-             source_code=source_code,
-             file_path=script_path,
-             language="python",
-         )
-         return code_file_node
-
-     code_file_node = CodeFile(
-         id=uuid5(NAMESPACE_OID, script_path),
-         name=file_path_relative_to_repo,
-         source_code=None,
-         file_path=script_path,
-         language="python",
-     )
-
-     async for part in extract_code_parts(source_code_tree.root_node, script_path=script_path):
-         part.file_path = script_path
-
-         if isinstance(part, FunctionDefinition):
-             code_file_node.provides_function_definition.append(part)
-         if isinstance(part, ClassDefinition):
-             code_file_node.provides_class_definition.append(part)
-         if isinstance(part, ImportStatement):
-             code_file_node.depends_on.append(part)
-
-     return code_file_node
-
-
- def find_node(nodes: list[Node], condition: callable) -> Node:
-     """
-     Find and return the first node that satisfies the given condition.
-
-     Iterate through the provided list of nodes and return the first node for which the
-     condition callable returns True. If no such node is found, return None.
-
-     Parameters:
-     -----------
-
-     - nodes (list[Node]): A list of Node objects to search through.
-     - condition (callable): A callable that takes a Node and returns a boolean
-       indicating if the node meets specified criteria.
-
-     Returns:
-     --------
-
-     - Node: The first Node that matches the condition, or None if no such node exists.
-     """
-     for node in nodes:
-         if condition(node):
-             return node
-
-     return None
-
-
- async def extract_code_parts(
-     tree_root: Node, script_path: str, existing_nodes: list[DataPoint] = {}
- ) -> AsyncGenerator[DataPoint, None]:
-     """
-     Extract code parts from a given AST node tree asynchronously.
-
-     Iteratively yields DataPoint nodes representing import statements, function definitions,
-     and class definitions found in the children of the specified tree root. The function
-     checks
-     if nodes are already present in the existing_nodes dictionary to prevent duplicates.
-     This function has to be used in an asynchronous context, and it requires a valid
-     tree_root
-     and proper initialization of existing_nodes.
-
-     Parameters:
-     -----------
-
-     - tree_root (Node): The root node of the AST tree containing code parts to extract.
-     - script_path (str): The file path of the script from which the AST was generated.
-     - existing_nodes (list[DataPoint]): A dictionary that holds already extracted
-       DataPoint nodes to avoid duplicates. (default {})
-
-     Returns:
-     --------
-
-     Yields DataPoint nodes representing imported modules, functions, and classes.
-     """
-     for child_node in tree_root.children:
-         if child_node.type == "import_statement" or child_node.type == "import_from_statement":
-             parts = child_node.text.decode("utf-8").split()
-
-             if parts[0] == "import":
-                 module_name = parts[1]
-                 function_name = None
-             elif parts[0] == "from":
-                 module_name = parts[1]
-                 function_name = parts[3]
-
-                 if " as " in function_name:
-                     function_name = function_name.split(" as ")[0]
-
-             if " as " in module_name:
-                 module_name = module_name.split(" as ")[0]
-
-             if function_name and "import " + function_name not in existing_nodes:
-                 import_statement_node = ImportStatement(
-                     name=function_name,
-                     module=module_name,
-                     start_point=child_node.start_point,
-                     end_point=child_node.end_point,
-                     file_path=script_path,
-                     source_code=child_node.text,
-                 )
-                 existing_nodes["import " + function_name] = import_statement_node
-
-             if function_name:
-                 yield existing_nodes["import " + function_name]
-
-             if module_name not in existing_nodes:
-                 import_statement_node = ImportStatement(
-                     name=module_name,
-                     module=module_name,
-                     start_point=child_node.start_point,
-                     end_point=child_node.end_point,
-                     file_path=script_path,
-                     source_code=child_node.text,
-                 )
-                 existing_nodes[module_name] = import_statement_node
-
-             yield existing_nodes[module_name]
-
-         if child_node.type == "function_definition":
-             function_node = find_node(child_node.children, lambda node: node.type == "identifier")
-             function_node_name = function_node.text
-
-             if function_node_name not in existing_nodes:
-                 function_definition_node = FunctionDefinition(
-                     name=function_node_name,
-                     start_point=child_node.start_point,
-                     end_point=child_node.end_point,
-                     file_path=script_path,
-                     source_code=child_node.text,
-                 )
-                 existing_nodes[function_node_name] = function_definition_node
-
-             yield existing_nodes[function_node_name]
-
-         if child_node.type == "class_definition":
-             class_name_node = find_node(child_node.children, lambda node: node.type == "identifier")
-             class_name_node_name = class_name_node.text
-
-             if class_name_node_name not in existing_nodes:
-                 class_definition_node = ClassDefinition(
-                     name=class_name_node_name,
-                     start_point=child_node.start_point,
-                     end_point=child_node.end_point,
-                     file_path=script_path,
-                     source_code=child_node.text,
-                 )
-                 existing_nodes[class_name_node_name] = class_definition_node
-
-             yield existing_nodes[class_name_node_name]