cognee 0.2.1.dev7__py3-none-any.whl → 0.2.2.dev1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (223)
  1. cognee/api/client.py +44 -4
  2. cognee/api/health.py +332 -0
  3. cognee/api/v1/add/add.py +5 -2
  4. cognee/api/v1/add/routers/get_add_router.py +3 -0
  5. cognee/api/v1/cognify/code_graph_pipeline.py +3 -1
  6. cognee/api/v1/cognify/cognify.py +8 -0
  7. cognee/api/v1/cognify/routers/get_cognify_router.py +8 -1
  8. cognee/api/v1/config/config.py +3 -1
  9. cognee/api/v1/datasets/routers/get_datasets_router.py +2 -8
  10. cognee/api/v1/delete/delete.py +16 -12
  11. cognee/api/v1/responses/routers/get_responses_router.py +3 -1
  12. cognee/api/v1/search/search.py +10 -0
  13. cognee/api/v1/settings/routers/get_settings_router.py +0 -2
  14. cognee/base_config.py +1 -0
  15. cognee/eval_framework/evaluation/direct_llm_eval_adapter.py +5 -6
  16. cognee/infrastructure/databases/graph/config.py +2 -0
  17. cognee/infrastructure/databases/graph/get_graph_engine.py +58 -12
  18. cognee/infrastructure/databases/graph/graph_db_interface.py +15 -10
  19. cognee/infrastructure/databases/graph/kuzu/adapter.py +43 -16
  20. cognee/infrastructure/databases/graph/kuzu/kuzu_migrate.py +281 -0
  21. cognee/infrastructure/databases/graph/neo4j_driver/adapter.py +151 -77
  22. cognee/infrastructure/databases/graph/neptune_driver/__init__.py +15 -0
  23. cognee/infrastructure/databases/graph/neptune_driver/adapter.py +1427 -0
  24. cognee/infrastructure/databases/graph/neptune_driver/exceptions.py +115 -0
  25. cognee/infrastructure/databases/graph/neptune_driver/neptune_utils.py +224 -0
  26. cognee/infrastructure/databases/graph/networkx/adapter.py +3 -3
  27. cognee/infrastructure/databases/hybrid/neptune_analytics/NeptuneAnalyticsAdapter.py +449 -0
  28. cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py +11 -3
  29. cognee/infrastructure/databases/vector/chromadb/ChromaDBAdapter.py +8 -3
  30. cognee/infrastructure/databases/vector/create_vector_engine.py +31 -23
  31. cognee/infrastructure/databases/vector/embeddings/FastembedEmbeddingEngine.py +3 -1
  32. cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py +21 -6
  33. cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py +4 -3
  34. cognee/infrastructure/databases/vector/embeddings/get_embedding_engine.py +3 -1
  35. cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py +22 -16
  36. cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py +36 -34
  37. cognee/infrastructure/databases/vector/vector_db_interface.py +78 -7
  38. cognee/infrastructure/files/utils/get_data_file_path.py +39 -0
  39. cognee/infrastructure/files/utils/guess_file_type.py +2 -2
  40. cognee/infrastructure/files/utils/open_data_file.py +4 -23
  41. cognee/infrastructure/llm/LLMGateway.py +137 -0
  42. cognee/infrastructure/llm/__init__.py +14 -4
  43. cognee/infrastructure/llm/config.py +29 -1
  44. cognee/infrastructure/llm/prompts/answer_hotpot_question.txt +1 -1
  45. cognee/infrastructure/llm/prompts/answer_hotpot_using_cognee_search.txt +1 -1
  46. cognee/infrastructure/llm/prompts/answer_simple_question.txt +1 -1
  47. cognee/infrastructure/llm/prompts/answer_simple_question_restricted.txt +1 -1
  48. cognee/infrastructure/llm/prompts/categorize_categories.txt +1 -1
  49. cognee/infrastructure/llm/prompts/classify_content.txt +1 -1
  50. cognee/infrastructure/llm/prompts/context_for_question.txt +1 -1
  51. cognee/infrastructure/llm/prompts/graph_context_for_question.txt +1 -1
  52. cognee/infrastructure/llm/prompts/natural_language_retriever_system.txt +1 -1
  53. cognee/infrastructure/llm/prompts/patch_gen_instructions.txt +1 -1
  54. cognee/infrastructure/llm/prompts/search_type_selector_prompt.txt +130 -0
  55. cognee/infrastructure/llm/prompts/summarize_code.txt +2 -2
  56. cognee/infrastructure/llm/structured_output_framework/baml/baml_client/__init__.py +57 -0
  57. cognee/infrastructure/llm/structured_output_framework/baml/baml_client/async_client.py +533 -0
  58. cognee/infrastructure/llm/structured_output_framework/baml/baml_client/config.py +94 -0
  59. cognee/infrastructure/llm/structured_output_framework/baml/baml_client/globals.py +37 -0
  60. cognee/infrastructure/llm/structured_output_framework/baml/baml_client/inlinedbaml.py +21 -0
  61. cognee/infrastructure/llm/structured_output_framework/baml/baml_client/parser.py +131 -0
  62. cognee/infrastructure/llm/structured_output_framework/baml/baml_client/runtime.py +266 -0
  63. cognee/infrastructure/llm/structured_output_framework/baml/baml_client/stream_types.py +137 -0
  64. cognee/infrastructure/llm/structured_output_framework/baml/baml_client/sync_client.py +550 -0
  65. cognee/infrastructure/llm/structured_output_framework/baml/baml_client/tracing.py +26 -0
  66. cognee/infrastructure/llm/structured_output_framework/baml/baml_client/type_builder.py +962 -0
  67. cognee/infrastructure/llm/structured_output_framework/baml/baml_client/type_map.py +52 -0
  68. cognee/infrastructure/llm/structured_output_framework/baml/baml_client/types.py +166 -0
  69. cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extract_categories.baml +109 -0
  70. cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extract_content_graph.baml +343 -0
  71. cognee/{modules/data → infrastructure/llm/structured_output_framework/baml/baml_src}/extraction/__init__.py +1 -0
  72. cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extraction/extract_summary.py +89 -0
  73. cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extraction/knowledge_graph/extract_content_graph.py +33 -0
  74. cognee/infrastructure/llm/structured_output_framework/baml/baml_src/generators.baml +18 -0
  75. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/extraction/__init__.py +3 -0
  76. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/extraction/extract_categories.py +12 -0
  77. cognee/{modules/data → infrastructure/llm/structured_output_framework/litellm_instructor}/extraction/extract_summary.py +16 -7
  78. cognee/{modules/data → infrastructure/llm/structured_output_framework/litellm_instructor}/extraction/knowledge_graph/extract_content_graph.py +7 -6
  79. cognee/infrastructure/llm/{anthropic → structured_output_framework/litellm_instructor/llm/anthropic}/adapter.py +10 -4
  80. cognee/infrastructure/llm/{gemini → structured_output_framework/litellm_instructor/llm/gemini}/adapter.py +6 -5
  81. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/__init__.py +0 -0
  82. cognee/infrastructure/llm/{generic_llm_api → structured_output_framework/litellm_instructor/llm/generic_llm_api}/adapter.py +7 -3
  83. cognee/infrastructure/llm/{get_llm_client.py → structured_output_framework/litellm_instructor/llm/get_llm_client.py} +18 -6
  84. cognee/infrastructure/llm/{llm_interface.py → structured_output_framework/litellm_instructor/llm/llm_interface.py} +2 -2
  85. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/__init__.py +0 -0
  86. cognee/infrastructure/llm/{ollama → structured_output_framework/litellm_instructor/llm/ollama}/adapter.py +4 -2
  87. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/__init__.py +0 -0
  88. cognee/infrastructure/llm/{openai → structured_output_framework/litellm_instructor/llm/openai}/adapter.py +6 -4
  89. cognee/infrastructure/llm/{rate_limiter.py → structured_output_framework/litellm_instructor/llm/rate_limiter.py} +0 -5
  90. cognee/infrastructure/llm/tokenizer/Gemini/adapter.py +4 -2
  91. cognee/infrastructure/llm/tokenizer/TikToken/adapter.py +7 -3
  92. cognee/infrastructure/llm/tokenizer/__init__.py +4 -0
  93. cognee/infrastructure/llm/utils.py +3 -1
  94. cognee/infrastructure/loaders/LoaderEngine.py +156 -0
  95. cognee/infrastructure/loaders/LoaderInterface.py +73 -0
  96. cognee/infrastructure/loaders/__init__.py +18 -0
  97. cognee/infrastructure/loaders/core/__init__.py +7 -0
  98. cognee/infrastructure/loaders/core/audio_loader.py +98 -0
  99. cognee/infrastructure/loaders/core/image_loader.py +114 -0
  100. cognee/infrastructure/loaders/core/text_loader.py +90 -0
  101. cognee/infrastructure/loaders/create_loader_engine.py +32 -0
  102. cognee/infrastructure/loaders/external/__init__.py +22 -0
  103. cognee/infrastructure/loaders/external/pypdf_loader.py +96 -0
  104. cognee/infrastructure/loaders/external/unstructured_loader.py +127 -0
  105. cognee/infrastructure/loaders/get_loader_engine.py +18 -0
  106. cognee/infrastructure/loaders/supported_loaders.py +18 -0
  107. cognee/infrastructure/loaders/use_loader.py +21 -0
  108. cognee/infrastructure/loaders/utils/__init__.py +0 -0
  109. cognee/modules/data/methods/__init__.py +1 -0
  110. cognee/modules/data/methods/get_authorized_dataset.py +23 -0
  111. cognee/modules/data/models/Data.py +13 -3
  112. cognee/modules/data/processing/document_types/AudioDocument.py +2 -2
  113. cognee/modules/data/processing/document_types/ImageDocument.py +2 -2
  114. cognee/modules/data/processing/document_types/PdfDocument.py +4 -11
  115. cognee/modules/data/processing/document_types/UnstructuredDocument.py +2 -5
  116. cognee/modules/engine/utils/generate_edge_id.py +5 -0
  117. cognee/modules/graph/cognee_graph/CogneeGraph.py +45 -35
  118. cognee/modules/graph/methods/get_formatted_graph_data.py +8 -2
  119. cognee/modules/graph/utils/get_graph_from_model.py +93 -101
  120. cognee/modules/ingestion/data_types/TextData.py +8 -2
  121. cognee/modules/ingestion/save_data_to_file.py +1 -1
  122. cognee/modules/pipelines/exceptions/__init__.py +1 -0
  123. cognee/modules/pipelines/exceptions/exceptions.py +12 -0
  124. cognee/modules/pipelines/models/DataItemStatus.py +5 -0
  125. cognee/modules/pipelines/models/PipelineRunInfo.py +6 -0
  126. cognee/modules/pipelines/models/__init__.py +1 -0
  127. cognee/modules/pipelines/operations/pipeline.py +10 -2
  128. cognee/modules/pipelines/operations/run_tasks.py +252 -20
  129. cognee/modules/pipelines/operations/run_tasks_distributed.py +1 -1
  130. cognee/modules/retrieval/chunks_retriever.py +23 -1
  131. cognee/modules/retrieval/code_retriever.py +66 -9
  132. cognee/modules/retrieval/completion_retriever.py +11 -9
  133. cognee/modules/retrieval/context_providers/TripletSearchContextProvider.py +0 -2
  134. cognee/modules/retrieval/graph_completion_context_extension_retriever.py +0 -2
  135. cognee/modules/retrieval/graph_completion_cot_retriever.py +8 -9
  136. cognee/modules/retrieval/graph_completion_retriever.py +1 -1
  137. cognee/modules/retrieval/insights_retriever.py +4 -0
  138. cognee/modules/retrieval/natural_language_retriever.py +9 -15
  139. cognee/modules/retrieval/summaries_retriever.py +23 -1
  140. cognee/modules/retrieval/utils/brute_force_triplet_search.py +23 -4
  141. cognee/modules/retrieval/utils/completion.py +6 -9
  142. cognee/modules/retrieval/utils/description_to_codepart_search.py +2 -3
  143. cognee/modules/search/methods/search.py +5 -1
  144. cognee/modules/search/operations/__init__.py +1 -0
  145. cognee/modules/search/operations/select_search_type.py +42 -0
  146. cognee/modules/search/types/SearchType.py +1 -0
  147. cognee/modules/settings/get_settings.py +0 -8
  148. cognee/modules/settings/save_vector_db_config.py +1 -1
  149. cognee/shared/data_models.py +3 -1
  150. cognee/shared/logging_utils.py +0 -5
  151. cognee/tasks/chunk_naive_llm_classifier/chunk_naive_llm_classifier.py +2 -2
  152. cognee/tasks/documents/extract_chunks_from_documents.py +10 -12
  153. cognee/tasks/entity_completion/entity_extractors/llm_entity_extractor.py +4 -6
  154. cognee/tasks/graph/cascade_extract/utils/extract_content_nodes_and_relationship_names.py +4 -6
  155. cognee/tasks/graph/cascade_extract/utils/extract_edge_triplets.py +6 -7
  156. cognee/tasks/graph/cascade_extract/utils/extract_nodes.py +4 -7
  157. cognee/tasks/graph/extract_graph_from_code.py +3 -2
  158. cognee/tasks/graph/extract_graph_from_data.py +4 -3
  159. cognee/tasks/graph/infer_data_ontology.py +5 -6
  160. cognee/tasks/ingestion/data_item_to_text_file.py +79 -0
  161. cognee/tasks/ingestion/ingest_data.py +91 -61
  162. cognee/tasks/ingestion/resolve_data_directories.py +3 -0
  163. cognee/tasks/repo_processor/get_repo_file_dependencies.py +3 -0
  164. cognee/tasks/storage/index_data_points.py +1 -1
  165. cognee/tasks/storage/index_graph_edges.py +4 -1
  166. cognee/tasks/summarization/summarize_code.py +2 -3
  167. cognee/tasks/summarization/summarize_text.py +3 -2
  168. cognee/tests/test_cognee_server_start.py +12 -7
  169. cognee/tests/test_deduplication.py +2 -2
  170. cognee/tests/test_deletion.py +58 -17
  171. cognee/tests/test_graph_visualization_permissions.py +161 -0
  172. cognee/tests/test_neptune_analytics_graph.py +309 -0
  173. cognee/tests/test_neptune_analytics_hybrid.py +176 -0
  174. cognee/tests/{test_weaviate.py → test_neptune_analytics_vector.py} +86 -11
  175. cognee/tests/test_pgvector.py +5 -5
  176. cognee/tests/test_s3.py +1 -6
  177. cognee/tests/unit/infrastructure/databases/test_rate_limiter.py +11 -10
  178. cognee/tests/unit/infrastructure/databases/vector/__init__.py +0 -0
  179. cognee/tests/unit/infrastructure/mock_embedding_engine.py +1 -1
  180. cognee/tests/unit/infrastructure/test_embedding_rate_limiting_realistic.py +5 -5
  181. cognee/tests/unit/infrastructure/test_rate_limiting_realistic.py +6 -4
  182. cognee/tests/unit/infrastructure/test_rate_limiting_retry.py +1 -1
  183. cognee/tests/unit/interfaces/graph/get_graph_from_model_unit_test.py +61 -3
  184. cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py +84 -9
  185. cognee/tests/unit/modules/search/search_methods_test.py +55 -0
  186. {cognee-0.2.1.dev7.dist-info → cognee-0.2.2.dev1.dist-info}/METADATA +13 -9
  187. {cognee-0.2.1.dev7.dist-info → cognee-0.2.2.dev1.dist-info}/RECORD +203 -164
  188. cognee/infrastructure/databases/vector/pinecone/adapter.py +0 -8
  189. cognee/infrastructure/databases/vector/qdrant/QDrantAdapter.py +0 -514
  190. cognee/infrastructure/databases/vector/qdrant/__init__.py +0 -2
  191. cognee/infrastructure/databases/vector/weaviate_db/WeaviateAdapter.py +0 -527
  192. cognee/infrastructure/databases/vector/weaviate_db/__init__.py +0 -1
  193. cognee/modules/data/extraction/extract_categories.py +0 -14
  194. cognee/tests/test_qdrant.py +0 -99
  195. distributed/Dockerfile +0 -34
  196. distributed/app.py +0 -4
  197. distributed/entrypoint.py +0 -71
  198. distributed/entrypoint.sh +0 -5
  199. distributed/modal_image.py +0 -11
  200. distributed/queues.py +0 -5
  201. distributed/tasks/queued_add_data_points.py +0 -13
  202. distributed/tasks/queued_add_edges.py +0 -13
  203. distributed/tasks/queued_add_nodes.py +0 -13
  204. distributed/test.py +0 -28
  205. distributed/utils.py +0 -19
  206. distributed/workers/data_point_saving_worker.py +0 -93
  207. distributed/workers/graph_saving_worker.py +0 -104
  208. /cognee/infrastructure/databases/{graph/memgraph → hybrid/neptune_analytics}/__init__.py +0 -0
  209. /cognee/infrastructure/{llm → databases/vector/embeddings}/embedding_rate_limiter.py +0 -0
  210. /cognee/infrastructure/{databases/vector/pinecone → llm/structured_output_framework}/__init__.py +0 -0
  211. /cognee/infrastructure/llm/{anthropic → structured_output_framework/baml/baml_src}/__init__.py +0 -0
  212. /cognee/infrastructure/llm/{gemini/__init__.py → structured_output_framework/baml/baml_src/extraction/extract_categories.py} +0 -0
  213. /cognee/infrastructure/llm/{generic_llm_api → structured_output_framework/baml/baml_src/extraction/knowledge_graph}/__init__.py +0 -0
  214. /cognee/infrastructure/llm/{ollama → structured_output_framework/litellm_instructor}/__init__.py +0 -0
  215. /cognee/{modules/data → infrastructure/llm/structured_output_framework/litellm_instructor}/extraction/knowledge_graph/__init__.py +0 -0
  216. /cognee/{modules/data → infrastructure/llm/structured_output_framework/litellm_instructor}/extraction/texts.json +0 -0
  217. /cognee/infrastructure/llm/{openai → structured_output_framework/litellm_instructor/llm}/__init__.py +0 -0
  218. {distributed → cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic}/__init__.py +0 -0
  219. {distributed/tasks → cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini}/__init__.py +0 -0
  220. /cognee/modules/data/{extraction/knowledge_graph → methods}/add_model_class_to_graph.py +0 -0
  221. {cognee-0.2.1.dev7.dist-info → cognee-0.2.2.dev1.dist-info}/WHEEL +0 -0
  222. {cognee-0.2.1.dev7.dist-info → cognee-0.2.2.dev1.dist-info}/licenses/LICENSE +0 -0
  223. {cognee-0.2.1.dev7.dist-info → cognee-0.2.2.dev1.dist-info}/licenses/NOTICE.md +0 -0
cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py
@@ -51,6 +51,7 @@ class LanceDBAdapter(VectorDBInterface):
         self.url = url
         self.api_key = api_key
         self.embedding_engine = embedding_engine
+        self.VECTOR_DB_LOCK = asyncio.Lock()

     async def get_connection(self):
         """
@@ -127,12 +128,14 @@ class LanceDBAdapter(VectorDBInterface):
             payload: payload_schema

         if not await self.has_collection(collection_name):
-            connection = await self.get_connection()
-            return await connection.create_table(
-                name=collection_name,
-                schema=LanceDataPoint,
-                exist_ok=True,
-            )
+            async with self.VECTOR_DB_LOCK:
+                if not await self.has_collection(collection_name):
+                    connection = await self.get_connection()
+                    return await connection.create_table(
+                        name=collection_name,
+                        schema=LanceDataPoint,
+                        exist_ok=True,
+                    )

     async def get_collection(self, collection_name: str):
         if not await self.has_collection(collection_name):
@@ -145,10 +148,12 @@ class LanceDBAdapter(VectorDBInterface):
         payload_schema = type(data_points[0])

         if not await self.has_collection(collection_name):
-            await self.create_collection(
-                collection_name,
-                payload_schema,
-            )
+            async with self.VECTOR_DB_LOCK:
+                if not await self.has_collection(collection_name):
+                    await self.create_collection(
+                        collection_name,
+                        payload_schema,
+                    )

         collection = await self.get_collection(collection_name)

@@ -188,12 +193,13 @@ class LanceDBAdapter(VectorDBInterface):
             for (data_point_index, data_point) in enumerate(data_points)
         ]

-        await (
-            collection.merge_insert("id")
-            .when_matched_update_all()
-            .when_not_matched_insert_all()
-            .execute(lance_data_points)
-        )
+        async with self.VECTOR_DB_LOCK:
+            await (
+                collection.merge_insert("id")
+                .when_matched_update_all()
+                .when_not_matched_insert_all()
+                .execute(lance_data_points)
+            )

     async def retrieve(self, collection_name: str, data_point_ids: list[str]):
         collection = await self.get_collection(collection_name)
cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py
@@ -54,6 +54,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
         self.api_key = api_key
         self.embedding_engine = embedding_engine
         self.db_uri: str = connection_string
+        self.VECTOR_DB_LOCK = asyncio.Lock()

         relational_db = get_relational_engine()

 
@@ -124,40 +125,41 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
124
125
  data_point_types = get_type_hints(DataPoint)
125
126
  vector_size = self.embedding_engine.get_vector_size()
126
127
 
127
- if not await self.has_collection(collection_name):
128
-
129
- class PGVectorDataPoint(Base):
130
- """
131
- Represent a point in a vector data space with associated data and vector representation.
132
-
133
- This class inherits from Base and is associated with a database table defined by
134
- __tablename__. It maintains the following public methods and instance variables:
135
-
136
- - __init__(self, id, payload, vector): Initializes a new PGVectorDataPoint instance.
137
-
138
- Instance variables:
139
- - id: Identifier for the data point, defined by data_point_types.
140
- - payload: JSON data associated with the data point.
141
- - vector: Vector representation of the data point, with size defined by vector_size.
142
- """
143
-
144
- __tablename__ = collection_name
145
- __table_args__ = {"extend_existing": True}
146
- # PGVector requires one column to be the primary key
147
- id: Mapped[data_point_types["id"]] = mapped_column(primary_key=True)
148
- payload = Column(JSON)
149
- vector = Column(self.Vector(vector_size))
150
-
151
- def __init__(self, id, payload, vector):
152
- self.id = id
153
- self.payload = payload
154
- self.vector = vector
155
-
156
- async with self.engine.begin() as connection:
157
- if len(Base.metadata.tables.keys()) > 0:
158
- await connection.run_sync(
159
- Base.metadata.create_all, tables=[PGVectorDataPoint.__table__]
160
- )
128
+ async with self.VECTOR_DB_LOCK:
129
+ if not await self.has_collection(collection_name):
130
+
131
+ class PGVectorDataPoint(Base):
132
+ """
133
+ Represent a point in a vector data space with associated data and vector representation.
134
+
135
+ This class inherits from Base and is associated with a database table defined by
136
+ __tablename__. It maintains the following public methods and instance variables:
137
+
138
+ - __init__(self, id, payload, vector): Initializes a new PGVectorDataPoint instance.
139
+
140
+ Instance variables:
141
+ - id: Identifier for the data point, defined by data_point_types.
142
+ - payload: JSON data associated with the data point.
143
+ - vector: Vector representation of the data point, with size defined by vector_size.
144
+ """
145
+
146
+ __tablename__ = collection_name
147
+ __table_args__ = {"extend_existing": True}
148
+ # PGVector requires one column to be the primary key
149
+ id: Mapped[data_point_types["id"]] = mapped_column(primary_key=True)
150
+ payload = Column(JSON)
151
+ vector = Column(self.Vector(vector_size))
152
+
153
+ def __init__(self, id, payload, vector):
154
+ self.id = id
155
+ self.payload = payload
156
+ self.vector = vector
157
+
158
+ async with self.engine.begin() as connection:
159
+ if len(Base.metadata.tables.keys()) > 0:
160
+ await connection.run_sync(
161
+ Base.metadata.create_all, tables=[PGVectorDataPoint.__table__]
162
+ )
161
163
 
162
164
  @retry(
163
165
  retry=retry_if_exception_type(DeadlockDetectedError),
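Both adapter changes above apply the same double-checked locking pattern: test for the collection, acquire a per-adapter asyncio.Lock, then test again before creating. A minimal standalone sketch of the pattern follows (hypothetical names, not code from this package):

    import asyncio

    class CollectionRegistry:
        # Illustrates the check-lock-recheck pattern used by the adapters above.
        def __init__(self):
            self._collections = set()
            self._lock = asyncio.Lock()  # plays the role of VECTOR_DB_LOCK

        async def has_collection(self, name: str) -> bool:
            return name in self._collections

        async def ensure_collection(self, name: str) -> None:
            if not await self.has_collection(name):  # fast path, no lock taken
                async with self._lock:
                    # Re-check inside the lock: another coroutine may have
                    # created the collection while this one was waiting.
                    if not await self.has_collection(name):
                        self._collections.add(name)

    async def main():
        registry = CollectionRegistry()
        # Ten concurrent callers race; the collection is created exactly once.
        await asyncio.gather(*(registry.ensure_collection("chunks") for _ in range(10)))

    asyncio.run(main())

Without the second check, every coroutine that lost the race would still run the creation step after acquiring the lock, which is the duplicate-DDL scenario the PGVector adapter's DeadlockDetectedError retry otherwise has to absorb.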
cognee/infrastructure/databases/vector/vector_db_interface.py
@@ -1,4 +1,4 @@
-from typing import List, Protocol, Optional
+from typing import List, Protocol, Optional, Union, Any
 from abc import abstractmethod
 from cognee.infrastructure.engine import DataPoint
 from .models.PayloadSchema import PayloadSchema
@@ -31,7 +31,7 @@ class VectorDBInterface(Protocol):
     async def create_collection(
         self,
         collection_name: str,
-        payload_schema: Optional[PayloadSchema] = None,
+        payload_schema: Optional[Any] = None,
     ):
         """
         Create a new collection with an optional payload schema.
@@ -40,8 +40,8 @@ class VectorDBInterface(Protocol):
         -----------

         - collection_name (str): The name of the new collection to create.
-        - payload_schema (Optional[PayloadSchema]): An optional schema for the payloads
-          within this collection. (default None)
+        - payload_schema (Optional[Any]): An optional schema for the payloads
+          within this collection. Can be PayloadSchema, BaseModel, or other schema types. (default None)
         """
         raise NotImplementedError

@@ -71,7 +71,7 @@ class VectorDBInterface(Protocol):

         - collection_name (str): The name of the collection from which to retrieve data
           points.
-        - data_point_ids (list[str]): A list of IDs of the data points to retrieve.
+        - data_point_ids (Union[List[str], list[str]]): A list of IDs of the data points to retrieve.
         """
         raise NotImplementedError

@@ -123,7 +123,9 @@ class VectorDBInterface(Protocol):
         raise NotImplementedError

     @abstractmethod
-    async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
+    async def delete_data_points(
+        self, collection_name: str, data_point_ids: Union[List[str], list[str]]
+    ):
         """
         Delete specified data points from a collection.

@@ -132,7 +134,7 @@ class VectorDBInterface(Protocol):

         - collection_name (str): The name of the collection from which to delete data
           points.
-        - data_point_ids (list[str]): A list of IDs of the data points to delete.
+        - data_point_ids (Union[List[str], list[str]]): A list of IDs of the data points to delete.
         """
         raise NotImplementedError

@@ -142,3 +144,72 @@ class VectorDBInterface(Protocol):
         Remove obsolete or unnecessary data from the database.
         """
         raise NotImplementedError
+
+    @abstractmethod
+    async def embed_data(self, data: List[str]) -> List[List[float]]:
+        """
+        Embed textual data into vector representations.
+
+        Parameters:
+        -----------
+
+        - data (List[str]): A list of strings to be embedded.
+
+        Returns:
+        --------
+
+        - List[List[float]]: A list of embedded vectors corresponding to the input data.
+        """
+        raise NotImplementedError
+
+    # Optional methods that may be implemented by adapters
+    async def get_connection(self):
+        """
+        Get a connection to the vector database.
+        This method is optional and may return None for adapters that don't use connections.
+        """
+        return None
+
+    async def get_collection(self, collection_name: str):
+        """
+        Get a collection object from the vector database.
+        This method is optional and may return None for adapters that don't expose collection objects.
+        """
+        return None
+
+    async def create_vector_index(self, index_name: str, index_property_name: str):
+        """
+        Create a vector index for improved search performance.
+        This method is optional and may be a no-op for adapters that don't support indexing.
+        """
+        pass
+
+    async def index_data_points(
+        self, index_name: str, index_property_name: str, data_points: List[DataPoint]
+    ):
+        """
+        Index data points for improved search performance.
+        This method is optional and may be a no-op for adapters that don't support separate indexing.
+
+        Parameters:
+        -----------
+        - index_name (str): Name of the index to create/update
+        - index_property_name (str): Property name to index on
+        - data_points (List[DataPoint]): Data points to index
+        """
+        pass
+
+    def get_data_point_schema(self, model_type: Any) -> Any:
+        """
+        Get or transform a data point schema for the specific vector database.
+        This method is optional and may return the input unchanged for simple adapters.
+
+        Parameters:
+        -----------
+        - model_type (Any): The model type to get schema for
+
+        Returns:
+        --------
+        - Any: The schema object suitable for this vector database
+        """
+        return model_type
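Because the optional methods now ship default bodies, a conforming adapter only has to override what its backend actually supports. A partial, in-memory sketch (a hypothetical adapter for illustration, not one of the package's adapters):

    from typing import Any

    class InMemoryAdapter:
        # Sketch of an adapter that leans on the interface's new optional defaults.
        def __init__(self):
            self.collections: dict[str, list] = {}

        async def has_collection(self, collection_name: str) -> bool:
            return collection_name in self.collections

        async def create_collection(self, collection_name: str, payload_schema: Any = None):
            self.collections.setdefault(collection_name, [])

        async def get_collection(self, collection_name: str):
            # Optional hook: expose the raw collection object.
            return self.collections.get(collection_name)

        def get_data_point_schema(self, model_type: Any) -> Any:
            # Optional hook: simple backends return the model type unchanged.
            return model_type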
cognee/infrastructure/files/utils/get_data_file_path.py
@@ -0,0 +1,39 @@
+import os
+from urllib.parse import urlparse
+
+
+def get_data_file_path(file_path: str):
+    # Check if this is a file URI BEFORE normalizing (which corrupts URIs)
+    if file_path.startswith("file://"):
+        # Normalize the file URI for Windows - replace backslashes with forward slashes
+        normalized_file_uri = os.path.normpath(file_path)
+
+        parsed_url = urlparse(normalized_file_uri)
+
+        # Convert URI path to file system path
+        if os.name == "nt":  # Windows
+            # Handle Windows drive letters correctly
+            fs_path = parsed_url.path
+            if fs_path.startswith("/") and len(fs_path) > 1 and fs_path[2] == ":":
+                fs_path = fs_path[1:]  # Remove leading slash for Windows drive paths
+        else:  # Unix-like systems
+            fs_path = parsed_url.path
+
+        # Now split the actual filesystem path
+        actual_fs_path = os.path.normpath(fs_path)
+        return actual_fs_path
+
+    elif file_path.startswith("s3://"):
+        # Handle S3 URLs without normalization (which corrupts them)
+        parsed_url = urlparse(file_path)
+
+        normalized_url = (
+            f"s3://{parsed_url.netloc}{os.sep}{os.path.normpath(parsed_url.path).lstrip(os.sep)}"
+        )
+
+        return normalized_url
+
+    else:
+        # Regular file path - normalize separators
+        normalized_path = os.path.normpath(file_path)
+        return normalized_path
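The helper leaves URI schemes intact and only normalizes the path component. On a POSIX system the three branches behave roughly as follows (the paths are illustrative; Windows additionally strips the leading slash before drive letters):

    from cognee.infrastructure.files.utils.get_data_file_path import get_data_file_path

    print(get_data_file_path("file:///tmp/data/report.txt"))    # /tmp/data/report.txt
    print(get_data_file_path("s3://my-bucket/data/report.txt")) # s3://my-bucket/data/report.txt
    print(get_data_file_path("./data/../report.txt"))           # report.txt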
cognee/infrastructure/files/utils/guess_file_type.py
@@ -109,8 +109,8 @@ def guess_file_type(file: BinaryIO) -> filetype.Type:
     """
     Guess the file type from the given binary file stream.

-    If the file type cannot be determined, raise a FileTypeException with an appropriate
-    message.
+    If the file type cannot be determined from content, attempts to infer from extension.
+    If still unable to determine, raise a FileTypeException with an appropriate message.

     Parameters:
     -----------
cognee/infrastructure/files/utils/open_data_file.py
@@ -3,6 +3,7 @@ from os import path
 from urllib.parse import urlparse
 from contextlib import asynccontextmanager

+from cognee.infrastructure.files.utils.get_data_file_path import get_data_file_path
 from cognee.infrastructure.files.storage.S3FileStorage import S3FileStorage
 from cognee.infrastructure.files.storage.LocalFileStorage import LocalFileStorage

@@ -11,22 +12,8 @@ from cognee.infrastructure.files.storage.LocalFileStorage import LocalFileStorage
 async def open_data_file(file_path: str, mode: str = "rb", encoding: str = None, **kwargs):
     # Check if this is a file URI BEFORE normalizing (which corrupts URIs)
     if file_path.startswith("file://"):
-        # Normalize the file URI for Windows - replace backslashes with forward slashes
-        normalized_file_uri = os.path.normpath(file_path)
-
-        parsed_url = urlparse(normalized_file_uri)
-
-        # Convert URI path to file system path
-        if os.name == "nt":  # Windows
-            # Handle Windows drive letters correctly
-            fs_path = parsed_url.path
-            if fs_path.startswith("/") and len(fs_path) > 1 and fs_path[2] == ":":
-                fs_path = fs_path[1:]  # Remove leading slash for Windows drive paths
-        else:  # Unix-like systems
-            fs_path = parsed_url.path
-
         # Now split the actual filesystem path
-        actual_fs_path = os.path.normpath(fs_path)
+        actual_fs_path = get_data_file_path(file_path)
         file_dir_path = path.dirname(actual_fs_path)
         file_name = path.basename(actual_fs_path)

@@ -36,13 +23,7 @@ async def open_data_file(file_path: str, mode: str = "rb", encoding: str = None,
             yield file

     elif file_path.startswith("s3://"):
-        # Handle S3 URLs without normalization (which corrupts them)
-        parsed_url = urlparse(file_path)
-
-        normalized_url = (
-            f"s3://{parsed_url.netloc}{os.sep}{os.path.normpath(parsed_url.path).lstrip(os.sep)}"
-        )
-
+        normalized_url = get_data_file_path(file_path)
         s3_dir_path = os.path.dirname(normalized_url)
         s3_filename = os.path.basename(normalized_url)

@@ -66,7 +47,7 @@ async def open_data_file(file_path: str, mode: str = "rb", encoding: str = None,

     else:
         # Regular file path - normalize separators
-        normalized_path = os.path.normpath(file_path)
+        normalized_path = get_data_file_path(file_path)
         file_dir_path = path.dirname(normalized_path)
         file_name = path.basename(normalized_path)

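With the normalization consolidated into get_data_file_path, open_data_file keeps its async-context-manager shape and callers are unchanged regardless of the path scheme. A usage sketch (the path is a placeholder and is assumed to exist):

    import asyncio
    from cognee.infrastructure.files.utils.open_data_file import open_data_file

    async def read_sample():
        # The same call works for file:// URIs, s3:// URLs, and plain paths.
        async with open_data_file("/tmp/example.txt", mode="r", encoding="utf-8") as file:
            print(file.read())

    asyncio.run(read_sample())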
cognee/infrastructure/llm/LLMGateway.py
@@ -0,0 +1,137 @@
+from typing import Type
+from pydantic import BaseModel
+from typing import Coroutine
+from cognee.infrastructure.llm import get_llm_config
+
+
+class LLMGateway:
+    """
+    Class handles selection of structured output frameworks and LLM functions.
+    Class used as a namespace for LLM related functions, should not be instantiated, all methods are static.
+    """
+
+    @staticmethod
+    def render_prompt(filename: str, context: dict, base_directory: str = None):
+        from cognee.infrastructure.llm.prompts import render_prompt
+
+        return render_prompt(filename=filename, context=context, base_directory=base_directory)
+
+    @staticmethod
+    def acreate_structured_output(
+        text_input: str, system_prompt: str, response_model: Type[BaseModel]
+    ) -> Coroutine:
+        from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.get_llm_client import (
+            get_llm_client,
+        )
+
+        llm_client = get_llm_client()
+        return llm_client.acreate_structured_output(
+            text_input=text_input, system_prompt=system_prompt, response_model=response_model
+        )
+
+    @staticmethod
+    def create_structured_output(
+        text_input: str, system_prompt: str, response_model: Type[BaseModel]
+    ) -> BaseModel:
+        from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.get_llm_client import (
+            get_llm_client,
+        )
+
+        llm_client = get_llm_client()
+        return llm_client.create_structured_output(
+            text_input=text_input, system_prompt=system_prompt, response_model=response_model
+        )
+
+    @staticmethod
+    def create_transcript(input) -> Coroutine:
+        from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.get_llm_client import (
+            get_llm_client,
+        )
+
+        llm_client = get_llm_client()
+        return llm_client.create_transcript(input=input)
+
+    @staticmethod
+    def transcribe_image(input) -> Coroutine:
+        from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.get_llm_client import (
+            get_llm_client,
+        )
+
+        llm_client = get_llm_client()
+        return llm_client.transcribe_image(input=input)
+
+    @staticmethod
+    def show_prompt(text_input: str, system_prompt: str) -> str:
+        from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.get_llm_client import (
+            get_llm_client,
+        )
+
+        llm_client = get_llm_client()
+        return llm_client.show_prompt(text_input=text_input, system_prompt=system_prompt)
+
+    @staticmethod
+    def read_query_prompt(prompt_file_name: str, base_directory: str = None):
+        from cognee.infrastructure.llm.prompts import (
+            read_query_prompt,
+        )
+
+        return read_query_prompt(prompt_file_name=prompt_file_name, base_directory=base_directory)
+
+    @staticmethod
+    def extract_content_graph(
+        content: str, response_model: Type[BaseModel], mode: str = "simple"
+    ) -> Coroutine:
+        llm_config = get_llm_config()
+        if llm_config.structured_output_framework.upper() == "BAML":
+            from cognee.infrastructure.llm.structured_output_framework.baml.baml_src.extraction import (
+                extract_content_graph,
+            )
+
+            return extract_content_graph(content=content, response_model=response_model, mode=mode)
+        else:
+            from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.extraction import (
+                extract_content_graph,
+            )
+
+            return extract_content_graph(content=content, response_model=response_model)
+
+    @staticmethod
+    def extract_categories(content: str, response_model: Type[BaseModel]) -> Coroutine:
+        # TODO: Add BAML version of category and extraction and update function
+        from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.extraction import (
+            extract_categories,
+        )
+
+        return extract_categories(content=content, response_model=response_model)
+
+    @staticmethod
+    def extract_code_summary(content: str) -> Coroutine:
+        llm_config = get_llm_config()
+        if llm_config.structured_output_framework.upper() == "BAML":
+            from cognee.infrastructure.llm.structured_output_framework.baml.baml_src.extraction import (
+                extract_code_summary,
+            )
+
+            return extract_code_summary(content=content)
+        else:
+            from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.extraction import (
+                extract_code_summary,
+            )
+
+            return extract_code_summary(content=content)
+
+    @staticmethod
+    def extract_summary(content: str, response_model: Type[BaseModel]) -> Coroutine:
+        llm_config = get_llm_config()
+        if llm_config.structured_output_framework.upper() == "BAML":
+            from cognee.infrastructure.llm.structured_output_framework.baml.baml_src.extraction import (
+                extract_summary,
+            )
+
+            return extract_summary(content=content, response_model=response_model)
+        else:
+            from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.extraction import (
+                extract_summary,
+            )
+
+            return extract_summary(content=content, response_model=response_model)
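Every method defers its imports, so the gateway can be imported before any provider is configured, and each call routes through get_llm_client or the BAML client as appropriate. A usage sketch with a hypothetical response model (assumes LLM credentials are already configured):

    import asyncio
    from pydantic import BaseModel
    from cognee.infrastructure.llm import LLMGateway

    class Sentiment(BaseModel):  # hypothetical example model, not from the package
        label: str
        confidence: float

    async def main():
        # LLMGateway is a namespace of static methods; it is never instantiated.
        result = await LLMGateway.acreate_structured_output(
            text_input="The release notes look great!",
            system_prompt="Classify the sentiment of the input text.",
            response_model=Sentiment,
        )
        print(result)

    asyncio.run(main())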
cognee/infrastructure/llm/__init__.py
@@ -1,4 +1,14 @@
-from .config import get_llm_config
-from .utils import get_max_chunk_tokens
-from .utils import test_llm_connection
-from .utils import test_embedding_connection
+from cognee.infrastructure.llm.config import (
+    get_llm_config,
+)
+from cognee.infrastructure.llm.utils import (
+    get_max_chunk_tokens,
+)
+from cognee.infrastructure.llm.utils import (
+    test_llm_connection,
+)
+from cognee.infrastructure.llm.utils import (
+    test_embedding_connection,
+)
+
+from cognee.infrastructure.llm.LLMGateway import LLMGateway
cognee/infrastructure/llm/config.py
@@ -1,8 +1,9 @@
 import os
-from typing import Optional
+from typing import Optional, ClassVar
 from functools import lru_cache
 from pydantic_settings import BaseSettings, SettingsConfigDict
 from pydantic import model_validator
+from baml_py import ClientRegistry


 class LLMConfig(BaseSettings):
@@ -32,6 +33,7 @@ class LLMConfig(BaseSettings):
     - to_dict
     """

+    structured_output_framework: str = "instructor"
     llm_provider: str = "openai"
     llm_model: str = "gpt-4o-mini"
     llm_endpoint: str = ""
@@ -40,6 +42,14 @@ class LLMConfig(BaseSettings):
     llm_temperature: float = 0.0
     llm_streaming: bool = False
     llm_max_tokens: int = 16384
+
+    baml_llm_provider: str = "openai"
+    baml_llm_model: str = "gpt-4o-mini"
+    baml_llm_endpoint: str = ""
+    baml_llm_api_key: Optional[str] = None
+    baml_llm_temperature: float = 0.0
+    baml_llm_api_version: str = ""
+
     transcription_model: str = "whisper-1"
     graph_prompt_path: str = "generate_graph_prompt.txt"
     llm_rate_limit_enabled: bool = False
@@ -53,8 +63,26 @@ class LLMConfig(BaseSettings):
     fallback_endpoint: str = ""
     fallback_model: str = ""

+    baml_registry: ClassVar[ClientRegistry] = ClientRegistry()
+
     model_config = SettingsConfigDict(env_file=".env", extra="allow")

+    def model_post_init(self, __context) -> None:
+        """Initialize the BAML registry after the model is created."""
+        self.baml_registry.add_llm_client(
+            name=self.baml_llm_provider,
+            provider=self.baml_llm_provider,
+            options={
+                "model": self.baml_llm_model,
+                "temperature": self.baml_llm_temperature,
+                "api_key": self.baml_llm_api_key,
+                "base_url": self.baml_llm_endpoint,
+                "api_version": self.baml_llm_api_version,
+            },
+        )
+        # Sets the primary client
+        self.baml_registry.set_primary(self.baml_llm_provider)
+
     @model_validator(mode="after")
     def ensure_env_vars_for_ollama(self) -> "LLMConfig":
         """
The remaining prompt-file hunks are whitespace-only: each file's last line is removed and re-added unchanged, which in a diff of this kind typically indicates a trailing-newline fix.

cognee/infrastructure/llm/prompts/answer_hotpot_question.txt
@@ -1,2 +1,2 @@
 Answer the question using the provided context. Be as brief as possible.
-Each entry in the context is a paragraph, which is represented as a list with two elements [title, sentences] and sentences is a list of strings.
+Each entry in the context is a paragraph, which is represented as a list with two elements [title, sentences] and sentences is a list of strings.

cognee/infrastructure/llm/prompts/answer_hotpot_using_cognee_search.txt
@@ -1,2 +1,2 @@
 Answer the question using the provided context. Be as brief as possible.
-Each entry in the context is tuple of length 3, representing an edge of a knowledge graph with its two nodes.
+Each entry in the context is tuple of length 3, representing an edge of a knowledge graph with its two nodes.

cognee/infrastructure/llm/prompts/answer_simple_question.txt
@@ -1 +1 @@
-Answer the question using the provided context. Be as brief as possible.
+Answer the question using the provided context. Be as brief as possible.

cognee/infrastructure/llm/prompts/answer_simple_question_restricted.txt
@@ -1 +1 @@
-Answer the question using the provided context. If the provided context is not connected to the question, just answer "The provided knowledge base does not contain the answer to the question". Be as brief as possible.
+Answer the question using the provided context. If the provided context is not connected to the question, just answer "The provided knowledge base does not contain the answer to the question". Be as brief as possible.

cognee/infrastructure/llm/prompts/categorize_categories.txt
@@ -1,2 +1,2 @@
 Chose the summary that is the most relevant to the query`{{ query }}`
-Here are the categories:`{{ categories }}`
+Here are the categories:`{{ categories }}`

cognee/infrastructure/llm/prompts/classify_content.txt
@@ -174,4 +174,4 @@ The possible classifications are:
         "Recipes and crafting instructions"
       ]
     }
-}
+}

cognee/infrastructure/llm/prompts/context_for_question.txt
@@ -1,2 +1,2 @@
 The question is: `{{ question }}`
-And here is the context: `{{ context }}`
+And here is the context: `{{ context }}`

cognee/infrastructure/llm/prompts/graph_context_for_question.txt
@@ -1,2 +1,2 @@
 The question is: `{{ question }}`
-and here is the context provided with a set of relationships from a knowledge graph separated by \n---\n each represented as node1 -- relation -- node2 triplet: `{{ context }}`
+and here is the context provided with a set of relationships from a knowledge graph separated by \n---\n each represented as node1 -- relation -- node2 triplet: `{{ context }}`

cognee/infrastructure/llm/prompts/natural_language_retriever_system.txt
@@ -63,4 +63,4 @@ This queries doesn't work. Do NOT use them:
 Example 1:
 Get all nodes connected to John
 MATCH (n:Entity {'name': 'John'})--(neighbor)
-RETURN n, neighbor
+RETURN n, neighbor

cognee/infrastructure/llm/prompts/patch_gen_instructions.txt
@@ -1,2 +1,2 @@
 I need you to solve this issue by generating a single patch file that I can apply directly to this repository using git apply.
-Please respond with a single patch file in the following format.
+Please respond with a single patch file in the following format.