cognee 0.4.1__py3-none-any.whl → 0.5.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (135)
  1. cognee/__init__.py +1 -0
  2. cognee/api/client.py +8 -0
  3. cognee/api/v1/add/routers/get_add_router.py +3 -1
  4. cognee/api/v1/cognify/routers/get_cognify_router.py +28 -1
  5. cognee/api/v1/ontologies/__init__.py +4 -0
  6. cognee/api/v1/ontologies/ontologies.py +183 -0
  7. cognee/api/v1/ontologies/routers/__init__.py +0 -0
  8. cognee/api/v1/ontologies/routers/get_ontology_router.py +107 -0
  9. cognee/api/v1/permissions/routers/get_permissions_router.py +41 -1
  10. cognee/cli/commands/cognify_command.py +8 -1
  11. cognee/cli/config.py +1 -1
  12. cognee/context_global_variables.py +41 -9
  13. cognee/infrastructure/databases/cache/config.py +3 -1
  14. cognee/infrastructure/databases/cache/fscache/FsCacheAdapter.py +151 -0
  15. cognee/infrastructure/databases/cache/get_cache_engine.py +20 -10
  16. cognee/infrastructure/databases/exceptions/exceptions.py +16 -0
  17. cognee/infrastructure/databases/graph/config.py +4 -0
  18. cognee/infrastructure/databases/graph/get_graph_engine.py +2 -0
  19. cognee/infrastructure/databases/hybrid/neptune_analytics/NeptuneAnalyticsAdapter.py +9 -0
  20. cognee/infrastructure/databases/utils/get_or_create_dataset_database.py +37 -3
  21. cognee/infrastructure/databases/vector/config.py +3 -0
  22. cognee/infrastructure/databases/vector/create_vector_engine.py +5 -1
  23. cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py +1 -4
  24. cognee/infrastructure/engine/models/Edge.py +13 -1
  25. cognee/infrastructure/files/utils/guess_file_type.py +4 -0
  26. cognee/infrastructure/llm/config.py +2 -0
  27. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py +5 -2
  28. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py +7 -1
  29. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py +7 -1
  30. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +8 -16
  31. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py +12 -2
  32. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py +13 -2
  33. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py +5 -2
  34. cognee/infrastructure/loaders/LoaderEngine.py +1 -0
  35. cognee/infrastructure/loaders/core/__init__.py +2 -1
  36. cognee/infrastructure/loaders/core/csv_loader.py +93 -0
  37. cognee/infrastructure/loaders/core/text_loader.py +1 -2
  38. cognee/infrastructure/loaders/external/advanced_pdf_loader.py +0 -9
  39. cognee/infrastructure/loaders/supported_loaders.py +2 -1
  40. cognee/memify_pipelines/persist_sessions_in_knowledge_graph.py +55 -0
  41. cognee/modules/chunking/CsvChunker.py +35 -0
  42. cognee/modules/chunking/models/DocumentChunk.py +2 -1
  43. cognee/modules/chunking/text_chunker_with_overlap.py +124 -0
  44. cognee/modules/data/methods/__init__.py +1 -0
  45. cognee/modules/data/methods/create_dataset.py +4 -2
  46. cognee/modules/data/methods/get_dataset_ids.py +5 -1
  47. cognee/modules/data/methods/get_unique_data_id.py +68 -0
  48. cognee/modules/data/methods/get_unique_dataset_id.py +66 -4
  49. cognee/modules/data/models/Dataset.py +2 -0
  50. cognee/modules/data/processing/document_types/CsvDocument.py +33 -0
  51. cognee/modules/data/processing/document_types/__init__.py +1 -0
  52. cognee/modules/graph/cognee_graph/CogneeGraph.py +4 -2
  53. cognee/modules/graph/utils/expand_with_nodes_and_edges.py +19 -2
  54. cognee/modules/graph/utils/resolve_edges_to_text.py +48 -49
  55. cognee/modules/ingestion/identify.py +4 -4
  56. cognee/modules/notebooks/operations/run_in_local_sandbox.py +3 -0
  57. cognee/modules/ontology/rdf_xml/RDFLibOntologyResolver.py +55 -23
  58. cognee/modules/pipelines/operations/run_tasks_data_item.py +1 -1
  59. cognee/modules/retrieval/EntityCompletionRetriever.py +10 -3
  60. cognee/modules/retrieval/base_graph_retriever.py +7 -3
  61. cognee/modules/retrieval/base_retriever.py +7 -3
  62. cognee/modules/retrieval/completion_retriever.py +11 -4
  63. cognee/modules/retrieval/graph_completion_context_extension_retriever.py +6 -2
  64. cognee/modules/retrieval/graph_completion_cot_retriever.py +14 -51
  65. cognee/modules/retrieval/graph_completion_retriever.py +4 -1
  66. cognee/modules/retrieval/temporal_retriever.py +9 -2
  67. cognee/modules/retrieval/utils/brute_force_triplet_search.py +1 -1
  68. cognee/modules/retrieval/utils/completion.py +2 -22
  69. cognee/modules/run_custom_pipeline/__init__.py +1 -0
  70. cognee/modules/run_custom_pipeline/run_custom_pipeline.py +69 -0
  71. cognee/modules/search/methods/search.py +5 -3
  72. cognee/modules/users/methods/create_user.py +12 -27
  73. cognee/modules/users/methods/get_authenticated_user.py +2 -1
  74. cognee/modules/users/methods/get_default_user.py +4 -2
  75. cognee/modules/users/methods/get_user.py +1 -1
  76. cognee/modules/users/methods/get_user_by_email.py +1 -1
  77. cognee/modules/users/models/DatasetDatabase.py +9 -0
  78. cognee/modules/users/models/Tenant.py +6 -7
  79. cognee/modules/users/models/User.py +6 -5
  80. cognee/modules/users/models/UserTenant.py +12 -0
  81. cognee/modules/users/models/__init__.py +1 -0
  82. cognee/modules/users/permissions/methods/get_all_user_permission_datasets.py +13 -13
  83. cognee/modules/users/roles/methods/add_user_to_role.py +3 -1
  84. cognee/modules/users/tenants/methods/__init__.py +1 -0
  85. cognee/modules/users/tenants/methods/add_user_to_tenant.py +21 -12
  86. cognee/modules/users/tenants/methods/create_tenant.py +22 -8
  87. cognee/modules/users/tenants/methods/select_tenant.py +62 -0
  88. cognee/shared/logging_utils.py +2 -0
  89. cognee/tasks/chunks/__init__.py +1 -0
  90. cognee/tasks/chunks/chunk_by_row.py +94 -0
  91. cognee/tasks/documents/classify_documents.py +2 -0
  92. cognee/tasks/feedback/generate_improved_answers.py +3 -3
  93. cognee/tasks/ingestion/ingest_data.py +1 -1
  94. cognee/tasks/memify/__init__.py +2 -0
  95. cognee/tasks/memify/cognify_session.py +41 -0
  96. cognee/tasks/memify/extract_user_sessions.py +73 -0
  97. cognee/tasks/storage/index_data_points.py +33 -22
  98. cognee/tasks/storage/index_graph_edges.py +37 -57
  99. cognee/tests/integration/documents/CsvDocument_test.py +70 -0
  100. cognee/tests/tasks/entity_extraction/entity_extraction_test.py +1 -1
  101. cognee/tests/test_add_docling_document.py +2 -2
  102. cognee/tests/test_cognee_server_start.py +84 -1
  103. cognee/tests/test_conversation_history.py +45 -4
  104. cognee/tests/test_data/example_with_header.csv +3 -0
  105. cognee/tests/test_delete_bmw_example.py +60 -0
  106. cognee/tests/test_edge_ingestion.py +27 -0
  107. cognee/tests/test_feedback_enrichment.py +1 -1
  108. cognee/tests/test_library.py +6 -4
  109. cognee/tests/test_load.py +62 -0
  110. cognee/tests/test_multi_tenancy.py +165 -0
  111. cognee/tests/test_parallel_databases.py +2 -0
  112. cognee/tests/test_relational_db_migration.py +54 -2
  113. cognee/tests/test_search_db.py +7 -1
  114. cognee/tests/unit/api/test_conditional_authentication_endpoints.py +12 -3
  115. cognee/tests/unit/api/test_ontology_endpoint.py +264 -0
  116. cognee/tests/unit/infrastructure/databases/cache/test_cache_config.py +5 -0
  117. cognee/tests/unit/infrastructure/databases/test_index_data_points.py +27 -0
  118. cognee/tests/unit/infrastructure/databases/test_index_graph_edges.py +14 -16
  119. cognee/tests/unit/modules/chunking/test_text_chunker.py +248 -0
  120. cognee/tests/unit/modules/chunking/test_text_chunker_with_overlap.py +324 -0
  121. cognee/tests/unit/modules/memify_tasks/test_cognify_session.py +111 -0
  122. cognee/tests/unit/modules/memify_tasks/test_extract_user_sessions.py +175 -0
  123. cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +0 -51
  124. cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py +1 -0
  125. cognee/tests/unit/modules/retrieval/structured_output_test.py +204 -0
  126. cognee/tests/unit/modules/retrieval/summaries_retriever_test.py +1 -1
  127. cognee/tests/unit/modules/retrieval/temporal_retriever_test.py +0 -1
  128. cognee/tests/unit/modules/users/test_conditional_authentication.py +0 -63
  129. cognee/tests/unit/processing/chunks/chunk_by_row_test.py +52 -0
  130. {cognee-0.4.1.dist-info → cognee-0.5.0.dev0.dist-info}/METADATA +88 -71
  131. {cognee-0.4.1.dist-info → cognee-0.5.0.dev0.dist-info}/RECORD +135 -104
  132. {cognee-0.4.1.dist-info → cognee-0.5.0.dev0.dist-info}/WHEEL +1 -1
  133. {cognee-0.4.1.dist-info → cognee-0.5.0.dev0.dist-info}/entry_points.txt +0 -1
  134. {cognee-0.4.1.dist-info → cognee-0.5.0.dev0.dist-info}/licenses/LICENSE +0 -0
  135. {cognee-0.4.1.dist-info → cognee-0.5.0.dev0.dist-info}/licenses/NOTICE.md +0 -0
cognee/modules/users/tenants/methods/select_tenant.py (new file)
@@ -0,0 +1,62 @@
+ from uuid import UUID
+ from typing import Union
+
+ import sqlalchemy.exc
+ from sqlalchemy import select
+
+ from cognee.infrastructure.databases.relational import get_relational_engine
+ from cognee.modules.users.methods.get_user import get_user
+ from cognee.modules.users.models.UserTenant import UserTenant
+ from cognee.modules.users.models.User import User
+ from cognee.modules.users.permissions.methods import get_tenant
+ from cognee.modules.users.exceptions import UserNotFoundError, TenantNotFoundError
+
+
+ async def select_tenant(user_id: UUID, tenant_id: Union[UUID, None]) -> User:
+     """
+     Set the user's active tenant to the provided tenant.
+
+     If tenant_id is None, reset the current tenant to the default single-user tenant.
+
+     Args:
+         user_id: UUID of the user.
+         tenant_id: UUID of the tenant, or None.
+
+     Returns:
+         The updated User.
+     """
+     db_engine = get_relational_engine()
+     async with db_engine.get_async_session() as session:
+         user = await get_user(user_id)
+         if tenant_id is None:
+             # If no tenant_id is provided, reset the current tenant to the single-user tenant
+             user.tenant_id = None
+             await session.merge(user)
+             await session.commit()
+             return user
+
+         tenant = await get_tenant(tenant_id)
+
+         if not user:
+             raise UserNotFoundError
+         elif not tenant:
+             raise TenantNotFoundError
+
+         # Check if the user is part of the tenant
+         result = await session.execute(
+             select(UserTenant)
+             .where(UserTenant.user_id == user.id)
+             .where(UserTenant.tenant_id == tenant_id)
+         )
+
+         try:
+             result = result.scalar_one()
+         except sqlalchemy.exc.NoResultFound as e:
+             raise TenantNotFoundError("User is not part of the tenant.") from e
+
+         if result:
+             # If the user is part of the tenant, update the user's current tenant
+             user.tenant_id = tenant_id
+             await session.merge(user)
+             await session.commit()
+             return user
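For reference, a minimal usage sketch of the new helper (a sketch only: the import path is inferred from the file layout above, the UUIDs are placeholders, and a configured cognee relational database is assumed):

import asyncio
from uuid import UUID

from cognee.modules.users.tenants.methods import select_tenant
from cognee.modules.users.exceptions import TenantNotFoundError


async def switch_tenant_example():
    # Hypothetical ids; in practice these come from your user and tenant records.
    user_id = UUID("00000000-0000-0000-0000-000000000001")
    tenant_id = UUID("00000000-0000-0000-0000-000000000002")

    try:
        # Make the given tenant the user's active tenant.
        user = await select_tenant(user_id, tenant_id)
        print(f"Active tenant is now {user.tenant_id}")
    except TenantNotFoundError:
        # Raised for unknown tenants and for users who are not members of the tenant.
        print("Tenant does not exist or user is not a member of it")

    # Passing None resets the user to the default single-user tenant.
    await select_tenant(user_id, None)


asyncio.run(switch_tenant_example())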
cognee/shared/logging_utils.py
@@ -450,6 +450,8 @@ def setup_logging(log_level=None, name=None):
      try:
          msg = self.format(record)
          stream = self.stream
+         if hasattr(stream, "closed") and stream.closed:
+             return
          stream.write("\n" + msg + self.terminator)
          self.flush()
      except Exception:
cognee/tasks/chunks/__init__.py
@@ -1,4 +1,5 @@
  from .chunk_by_word import chunk_by_word
  from .chunk_by_sentence import chunk_by_sentence
  from .chunk_by_paragraph import chunk_by_paragraph
+ from .chunk_by_row import chunk_by_row
  from .remove_disconnected_chunks import remove_disconnected_chunks
cognee/tasks/chunks/chunk_by_row.py (new file)
@@ -0,0 +1,94 @@
+ from typing import Any, Dict, Iterator
+ from uuid import NAMESPACE_OID, uuid5
+
+ from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine
+
+
+ def _get_pair_size(pair_text: str) -> int:
+     """
+     Calculate the size of the given text in tokens.
+
+     If the embedding engine's tokenizer is available, count the tokens for the provided
+     pair text. If the tokenizer is not available, fall back to a fixed estimate per pair.
+
+     Parameters:
+     -----------
+
+     - pair_text (str): The key:value pair text for which the token size is to be calculated.
+
+     Returns:
+     --------
+
+     - int: The number of tokens representing the text, depending on the tokenizer's output.
+     """
+     embedding_engine = get_embedding_engine()
+     if embedding_engine.tokenizer:
+         return embedding_engine.tokenizer.count_tokens(pair_text)
+     else:
+         return 3
+
+
+ def chunk_by_row(
+     data: str,
+     max_chunk_size,
+ ) -> Iterator[Dict[str, Any]]:
+     """
+     Chunk the input text by row while enabling exact text reconstruction.
+
+     This function divides the given text data into smaller chunks on a line-by-line basis,
+     ensuring that the size of each chunk is less than or equal to the specified maximum
+     chunk size. It guarantees that when the generated chunks are concatenated, they
+     reproduce the original text accurately. The tokenization process is handled by
+     adapters compatible with the vector engine's embedding model.
+
+     Parameters:
+     -----------
+
+     - data (str): The input text to be chunked.
+     - max_chunk_size: The maximum allowed size for each chunk, in terms of tokens or
+       words.
+     """
+     current_chunk_list = []
+     chunk_index = 0
+     current_chunk_size = 0
+
+     lines = data.split("\n\n")
+     for line in lines:
+         pairs_text = line.split(", ")
+
+         for pair_text in pairs_text:
+             pair_size = _get_pair_size(pair_text)
+             if current_chunk_size > 0 and (current_chunk_size + pair_size > max_chunk_size):
+                 # Yield the current cut chunk
+                 current_chunk = ", ".join(current_chunk_list)
+                 chunk_dict = {
+                     "text": current_chunk,
+                     "chunk_size": current_chunk_size,
+                     "chunk_id": uuid5(NAMESPACE_OID, current_chunk),
+                     "chunk_index": chunk_index,
+                     "cut_type": "row_cut",
+                 }
+
+                 yield chunk_dict
+
+                 # Start a new chunk with the current pair text
+                 current_chunk_list = []
+                 current_chunk_size = 0
+                 chunk_index += 1
+
+             current_chunk_list.append(pair_text)
+             current_chunk_size += pair_size
+
+     # Yield the final row chunk
+     current_chunk = ", ".join(current_chunk_list)
+     if current_chunk:
+         chunk_dict = {
+             "text": current_chunk,
+             "chunk_size": current_chunk_size,
+             "chunk_id": uuid5(NAMESPACE_OID, current_chunk),
+             "chunk_index": chunk_index,
+             "cut_type": "row_end",
+         }
+
+         yield chunk_dict
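A small usage sketch of the generator above (a sketch only: the row-style input format is an assumption based on the splitting logic, and a configured embedding engine is expected because _get_pair_size consults its tokenizer):

from cognee.tasks.chunks import chunk_by_row

# Row-style text: rows separated by blank lines, "column: value" pairs joined by ", ".
rows_text = (
    "name: Ada, language: Python, role: engineer\n\n"
    "name: Grace, language: COBOL, role: admiral"
)

for chunk in chunk_by_row(rows_text, max_chunk_size=128):
    # "row_cut" marks chunks closed because the size limit was reached;
    # "row_end" marks the final chunk.
    print(chunk["chunk_index"], chunk["cut_type"], repr(chunk["text"]))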
cognee/tasks/documents/classify_documents.py
@@ -7,6 +7,7 @@ from cognee.modules.data.processing.document_types import (
      ImageDocument,
      TextDocument,
      UnstructuredDocument,
+     CsvDocument,
  )
  from cognee.modules.engine.models.node_set import NodeSet
  from cognee.modules.engine.utils.generate_node_id import generate_node_id
@@ -15,6 +16,7 @@ from cognee.tasks.documents.exceptions import WrongDataDocumentInputError
  EXTENSION_TO_DOCUMENT_CLASS = {
      "pdf": PdfDocument,  # Text documents
      "txt": TextDocument,
+     "csv": CsvDocument,
      "docx": UnstructuredDocument,
      "doc": UnstructuredDocument,
      "odt": UnstructuredDocument,
cognee/tasks/feedback/generate_improved_answers.py
@@ -61,7 +61,7 @@ async def _generate_improved_answer_for_single_interaction(
      )

      retrieved_context = await retriever.get_context(query_text)
-     completion = await retriever.get_structured_completion(
+     completion = await retriever.get_completion(
          query=query_text,
          context=retrieved_context,
          response_model=ImprovedAnswerResponse,
@@ -70,9 +70,9 @@ async def _generate_improved_answer_for_single_interaction(
      new_context_text = await retriever.resolve_edges_to_text(retrieved_context)

      if completion:
-         enrichment.improved_answer = completion.answer
+         enrichment.improved_answer = completion[0].answer
          enrichment.new_context = new_context_text
-         enrichment.explanation = completion.explanation
+         enrichment.explanation = completion[0].explanation
          return enrichment
      else:
          logger.warning(
cognee/tasks/ingestion/ingest_data.py
@@ -99,7 +99,7 @@ async def ingest_data(

      # data_id is the hash of original file contents + owner id to avoid duplicate data

-     data_id = ingestion.identify(classified_data, user)
+     data_id = await ingestion.identify(classified_data, user)
      original_file_metadata = classified_data.get_metadata()

      # Find metadata from Cognee data storage text file
cognee/tasks/memify/__init__.py
@@ -1,2 +1,4 @@
  from .extract_subgraph import extract_subgraph
  from .extract_subgraph_chunks import extract_subgraph_chunks
+ from .cognify_session import cognify_session
+ from .extract_user_sessions import extract_user_sessions
cognee/tasks/memify/cognify_session.py (new file)
@@ -0,0 +1,41 @@
+ import cognee
+
+ from cognee.exceptions import CogneeValidationError, CogneeSystemError
+ from cognee.shared.logging_utils import get_logger
+
+ logger = get_logger("cognify_session")
+
+
+ async def cognify_session(data, dataset_id=None):
+     """
+     Process and cognify session data into the knowledge graph.
+
+     Adds session content to cognee with a dedicated "user_sessions" node set,
+     then triggers the cognify pipeline to extract entities and relationships
+     from the session data.
+
+     Args:
+         data: Session string containing Question, Context, and Answer information.
+         dataset_id: Id of the dataset the session data is added to.
+
+     Raises:
+         CogneeValidationError: If data is None or empty.
+         CogneeSystemError: If cognee operations fail.
+     """
+     try:
+         if not data or (isinstance(data, str) and not data.strip()):
+             logger.warning("Empty session data provided to cognify_session task, skipping")
+             raise CogneeValidationError(message="Session data cannot be empty", log=False)
+
+         logger.info("Processing session data for cognification")
+
+         await cognee.add(data, dataset_id=dataset_id, node_set=["user_sessions_from_cache"])
+         logger.debug("Session data added to cognee with node_set: user_sessions")
+         await cognee.cognify(datasets=[dataset_id])
+         logger.info("Session data successfully cognified")
+
+     except CogneeValidationError:
+         raise
+     except Exception as e:
+         logger.error(f"Error cognifying session data: {str(e)}")
+         raise CogneeSystemError(message=f"Failed to cognify session data: {str(e)}", log=False)
cognee/tasks/memify/extract_user_sessions.py (new file)
@@ -0,0 +1,73 @@
+ from typing import Optional, List
+
+ from cognee.context_global_variables import session_user
+ from cognee.exceptions import CogneeSystemError
+ from cognee.infrastructure.databases.cache.get_cache_engine import get_cache_engine
+ from cognee.shared.logging_utils import get_logger
+ from cognee.modules.users.models import User
+
+ logger = get_logger("extract_user_sessions")
+
+
+ async def extract_user_sessions(
+     data,
+     session_ids: Optional[List[str]] = None,
+ ):
+     """
+     Extract Q&A sessions for the current user from cache.
+
+     Retrieves all Q&A pairs from the specified session IDs and yields them
+     as formatted strings combining the session ID, questions, and answers.
+
+     Args:
+         data: Data passed from memify. If empty dict ({}), no external data is provided.
+         session_ids: Optional list of specific session IDs to extract.
+
+     Yields:
+         A string containing the session ID and all of its Q&A pairs.
+
+     Raises:
+         CogneeSystemError: If the cache engine is unavailable or extraction fails.
+     """
+     try:
+         if not data or data == [{}]:
+             logger.info("Fetching session metadata for current user")
+
+             user: User = session_user.get()
+             if not user:
+                 raise CogneeSystemError(message="No authenticated user found in context", log=False)
+
+             user_id = str(user.id)
+
+             cache_engine = get_cache_engine()
+             if cache_engine is None:
+                 raise CogneeSystemError(
+                     message="Cache engine not available for session extraction; enable caching so that sessions can be saved",
+                     log=False,
+                 )
+
+             if session_ids:
+                 for session_id in session_ids:
+                     try:
+                         qa_data = await cache_engine.get_all_qas(user_id, session_id)
+                         if qa_data:
+                             logger.info(f"Extracted session {session_id} with {len(qa_data)} Q&A pairs")
+                             session_string = f"Session ID: {session_id}\n\n"
+                             for qa_pair in qa_data:
+                                 question = qa_pair.get("question", "")
+                                 answer = qa_pair.get("answer", "")
+                                 session_string += f"Question: {question}\n\nAnswer: {answer}\n\n"
+                             yield session_string
+                     except Exception as e:
+                         logger.warning(f"Failed to extract session {session_id}: {str(e)}")
+                         continue
+             else:
+                 logger.info(
+                     "No specific session_ids provided. Please specify which sessions to extract."
+                 )
+
+     except CogneeSystemError:
+         raise
+     except Exception as e:
+         logger.error(f"Error extracting user sessions: {str(e)}")
+         raise CogneeSystemError(message=f"Failed to extract user sessions: {str(e)}", log=False)
cognee/tasks/storage/index_data_points.py
@@ -8,47 +8,58 @@ logger = get_logger("index_data_points")


  async def index_data_points(data_points: list[DataPoint]):
-     created_indexes = {}
-     index_points = {}
+     """Index data points in the vector engine by creating embeddings for specified fields.
+
+     Process:
+     1. Groups data points into a nested dict: {type_name: {field_name: [points]}}
+     2. Creates vector indexes for each (type, field) combination on first encounter
+     3. Batches points per (type, field) and creates async indexing tasks
+     4. Executes all indexing tasks in parallel for efficient embedding generation
+
+     Args:
+         data_points: List of DataPoint objects to index. Each DataPoint's metadata must
+             contain an 'index_fields' list specifying which fields to embed.
+
+     Returns:
+         The original data_points list.
+     """
+     data_points_by_type = {}

      vector_engine = get_vector_engine()

      for data_point in data_points:
          data_point_type = type(data_point)
+         type_name = data_point_type.__name__

          for field_name in data_point.metadata["index_fields"]:
              if getattr(data_point, field_name, None) is None:
                  continue

-             index_name = f"{data_point_type.__name__}_{field_name}"
+             if type_name not in data_points_by_type:
+                 data_points_by_type[type_name] = {}

-             if index_name not in created_indexes:
-                 await vector_engine.create_vector_index(data_point_type.__name__, field_name)
-                 created_indexes[index_name] = True
-
-             if index_name not in index_points:
-                 index_points[index_name] = []
+             if field_name not in data_points_by_type[type_name]:
+                 await vector_engine.create_vector_index(type_name, field_name)
+                 data_points_by_type[type_name][field_name] = []

              indexed_data_point = data_point.model_copy()
              indexed_data_point.metadata["index_fields"] = [field_name]
-             index_points[index_name].append(indexed_data_point)
+             data_points_by_type[type_name][field_name].append(indexed_data_point)

-     tasks: list[asyncio.Task] = []
      batch_size = vector_engine.embedding_engine.get_batch_size()

-     for index_name_and_field, points in index_points.items():
-         first = index_name_and_field.index("_")
-         index_name = index_name_and_field[:first]
-         field_name = index_name_and_field[first + 1 :]
+     batches = (
+         (type_name, field_name, points[i : i + batch_size])
+         for type_name, fields in data_points_by_type.items()
+         for field_name, points in fields.items()
+         for i in range(0, len(points), batch_size)
+     )

-         # Create embedding requests per batch to run in parallel later
-         for i in range(0, len(points), batch_size):
-             batch = points[i : i + batch_size]
-             tasks.append(
-                 asyncio.create_task(vector_engine.index_data_points(index_name, field_name, batch))
-             )
+     tasks = [
+         asyncio.create_task(vector_engine.index_data_points(type_name, field_name, batch_points))
+         for type_name, field_name, batch_points in batches
+     ]

-     # Run all embedding requests in parallel
      await asyncio.gather(*tasks)

      return data_points
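To make the new control flow concrete, here is a self-contained sketch of the grouping and batching shape the refactor introduces, using plain dicts in place of DataPoint objects and leaving the vector engine out entirely:

from collections import defaultdict

points = [
    {"type": "Entity", "fields": {"name": "Ada"}},
    {"type": "Entity", "fields": {"name": "Grace"}},
    {"type": "EdgeType", "fields": {"relationship_name": "works_with"}},
]

# {type_name: {field_name: [points]}} -- one vector index per (type, field) pair.
grouped = defaultdict(lambda: defaultdict(list))
for point in points:
    for field_name, value in point["fields"].items():
        if value is not None:
            grouped[point["type"]][field_name].append(point)

# Flatten into (type, field, batch) triples: exactly one indexing task per batch.
batch_size = 2
batches = [
    (type_name, field_name, field_points[i : i + batch_size])
    for type_name, fields in grouped.items()
    for field_name, field_points in fields.items()
    for i in range(0, len(field_points), batch_size)
]
print(batches)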
cognee/tasks/storage/index_graph_edges.py
@@ -1,17 +1,44 @@
- import asyncio
+ from collections import Counter
+ from typing import Optional, Dict, Any, List, Tuple, Union

  from cognee.modules.engine.utils.generate_edge_id import generate_edge_id
  from cognee.shared.logging_utils import get_logger
- from collections import Counter
- from typing import Optional, Dict, Any, List, Tuple, Union
- from cognee.infrastructure.databases.vector import get_vector_engine
  from cognee.infrastructure.databases.graph import get_graph_engine
  from cognee.modules.graph.models.EdgeType import EdgeType
  from cognee.infrastructure.databases.graph.graph_db_interface import EdgeData
+ from cognee.tasks.storage.index_data_points import index_data_points

  logger = get_logger()


+ def _get_edge_text(item: dict) -> str:
+     """Extract edge text for embedding - prefers edge_text field with fallback."""
+     if "edge_text" in item:
+         return item["edge_text"]
+
+     if "relationship_name" in item:
+         return item["relationship_name"]
+
+     return ""
+
+
+ def create_edge_type_datapoints(edges_data) -> list[EdgeType]:
+     """Transform raw edge data into EdgeType datapoints."""
+     edge_texts = [
+         _get_edge_text(item)
+         for edge in edges_data
+         for item in edge
+         if isinstance(item, dict) and "relationship_name" in item
+     ]
+
+     edge_types = Counter(edge_texts)
+
+     return [
+         EdgeType(id=generate_edge_id(edge_id=text), relationship_name=text, number_of_edges=count)
+         for text, count in edge_types.items()
+     ]
+
+
  async def index_graph_edges(
      edges_data: Union[List[EdgeData], List[Tuple[str, str, str, Optional[Dict[str, Any]]]]] = None,
  ):
@@ -23,24 +50,17 @@ async def index_graph_edges(
      the `relationship_name` field.

      Steps:
-     1. Initialize the vector engine and graph engine.
-     2. Retrieve graph edge data and count relationship types (`relationship_name`).
-     3. Create vector indexes for `relationship_name` if they don't exist.
-     4. Transform the counted relationships into `EdgeType` objects.
-     5. Index the transformed data points in the vector engine.
+     1. Initialize the graph engine if needed and retrieve edge data.
+     2. Transform edge data into EdgeType datapoints.
+     3. Index the EdgeType datapoints using the standard indexing function.

      Raises:
-         RuntimeError: If initialization of the vector engine or graph engine fails.
+         RuntimeError: If initialization of the graph engine fails.

      Returns:
          None
      """
      try:
-         created_indexes = {}
-         index_points = {}
-
-         vector_engine = get_vector_engine()
-
          if edges_data is None:
              graph_engine = await get_graph_engine()
              _, edges_data = await graph_engine.get_graph_data()
@@ -51,47 +71,7 @@ async def index_graph_edges(
          logger.error("Failed to initialize engines: %s", e)
          raise RuntimeError("Initialization error") from e

-     edge_types = Counter(
-         item.get("relationship_name")
-         for edge in edges_data
-         for item in edge
-         if isinstance(item, dict) and "relationship_name" in item
-     )
-
-     for text, count in edge_types.items():
-         edge = EdgeType(
-             id=generate_edge_id(edge_id=text), relationship_name=text, number_of_edges=count
-         )
-         data_point_type = type(edge)
-
-         for field_name in edge.metadata["index_fields"]:
-             index_name = f"{data_point_type.__name__}.{field_name}"
-
-             if index_name not in created_indexes:
-                 await vector_engine.create_vector_index(data_point_type.__name__, field_name)
-                 created_indexes[index_name] = True
-
-             if index_name not in index_points:
-                 index_points[index_name] = []
-
-             indexed_data_point = edge.model_copy()
-             indexed_data_point.metadata["index_fields"] = [field_name]
-             index_points[index_name].append(indexed_data_point)
-
-     # Get maximum batch size for embedding model
-     batch_size = vector_engine.embedding_engine.get_batch_size()
-     tasks: list[asyncio.Task] = []
-
-     for index_name, indexable_points in index_points.items():
-         index_name, field_name = index_name.split(".")
-
-         # Create embedding tasks to run in parallel later
-         for start in range(0, len(indexable_points), batch_size):
-             batch = indexable_points[start : start + batch_size]
-
-             tasks.append(vector_engine.index_data_points(index_name, field_name, batch))
-
-     # Start all embedding tasks and wait for completion
-     await asyncio.gather(*tasks)
+     edge_type_datapoints = create_edge_type_datapoints(edges_data)
+     await index_data_points(edge_type_datapoints)

      return None
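A small sketch of the extracted helper in isolation (a sketch only: it assumes EdgeType and generate_edge_id can be constructed without a live database, and the tuples below are hypothetical edges in the (source, target, relationship_name, properties) shape the signature mentions):

from cognee.tasks.storage.index_graph_edges import create_edge_type_datapoints

# Only dict items carrying "relationship_name" are inspected; "edge_text" wins when present.
edges_data = [
    ("node_a", "node_b", "works_with", {"relationship_name": "works_with"}),
    ("node_b", "node_c", "works_with", {"relationship_name": "works_with"}),
    ("node_a", "node_c", "mentors", {"relationship_name": "mentors", "edge_text": "node_a mentors node_c"}),
]

for edge_type in create_edge_type_datapoints(edges_data):
    # Expect "works_with" counted twice and the mentors edge_text counted once.
    print(edge_type.relationship_name, edge_type.number_of_edges)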
cognee/tests/integration/documents/CsvDocument_test.py (new file)
@@ -0,0 +1,70 @@
+ import os
+ import sys
+ import uuid
+ import pytest
+ import pathlib
+ from unittest.mock import patch
+
+ from cognee.modules.chunking.CsvChunker import CsvChunker
+ from cognee.modules.data.processing.document_types.CsvDocument import CsvDocument
+ from cognee.tests.integration.documents.AudioDocument_test import mock_get_embedding_engine
+ from cognee.tests.integration.documents.async_gen_zip import async_gen_zip
+
+ chunk_by_row_module = sys.modules.get("cognee.tasks.chunks.chunk_by_row")
+
+
+ GROUND_TRUTH = {
+     "chunk_size_10": [
+         {"token_count": 9, "len_text": 26, "cut_type": "row_cut", "chunk_index": 0},
+         {"token_count": 6, "len_text": 29, "cut_type": "row_end", "chunk_index": 1},
+         {"token_count": 9, "len_text": 25, "cut_type": "row_cut", "chunk_index": 2},
+         {"token_count": 6, "len_text": 30, "cut_type": "row_end", "chunk_index": 3},
+     ],
+     "chunk_size_128": [
+         {"token_count": 15, "len_text": 57, "cut_type": "row_end", "chunk_index": 0},
+         {"token_count": 15, "len_text": 57, "cut_type": "row_end", "chunk_index": 1},
+     ],
+ }
+
+
+ @pytest.mark.parametrize(
+     "input_file,chunk_size",
+     [("example_with_header.csv", 10), ("example_with_header.csv", 128)],
+ )
+ @patch.object(chunk_by_row_module, "get_embedding_engine", side_effect=mock_get_embedding_engine)
+ @pytest.mark.asyncio
+ async def test_CsvDocument(mock_engine, input_file, chunk_size):
+     # Define file paths of test data
+     csv_file_path = os.path.join(
+         pathlib.Path(__file__).parent.parent.parent,
+         "test_data",
+         input_file,
+     )
+
+     # Define test documents
+     csv_document = CsvDocument(
+         id=uuid.uuid4(),
+         name="example_with_header.csv",
+         raw_data_location=csv_file_path,
+         external_metadata="",
+         mime_type="text/csv",
+     )
+
+     # TEST CSV
+     ground_truth_key = f"chunk_size_{chunk_size}"
+     async for ground_truth, row_data in async_gen_zip(
+         GROUND_TRUTH[ground_truth_key],
+         csv_document.read(chunker_cls=CsvChunker, max_chunk_size=chunk_size),
+     ):
+         assert ground_truth["token_count"] == row_data.chunk_size, (
+             f'{ground_truth["token_count"] = } != {row_data.chunk_size = }'
+         )
+         assert ground_truth["len_text"] == len(row_data.text), (
+             f'{ground_truth["len_text"] = } != {len(row_data.text) = }'
+         )
+         assert ground_truth["cut_type"] == row_data.cut_type, (
+             f'{ground_truth["cut_type"] = } != {row_data.cut_type = }'
+         )
+         assert ground_truth["chunk_index"] == row_data.chunk_index, (
+             f'{ground_truth["chunk_index"] = } != {row_data.chunk_index = }'
+         )
@@ -55,7 +55,7 @@ async def main():
      classified_data = ingestion.classify(file)

      # data_id is the hash of original file contents + owner id to avoid duplicate data
-     data_id = ingestion.identify(classified_data, await get_default_user())
+     data_id = await ingestion.identify(classified_data, await get_default_user())

      await cognee.add(file_path)

@@ -39,12 +39,12 @@ async def main():

      answer = await cognee.search("Do programmers change light bulbs?")
      assert len(answer) != 0
-     lowercase_answer = answer[0].lower()
+     lowercase_answer = answer[0]["search_result"][0].lower()
      assert ("no" in lowercase_answer) or ("none" in lowercase_answer)

      answer = await cognee.search("What colours are there in the presentation table?")
      assert len(answer) != 0
-     lowercase_answer = answer[0].lower()
+     lowercase_answer = answer[0]["search_result"][0].lower()
      assert (
          ("red" in lowercase_answer)
          and ("blue" in lowercase_answer)