cognee 0.5.0.dev0__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff compares the contents of two publicly released package versions as they appear in their public registry and is provided for informational purposes only.
Files changed (132)
  1. cognee/api/client.py +1 -5
  2. cognee/api/v1/add/add.py +2 -1
  3. cognee/api/v1/cognify/cognify.py +24 -16
  4. cognee/api/v1/cognify/routers/__init__.py +0 -1
  5. cognee/api/v1/cognify/routers/get_cognify_router.py +3 -1
  6. cognee/api/v1/datasets/routers/get_datasets_router.py +3 -3
  7. cognee/api/v1/ontologies/ontologies.py +12 -37
  8. cognee/api/v1/ontologies/routers/get_ontology_router.py +27 -25
  9. cognee/api/v1/search/search.py +8 -0
  10. cognee/api/v1/ui/node_setup.py +360 -0
  11. cognee/api/v1/ui/npm_utils.py +50 -0
  12. cognee/api/v1/ui/ui.py +38 -68
  13. cognee/context_global_variables.py +61 -16
  14. cognee/eval_framework/Dockerfile +29 -0
  15. cognee/eval_framework/answer_generation/answer_generation_executor.py +10 -0
  16. cognee/eval_framework/answer_generation/run_question_answering_module.py +1 -1
  17. cognee/eval_framework/corpus_builder/task_getters/get_cascade_graph_tasks.py +0 -2
  18. cognee/eval_framework/corpus_builder/task_getters/get_default_tasks_by_indices.py +4 -4
  19. cognee/eval_framework/eval_config.py +2 -2
  20. cognee/eval_framework/modal_run_eval.py +16 -28
  21. cognee/infrastructure/databases/dataset_database_handler/__init__.py +3 -0
  22. cognee/infrastructure/databases/dataset_database_handler/dataset_database_handler_interface.py +80 -0
  23. cognee/infrastructure/databases/dataset_database_handler/supported_dataset_database_handlers.py +18 -0
  24. cognee/infrastructure/databases/dataset_database_handler/use_dataset_database_handler.py +10 -0
  25. cognee/infrastructure/databases/graph/config.py +3 -0
  26. cognee/infrastructure/databases/graph/get_graph_engine.py +1 -0
  27. cognee/infrastructure/databases/graph/graph_db_interface.py +15 -0
  28. cognee/infrastructure/databases/graph/kuzu/KuzuDatasetDatabaseHandler.py +81 -0
  29. cognee/infrastructure/databases/graph/kuzu/adapter.py +228 -0
  30. cognee/infrastructure/databases/graph/neo4j_driver/Neo4jAuraDevDatasetDatabaseHandler.py +168 -0
  31. cognee/infrastructure/databases/graph/neo4j_driver/adapter.py +80 -1
  32. cognee/infrastructure/databases/utils/__init__.py +3 -0
  33. cognee/infrastructure/databases/utils/get_graph_dataset_database_handler.py +10 -0
  34. cognee/infrastructure/databases/utils/get_or_create_dataset_database.py +62 -48
  35. cognee/infrastructure/databases/utils/get_vector_dataset_database_handler.py +10 -0
  36. cognee/infrastructure/databases/utils/resolve_dataset_database_connection_info.py +30 -0
  37. cognee/infrastructure/databases/vector/config.py +2 -0
  38. cognee/infrastructure/databases/vector/create_vector_engine.py +1 -0
  39. cognee/infrastructure/databases/vector/embeddings/FastembedEmbeddingEngine.py +8 -6
  40. cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py +9 -7
  41. cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py +11 -10
  42. cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py +2 -0
  43. cognee/infrastructure/databases/vector/lancedb/LanceDBDatasetDatabaseHandler.py +50 -0
  44. cognee/infrastructure/databases/vector/vector_db_interface.py +35 -0
  45. cognee/infrastructure/files/storage/s3_config.py +2 -0
  46. cognee/infrastructure/llm/LLMGateway.py +5 -2
  47. cognee/infrastructure/llm/config.py +35 -0
  48. cognee/infrastructure/llm/extraction/knowledge_graph/extract_content_graph.py +2 -2
  49. cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extraction/acreate_structured_output.py +23 -8
  50. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py +17 -16
  51. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/bedrock/__init__.py +5 -0
  52. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/bedrock/adapter.py +153 -0
  53. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py +40 -37
  54. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py +39 -36
  55. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +19 -1
  56. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py +11 -9
  57. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py +23 -21
  58. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py +42 -34
  59. cognee/memify_pipelines/create_triplet_embeddings.py +53 -0
  60. cognee/modules/cognify/config.py +2 -0
  61. cognee/modules/data/deletion/prune_system.py +52 -2
  62. cognee/modules/data/methods/delete_dataset.py +26 -0
  63. cognee/modules/engine/models/Triplet.py +9 -0
  64. cognee/modules/engine/models/__init__.py +1 -0
  65. cognee/modules/graph/cognee_graph/CogneeGraph.py +85 -37
  66. cognee/modules/graph/cognee_graph/CogneeGraphElements.py +8 -3
  67. cognee/modules/memify/memify.py +1 -7
  68. cognee/modules/pipelines/operations/pipeline.py +18 -2
  69. cognee/modules/retrieval/__init__.py +1 -1
  70. cognee/modules/retrieval/graph_completion_context_extension_retriever.py +4 -0
  71. cognee/modules/retrieval/graph_completion_cot_retriever.py +4 -0
  72. cognee/modules/retrieval/graph_completion_retriever.py +10 -0
  73. cognee/modules/retrieval/graph_summary_completion_retriever.py +4 -0
  74. cognee/modules/retrieval/register_retriever.py +10 -0
  75. cognee/modules/retrieval/registered_community_retrievers.py +1 -0
  76. cognee/modules/retrieval/temporal_retriever.py +4 -0
  77. cognee/modules/retrieval/triplet_retriever.py +182 -0
  78. cognee/modules/retrieval/utils/brute_force_triplet_search.py +42 -10
  79. cognee/modules/run_custom_pipeline/run_custom_pipeline.py +8 -1
  80. cognee/modules/search/methods/get_search_type_tools.py +54 -8
  81. cognee/modules/search/methods/no_access_control_search.py +4 -0
  82. cognee/modules/search/methods/search.py +46 -18
  83. cognee/modules/search/types/SearchType.py +1 -1
  84. cognee/modules/settings/get_settings.py +19 -0
  85. cognee/modules/users/methods/get_authenticated_user.py +2 -2
  86. cognee/modules/users/models/DatasetDatabase.py +15 -3
  87. cognee/shared/logging_utils.py +4 -0
  88. cognee/shared/rate_limiting.py +30 -0
  89. cognee/tasks/documents/__init__.py +0 -1
  90. cognee/tasks/graph/extract_graph_from_data.py +9 -10
  91. cognee/tasks/memify/get_triplet_datapoints.py +289 -0
  92. cognee/tasks/storage/add_data_points.py +142 -2
  93. cognee/tests/integration/retrieval/test_triplet_retriever.py +84 -0
  94. cognee/tests/integration/tasks/test_add_data_points.py +139 -0
  95. cognee/tests/integration/tasks/test_get_triplet_datapoints.py +69 -0
  96. cognee/tests/test_cognee_server_start.py +2 -4
  97. cognee/tests/test_conversation_history.py +23 -1
  98. cognee/tests/test_dataset_database_handler.py +137 -0
  99. cognee/tests/test_dataset_delete.py +76 -0
  100. cognee/tests/test_edge_centered_payload.py +170 -0
  101. cognee/tests/test_pipeline_cache.py +164 -0
  102. cognee/tests/test_search_db.py +37 -1
  103. cognee/tests/unit/api/test_ontology_endpoint.py +77 -89
  104. cognee/tests/unit/infrastructure/llm/test_llm_config.py +46 -0
  105. cognee/tests/unit/infrastructure/mock_embedding_engine.py +3 -7
  106. cognee/tests/unit/infrastructure/test_embedding_rate_limiting_realistic.py +0 -5
  107. cognee/tests/unit/modules/graph/cognee_graph_elements_test.py +2 -2
  108. cognee/tests/unit/modules/graph/cognee_graph_test.py +406 -0
  109. cognee/tests/unit/modules/memify_tasks/test_get_triplet_datapoints.py +214 -0
  110. cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py +608 -0
  111. cognee/tests/unit/modules/retrieval/triplet_retriever_test.py +83 -0
  112. cognee/tests/unit/modules/search/test_search.py +100 -0
  113. cognee/tests/unit/tasks/storage/test_add_data_points.py +288 -0
  114. {cognee-0.5.0.dev0.dist-info → cognee-0.5.1.dist-info}/METADATA +76 -89
  115. {cognee-0.5.0.dev0.dist-info → cognee-0.5.1.dist-info}/RECORD +119 -97
  116. {cognee-0.5.0.dev0.dist-info → cognee-0.5.1.dist-info}/WHEEL +1 -1
  117. cognee/api/v1/cognify/code_graph_pipeline.py +0 -119
  118. cognee/api/v1/cognify/routers/get_code_pipeline_router.py +0 -90
  119. cognee/infrastructure/databases/vector/embeddings/embedding_rate_limiter.py +0 -544
  120. cognee/modules/retrieval/code_retriever.py +0 -232
  121. cognee/tasks/code/enrich_dependency_graph_checker.py +0 -35
  122. cognee/tasks/code/get_local_dependencies_checker.py +0 -20
  123. cognee/tasks/code/get_repo_dependency_graph_checker.py +0 -35
  124. cognee/tasks/documents/check_permissions_on_dataset.py +0 -26
  125. cognee/tasks/repo_processor/__init__.py +0 -2
  126. cognee/tasks/repo_processor/get_local_dependencies.py +0 -335
  127. cognee/tasks/repo_processor/get_non_code_files.py +0 -158
  128. cognee/tasks/repo_processor/get_repo_file_dependencies.py +0 -243
  129. cognee/tests/test_delete_bmw_example.py +0 -60
  130. {cognee-0.5.0.dev0.dist-info → cognee-0.5.1.dist-info}/entry_points.txt +0 -0
  131. {cognee-0.5.0.dev0.dist-info → cognee-0.5.1.dist-info}/licenses/LICENSE +0 -0
  132. {cognee-0.5.0.dev0.dist-info → cognee-0.5.1.dist-info}/licenses/NOTICE.md +0 -0
cognee/infrastructure/databases/utils/get_vector_dataset_database_handler.py
@@ -0,0 +1,10 @@
+from cognee.modules.users.models.DatasetDatabase import DatasetDatabase
+
+
+def get_vector_dataset_database_handler(dataset_database: DatasetDatabase) -> dict:
+    from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import (
+        supported_dataset_database_handlers,
+    )
+
+    handler = supported_dataset_database_handlers[dataset_database.vector_dataset_database_handler]
+    return handler
cognee/infrastructure/databases/utils/resolve_dataset_database_connection_info.py
@@ -0,0 +1,30 @@
+from cognee.infrastructure.databases.utils.get_graph_dataset_database_handler import (
+    get_graph_dataset_database_handler,
+)
+from cognee.infrastructure.databases.utils.get_vector_dataset_database_handler import (
+    get_vector_dataset_database_handler,
+)
+from cognee.modules.users.models.DatasetDatabase import DatasetDatabase
+
+
+async def resolve_dataset_database_connection_info(
+    dataset_database: DatasetDatabase,
+) -> DatasetDatabase:
+    """
+    Resolve the connection info for the given DatasetDatabase instance.
+    Resolve both vector and graph database connection info and return the updated DatasetDatabase instance.
+
+    Args:
+        dataset_database: DatasetDatabase instance
+    Returns:
+        DatasetDatabase instance with resolved connection info
+    """
+    vector_dataset_database_handler = get_vector_dataset_database_handler(dataset_database)
+    graph_dataset_database_handler = get_graph_dataset_database_handler(dataset_database)
+    dataset_database = await vector_dataset_database_handler[
+        "handler_instance"
+    ].resolve_dataset_connection_info(dataset_database)
+    dataset_database = await graph_dataset_database_handler[
+        "handler_instance"
+    ].resolve_dataset_connection_info(dataset_database)
+    return dataset_database
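The two utilities above imply the shape of the new handler registry: supported_dataset_database_handlers maps a handler name (stored per dataset on the DatasetDatabase row) to an entry carrying a "handler_instance" that exposes resolve_dataset_connection_info. A minimal sketch of plugging in a custom entry under those assumptions; everything not visible in this diff is hypothetical:

# Hypothetical registration sketch; only the registry name, the
# "handler_instance" key, and resolve_dataset_connection_info appear in this diff.
from cognee.infrastructure.databases.dataset_database_handler.supported_dataset_database_handlers import (
    supported_dataset_database_handlers,
)


class MyVectorHandler:
    async def resolve_dataset_connection_info(self, dataset_database):
        # Fill in whatever connection details the stored row does not carry.
        return dataset_database


# The key must match DatasetDatabase.vector_dataset_database_handler for the dataset.
supported_dataset_database_handlers["my_vector"] = {"handler_instance": MyVectorHandler()}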
cognee/infrastructure/databases/vector/config.py
@@ -28,6 +28,7 @@ class VectorConfig(BaseSettings):
     vector_db_name: str = ""
     vector_db_key: str = ""
     vector_db_provider: str = "lancedb"
+    vector_dataset_database_handler: str = "lancedb"
 
     model_config = SettingsConfigDict(env_file=".env", extra="allow")
 
@@ -63,6 +64,7 @@ class VectorConfig(BaseSettings):
             "vector_db_name": self.vector_db_name,
             "vector_db_key": self.vector_db_key,
             "vector_db_provider": self.vector_db_provider,
+            "vector_dataset_database_handler": self.vector_dataset_database_handler,
         }
 
 
cognee/infrastructure/databases/vector/create_vector_engine.py
@@ -12,6 +12,7 @@ def create_vector_engine(
     vector_db_name: str,
     vector_db_port: str = "",
     vector_db_key: str = "",
+    vector_dataset_database_handler: str = "",
 ):
     """
     Create a vector database engine based on the specified provider.
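Because VectorConfig is a pydantic BaseSettings with env_file=".env", the new vector_dataset_database_handler field should be configurable from the environment like its sibling fields. A sketch, assuming pydantic-settings' default case-insensitive field-to-env-var mapping (this diff does not show the env wiring explicitly):

# Assumes pydantic-settings' standard field <-> env-var mapping; set the
# variable before the config object is first constructed/cached.
import os

os.environ["VECTOR_DATASET_DATABASE_HANDLER"] = "lancedb"

from cognee.infrastructure.databases.vector import get_vectordb_config

config = get_vectordb_config()
print(config.vector_db_provider, config.vector_dataset_database_handler)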
cognee/infrastructure/databases/vector/embeddings/FastembedEmbeddingEngine.py
@@ -17,6 +17,7 @@ from cognee.infrastructure.databases.exceptions import EmbeddingException
 from cognee.infrastructure.llm.tokenizer.TikToken import (
     TikTokenTokenizer,
 )
+from cognee.shared.rate_limiting import embedding_rate_limiter_context_manager
 
 litellm.set_verbose = False
 logger = get_logger("FastembedEmbeddingEngine")
@@ -68,7 +69,7 @@ class FastembedEmbeddingEngine(EmbeddingEngine):
 
     @retry(
         stop=stop_after_delay(128),
-        wait=wait_exponential_jitter(2, 128),
+        wait=wait_exponential_jitter(8, 128),
         retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
         before_sleep=before_sleep_log(logger, logging.DEBUG),
         reraise=True,
@@ -96,11 +97,12 @@ class FastembedEmbeddingEngine(EmbeddingEngine):
         if self.mock:
             return [[0.0] * self.dimensions for _ in text]
         else:
-            embeddings = self.embedding_model.embed(
-                text,
-                batch_size=len(text),
-                parallel=None,
-            )
+            async with embedding_rate_limiter_context_manager():
+                embeddings = self.embedding_model.embed(
+                    text,
+                    batch_size=len(text),
+                    parallel=None,
+                )
 
         return list(embeddings)
 
cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py
@@ -25,6 +25,7 @@ from cognee.infrastructure.llm.tokenizer.Mistral import (
 from cognee.infrastructure.llm.tokenizer.TikToken import (
     TikTokenTokenizer,
 )
+from cognee.shared.rate_limiting import embedding_rate_limiter_context_manager
 
 litellm.set_verbose = False
 logger = get_logger("LiteLLMEmbeddingEngine")
@@ -109,13 +110,14 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
             response = {"data": [{"embedding": [0.0] * self.dimensions} for _ in text]}
             return [data["embedding"] for data in response["data"]]
         else:
-            response = await litellm.aembedding(
-                model=self.model,
-                input=text,
-                api_key=self.api_key,
-                api_base=self.endpoint,
-                api_version=self.api_version,
-            )
+            async with embedding_rate_limiter_context_manager():
+                response = await litellm.aembedding(
+                    model=self.model,
+                    input=text,
+                    api_key=self.api_key,
+                    api_base=self.endpoint,
+                    api_version=self.api_version,
+                )
 
         return [data["embedding"] for data in response.data]
 
cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py
@@ -18,10 +18,7 @@ from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import Em
 from cognee.infrastructure.llm.tokenizer.HuggingFace import (
     HuggingFaceTokenizer,
 )
-from cognee.infrastructure.databases.vector.embeddings.embedding_rate_limiter import (
-    embedding_rate_limit_async,
-    embedding_sleep_and_retry_async,
-)
+from cognee.shared.rate_limiting import embedding_rate_limiter_context_manager
 from cognee.shared.utils import create_secure_ssl_context
 
 logger = get_logger("OllamaEmbeddingEngine")
@@ -101,7 +98,7 @@ class OllamaEmbeddingEngine(EmbeddingEngine):
 
     @retry(
         stop=stop_after_delay(128),
-        wait=wait_exponential_jitter(2, 128),
+        wait=wait_exponential_jitter(8, 128),
         retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
         before_sleep=before_sleep_log(logger, logging.DEBUG),
         reraise=True,
@@ -120,11 +117,15 @@
         ssl_context = create_secure_ssl_context()
         connector = aiohttp.TCPConnector(ssl=ssl_context)
         async with aiohttp.ClientSession(connector=connector) as session:
-            async with session.post(
-                self.endpoint, json=payload, headers=headers, timeout=60.0
-            ) as response:
-                data = await response.json()
-                return data["embeddings"][0]
+            async with embedding_rate_limiter_context_manager():
+                async with session.post(
+                    self.endpoint, json=payload, headers=headers, timeout=60.0
+                ) as response:
+                    data = await response.json()
+                    if "embeddings" in data:
+                        return data["embeddings"][0]
+                    else:
+                        return data["data"][0]["embedding"]
 
     def get_vector_size(self) -> int:
         """
cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py
@@ -193,6 +193,8 @@ class LanceDBAdapter(VectorDBInterface):
            for (data_point_index, data_point) in enumerate(data_points)
        ]
 
+        lance_data_points = list({dp.id: dp for dp in lance_data_points}.values())
+
        async with self.VECTOR_DB_LOCK:
            await (
                collection.merge_insert("id")
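The inserted dedup line guards merge_insert("id") against duplicate ids within a single batch. The dict comprehension keeps the last value seen for each id at its first-seen position, since Python dicts are insertion-ordered:

# Same idiom as above, shown standalone: last value wins, first-seen order kept.
points = [{"id": 1, "v": "a"}, {"id": 2, "v": "b"}, {"id": 1, "v": "c"}]
deduped = list({p["id"]: p for p in points}.values())
print(deduped)  # [{'id': 1, 'v': 'c'}, {'id': 2, 'v': 'b'}]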
cognee/infrastructure/databases/vector/lancedb/LanceDBDatasetDatabaseHandler.py
@@ -0,0 +1,50 @@
+import os
+from uuid import UUID
+from typing import Optional
+
+from cognee.infrastructure.databases.vector.create_vector_engine import create_vector_engine
+from cognee.modules.users.models import User
+from cognee.modules.users.models import DatasetDatabase
+from cognee.base_config import get_base_config
+from cognee.infrastructure.databases.vector import get_vectordb_config
+from cognee.infrastructure.databases.dataset_database_handler import DatasetDatabaseHandlerInterface
+
+
+class LanceDBDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
+    """
+    Handler for interacting with LanceDB Dataset databases.
+    """
+
+    @classmethod
+    async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict:
+        vector_config = get_vectordb_config()
+        base_config = get_base_config()
+
+        if vector_config.vector_db_provider != "lancedb":
+            raise ValueError(
+                "LanceDBDatasetDatabaseHandler can only be used with LanceDB vector database provider."
+            )
+
+        databases_directory_path = os.path.join(
+            base_config.system_root_directory, "databases", str(user.id)
+        )
+
+        vector_db_name = f"{dataset_id}.lance.db"
+
+        return {
+            "vector_database_provider": vector_config.vector_db_provider,
+            "vector_database_url": os.path.join(databases_directory_path, vector_db_name),
+            "vector_database_key": vector_config.vector_db_key,
+            "vector_database_name": vector_db_name,
+            "vector_dataset_database_handler": "lancedb",
+        }
+
+    @classmethod
+    async def delete_dataset(cls, dataset_database: DatasetDatabase):
+        vector_engine = create_vector_engine(
+            vector_db_provider=dataset_database.vector_database_provider,
+            vector_db_url=dataset_database.vector_database_url,
+            vector_db_key=dataset_database.vector_database_key,
+            vector_db_name=dataset_database.vector_database_name,
+        )
+        await vector_engine.prune()
cognee/infrastructure/databases/vector/vector_db_interface.py
@@ -2,6 +2,8 @@ from typing import List, Protocol, Optional, Union, Any
 from abc import abstractmethod
 from cognee.infrastructure.engine import DataPoint
 from .models.PayloadSchema import PayloadSchema
+from uuid import UUID
+from cognee.modules.users.models import User
 
 
 class VectorDBInterface(Protocol):
@@ -217,3 +219,36 @@
         - Any: The schema object suitable for this vector database
         """
         return model_type
+
+    @classmethod
+    async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict:
+        """
+        Return a dictionary with connection info for a vector database for the given dataset.
+        Function can auto handle deploying of the actual database if needed, but is not necessary.
+        Only providing connection info is sufficient, this info will be mapped when trying to connect to the provided dataset in the future.
+        Needed for Cognee multi-tenant/multi-user and backend access control support.
+
+        Dictionary returned from this function will be used to create a DatasetDatabase row in the relational database.
+        From which internal mapping of dataset -> database connection info will be done.
+
+        Each dataset needs to map to a unique vector database when backend access control is enabled to facilitate a separation of concern for data.
+
+        Args:
+            dataset_id: UUID of the dataset if needed by the database creation logic
+            user: User object if needed by the database creation logic
+        Returns:
+            dict: Connection info for the created vector database instance.
+        """
+        pass
+
+    async def delete_dataset(self, dataset_id: UUID, user: User) -> None:
+        """
+        Delete the vector database for the given dataset.
+        Function should auto handle deleting of the actual database or send a request to the proper service to delete the database.
+        Needed for maintaining a database for Cognee multi-tenant/multi-user and backend access control.
+
+        Args:
+            dataset_id: UUID of the dataset
+            user: User object
+        """
+        pass
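These two hooks, together with the new DatasetDatabaseHandlerInterface (+80 lines, not shown here), define the contract a per-dataset database backend must satisfy: create_dataset returns the dict that becomes a DatasetDatabase row, and delete_dataset tears the underlying store down. A hedged skeleton of a custom handler, modeled on the LanceDB handler above; the "inmemory" backend is invented for illustration:

# Skeleton only; the key names mirror the LanceDB handler's return dict in this diff.
from typing import Optional
from uuid import UUID

from cognee.infrastructure.databases.dataset_database_handler import DatasetDatabaseHandlerInterface
from cognee.modules.users.models import DatasetDatabase, User


class InMemoryDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
    @classmethod
    async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict:
        return {
            "vector_database_provider": "inmemory",  # hypothetical provider name
            "vector_database_url": f"memory://{dataset_id}",
            "vector_database_key": "",
            "vector_database_name": str(dataset_id),
            "vector_dataset_database_handler": "inmemory",
        }

    @classmethod
    async def delete_dataset(cls, dataset_database: DatasetDatabase):
        # Nothing persistent to remove for this toy backend.
        pass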
cognee/infrastructure/files/storage/s3_config.py
@@ -9,6 +9,8 @@ class S3Config(BaseSettings):
     aws_access_key_id: Optional[str] = None
     aws_secret_access_key: Optional[str] = None
     aws_session_token: Optional[str] = None
+    aws_profile_name: Optional[str] = None
+    aws_bedrock_runtime_endpoint: Optional[str] = None
     model_config = SettingsConfigDict(env_file=".env", extra="allow")
 
 
cognee/infrastructure/llm/LLMGateway.py
@@ -11,7 +11,7 @@ class LLMGateway:
 
     @staticmethod
     def acreate_structured_output(
-        text_input: str, system_prompt: str, response_model: Type[BaseModel]
+        text_input: str, system_prompt: str, response_model: Type[BaseModel], **kwargs
     ) -> Coroutine:
         llm_config = get_llm_config()
         if llm_config.structured_output_framework.upper() == "BAML":
@@ -31,7 +31,10 @@ class LLMGateway:
 
         llm_client = get_llm_client()
         return llm_client.acreate_structured_output(
-            text_input=text_input, system_prompt=system_prompt, response_model=response_model
+            text_input=text_input,
+            system_prompt=system_prompt,
+            response_model=response_model,
+            **kwargs,
         )
 
     @staticmethod
cognee/infrastructure/llm/config.py
@@ -74,6 +74,41 @@ class LLMConfig(BaseSettings):
 
     model_config = SettingsConfigDict(env_file=".env", extra="allow")
 
+    @model_validator(mode="after")
+    def strip_quotes_from_strings(self) -> "LLMConfig":
+        """
+        Strip surrounding quotes from specific string fields that often come from
+        environment variables with extra quotes (e.g., via Docker's --env-file).
+
+        Only applies to known config keys where quotes are invalid or cause issues.
+        """
+        string_fields_to_strip = [
+            "llm_api_key",
+            "llm_endpoint",
+            "llm_api_version",
+            "baml_llm_api_key",
+            "baml_llm_endpoint",
+            "baml_llm_api_version",
+            "fallback_api_key",
+            "fallback_endpoint",
+            "fallback_model",
+            "llm_provider",
+            "llm_model",
+            "baml_llm_provider",
+            "baml_llm_model",
+        ]
+
+        cls = self.__class__
+        for field_name in string_fields_to_strip:
+            if field_name not in cls.model_fields:
+                continue
+            value = getattr(self, field_name, None)
+            if isinstance(value, str) and len(value) >= 2:
+                if value[0] == value[-1] and value[0] in ("'", '"'):
+                    setattr(self, field_name, value[1:-1])
+
+        return self
+
 
     def model_post_init(self, __context) -> None:
         """Initialize the BAML registry after the model is created."""
         # Check if BAML is selected as structured output framework but not available
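A quick check of the validator's behavior, assuming LLMConfig can be instantiated with keyword arguments the way pydantic BaseSettings normally allows (and that any remaining fields have defaults or come from the environment):

# Assumes standard pydantic BaseSettings keyword initialization.
from cognee.infrastructure.llm.config import LLMConfig

config = LLMConfig(llm_api_key='"sk-abc123"', llm_endpoint="'https://api.example.com'")
print(config.llm_api_key)   # sk-abc123  (surrounding double quotes stripped)
print(config.llm_endpoint)  # https://api.example.com  (surrounding single quotes stripped)
# Unbalanced or mismatched quotes (e.g. '"sk-abc' with differing ends) are left untouched.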
cognee/infrastructure/llm/extraction/knowledge_graph/extract_content_graph.py
@@ -10,7 +10,7 @@ from cognee.infrastructure.llm.config import (
 
 
 async def extract_content_graph(
-    content: str, response_model: Type[BaseModel], custom_prompt: Optional[str] = None
+    content: str, response_model: Type[BaseModel], custom_prompt: Optional[str] = None, **kwargs
 ):
     if custom_prompt:
         system_prompt = custom_prompt
@@ -30,7 +30,7 @@ async def extract_content_graph(
         system_prompt = render_prompt(prompt_path, {}, base_directory=base_directory)
 
     content_graph = await LLMGateway.acreate_structured_output(
-        content, system_prompt, response_model
+        content, system_prompt, response_model, **kwargs
     )
 
     return content_graph
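With **kwargs threaded through both extract_content_graph and LLMGateway.acreate_structured_output, callers can forward extra options to the active adapter without touching each layer's signature. A hedged sketch; which keywords the adapter actually honors is not shown in this diff:

# Sketch; requires a configured LLM provider to actually run.
import asyncio
from pydantic import BaseModel

from cognee.infrastructure.llm.extraction.knowledge_graph.extract_content_graph import (
    extract_content_graph,
)


class MiniGraph(BaseModel):  # stand-in response model for illustration
    summary: str


async def main():
    graph = await extract_content_graph(
        "Ada Lovelace worked with Charles Babbage.",
        MiniGraph,
        # Any extra keywords are forwarded via **kwargs to the active adapter.
    )
    print(graph)


# asyncio.run(main())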
cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extraction/acreate_structured_output.py
@@ -1,7 +1,15 @@
 import asyncio
 from typing import Type
-from cognee.shared.logging_utils import get_logger
+from pydantic import BaseModel
+from tenacity import (
+    retry,
+    stop_after_delay,
+    wait_exponential_jitter,
+    retry_if_not_exception_type,
+    before_sleep_log,
+)
 
+from cognee.shared.logging_utils import get_logger
 from cognee.infrastructure.llm.config import get_llm_config
 from cognee.infrastructure.llm.structured_output_framework.baml.baml_src.extraction.create_dynamic_baml_type import (
     create_dynamic_baml_type,
@@ -10,12 +18,18 @@ from cognee.infrastructure.llm.structured_output_framework.baml.baml_client.type
     TypeBuilder,
 )
 from cognee.infrastructure.llm.structured_output_framework.baml.baml_client import b
-from pydantic import BaseModel
-
+from cognee.shared.rate_limiting import llm_rate_limiter_context_manager
+import logging
 
 logger = get_logger()
 
 
+@retry(
+    stop=stop_after_delay(128),
+    wait=wait_exponential_jitter(8, 128),
+    before_sleep=before_sleep_log(logger, logging.DEBUG),
+    reraise=True,
+)
 async def acreate_structured_output(
     text_input: str, system_prompt: str, response_model: Type[BaseModel]
 ):
@@ -45,11 +59,12 @@ async def acreate_structured_output(
     tb = TypeBuilder()
     type_builder = create_dynamic_baml_type(tb, tb.ResponseModel, response_model)
 
-    result = await b.AcreateStructuredOutput(
-        text_input=text_input,
-        system_prompt=system_prompt,
-        baml_options={"client_registry": config.baml_registry, "tb": type_builder},
-    )
+    async with llm_rate_limiter_context_manager():
+        result = await b.AcreateStructuredOutput(
+            text_input=text_input,
+            system_prompt=system_prompt,
+            baml_options={"client_registry": config.baml_registry, "tb": type_builder},
+        )
 
     # Transform BAML response to proper pydantic reponse model
     if response_model is str:
cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py
@@ -15,6 +15,7 @@ from tenacity import (
 from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
     LLMInterface,
 )
+from cognee.shared.rate_limiting import llm_rate_limiter_context_manager
 from cognee.infrastructure.llm.config import get_llm_config
 
 logger = get_logger()
@@ -45,13 +46,13 @@ class AnthropicAdapter(LLMInterface):
 
     @retry(
         stop=stop_after_delay(128),
-        wait=wait_exponential_jitter(2, 128),
+        wait=wait_exponential_jitter(8, 128),
         retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
         before_sleep=before_sleep_log(logger, logging.DEBUG),
         reraise=True,
     )
     async def acreate_structured_output(
-        self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
+        self, text_input: str, system_prompt: str, response_model: Type[BaseModel], **kwargs
     ) -> BaseModel:
         """
         Generate a response from a user query.
@@ -69,17 +70,17 @@ class AnthropicAdapter(LLMInterface):
 
         - BaseModel: An instance of BaseModel containing the structured response.
         """
-
-        return await self.aclient(
-            model=self.model,
-            max_tokens=4096,
-            max_retries=5,
-            messages=[
-                {
-                    "role": "user",
-                    "content": f"""Use the given format to extract information
-                    from the following input: {text_input}. {system_prompt}""",
-                }
-            ],
-            response_model=response_model,
-        )
+        async with llm_rate_limiter_context_manager():
+            return await self.aclient(
+                model=self.model,
+                max_tokens=4096,
+                max_retries=2,
+                messages=[
+                    {
+                        "role": "user",
+                        "content": f"""Use the given format to extract information
+                        from the following input: {text_input}. {system_prompt}""",
+                    }
+                ],
+                response_model=response_model,
+            )
cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/bedrock/__init__.py
@@ -0,0 +1,5 @@
+"""Bedrock LLM adapter module."""
+
+from .adapter import BedrockAdapter
+
+__all__ = ["BedrockAdapter"]
cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/bedrock/adapter.py
@@ -0,0 +1,153 @@
+import litellm
+import instructor
+from typing import Type
+from pydantic import BaseModel
+from litellm.exceptions import ContentPolicyViolationError
+from instructor.exceptions import InstructorRetryException
+
+from cognee.infrastructure.llm.LLMGateway import LLMGateway
+from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
+    LLMInterface,
+)
+from cognee.infrastructure.llm.exceptions import (
+    ContentPolicyFilterError,
+    MissingSystemPromptPathError,
+)
+from cognee.infrastructure.files.storage.s3_config import get_s3_config
+from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.rate_limiter import (
+    rate_limit_async,
+    rate_limit_sync,
+    sleep_and_retry_async,
+    sleep_and_retry_sync,
+)
+from cognee.modules.observability.get_observe import get_observe
+
+observe = get_observe()
+
+
+class BedrockAdapter(LLMInterface):
+    """
+    Adapter for AWS Bedrock API with support for three authentication methods:
+    1. API Key (Bearer Token)
+    2. AWS Credentials (access key + secret key)
+    3. AWS Profile (boto3 credential chain)
+    """
+
+    name = "Bedrock"
+    model: str
+    api_key: str
+    default_instructor_mode = "json_schema_mode"
+
+    MAX_RETRIES = 5
+
+    def __init__(
+        self,
+        model: str,
+        api_key: str = None,
+        max_completion_tokens: int = 16384,
+        streaming: bool = False,
+        instructor_mode: str = None,
+    ):
+        self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
+
+        self.aclient = instructor.from_litellm(
+            litellm.acompletion, mode=instructor.Mode(self.instructor_mode)
+        )
+        self.client = instructor.from_litellm(litellm.completion)
+        self.model = model
+        self.api_key = api_key
+        self.max_completion_tokens = max_completion_tokens
+        self.streaming = streaming
+
+    def _create_bedrock_request(
+        self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
+    ) -> dict:
+        """Create Bedrock request with authentication."""
+
+        request_params = {
+            "model": self.model,
+            "custom_llm_provider": "bedrock",
+            "drop_params": True,
+            "messages": [
+                {"role": "user", "content": text_input},
+                {"role": "system", "content": system_prompt},
+            ],
+            "response_model": response_model,
+            "max_retries": self.MAX_RETRIES,
+            "max_completion_tokens": self.max_completion_tokens,
+            "stream": self.streaming,
+        }
+
+        s3_config = get_s3_config()
+
+        # Add authentication parameters
+        if self.api_key:
+            request_params["api_key"] = self.api_key
+        elif s3_config.aws_access_key_id and s3_config.aws_secret_access_key:
+            request_params["aws_access_key_id"] = s3_config.aws_access_key_id
+            request_params["aws_secret_access_key"] = s3_config.aws_secret_access_key
+            if s3_config.aws_session_token:
+                request_params["aws_session_token"] = s3_config.aws_session_token
+        elif s3_config.aws_profile_name:
+            request_params["aws_profile_name"] = s3_config.aws_profile_name
+
+        if s3_config.aws_region:
+            request_params["aws_region_name"] = s3_config.aws_region
+
+        # Add optional parameters
+        if s3_config.aws_bedrock_runtime_endpoint:
+            request_params["aws_bedrock_runtime_endpoint"] = s3_config.aws_bedrock_runtime_endpoint
+
+        return request_params
+
+    @observe(as_type="generation")
+    @sleep_and_retry_async()
+    @rate_limit_async
+    async def acreate_structured_output(
+        self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
+    ) -> BaseModel:
+        """Generate structured output from AWS Bedrock API."""
+
+        try:
+            request_params = self._create_bedrock_request(text_input, system_prompt, response_model)
+            return await self.aclient.chat.completions.create(**request_params)
+
+        except (
+            ContentPolicyViolationError,
+            InstructorRetryException,
+        ) as error:
+            if (
+                isinstance(error, InstructorRetryException)
+                and "content management policy" not in str(error).lower()
+            ):
+                raise error
+
+            raise ContentPolicyFilterError(
+                f"The provided input contains content that is not aligned with our content policy: {text_input}"
+            )
+
+    @observe
+    @sleep_and_retry_sync()
+    @rate_limit_sync
+    def create_structured_output(
+        self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
+    ) -> BaseModel:
+        """Generate structured output from AWS Bedrock API (synchronous)."""
+
+        request_params = self._create_bedrock_request(text_input, system_prompt, response_model)
+        return self.client.chat.completions.create(**request_params)
+
+    def show_prompt(self, text_input: str, system_prompt: str) -> str:
+        """Format and display the prompt for a user query."""
+        if not text_input:
+            text_input = "No user input provided."
+        if not system_prompt:
+            raise MissingSystemPromptPathError()
+        system_prompt = LLMGateway.read_query_prompt(system_prompt)
+
+        formatted_prompt = (
+            f"""System Prompt:\n{system_prompt}\n\nUser Input:\n{text_input}\n"""
+            if system_prompt
+            else None
+        )
+        return formatted_prompt
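A hedged usage sketch of the new adapter. The constructor arguments mirror __init__ above; when api_key is omitted, authentication falls back to the S3Config AWS credentials, including the aws_profile_name and aws_bedrock_runtime_endpoint fields added in this release. The model id shown is illustrative:

# Sketch; requires AWS Bedrock access and a model id enabled for your account.
import asyncio
from pydantic import BaseModel

from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.bedrock import (
    BedrockAdapter,
)


class Sentiment(BaseModel):
    label: str
    confidence: float


async def main():
    adapter = BedrockAdapter(model="anthropic.claude-3-haiku-20240307-v1:0")
    result = await adapter.acreate_structured_output(
        text_input="I love this library!",
        system_prompt="Classify the sentiment of the user input.",
        response_model=Sentiment,
    )
    print(result)


# asyncio.run(main())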