cognee 0.3.4.dev3__py3-none-any.whl → 0.3.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (186)
  1. cognee/api/client.py +16 -7
  2. cognee/api/health.py +5 -9
  3. cognee/api/v1/add/add.py +3 -1
  4. cognee/api/v1/cognify/cognify.py +44 -7
  5. cognee/api/v1/permissions/routers/get_permissions_router.py +8 -4
  6. cognee/api/v1/search/search.py +3 -0
  7. cognee/api/v1/ui/__init__.py +1 -1
  8. cognee/api/v1/ui/ui.py +215 -150
  9. cognee/api/v1/update/__init__.py +1 -0
  10. cognee/api/v1/update/routers/__init__.py +1 -0
  11. cognee/api/v1/update/routers/get_update_router.py +90 -0
  12. cognee/api/v1/update/update.py +100 -0
  13. cognee/base_config.py +5 -2
  14. cognee/cli/_cognee.py +28 -10
  15. cognee/cli/commands/delete_command.py +34 -2
  16. cognee/eval_framework/corpus_builder/task_getters/get_default_tasks_by_indices.py +2 -2
  17. cognee/eval_framework/evaluation/direct_llm_eval_adapter.py +3 -2
  18. cognee/eval_framework/modal_eval_dashboard.py +9 -1
  19. cognee/infrastructure/databases/graph/config.py +9 -9
  20. cognee/infrastructure/databases/graph/get_graph_engine.py +4 -21
  21. cognee/infrastructure/databases/graph/kuzu/adapter.py +60 -9
  22. cognee/infrastructure/databases/hybrid/neptune_analytics/NeptuneAnalyticsAdapter.py +3 -3
  23. cognee/infrastructure/databases/relational/config.py +4 -4
  24. cognee/infrastructure/databases/relational/create_relational_engine.py +11 -3
  25. cognee/infrastructure/databases/vector/chromadb/ChromaDBAdapter.py +7 -3
  26. cognee/infrastructure/databases/vector/config.py +7 -7
  27. cognee/infrastructure/databases/vector/create_vector_engine.py +7 -15
  28. cognee/infrastructure/databases/vector/embeddings/EmbeddingEngine.py +9 -0
  29. cognee/infrastructure/databases/vector/embeddings/FastembedEmbeddingEngine.py +11 -0
  30. cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py +19 -2
  31. cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py +11 -0
  32. cognee/infrastructure/databases/vector/embeddings/config.py +8 -0
  33. cognee/infrastructure/databases/vector/embeddings/get_embedding_engine.py +5 -0
  34. cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py +11 -10
  35. cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py +48 -38
  36. cognee/infrastructure/databases/vector/vector_db_interface.py +8 -4
  37. cognee/infrastructure/files/storage/S3FileStorage.py +15 -5
  38. cognee/infrastructure/files/storage/s3_config.py +1 -0
  39. cognee/infrastructure/files/utils/open_data_file.py +7 -14
  40. cognee/infrastructure/llm/LLMGateway.py +19 -117
  41. cognee/infrastructure/llm/config.py +28 -13
  42. cognee/infrastructure/llm/{structured_output_framework/litellm_instructor/extraction → extraction}/extract_categories.py +2 -1
  43. cognee/infrastructure/llm/{structured_output_framework/litellm_instructor/extraction → extraction}/extract_event_entities.py +3 -2
  44. cognee/infrastructure/llm/{structured_output_framework/litellm_instructor/extraction → extraction}/extract_summary.py +3 -2
  45. cognee/infrastructure/llm/{structured_output_framework/litellm_instructor/extraction → extraction}/knowledge_graph/extract_content_graph.py +2 -1
  46. cognee/infrastructure/llm/{structured_output_framework/litellm_instructor/extraction → extraction}/knowledge_graph/extract_event_graph.py +3 -2
  47. cognee/infrastructure/llm/prompts/read_query_prompt.py +3 -2
  48. cognee/infrastructure/llm/prompts/show_prompt.py +35 -0
  49. cognee/infrastructure/llm/prompts/test.txt +1 -0
  50. cognee/infrastructure/llm/structured_output_framework/baml/baml_client/__init__.py +2 -2
  51. cognee/infrastructure/llm/structured_output_framework/baml/baml_client/async_client.py +50 -397
  52. cognee/infrastructure/llm/structured_output_framework/baml/baml_client/inlinedbaml.py +2 -3
  53. cognee/infrastructure/llm/structured_output_framework/baml/baml_client/parser.py +8 -88
  54. cognee/infrastructure/llm/structured_output_framework/baml/baml_client/runtime.py +78 -0
  55. cognee/infrastructure/llm/structured_output_framework/baml/baml_client/stream_types.py +2 -99
  56. cognee/infrastructure/llm/structured_output_framework/baml/baml_client/sync_client.py +49 -401
  57. cognee/infrastructure/llm/structured_output_framework/baml/baml_client/type_builder.py +19 -882
  58. cognee/infrastructure/llm/structured_output_framework/baml/baml_client/type_map.py +2 -34
  59. cognee/infrastructure/llm/structured_output_framework/baml/baml_client/types.py +2 -107
  60. cognee/infrastructure/llm/structured_output_framework/baml/baml_src/acreate_structured_output.baml +26 -0
  61. cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extraction/__init__.py +1 -2
  62. cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extraction/acreate_structured_output.py +76 -0
  63. cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extraction/create_dynamic_baml_type.py +122 -0
  64. cognee/infrastructure/llm/structured_output_framework/baml/baml_src/generators.baml +3 -3
  65. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py +0 -32
  66. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py +107 -98
  67. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py +5 -6
  68. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +5 -6
  69. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/llm_interface.py +0 -26
  70. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py +17 -67
  71. cognee/infrastructure/llm/tokenizer/Gemini/adapter.py +8 -7
  72. cognee/infrastructure/llm/utils.py +4 -4
  73. cognee/infrastructure/loaders/LoaderEngine.py +5 -2
  74. cognee/infrastructure/loaders/external/__init__.py +7 -0
  75. cognee/infrastructure/loaders/external/advanced_pdf_loader.py +244 -0
  76. cognee/infrastructure/loaders/supported_loaders.py +7 -0
  77. cognee/modules/data/methods/create_authorized_dataset.py +9 -0
  78. cognee/modules/data/methods/get_authorized_dataset.py +1 -1
  79. cognee/modules/data/methods/get_authorized_dataset_by_name.py +11 -0
  80. cognee/modules/data/methods/get_deletion_counts.py +92 -0
  81. cognee/modules/graph/cognee_graph/CogneeGraph.py +1 -1
  82. cognee/modules/graph/utils/expand_with_nodes_and_edges.py +22 -8
  83. cognee/modules/graph/utils/retrieve_existing_edges.py +0 -2
  84. cognee/modules/ingestion/data_types/TextData.py +0 -1
  85. cognee/modules/notebooks/methods/create_notebook.py +3 -1
  86. cognee/modules/notebooks/methods/get_notebooks.py +27 -1
  87. cognee/modules/observability/get_observe.py +14 -0
  88. cognee/modules/observability/observers.py +1 -0
  89. cognee/modules/ontology/base_ontology_resolver.py +42 -0
  90. cognee/modules/ontology/get_default_ontology_resolver.py +41 -0
  91. cognee/modules/ontology/matching_strategies.py +53 -0
  92. cognee/modules/ontology/models.py +20 -0
  93. cognee/modules/ontology/ontology_config.py +24 -0
  94. cognee/modules/ontology/ontology_env_config.py +45 -0
  95. cognee/modules/ontology/rdf_xml/{OntologyResolver.py → RDFLibOntologyResolver.py} +20 -28
  96. cognee/modules/pipelines/layers/resolve_authorized_user_dataset.py +21 -24
  97. cognee/modules/pipelines/layers/resolve_authorized_user_datasets.py +3 -3
  98. cognee/modules/retrieval/code_retriever.py +2 -1
  99. cognee/modules/retrieval/context_providers/TripletSearchContextProvider.py +1 -4
  100. cognee/modules/retrieval/graph_completion_cot_retriever.py +6 -5
  101. cognee/modules/retrieval/graph_completion_retriever.py +0 -3
  102. cognee/modules/retrieval/insights_retriever.py +1 -1
  103. cognee/modules/retrieval/jaccard_retrival.py +60 -0
  104. cognee/modules/retrieval/lexical_retriever.py +123 -0
  105. cognee/modules/retrieval/natural_language_retriever.py +2 -1
  106. cognee/modules/retrieval/temporal_retriever.py +3 -2
  107. cognee/modules/retrieval/utils/brute_force_triplet_search.py +2 -12
  108. cognee/modules/retrieval/utils/completion.py +4 -7
  109. cognee/modules/search/methods/get_search_type_tools.py +7 -0
  110. cognee/modules/search/methods/no_access_control_search.py +1 -1
  111. cognee/modules/search/methods/search.py +32 -13
  112. cognee/modules/search/types/SearchType.py +1 -0
  113. cognee/modules/users/methods/create_user.py +0 -2
  114. cognee/modules/users/permissions/methods/authorized_give_permission_on_datasets.py +12 -0
  115. cognee/modules/users/permissions/methods/check_permission_on_dataset.py +11 -0
  116. cognee/modules/users/permissions/methods/get_all_user_permission_datasets.py +10 -0
  117. cognee/modules/users/permissions/methods/get_document_ids_for_user.py +10 -0
  118. cognee/modules/users/permissions/methods/get_principal.py +9 -0
  119. cognee/modules/users/permissions/methods/get_principal_datasets.py +11 -0
  120. cognee/modules/users/permissions/methods/get_role.py +10 -0
  121. cognee/modules/users/permissions/methods/get_specific_user_permission_datasets.py +3 -3
  122. cognee/modules/users/permissions/methods/get_tenant.py +9 -0
  123. cognee/modules/users/permissions/methods/give_default_permission_to_role.py +9 -0
  124. cognee/modules/users/permissions/methods/give_default_permission_to_tenant.py +9 -0
  125. cognee/modules/users/permissions/methods/give_default_permission_to_user.py +9 -0
  126. cognee/modules/users/permissions/methods/give_permission_on_dataset.py +10 -0
  127. cognee/modules/users/roles/methods/add_user_to_role.py +11 -0
  128. cognee/modules/users/roles/methods/create_role.py +12 -1
  129. cognee/modules/users/tenants/methods/add_user_to_tenant.py +12 -0
  130. cognee/modules/users/tenants/methods/create_tenant.py +12 -1
  131. cognee/modules/visualization/cognee_network_visualization.py +13 -9
  132. cognee/shared/data_models.py +0 -1
  133. cognee/shared/utils.py +0 -32
  134. cognee/tasks/chunk_naive_llm_classifier/chunk_naive_llm_classifier.py +2 -2
  135. cognee/tasks/codingagents/coding_rule_associations.py +3 -2
  136. cognee/tasks/entity_completion/entity_extractors/llm_entity_extractor.py +3 -2
  137. cognee/tasks/graph/cascade_extract/utils/extract_content_nodes_and_relationship_names.py +3 -2
  138. cognee/tasks/graph/cascade_extract/utils/extract_edge_triplets.py +3 -2
  139. cognee/tasks/graph/cascade_extract/utils/extract_nodes.py +3 -2
  140. cognee/tasks/graph/extract_graph_from_code.py +2 -2
  141. cognee/tasks/graph/extract_graph_from_data.py +55 -12
  142. cognee/tasks/graph/extract_graph_from_data_v2.py +16 -4
  143. cognee/tasks/ingestion/migrate_relational_database.py +132 -41
  144. cognee/tasks/ingestion/resolve_data_directories.py +4 -1
  145. cognee/tasks/schema/ingest_database_schema.py +134 -0
  146. cognee/tasks/schema/models.py +40 -0
  147. cognee/tasks/storage/index_data_points.py +1 -1
  148. cognee/tasks/storage/index_graph_edges.py +3 -1
  149. cognee/tasks/summarization/summarize_code.py +2 -2
  150. cognee/tasks/summarization/summarize_text.py +2 -2
  151. cognee/tasks/temporal_graph/enrich_events.py +2 -2
  152. cognee/tasks/temporal_graph/extract_events_and_entities.py +2 -2
  153. cognee/tests/cli_tests/cli_unit_tests/test_cli_commands.py +13 -4
  154. cognee/tests/cli_tests/cli_unit_tests/test_cli_edge_cases.py +13 -3
  155. cognee/tests/test_advanced_pdf_loader.py +141 -0
  156. cognee/tests/test_chromadb.py +40 -0
  157. cognee/tests/test_cognee_server_start.py +6 -1
  158. cognee/tests/test_data/Quantum_computers.txt +9 -0
  159. cognee/tests/test_lancedb.py +211 -0
  160. cognee/tests/test_pgvector.py +40 -0
  161. cognee/tests/test_relational_db_migration.py +76 -0
  162. cognee/tests/unit/infrastructure/databases/test_index_graph_edges.py +2 -1
  163. cognee/tests/unit/modules/ontology/test_ontology_adapter.py +330 -13
  164. cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py +0 -4
  165. cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +0 -4
  166. cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py +0 -4
  167. {cognee-0.3.4.dev3.dist-info → cognee-0.3.5.dist-info}/METADATA +92 -96
  168. {cognee-0.3.4.dev3.dist-info → cognee-0.3.5.dist-info}/RECORD +176 -162
  169. distributed/pyproject.toml +0 -1
  170. cognee/infrastructure/data/utils/extract_keywords.py +0 -48
  171. cognee/infrastructure/databases/hybrid/falkordb/FalkorDBAdapter.py +0 -1227
  172. cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extract_categories.baml +0 -109
  173. cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extract_content_graph.baml +0 -343
  174. cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extraction/extract_categories.py +0 -0
  175. cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extraction/extract_summary.py +0 -89
  176. cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extraction/knowledge_graph/__init__.py +0 -0
  177. cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extraction/knowledge_graph/extract_content_graph.py +0 -44
  178. cognee/tasks/graph/infer_data_ontology.py +0 -309
  179. cognee/tests/test_falkordb.py +0 -174
  180. /cognee/infrastructure/llm/{structured_output_framework/litellm_instructor/extraction → extraction}/__init__.py +0 -0
  181. /cognee/infrastructure/llm/{structured_output_framework/litellm_instructor/extraction → extraction}/knowledge_graph/__init__.py +0 -0
  182. /cognee/infrastructure/llm/{structured_output_framework/litellm_instructor/extraction → extraction}/texts.json +0 -0
  183. {cognee-0.3.4.dev3.dist-info → cognee-0.3.5.dist-info}/WHEEL +0 -0
  184. {cognee-0.3.4.dev3.dist-info → cognee-0.3.5.dist-info}/entry_points.txt +0 -0
  185. {cognee-0.3.4.dev3.dist-info → cognee-0.3.5.dist-info}/licenses/LICENSE +0 -0
  186. {cognee-0.3.4.dev3.dist-info → cognee-0.3.5.dist-info}/licenses/NOTICE.md +0 -0
cognee/infrastructure/databases/vector/create_vector_engine.py
@@ -19,8 +19,7 @@ def create_vector_engine(
     for each provider, raising an EnvironmentError if any are missing, or ImportError if the
     ChromaDB package is not installed.
 
-    Supported providers include: pgvector, FalkorDB, ChromaDB, and
-    LanceDB.
+    Supported providers include: pgvector, ChromaDB, and LanceDB.
 
     Parameters:
     -----------
@@ -66,7 +65,12 @@ def create_vector_engine(
             f"postgresql+asyncpg://{db_username}:{db_password}@{db_host}:{db_port}/{db_name}"
         )
 
-        from .pgvector.PGVectorAdapter import PGVectorAdapter
+        try:
+            from .pgvector.PGVectorAdapter import PGVectorAdapter
+        except ImportError:
+            raise ImportError(
+                "PostgreSQL dependencies are not installed. Please install with 'pip install cognee\"[postgres]\"' or 'pip install cognee\"[postgres-binary]\"' to use PGVector functionality."
+            )
 
         return PGVectorAdapter(
             connection_string,
@@ -74,18 +78,6 @@ def create_vector_engine(
             embedding_engine,
         )
 
-    elif vector_db_provider == "falkordb":
-        if not (vector_db_url and vector_db_port):
-            raise EnvironmentError("Missing requred FalkorDB credentials!")
-
-        from ..hybrid.falkordb.FalkorDBAdapter import FalkorDBAdapter
-
-        return FalkorDBAdapter(
-            database_url=vector_db_url,
-            database_port=vector_db_port,
-            embedding_engine=embedding_engine,
-        )
-
     elif vector_db_provider == "chromadb":
         try:
             import chromadb
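
Note: the guarded import added above is the standard optional-dependency pattern. A minimal generic sketch of the same idea (the helper name and usage are illustrative, not part of cognee's API):

import importlib

def import_optional(module_name: str, extra: str):
    # Import an optional dependency, failing fast with an actionable hint,
    # mirroring the try/except ImportError guard added in create_vector_engine.
    try:
        return importlib.import_module(module_name)
    except ImportError as error:
        raise ImportError(
            f'{module_name} is not installed. Install it with: pip install cognee"[{extra}]"'
        ) from error

# Example: only required when the pgvector backend is selected.
# asyncpg = import_optional("asyncpg", "postgres")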
cognee/infrastructure/databases/vector/embeddings/EmbeddingEngine.py
@@ -34,3 +34,12 @@ class EmbeddingEngine(Protocol):
         - int: An integer representing the number of dimensions in the embedding vector.
         """
         raise NotImplementedError()
+
+    def get_batch_size(self) -> int:
+        """
+        Return the desired batch size for embedding calls
+
+        Returns:
+
+        """
+        raise NotImplementedError()
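
Note: to show how a caller might consume the new get_batch_size() hook, here is a simplified synchronous sketch (cognee's real embed_text is async; the helper below is hypothetical):

from typing import List, Protocol

class SupportsBatchedEmbedding(Protocol):
    def embed_text(self, texts: List[str]) -> List[List[float]]: ...
    def get_batch_size(self) -> int: ...

def embed_in_batches(engine: SupportsBatchedEmbedding, texts: List[str]) -> List[List[float]]:
    # Chunk the inputs by the engine's preferred batch size so a single
    # oversized request never exceeds the provider's per-call limits.
    batch_size = max(1, engine.get_batch_size())
    vectors: List[List[float]] = []
    for start in range(0, len(texts), batch_size):
        vectors.extend(engine.embed_text(texts[start : start + batch_size]))
    return vectors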
cognee/infrastructure/databases/vector/embeddings/FastembedEmbeddingEngine.py
@@ -42,11 +42,13 @@ class FastembedEmbeddingEngine(EmbeddingEngine):
         model: Optional[str] = "openai/text-embedding-3-large",
         dimensions: Optional[int] = 3072,
         max_completion_tokens: int = 512,
+        batch_size: int = 100,
     ):
         self.model = model
         self.dimensions = dimensions
         self.max_completion_tokens = max_completion_tokens
         self.tokenizer = self.get_tokenizer()
+        self.batch_size = batch_size
         # self.retry_count = 0
         self.embedding_model = TextEmbedding(model_name=model)
 
@@ -101,6 +103,15 @@ class FastembedEmbeddingEngine(EmbeddingEngine):
         """
         return self.dimensions
 
+    def get_batch_size(self) -> int:
+        """
+        Return the desired batch size for embedding calls
+
+        Returns:
+
+        """
+        return self.batch_size
+
     def get_tokenizer(self):
         """
         Instantiate and return the tokenizer used for preparing text for embedding.
cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py
@@ -58,6 +58,7 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
         endpoint: str = None,
         api_version: str = None,
         max_completion_tokens: int = 512,
+        batch_size: int = 100,
     ):
         self.api_key = api_key
         self.endpoint = endpoint
@@ -68,6 +69,7 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
         self.max_completion_tokens = max_completion_tokens
         self.tokenizer = self.get_tokenizer()
         self.retry_count = 0
+        self.batch_size = batch_size
 
         enable_mocking = os.getenv("MOCK_EMBEDDING", "false")
         if isinstance(enable_mocking, bool):
@@ -165,6 +167,15 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
         """
         return self.dimensions
 
+    def get_batch_size(self) -> int:
+        """
+        Return the desired batch size for embedding calls
+
+        Returns:
+
+        """
+        return self.batch_size
+
     def get_tokenizer(self):
         """
         Load and return the appropriate tokenizer for the specified model based on the provider.
@@ -183,9 +194,15 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
                 model=model, max_completion_tokens=self.max_completion_tokens
             )
         elif "gemini" in self.provider.lower():
-            tokenizer = GeminiTokenizer(
-                model=model, max_completion_tokens=self.max_completion_tokens
+            # Since Gemini tokenization needs to send an API request to get the token count we will use TikToken to
+            # count tokens as we calculate tokens word by word
+            tokenizer = TikTokenTokenizer(
+                model=None, max_completion_tokens=self.max_completion_tokens
             )
+            # Note: Gemini Tokenizer expects an LLM model as input and not the embedding model
+            # tokenizer = GeminiTokenizer(
+            #     llm_model=llm_model, max_completion_tokens=self.max_completion_tokens
+            # )
         elif "mistral" in self.provider.lower():
             tokenizer = MistralTokenizer(
                 model=model, max_completion_tokens=self.max_completion_tokens
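
Note: the change above trades exact Gemini token counts for a local approximation. A standalone sketch of local counting with tiktoken (the encoding choice is an assumption; cognee's TikTokenTokenizer wrapper may select differently):

import tiktoken  # pip install tiktoken

def count_tokens(text: str) -> int:
    # cl100k_base is a general-purpose encoding; counts are approximate for
    # non-OpenAI models such as Gemini, but need no API round-trip.
    encoding = tiktoken.get_encoding("cl100k_base")
    return len(encoding.encode(text))

print(count_tokens("Quantum computers use qubits."))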
cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py
@@ -54,12 +54,14 @@ class OllamaEmbeddingEngine(EmbeddingEngine):
         max_completion_tokens: int = 512,
         endpoint: Optional[str] = "http://localhost:11434/api/embeddings",
         huggingface_tokenizer: str = "Salesforce/SFR-Embedding-Mistral",
+        batch_size: int = 100,
     ):
         self.model = model
         self.dimensions = dimensions
         self.max_completion_tokens = max_completion_tokens
         self.endpoint = endpoint
         self.huggingface_tokenizer_name = huggingface_tokenizer
+        self.batch_size = batch_size
         self.tokenizer = self.get_tokenizer()
 
         enable_mocking = os.getenv("MOCK_EMBEDDING", "false")
@@ -122,6 +124,15 @@ class OllamaEmbeddingEngine(EmbeddingEngine):
         """
         return self.dimensions
 
+    def get_batch_size(self) -> int:
+        """
+        Return the desired batch size for embedding calls
+
+        Returns:
+
+        """
+        return self.batch_size
+
     def get_tokenizer(self):
         """
         Load and return a HuggingFace tokenizer for the embedding engine.
cognee/infrastructure/databases/vector/embeddings/config.py
@@ -19,9 +19,17 @@ class EmbeddingConfig(BaseSettings):
     embedding_api_key: Optional[str] = None
     embedding_api_version: Optional[str] = None
     embedding_max_completion_tokens: Optional[int] = 8191
+    embedding_batch_size: Optional[int] = None
     huggingface_tokenizer: Optional[str] = None
     model_config = SettingsConfigDict(env_file=".env", extra="allow")
 
+    def model_post_init(self, __context) -> None:
+        # If embedding batch size is not defined use 2048 as default for OpenAI and 100 for all other embedding models
+        if not self.embedding_batch_size and self.embedding_provider.lower() == "openai":
+            self.embedding_batch_size = 2048
+        elif not self.embedding_batch_size:
+            self.embedding_batch_size = 100
+
     def to_dict(self) -> dict:
         """
         Serialize all embedding configuration settings to a dictionary.
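
Note: the net effect of model_post_init is provider-dependent defaults. A self-contained sketch of the same resolution logic (assuming pydantic-settings, with unrelated fields omitted and no overriding environment variables set):

from typing import Optional
from pydantic_settings import BaseSettings

class EmbeddingConfigSketch(BaseSettings):
    embedding_provider: str = "openai"
    embedding_batch_size: Optional[int] = None  # EMBEDDING_BATCH_SIZE env var overrides

    def model_post_init(self, __context) -> None:
        # OpenAI tolerates large embedding batches; everything else defaults to 100.
        if not self.embedding_batch_size and self.embedding_provider.lower() == "openai":
            self.embedding_batch_size = 2048
        elif not self.embedding_batch_size:
            self.embedding_batch_size = 100

assert EmbeddingConfigSketch().embedding_batch_size == 2048
assert EmbeddingConfigSketch(embedding_provider="ollama").embedding_batch_size == 100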
cognee/infrastructure/databases/vector/embeddings/get_embedding_engine.py
@@ -31,6 +31,7 @@ def get_embedding_engine() -> EmbeddingEngine:
         config.embedding_endpoint,
         config.embedding_api_key,
         config.embedding_api_version,
+        config.embedding_batch_size,
         config.huggingface_tokenizer,
         llm_config.llm_api_key,
         llm_config.llm_provider,
@@ -46,6 +47,7 @@ def create_embedding_engine(
     embedding_endpoint,
     embedding_api_key,
     embedding_api_version,
+    embedding_batch_size,
    huggingface_tokenizer,
    llm_api_key,
    llm_provider,
@@ -84,6 +86,7 @@ def create_embedding_engine(
             model=embedding_model,
             dimensions=embedding_dimensions,
             max_completion_tokens=embedding_max_completion_tokens,
+            batch_size=embedding_batch_size,
         )
 
     if embedding_provider == "ollama":
@@ -95,6 +98,7 @@ def create_embedding_engine(
             max_completion_tokens=embedding_max_completion_tokens,
             endpoint=embedding_endpoint,
             huggingface_tokenizer=huggingface_tokenizer,
+            batch_size=embedding_batch_size,
         )
 
     from .LiteLLMEmbeddingEngine import LiteLLMEmbeddingEngine
@@ -108,4 +112,5 @@ def create_embedding_engine(
         model=embedding_model,
         dimensions=embedding_dimensions,
         max_completion_tokens=embedding_max_completion_tokens,
+        batch_size=embedding_batch_size,
     )
cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py
@@ -205,9 +205,12 @@ class LanceDBAdapter(VectorDBInterface):
         collection = await self.get_collection(collection_name)
 
         if len(data_point_ids) == 1:
-            results = await collection.query().where(f"id = '{data_point_ids[0]}'").to_pandas()
+            results = await collection.query().where(f"id = '{data_point_ids[0]}'")
         else:
-            results = await collection.query().where(f"id IN {tuple(data_point_ids)}").to_pandas()
+            results = await collection.query().where(f"id IN {tuple(data_point_ids)}")
+
+        # Convert query results to list format
+        results_list = results.to_list() if hasattr(results, "to_list") else list(results)
 
         return [
             ScoredResult(
@@ -215,7 +218,7 @@ class LanceDBAdapter(VectorDBInterface):
                 payload=result["payload"],
                 score=0,
             )
-            for result in results.to_dict("index").values()
+            for result in results_list
         ]
 
     async def search(
@@ -223,7 +226,7 @@ class LanceDBAdapter(VectorDBInterface):
         collection_name: str,
         query_text: str = None,
         query_vector: List[float] = None,
-        limit: int = 15,
+        limit: Optional[int] = 15,
         with_vector: bool = False,
         normalized: bool = True,
     ):
@@ -235,16 +238,14 @@ class LanceDBAdapter(VectorDBInterface):
 
         collection = await self.get_collection(collection_name)
 
-        if limit == 0:
+        if limit is None:
             limit = await collection.count_rows()
 
         # LanceDB search will break if limit is 0 so we must return
-        if limit == 0:
+        if limit <= 0:
             return []
 
-        results = await collection.vector_search(query_vector).limit(limit).to_pandas()
-
-        result_values = list(results.to_dict("index").values())
+        result_values = await collection.vector_search(query_vector).limit(limit).to_list()
 
         if not result_values:
             return []
@@ -264,7 +265,7 @@ class LanceDBAdapter(VectorDBInterface):
         self,
         collection_name: str,
         query_texts: List[str],
-        limit: int = None,
+        limit: Optional[int] = None,
         with_vectors: bool = False,
     ):
         query_vectors = await self.embedding_engine.embed_text(query_texts)
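
Note: after this change, limit=None means "return every row", while an empty collection short-circuits to an empty list. A hypothetical caller (adapter construction elided):

from typing import Any, List

async def fetch_all_matches(adapter: Any, collection_name: str, query: str) -> List[Any]:
    # limit=None: the adapter counts the rows and uses that count as the limit,
    # so an empty collection yields [] instead of breaking LanceDB.
    return await adapter.search(
        collection_name=collection_name,
        query_text=query,
        limit=None,
    )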
cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py
@@ -3,13 +3,12 @@ from typing import List, Optional, get_type_hints
 from sqlalchemy.inspection import inspect
 from sqlalchemy.orm import Mapped, mapped_column
 from sqlalchemy.dialects.postgresql import insert
-from sqlalchemy import JSON, Column, Table, select, delete, MetaData
+from sqlalchemy import JSON, Column, Table, select, delete, MetaData, func
 from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
 from sqlalchemy.exc import ProgrammingError
 from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential
 from asyncpg import DeadlockDetectedError, DuplicateTableError, UniqueViolationError
 
-
 from cognee.shared.logging_utils import get_logger
 from cognee.infrastructure.engine import DataPoint
 from cognee.infrastructure.engine.utils import parse_id
@@ -126,41 +125,42 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
         data_point_types = get_type_hints(DataPoint)
         vector_size = self.embedding_engine.get_vector_size()
 
-        async with self.VECTOR_DB_LOCK:
-            if not await self.has_collection(collection_name):
-
-                class PGVectorDataPoint(Base):
-                    """
-                    Represent a point in a vector data space with associated data and vector representation.
-
-                    This class inherits from Base and is associated with a database table defined by
-                    __tablename__. It maintains the following public methods and instance variables:
-
-                    - __init__(self, id, payload, vector): Initializes a new PGVectorDataPoint instance.
-
-                    Instance variables:
-                    - id: Identifier for the data point, defined by data_point_types.
-                    - payload: JSON data associated with the data point.
-                    - vector: Vector representation of the data point, with size defined by vector_size.
-                    """
-
-                    __tablename__ = collection_name
-                    __table_args__ = {"extend_existing": True}
-                    # PGVector requires one column to be the primary key
-                    id: Mapped[data_point_types["id"]] = mapped_column(primary_key=True)
-                    payload = Column(JSON)
-                    vector = Column(self.Vector(vector_size))
-
-                    def __init__(self, id, payload, vector):
-                        self.id = id
-                        self.payload = payload
-                        self.vector = vector
-
-                async with self.engine.begin() as connection:
-                    if len(Base.metadata.tables.keys()) > 0:
-                        await connection.run_sync(
-                            Base.metadata.create_all, tables=[PGVectorDataPoint.__table__]
-                        )
+        if not await self.has_collection(collection_name):
+            async with self.VECTOR_DB_LOCK:
+                if not await self.has_collection(collection_name):
+
+                    class PGVectorDataPoint(Base):
+                        """
+                        Represent a point in a vector data space with associated data and vector representation.
+
+                        This class inherits from Base and is associated with a database table defined by
+                        __tablename__. It maintains the following public methods and instance variables:
+
+                        - __init__(self, id, payload, vector): Initializes a new PGVectorDataPoint instance.
+
+                        Instance variables:
+                        - id: Identifier for the data point, defined by data_point_types.
+                        - payload: JSON data associated with the data point.
+                        - vector: Vector representation of the data point, with size defined by vector_size.
+                        """
+
+                        __tablename__ = collection_name
+                        __table_args__ = {"extend_existing": True}
+                        # PGVector requires one column to be the primary key
+                        id: Mapped[data_point_types["id"]] = mapped_column(primary_key=True)
+                        payload = Column(JSON)
+                        vector = Column(self.Vector(vector_size))
+
+                        def __init__(self, id, payload, vector):
+                            self.id = id
+                            self.payload = payload
+                            self.vector = vector
+
+                    async with self.engine.begin() as connection:
+                        if len(Base.metadata.tables.keys()) > 0:
+                            await connection.run_sync(
+                                Base.metadata.create_all, tables=[PGVectorDataPoint.__table__]
+                            )
 
     @retry(
         retry=retry_if_exception_type(DeadlockDetectedError),
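
Note: the reordering above is classic double-checked locking: the cheap existence check runs before the lock is taken and is repeated after acquisition, so concurrent callers neither serialize on the common path nor race to create the same table. A minimal generic sketch of the pattern:

import asyncio

_lock = asyncio.Lock()
_created: set = set()  # stand-in for the real has_collection() check

async def ensure_collection(name: str) -> None:
    # Fast path: no lock needed when the collection already exists.
    if name not in _created:
        async with _lock:
            # Re-check under the lock: another coroutine may have created it
            # between our first check and lock acquisition.
            if name not in _created:
                _created.add(name)  # real code would create the table here

async def main() -> None:
    await asyncio.gather(*(ensure_collection("chunks") for _ in range(5)))

asyncio.run(main())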
@@ -299,7 +299,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
         collection_name: str,
         query_text: Optional[str] = None,
         query_vector: Optional[List[float]] = None,
-        limit: int = 15,
+        limit: Optional[int] = 15,
         with_vector: bool = False,
     ) -> List[ScoredResult]:
         if query_text is None and query_vector is None:
@@ -311,6 +311,16 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
         # Get PGVectorDataPoint Table from database
         PGVectorDataPoint = await self.get_table(collection_name)
 
+        if limit is None:
+            async with self.get_async_session() as session:
+                query = select(func.count()).select_from(PGVectorDataPoint)
+                result = await session.execute(query)
+                limit = result.scalar_one()
+
+        # If limit is still 0, no need to do the search, just return empty results
+        if limit <= 0:
+            return []
+
         # NOTE: This needs to be initialized in case search doesn't return a value
         closest_items = []
 
cognee/infrastructure/databases/vector/vector_db_interface.py
@@ -83,7 +83,7 @@ class VectorDBInterface(Protocol):
         collection_name: str,
         query_text: Optional[str],
         query_vector: Optional[List[float]],
-        limit: int,
+        limit: Optional[int],
         with_vector: bool = False,
     ):
         """
@@ -98,7 +98,7 @@ class VectorDBInterface(Protocol):
           collection.
         - query_vector (Optional[List[float]]): An optional vector representation for
           searching the collection.
-        - limit (int): The maximum number of results to return from the search.
+        - limit (Optional[int]): The maximum number of results to return from the search.
         - with_vector (bool): Whether to return the vector representations with search
           results. (default False)
         """
@@ -106,7 +106,11 @@ class VectorDBInterface(Protocol):
 
     @abstractmethod
     async def batch_search(
-        self, collection_name: str, query_texts: List[str], limit: int, with_vectors: bool = False
+        self,
+        collection_name: str,
+        query_texts: List[str],
+        limit: Optional[int],
+        with_vectors: bool = False,
     ):
         """
         Perform a batch search using multiple text queries against a collection.
@@ -116,7 +120,7 @@ class VectorDBInterface(Protocol):
 
         - collection_name (str): The name of the collection to conduct the batch search in.
         - query_texts (List[str]): A list of text queries to use for the search.
-        - limit (int): The maximum number of results to return for each query.
+        - limit (Optional[int]): The maximum number of results to return for each query.
         - with_vectors (bool): Whether to include vector representations with search
           results. (default False)
         """
cognee/infrastructure/files/storage/S3FileStorage.py
@@ -1,6 +1,5 @@
 import os
-import s3fs
-from typing import BinaryIO, Union
+from typing import BinaryIO, Union, TYPE_CHECKING
 from contextlib import asynccontextmanager
 
 from cognee.infrastructure.files.storage.s3_config import get_s3_config
@@ -8,23 +7,34 @@ from cognee.infrastructure.utils.run_async import run_async
 from cognee.infrastructure.files.storage.FileBufferedReader import FileBufferedReader
 from .storage import Storage
 
+if TYPE_CHECKING:
+    import s3fs
+
 
 class S3FileStorage(Storage):
     """
-    Manage local file storage operations such as storing, retrieving, and managing files on
-    the filesystem.
+    Manage S3 file storage operations such as storing, retrieving, and managing files on
+    S3-compatible storage.
     """
 
     storage_path: str
-    s3: s3fs.S3FileSystem
+    s3: "s3fs.S3FileSystem"
 
     def __init__(self, storage_path: str):
+        try:
+            import s3fs
+        except ImportError:
+            raise ImportError(
+                's3fs is required for S3FileStorage. Install it with: pip install cognee"[aws]"'
+            )
+
         self.storage_path = storage_path
         s3_config = get_s3_config()
         if s3_config.aws_access_key_id is not None and s3_config.aws_secret_access_key is not None:
             self.s3 = s3fs.S3FileSystem(
                 key=s3_config.aws_access_key_id,
                 secret=s3_config.aws_secret_access_key,
+                token=s3_config.aws_session_token,
                 anon=False,
                 endpoint_url=s3_config.aws_endpoint_url,
                 client_kwargs={"region_name": s3_config.aws_region},
cognee/infrastructure/files/storage/s3_config.py
@@ -8,6 +8,7 @@ class S3Config(BaseSettings):
     aws_endpoint_url: Optional[str] = None
     aws_access_key_id: Optional[str] = None
     aws_secret_access_key: Optional[str] = None
+    aws_session_token: Optional[str] = None
     model_config = SettingsConfigDict(env_file=".env", extra="allow")
 
 
cognee/infrastructure/files/utils/open_data_file.py
@@ -4,7 +4,6 @@ from urllib.parse import urlparse
 from contextlib import asynccontextmanager
 
 from cognee.infrastructure.files.utils.get_data_file_path import get_data_file_path
-from cognee.infrastructure.files.storage.S3FileStorage import S3FileStorage
 from cognee.infrastructure.files.storage.LocalFileStorage import LocalFileStorage
 
 
@@ -23,23 +22,17 @@ async def open_data_file(file_path: str, mode: str = "rb", encoding: str = None,
             yield file
 
     elif file_path.startswith("s3://"):
+        try:
+            from cognee.infrastructure.files.storage.S3FileStorage import S3FileStorage
+        except ImportError:
+            raise ImportError(
+                "S3 dependencies are not installed. Please install with 'pip install cognee\"[aws]\"' to use S3 functionality."
+            )
+
         normalized_url = get_data_file_path(file_path)
         s3_dir_path = os.path.dirname(normalized_url)
         s3_filename = os.path.basename(normalized_url)
 
-        # if "/" in s3_path:
-        #     s3_dir = "/".join(s3_path.split("/")[:-1])
-        #     s3_filename = s3_path.split("/")[-1]
-        # else:
-        #     s3_dir = ""
-        #     s3_filename = s3_path
-
-        # Extract filesystem path from S3 URL structure
-        # file_dir_path = (
-        #     f"s3://{parsed_url.netloc}/{s3_dir}" if s3_dir else f"s3://{parsed_url.netloc}"
-        # )
-        # file_name = s3_filename
-
         file_storage = S3FileStorage(s3_dir_path)
 
         async with file_storage.open(s3_filename, mode=mode, **kwargs) as file:
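
Note: callers are unchanged; the ImportError now simply surfaces lazily on first S3 access. A hypothetical usage (bucket and key are placeholders; requires the aws extra, and read() is assumed synchronous on the yielded file object):

import asyncio
from cognee.infrastructure.files.utils.open_data_file import open_data_file

async def read_s3_document() -> bytes:
    # Raises the descriptive ImportError above if s3fs is not installed.
    async with open_data_file("s3://example-bucket/docs/report.pdf", mode="rb") as file:
        return file.read()

asyncio.run(read_s3_document())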