cognee 0.4.1__py3-none-any.whl → 0.5.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (135)
  1. cognee/__init__.py +1 -0
  2. cognee/api/client.py +8 -0
  3. cognee/api/v1/add/routers/get_add_router.py +3 -1
  4. cognee/api/v1/cognify/routers/get_cognify_router.py +28 -1
  5. cognee/api/v1/ontologies/__init__.py +4 -0
  6. cognee/api/v1/ontologies/ontologies.py +183 -0
  7. cognee/api/v1/ontologies/routers/__init__.py +0 -0
  8. cognee/api/v1/ontologies/routers/get_ontology_router.py +107 -0
  9. cognee/api/v1/permissions/routers/get_permissions_router.py +41 -1
  10. cognee/cli/commands/cognify_command.py +8 -1
  11. cognee/cli/config.py +1 -1
  12. cognee/context_global_variables.py +41 -9
  13. cognee/infrastructure/databases/cache/config.py +3 -1
  14. cognee/infrastructure/databases/cache/fscache/FsCacheAdapter.py +151 -0
  15. cognee/infrastructure/databases/cache/get_cache_engine.py +20 -10
  16. cognee/infrastructure/databases/exceptions/exceptions.py +16 -0
  17. cognee/infrastructure/databases/graph/config.py +4 -0
  18. cognee/infrastructure/databases/graph/get_graph_engine.py +2 -0
  19. cognee/infrastructure/databases/hybrid/neptune_analytics/NeptuneAnalyticsAdapter.py +9 -0
  20. cognee/infrastructure/databases/utils/get_or_create_dataset_database.py +37 -3
  21. cognee/infrastructure/databases/vector/config.py +3 -0
  22. cognee/infrastructure/databases/vector/create_vector_engine.py +5 -1
  23. cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py +1 -4
  24. cognee/infrastructure/engine/models/Edge.py +13 -1
  25. cognee/infrastructure/files/utils/guess_file_type.py +4 -0
  26. cognee/infrastructure/llm/config.py +2 -0
  27. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py +5 -2
  28. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py +7 -1
  29. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py +7 -1
  30. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +8 -16
  31. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py +12 -2
  32. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py +13 -2
  33. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py +5 -2
  34. cognee/infrastructure/loaders/LoaderEngine.py +1 -0
  35. cognee/infrastructure/loaders/core/__init__.py +2 -1
  36. cognee/infrastructure/loaders/core/csv_loader.py +93 -0
  37. cognee/infrastructure/loaders/core/text_loader.py +1 -2
  38. cognee/infrastructure/loaders/external/advanced_pdf_loader.py +0 -9
  39. cognee/infrastructure/loaders/supported_loaders.py +2 -1
  40. cognee/memify_pipelines/persist_sessions_in_knowledge_graph.py +55 -0
  41. cognee/modules/chunking/CsvChunker.py +35 -0
  42. cognee/modules/chunking/models/DocumentChunk.py +2 -1
  43. cognee/modules/chunking/text_chunker_with_overlap.py +124 -0
  44. cognee/modules/data/methods/__init__.py +1 -0
  45. cognee/modules/data/methods/create_dataset.py +4 -2
  46. cognee/modules/data/methods/get_dataset_ids.py +5 -1
  47. cognee/modules/data/methods/get_unique_data_id.py +68 -0
  48. cognee/modules/data/methods/get_unique_dataset_id.py +66 -4
  49. cognee/modules/data/models/Dataset.py +2 -0
  50. cognee/modules/data/processing/document_types/CsvDocument.py +33 -0
  51. cognee/modules/data/processing/document_types/__init__.py +1 -0
  52. cognee/modules/graph/cognee_graph/CogneeGraph.py +4 -2
  53. cognee/modules/graph/utils/expand_with_nodes_and_edges.py +19 -2
  54. cognee/modules/graph/utils/resolve_edges_to_text.py +48 -49
  55. cognee/modules/ingestion/identify.py +4 -4
  56. cognee/modules/notebooks/operations/run_in_local_sandbox.py +3 -0
  57. cognee/modules/ontology/rdf_xml/RDFLibOntologyResolver.py +55 -23
  58. cognee/modules/pipelines/operations/run_tasks_data_item.py +1 -1
  59. cognee/modules/retrieval/EntityCompletionRetriever.py +10 -3
  60. cognee/modules/retrieval/base_graph_retriever.py +7 -3
  61. cognee/modules/retrieval/base_retriever.py +7 -3
  62. cognee/modules/retrieval/completion_retriever.py +11 -4
  63. cognee/modules/retrieval/graph_completion_context_extension_retriever.py +6 -2
  64. cognee/modules/retrieval/graph_completion_cot_retriever.py +14 -51
  65. cognee/modules/retrieval/graph_completion_retriever.py +4 -1
  66. cognee/modules/retrieval/temporal_retriever.py +9 -2
  67. cognee/modules/retrieval/utils/brute_force_triplet_search.py +1 -1
  68. cognee/modules/retrieval/utils/completion.py +2 -22
  69. cognee/modules/run_custom_pipeline/__init__.py +1 -0
  70. cognee/modules/run_custom_pipeline/run_custom_pipeline.py +69 -0
  71. cognee/modules/search/methods/search.py +5 -3
  72. cognee/modules/users/methods/create_user.py +12 -27
  73. cognee/modules/users/methods/get_authenticated_user.py +2 -1
  74. cognee/modules/users/methods/get_default_user.py +4 -2
  75. cognee/modules/users/methods/get_user.py +1 -1
  76. cognee/modules/users/methods/get_user_by_email.py +1 -1
  77. cognee/modules/users/models/DatasetDatabase.py +9 -0
  78. cognee/modules/users/models/Tenant.py +6 -7
  79. cognee/modules/users/models/User.py +6 -5
  80. cognee/modules/users/models/UserTenant.py +12 -0
  81. cognee/modules/users/models/__init__.py +1 -0
  82. cognee/modules/users/permissions/methods/get_all_user_permission_datasets.py +13 -13
  83. cognee/modules/users/roles/methods/add_user_to_role.py +3 -1
  84. cognee/modules/users/tenants/methods/__init__.py +1 -0
  85. cognee/modules/users/tenants/methods/add_user_to_tenant.py +21 -12
  86. cognee/modules/users/tenants/methods/create_tenant.py +22 -8
  87. cognee/modules/users/tenants/methods/select_tenant.py +62 -0
  88. cognee/shared/logging_utils.py +2 -0
  89. cognee/tasks/chunks/__init__.py +1 -0
  90. cognee/tasks/chunks/chunk_by_row.py +94 -0
  91. cognee/tasks/documents/classify_documents.py +2 -0
  92. cognee/tasks/feedback/generate_improved_answers.py +3 -3
  93. cognee/tasks/ingestion/ingest_data.py +1 -1
  94. cognee/tasks/memify/__init__.py +2 -0
  95. cognee/tasks/memify/cognify_session.py +41 -0
  96. cognee/tasks/memify/extract_user_sessions.py +73 -0
  97. cognee/tasks/storage/index_data_points.py +33 -22
  98. cognee/tasks/storage/index_graph_edges.py +37 -57
  99. cognee/tests/integration/documents/CsvDocument_test.py +70 -0
  100. cognee/tests/tasks/entity_extraction/entity_extraction_test.py +1 -1
  101. cognee/tests/test_add_docling_document.py +2 -2
  102. cognee/tests/test_cognee_server_start.py +84 -1
  103. cognee/tests/test_conversation_history.py +45 -4
  104. cognee/tests/test_data/example_with_header.csv +3 -0
  105. cognee/tests/test_delete_bmw_example.py +60 -0
  106. cognee/tests/test_edge_ingestion.py +27 -0
  107. cognee/tests/test_feedback_enrichment.py +1 -1
  108. cognee/tests/test_library.py +6 -4
  109. cognee/tests/test_load.py +62 -0
  110. cognee/tests/test_multi_tenancy.py +165 -0
  111. cognee/tests/test_parallel_databases.py +2 -0
  112. cognee/tests/test_relational_db_migration.py +54 -2
  113. cognee/tests/test_search_db.py +7 -1
  114. cognee/tests/unit/api/test_conditional_authentication_endpoints.py +12 -3
  115. cognee/tests/unit/api/test_ontology_endpoint.py +264 -0
  116. cognee/tests/unit/infrastructure/databases/cache/test_cache_config.py +5 -0
  117. cognee/tests/unit/infrastructure/databases/test_index_data_points.py +27 -0
  118. cognee/tests/unit/infrastructure/databases/test_index_graph_edges.py +14 -16
  119. cognee/tests/unit/modules/chunking/test_text_chunker.py +248 -0
  120. cognee/tests/unit/modules/chunking/test_text_chunker_with_overlap.py +324 -0
  121. cognee/tests/unit/modules/memify_tasks/test_cognify_session.py +111 -0
  122. cognee/tests/unit/modules/memify_tasks/test_extract_user_sessions.py +175 -0
  123. cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +0 -51
  124. cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py +1 -0
  125. cognee/tests/unit/modules/retrieval/structured_output_test.py +204 -0
  126. cognee/tests/unit/modules/retrieval/summaries_retriever_test.py +1 -1
  127. cognee/tests/unit/modules/retrieval/temporal_retriever_test.py +0 -1
  128. cognee/tests/unit/modules/users/test_conditional_authentication.py +0 -63
  129. cognee/tests/unit/processing/chunks/chunk_by_row_test.py +52 -0
  130. {cognee-0.4.1.dist-info → cognee-0.5.0.dev0.dist-info}/METADATA +88 -71
  131. {cognee-0.4.1.dist-info → cognee-0.5.0.dev0.dist-info}/RECORD +135 -104
  132. {cognee-0.4.1.dist-info → cognee-0.5.0.dev0.dist-info}/WHEEL +1 -1
  133. {cognee-0.4.1.dist-info → cognee-0.5.0.dev0.dist-info}/entry_points.txt +0 -1
  134. {cognee-0.4.1.dist-info → cognee-0.5.0.dev0.dist-info}/licenses/LICENSE +0 -0
  135. {cognee-0.4.1.dist-info → cognee-0.5.0.dev0.dist-info}/licenses/NOTICE.md +0 -0
@@ -146,7 +146,8 @@ class GraphCompletionRetriever(BaseGraphRetriever):
146
146
  query: str,
147
147
  context: Optional[List[Edge]] = None,
148
148
  session_id: Optional[str] = None,
149
- ) -> List[str]:
149
+ response_model: Type = str,
150
+ ) -> List[Any]:
150
151
  """
151
152
  Generates a completion using graph connections context based on a query.
152
153
 
@@ -188,6 +189,7 @@ class GraphCompletionRetriever(BaseGraphRetriever):
188
189
  system_prompt_path=self.system_prompt_path,
189
190
  system_prompt=self.system_prompt,
190
191
  conversation_history=conversation_history,
192
+ response_model=response_model,
191
193
  ),
192
194
  )
193
195
  else:
@@ -197,6 +199,7 @@ class GraphCompletionRetriever(BaseGraphRetriever):
197
199
  user_prompt_path=self.user_prompt_path,
198
200
  system_prompt_path=self.system_prompt_path,
199
201
  system_prompt=self.system_prompt,
202
+ response_model=response_model,
200
203
  )
201
204
 
202
205
  if self.save_interaction and context and triplets and completion:
@@ -146,8 +146,12 @@ class TemporalRetriever(GraphCompletionRetriever):
146
146
  return self.descriptions_to_string(top_k_events)
147
147
 
148
148
  async def get_completion(
149
- self, query: str, context: Optional[str] = None, session_id: Optional[str] = None
150
- ) -> List[str]:
149
+ self,
150
+ query: str,
151
+ context: Optional[str] = None,
152
+ session_id: Optional[str] = None,
153
+ response_model: Type = str,
154
+ ) -> List[Any]:
151
155
  """
152
156
  Generates a response using the query and optional context.
153
157
 
@@ -159,6 +163,7 @@ class TemporalRetriever(GraphCompletionRetriever):
159
163
  retrieved based on the query. (default None)
160
164
  - session_id (Optional[str]): Optional session identifier for caching. If None,
161
165
  defaults to 'default_session'. (default None)
166
+ - response_model (Type): The Pydantic model type for structured output. (default str)
162
167
 
163
168
  Returns:
164
169
  --------
@@ -186,6 +191,7 @@ class TemporalRetriever(GraphCompletionRetriever):
186
191
  user_prompt_path=self.user_prompt_path,
187
192
  system_prompt_path=self.system_prompt_path,
188
193
  conversation_history=conversation_history,
194
+ response_model=response_model,
189
195
  ),
190
196
  )
191
197
  else:
@@ -194,6 +200,7 @@ class TemporalRetriever(GraphCompletionRetriever):
194
200
  context=context,
195
201
  user_prompt_path=self.user_prompt_path,
196
202
  system_prompt_path=self.system_prompt_path,
203
+ response_model=response_model,
197
204
  )
198
205
 
199
206
  if session_save:
@@ -71,7 +71,7 @@ async def get_memory_fragment(
71
71
  await memory_fragment.project_graph_from_db(
72
72
  graph_engine,
73
73
  node_properties_to_project=properties_to_project,
74
- edge_properties_to_project=["relationship_name"],
74
+ edge_properties_to_project=["relationship_name", "edge_text"],
75
75
  node_type=node_type,
76
76
  node_name=node_name,
77
77
  )
@@ -3,7 +3,7 @@ from cognee.infrastructure.llm.LLMGateway import LLMGateway
3
3
  from cognee.infrastructure.llm.prompts import render_prompt, read_query_prompt
4
4
 
5
5
 
6
- async def generate_structured_completion(
6
+ async def generate_completion(
7
7
  query: str,
8
8
  context: str,
9
9
  user_prompt_path: str,
@@ -12,7 +12,7 @@ async def generate_structured_completion(
12
12
  conversation_history: Optional[str] = None,
13
13
  response_model: Type = str,
14
14
  ) -> Any:
15
- """Generates a structured completion using LLM with given context and prompts."""
15
+ """Generates a completion using LLM with given context and prompts."""
16
16
  args = {"question": query, "context": context}
17
17
  user_prompt = render_prompt(user_prompt_path, args)
18
18
  system_prompt = system_prompt if system_prompt else read_query_prompt(system_prompt_path)
@@ -28,26 +28,6 @@ async def generate_structured_completion(
28
28
  )
29
29
 
30
30
 
31
- async def generate_completion(
32
- query: str,
33
- context: str,
34
- user_prompt_path: str,
35
- system_prompt_path: str,
36
- system_prompt: Optional[str] = None,
37
- conversation_history: Optional[str] = None,
38
- ) -> str:
39
- """Generates a completion using LLM with given context and prompts."""
40
- return await generate_structured_completion(
41
- query=query,
42
- context=context,
43
- user_prompt_path=user_prompt_path,
44
- system_prompt_path=system_prompt_path,
45
- system_prompt=system_prompt,
46
- conversation_history=conversation_history,
47
- response_model=str,
48
- )
49
-
50
-
51
31
  async def summarize_text(
52
32
  text: str,
53
33
  system_prompt_path: str = "summarize_search_results.txt",
@@ -0,0 +1 @@
1
+ from .run_custom_pipeline import run_custom_pipeline
@@ -0,0 +1,69 @@
1
+ from typing import Union, Optional, List, Type, Any
2
+ from uuid import UUID
3
+
4
+ from cognee.shared.logging_utils import get_logger
5
+
6
+ from cognee.modules.pipelines import run_pipeline
7
+ from cognee.modules.pipelines.tasks.task import Task
8
+ from cognee.modules.users.models import User
9
+ from cognee.modules.pipelines.layers.pipeline_execution_mode import get_pipeline_executor
10
+
11
+ logger = get_logger()
12
+
13
+
14
+ async def run_custom_pipeline(
15
+ tasks: Union[List[Task], List[str]] = None,
16
+ data: Any = None,
17
+ dataset: Union[str, UUID] = "main_dataset",
18
+ user: User = None,
19
+ vector_db_config: Optional[dict] = None,
20
+ graph_db_config: Optional[dict] = None,
21
+ data_per_batch: int = 20,
22
+ run_in_background: bool = False,
23
+ pipeline_name: str = "custom_pipeline",
24
+ ):
25
+ """
26
+ Custom pipeline in Cognee, can work with already built graphs. Data needs to be provided which can be processed
27
+ with provided tasks.
28
+
29
+ Provided tasks and data will be arranged to run the Cognee pipeline and execute graph enrichment/creation.
30
+
31
+ This is the core processing step in Cognee that converts raw text and documents
32
+ into an intelligent knowledge graph. It analyzes content, extracts entities and
33
+ relationships, and creates semantic connections for enhanced search and reasoning.
34
+
35
+ Args:
36
+ tasks: List of Cognee Tasks to execute.
37
+ data: The data to ingest. Can be anything when custom extraction and enrichment tasks are used.
38
+ Data provided here will be forwarded to the first extraction task in the pipeline as input.
39
+ dataset: Dataset name or dataset uuid to process.
40
+ user: User context for authentication and data access. Uses default if None.
41
+ vector_db_config: Custom vector database configuration for embeddings storage.
42
+ graph_db_config: Custom graph database configuration for relationship storage.
43
+ data_per_batch: Number of data items to be processed in parallel.
44
+ run_in_background: If True, starts processing asynchronously and returns immediately.
45
+ If False, waits for completion before returning.
46
+ Background mode recommended for large datasets (>100MB).
47
+ Use pipeline_run_id from return value to monitor progress.
48
+ """
49
+
50
+ custom_tasks = [
51
+ *tasks,
52
+ ]
53
+
54
+ # By calling get pipeline executor we get a function that will have the run_pipeline run in the background or a function that we will need to wait for
55
+ pipeline_executor_func = get_pipeline_executor(run_in_background=run_in_background)
56
+
57
+ # Run the run_pipeline in the background or blocking based on executor
58
+ return await pipeline_executor_func(
59
+ pipeline=run_pipeline,
60
+ tasks=custom_tasks,
61
+ user=user,
62
+ data=data,
63
+ datasets=dataset,
64
+ vector_db_config=vector_db_config,
65
+ graph_db_config=graph_db_config,
66
+ incremental_loading=False,
67
+ data_per_batch=data_per_batch,
68
+ pipeline_name=pipeline_name,
69
+ )
@@ -1,4 +1,3 @@
1
- import os
2
1
  import json
3
2
  import asyncio
4
3
  from uuid import UUID
@@ -9,6 +8,7 @@ from cognee.infrastructure.databases.graph import get_graph_engine
9
8
  from cognee.shared.logging_utils import get_logger
10
9
  from cognee.shared.utils import send_telemetry
11
10
  from cognee.context_global_variables import set_database_global_context_variables
11
+ from cognee.context_global_variables import backend_access_control_enabled
12
12
 
13
13
  from cognee.modules.engine.models.node_set import NodeSet
14
14
  from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
@@ -74,7 +74,7 @@ async def search(
74
74
  )
75
75
 
76
76
  # Use search function filtered by permissions if access control is enabled
77
- if os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true":
77
+ if backend_access_control_enabled():
78
78
  search_results = await authorized_search(
79
79
  query_type=query_type,
80
80
  query_text=query_text,
@@ -156,7 +156,7 @@ async def search(
156
156
  )
157
157
  else:
158
158
  # This is for maintaining backwards compatibility
159
- if os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true":
159
+ if backend_access_control_enabled():
160
160
  return_value = []
161
161
  for search_result in search_results:
162
162
  prepared_search_results = await prepare_search_result(search_result)
@@ -172,6 +172,7 @@ async def search(
172
172
  "search_result": [context] if context else None,
173
173
  "dataset_id": datasets[0].id,
174
174
  "dataset_name": datasets[0].name,
175
+ "dataset_tenant_id": datasets[0].tenant_id,
175
176
  "graphs": graphs,
176
177
  }
177
178
  )
@@ -181,6 +182,7 @@ async def search(
181
182
  "search_result": [result] if result else None,
182
183
  "dataset_id": datasets[0].id,
183
184
  "dataset_name": datasets[0].name,
185
+ "dataset_tenant_id": datasets[0].tenant_id,
184
186
  "graphs": graphs,
185
187
  }
186
188
  )
@@ -18,7 +18,6 @@ from typing import Optional
18
18
  async def create_user(
19
19
  email: str,
20
20
  password: str,
21
- tenant_id: Optional[str] = None,
22
21
  is_superuser: bool = False,
23
22
  is_active: bool = True,
24
23
  is_verified: bool = False,
@@ -30,37 +29,23 @@ async def create_user(
30
29
  async with relational_engine.get_async_session() as session:
31
30
  async with get_user_db_context(session) as user_db:
32
31
  async with get_user_manager_context(user_db) as user_manager:
33
- if tenant_id:
34
- # Check if the tenant already exists
35
- result = await session.execute(select(Tenant).where(Tenant.id == tenant_id))
36
- tenant = result.scalars().first()
37
- if not tenant:
38
- raise TenantNotFoundError
39
-
40
- user = await user_manager.create(
41
- UserCreate(
42
- email=email,
43
- password=password,
44
- tenant_id=tenant.id,
45
- is_superuser=is_superuser,
46
- is_active=is_active,
47
- is_verified=is_verified,
48
- )
49
- )
50
- else:
51
- user = await user_manager.create(
52
- UserCreate(
53
- email=email,
54
- password=password,
55
- is_superuser=is_superuser,
56
- is_active=is_active,
57
- is_verified=is_verified,
58
- )
32
+ user = await user_manager.create(
33
+ UserCreate(
34
+ email=email,
35
+ password=password,
36
+ is_superuser=is_superuser,
37
+ is_active=is_active,
38
+ is_verified=is_verified,
59
39
  )
40
+ )
60
41
 
61
42
  if auto_login:
62
43
  await session.refresh(user)
63
44
 
45
+ # Update tenants and roles information for User object
46
+ _ = await user.awaitable_attrs.tenants
47
+ _ = await user.awaitable_attrs.roles
48
+
64
49
  return user
65
50
  except UserAlreadyExists as error:
66
51
  print(f"User {email} already exists")
@@ -5,6 +5,7 @@ from ..models import User
5
5
  from ..get_fastapi_users import get_fastapi_users
6
6
  from .get_default_user import get_default_user
7
7
  from cognee.shared.logging_utils import get_logger
8
+ from cognee.context_global_variables import backend_access_control_enabled
8
9
 
9
10
 
10
11
  logger = get_logger("get_authenticated_user")
@@ -12,7 +13,7 @@ logger = get_logger("get_authenticated_user")
12
13
  # Check environment variable to determine authentication requirement
13
14
  REQUIRE_AUTHENTICATION = (
14
15
  os.getenv("REQUIRE_AUTHENTICATION", "false").lower() == "true"
15
- or os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true"
16
+ or backend_access_control_enabled()
16
17
  )
17
18
 
18
19
  fastapi_users = get_fastapi_users()
@@ -10,7 +10,7 @@ from cognee.infrastructure.databases.relational import get_relational_engine
10
10
  from cognee.modules.users.methods.create_default_user import create_default_user
11
11
 
12
12
 
13
- async def get_default_user() -> SimpleNamespace:
13
+ async def get_default_user() -> User:
14
14
  db_engine = get_relational_engine()
15
15
  base_config = get_base_config()
16
16
  default_email = base_config.default_user_email or "default_user@example.com"
@@ -18,7 +18,9 @@ async def get_default_user() -> SimpleNamespace:
18
18
  try:
19
19
  async with db_engine.get_async_session() as session:
20
20
  query = (
21
- select(User).options(selectinload(User.roles)).where(User.email == default_email)
21
+ select(User)
22
+ .options(selectinload(User.roles), selectinload(User.tenants))
23
+ .where(User.email == default_email)
22
24
  )
23
25
 
24
26
  result = await session.execute(query)
@@ -14,7 +14,7 @@ async def get_user(user_id: UUID):
14
14
  user = (
15
15
  await session.execute(
16
16
  select(User)
17
- .options(selectinload(User.roles), selectinload(User.tenant))
17
+ .options(selectinload(User.roles), selectinload(User.tenants))
18
18
  .where(User.id == user_id)
19
19
  )
20
20
  ).scalar()
@@ -13,7 +13,7 @@ async def get_user_by_email(user_email: str):
13
13
  user = (
14
14
  await session.execute(
15
15
  select(User)
16
- .options(joinedload(User.roles), joinedload(User.tenant))
16
+ .options(joinedload(User.roles), joinedload(User.tenants))
17
17
  .where(User.email == user_email)
18
18
  )
19
19
  ).scalar()
@@ -15,5 +15,14 @@ class DatasetDatabase(Base):
15
15
  vector_database_name = Column(String, unique=True, nullable=False)
16
16
  graph_database_name = Column(String, unique=True, nullable=False)
17
17
 
18
+ vector_database_provider = Column(String, unique=False, nullable=False)
19
+ graph_database_provider = Column(String, unique=False, nullable=False)
20
+
21
+ vector_database_url = Column(String, unique=False, nullable=True)
22
+ graph_database_url = Column(String, unique=False, nullable=True)
23
+
24
+ vector_database_key = Column(String, unique=False, nullable=True)
25
+ graph_database_key = Column(String, unique=False, nullable=True)
26
+
18
27
  created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
19
28
  updated_at = Column(DateTime(timezone=True), onupdate=lambda: datetime.now(timezone.utc))
@@ -1,7 +1,7 @@
1
- from sqlalchemy.orm import relationship
1
+ from sqlalchemy.orm import relationship, Mapped
2
2
  from sqlalchemy import Column, String, ForeignKey, UUID
3
3
  from .Principal import Principal
4
- from .User import User
4
+ from .UserTenant import UserTenant
5
5
  from .Role import Role
6
6
 
7
7
 
@@ -13,14 +13,13 @@ class Tenant(Principal):
13
13
 
14
14
  owner_id = Column(UUID, index=True)
15
15
 
16
- # One-to-Many relationship with User; specify the join via User.tenant_id
17
- users = relationship(
16
+ users: Mapped[list["User"]] = relationship( # noqa: F821
18
17
  "User",
19
- back_populates="tenant",
20
- foreign_keys=lambda: [User.tenant_id],
18
+ secondary=UserTenant.__tablename__,
19
+ back_populates="tenants",
21
20
  )
22
21
 
23
- # One-to-Many relationship with Role (if needed; similar fix)
22
+ # One-to-Many relationship with Role
24
23
  roles = relationship(
25
24
  "Role",
26
25
  back_populates="tenant",
@@ -6,8 +6,10 @@ from sqlalchemy import ForeignKey, Column, UUID
6
6
  from sqlalchemy.orm import relationship, Mapped
7
7
 
8
8
  from .Principal import Principal
9
+ from .UserTenant import UserTenant
9
10
  from .UserRole import UserRole
10
11
  from .Role import Role
12
+ from .Tenant import Tenant
11
13
 
12
14
 
13
15
  class User(SQLAlchemyBaseUserTableUUID, Principal):
@@ -15,7 +17,7 @@ class User(SQLAlchemyBaseUserTableUUID, Principal):
15
17
 
16
18
  id = Column(UUID, ForeignKey("principals.id", ondelete="CASCADE"), primary_key=True)
17
19
 
18
- # Foreign key to Tenant (Many-to-One relationship)
20
+ # Foreign key to current Tenant (Many-to-One relationship)
19
21
  tenant_id = Column(UUID, ForeignKey("tenants.id"))
20
22
 
21
23
  # Many-to-Many Relationship with Roles
@@ -25,11 +27,11 @@ class User(SQLAlchemyBaseUserTableUUID, Principal):
25
27
  back_populates="users",
26
28
  )
27
29
 
28
- # Relationship to Tenant
29
- tenant = relationship(
30
+ # Many-to-Many Relationship with Tenants user is a part of
31
+ tenants: Mapped[list["Tenant"]] = relationship(
30
32
  "Tenant",
33
+ secondary=UserTenant.__tablename__,
31
34
  back_populates="users",
32
- foreign_keys=[tenant_id],
33
35
  )
34
36
 
35
37
  # ACL Relationship (One-to-Many)
@@ -46,7 +48,6 @@ class UserRead(schemas.BaseUser[uuid_UUID]):
46
48
 
47
49
 
48
50
  class UserCreate(schemas.BaseUserCreate):
49
- tenant_id: Optional[uuid_UUID] = None
50
51
  is_verified: bool = True
51
52
 
52
53
 
@@ -0,0 +1,12 @@
1
+ from datetime import datetime, timezone
2
+ from sqlalchemy import Column, ForeignKey, DateTime, UUID
3
+ from cognee.infrastructure.databases.relational import Base
4
+
5
+
6
+ class UserTenant(Base):
7
+ __tablename__ = "user_tenants"
8
+
9
+ created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
10
+
11
+ user_id = Column(UUID, ForeignKey("users.id"), primary_key=True)
12
+ tenant_id = Column(UUID, ForeignKey("tenants.id"), primary_key=True)
@@ -1,6 +1,7 @@
1
1
  from .User import User
2
2
  from .Role import Role
3
3
  from .UserRole import UserRole
4
+ from .UserTenant import UserTenant
4
5
  from .DatasetDatabase import DatasetDatabase
5
6
  from .RoleDefaultPermissions import RoleDefaultPermissions
6
7
  from .UserDefaultPermissions import UserDefaultPermissions
@@ -1,11 +1,8 @@
1
- from types import SimpleNamespace
2
-
3
1
  from cognee.shared.logging_utils import get_logger
4
2
 
5
3
  from ...models.User import User
6
4
  from cognee.modules.data.models.Dataset import Dataset
7
5
  from cognee.modules.users.permissions.methods import get_principal_datasets
8
- from cognee.modules.users.permissions.methods import get_role, get_tenant
9
6
 
10
7
  logger = get_logger()
11
8
 
@@ -25,17 +22,14 @@ async def get_all_user_permission_datasets(user: User, permission_type: str) ->
25
22
  # Get all datasets User has explicit access to
26
23
  datasets.extend(await get_principal_datasets(user, permission_type))
27
24
 
28
- if user.tenant_id:
29
- # Get all datasets all tenants have access to
30
- tenant = await get_tenant(user.tenant_id)
25
+ # Get all tenants user is a part of
26
+ tenants = await user.awaitable_attrs.tenants
27
+ for tenant in tenants:
28
+ # Get all datasets all tenant members have access to
31
29
  datasets.extend(await get_principal_datasets(tenant, permission_type))
32
30
 
33
- # Get all datasets Users roles have access to
34
- if isinstance(user, SimpleNamespace):
35
- # If simple namespace use roles defined in user
36
- roles = user.roles
37
- else:
38
- roles = await user.awaitable_attrs.roles
31
+ # Get all datasets accessible by roles user is a part of
32
+ roles = await user.awaitable_attrs.roles
39
33
  for role in roles:
40
34
  datasets.extend(await get_principal_datasets(role, permission_type))
41
35
 
@@ -45,4 +39,10 @@ async def get_all_user_permission_datasets(user: User, permission_type: str) ->
45
39
  # If the dataset id key already exists, leave the dictionary unchanged.
46
40
  unique.setdefault(dataset.id, dataset)
47
41
 
48
- return list(unique.values())
42
+ # Filter out dataset that aren't part of the selected user's tenant
43
+ filtered_datasets = []
44
+ for dataset in list(unique.values()):
45
+ if dataset.tenant_id == user.tenant_id:
46
+ filtered_datasets.append(dataset)
47
+
48
+ return filtered_datasets
@@ -42,11 +42,13 @@ async def add_user_to_role(user_id: UUID, role_id: UUID, owner_id: UUID):
42
42
  .first()
43
43
  )
44
44
 
45
+ user_tenants = await user.awaitable_attrs.tenants
46
+
45
47
  if not user:
46
48
  raise UserNotFoundError
47
49
  elif not role:
48
50
  raise RoleNotFoundError
49
- elif user.tenant_id != role.tenant_id:
51
+ elif role.tenant_id not in [tenant.id for tenant in user_tenants]:
50
52
  raise TenantNotFoundError(
51
53
  message="User tenant does not match role tenant. User cannot be added to role."
52
54
  )
@@ -1,2 +1,3 @@
1
1
  from .create_tenant import create_tenant
2
2
  from .add_user_to_tenant import add_user_to_tenant
3
+ from .select_tenant import select_tenant
@@ -1,8 +1,11 @@
1
+ from typing import Optional
1
2
  from uuid import UUID
2
3
  from sqlalchemy.exc import IntegrityError
4
+ from sqlalchemy import insert
3
5
 
4
6
  from cognee.infrastructure.databases.exceptions import EntityAlreadyExistsError
5
7
  from cognee.infrastructure.databases.relational import get_relational_engine
8
+ from cognee.modules.users.models.UserTenant import UserTenant
6
9
  from cognee.modules.users.methods import get_user
7
10
  from cognee.modules.users.permissions.methods import get_tenant
8
11
  from cognee.modules.users.exceptions import (
@@ -12,14 +15,19 @@ from cognee.modules.users.exceptions import (
12
15
  )
13
16
 
14
17
 
15
- async def add_user_to_tenant(user_id: UUID, tenant_id: UUID, owner_id: UUID):
18
+ async def add_user_to_tenant(
19
+ user_id: UUID, tenant_id: UUID, owner_id: UUID, set_as_active_tenant: Optional[bool] = False
20
+ ):
16
21
  """
17
22
  Add a user with the given id to the tenant with the given id.
18
23
  This can only be successful if the request owner with the given id is the tenant owner.
24
+
25
+ If set_as_active_tenant is true it will automatically set the users active tenant to provided tenant.
19
26
  Args:
20
27
  user_id: Id of the user.
21
28
  tenant_id: Id of the tenant.
22
29
  owner_id: Id of the request owner.
30
+ set_as_active_tenant: If set_as_active_tenant is true it will automatically set the users active tenant to provided tenant.
23
31
 
24
32
  Returns:
25
33
  None
@@ -40,17 +48,18 @@ async def add_user_to_tenant(user_id: UUID, tenant_id: UUID, owner_id: UUID):
40
48
  message="Only tenant owner can add other users to organization."
41
49
  )
42
50
 
43
- try:
44
- if user.tenant_id is None:
45
- user.tenant_id = tenant_id
46
- elif user.tenant_id == tenant_id:
47
- return
48
- else:
49
- raise IntegrityError
50
-
51
+ if set_as_active_tenant:
52
+ user.tenant_id = tenant_id
51
53
  await session.merge(user)
52
54
  await session.commit()
53
- except IntegrityError:
54
- raise EntityAlreadyExistsError(
55
- message="User is already part of a tenant. Only one tenant can be assigned to user."
55
+
56
+ try:
57
+ # Add association directly to the association table
58
+ create_user_tenant_statement = insert(UserTenant).values(
59
+ user_id=user_id, tenant_id=tenant_id
56
60
  )
61
+ await session.execute(create_user_tenant_statement)
62
+ await session.commit()
63
+
64
+ except IntegrityError:
65
+ raise EntityAlreadyExistsError(message="User is already part of group.")
@@ -1,19 +1,25 @@
1
1
  from uuid import UUID
2
+ from sqlalchemy import insert
2
3
  from sqlalchemy.exc import IntegrityError
4
+ from typing import Optional
3
5
 
6
+ from cognee.modules.users.models.UserTenant import UserTenant
4
7
  from cognee.infrastructure.databases.exceptions import EntityAlreadyExistsError
5
8
  from cognee.infrastructure.databases.relational import get_relational_engine
6
9
  from cognee.modules.users.models import Tenant
7
10
  from cognee.modules.users.methods import get_user
8
11
 
9
12
 
10
- async def create_tenant(tenant_name: str, user_id: UUID) -> UUID:
13
+ async def create_tenant(
14
+ tenant_name: str, user_id: UUID, set_as_active_tenant: Optional[bool] = True
15
+ ) -> UUID:
11
16
  """
12
17
  Create a new tenant with the given name, for the user with the given id.
13
18
  This user is the owner of the tenant.
14
19
  Args:
15
20
  tenant_name: Name of the new tenant.
16
21
  user_id: Id of the user.
22
+ set_as_active_tenant: If true, set the newly created tenant as the active tenant for the user.
17
23
 
18
24
  Returns:
19
25
  None
@@ -22,18 +28,26 @@ async def create_tenant(tenant_name: str, user_id: UUID) -> UUID:
22
28
  async with db_engine.get_async_session() as session:
23
29
  try:
24
30
  user = await get_user(user_id)
25
- if user.tenant_id:
26
- raise EntityAlreadyExistsError(
27
- message="User already has a tenant. New tenant cannot be created."
28
- )
29
31
 
30
32
  tenant = Tenant(name=tenant_name, owner_id=user_id)
31
33
  session.add(tenant)
32
34
  await session.flush()
33
35
 
34
- user.tenant_id = tenant.id
35
- await session.merge(user)
36
- await session.commit()
36
+ if set_as_active_tenant:
37
+ user.tenant_id = tenant.id
38
+ await session.merge(user)
39
+ await session.commit()
40
+
41
+ try:
42
+ # Add association directly to the association table
43
+ create_user_tenant_statement = insert(UserTenant).values(
44
+ user_id=user_id, tenant_id=tenant.id
45
+ )
46
+ await session.execute(create_user_tenant_statement)
47
+ await session.commit()
48
+ except IntegrityError:
49
+ raise EntityAlreadyExistsError(message="User is already part of tenant.")
50
+
37
51
  return tenant.id
38
52
  except IntegrityError as e:
39
53
  raise EntityAlreadyExistsError(message="Tenant already exists.") from e