cognee 0.2.2.dev0__py3-none-any.whl → 0.2.3__py3-none-any.whl
This diff compares the contents of publicly available package versions as released to their public registries. It is provided for informational purposes only and reflects the changes between the two versions as published.
- cognee/api/client.py +41 -3
- cognee/api/health.py +332 -0
- cognee/api/v1/add/add.py +5 -2
- cognee/api/v1/add/routers/get_add_router.py +3 -0
- cognee/api/v1/cognify/code_graph_pipeline.py +3 -1
- cognee/api/v1/cognify/cognify.py +8 -0
- cognee/api/v1/cognify/routers/get_cognify_router.py +8 -1
- cognee/api/v1/config/config.py +3 -1
- cognee/api/v1/datasets/routers/get_datasets_router.py +1 -7
- cognee/api/v1/delete/delete.py +16 -12
- cognee/api/v1/responses/routers/get_responses_router.py +3 -1
- cognee/api/v1/search/search.py +10 -0
- cognee/api/v1/settings/routers/get_settings_router.py +0 -2
- cognee/base_config.py +1 -0
- cognee/eval_framework/evaluation/direct_llm_eval_adapter.py +5 -6
- cognee/infrastructure/databases/graph/config.py +2 -0
- cognee/infrastructure/databases/graph/get_graph_engine.py +58 -12
- cognee/infrastructure/databases/graph/graph_db_interface.py +15 -10
- cognee/infrastructure/databases/graph/kuzu/adapter.py +12 -7
- cognee/infrastructure/databases/graph/kuzu/kuzu_migrate.py +1 -1
- cognee/infrastructure/databases/graph/neo4j_driver/adapter.py +48 -13
- cognee/infrastructure/databases/graph/neptune_driver/__init__.py +15 -0
- cognee/infrastructure/databases/graph/neptune_driver/adapter.py +1427 -0
- cognee/infrastructure/databases/graph/neptune_driver/exceptions.py +115 -0
- cognee/infrastructure/databases/graph/neptune_driver/neptune_utils.py +224 -0
- cognee/infrastructure/databases/graph/networkx/adapter.py +3 -3
- cognee/infrastructure/databases/hybrid/neptune_analytics/NeptuneAnalyticsAdapter.py +449 -0
- cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py +1 -0
- cognee/infrastructure/databases/vector/chromadb/ChromaDBAdapter.py +8 -3
- cognee/infrastructure/databases/vector/create_vector_engine.py +31 -15
- cognee/infrastructure/databases/vector/embeddings/FastembedEmbeddingEngine.py +3 -1
- cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py +21 -6
- cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py +4 -3
- cognee/infrastructure/databases/vector/embeddings/get_embedding_engine.py +3 -1
- cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py +22 -16
- cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py +36 -34
- cognee/infrastructure/databases/vector/vector_db_interface.py +78 -7
- cognee/infrastructure/files/utils/get_data_file_path.py +39 -0
- cognee/infrastructure/files/utils/guess_file_type.py +2 -2
- cognee/infrastructure/files/utils/open_data_file.py +4 -23
- cognee/infrastructure/llm/LLMGateway.py +137 -0
- cognee/infrastructure/llm/__init__.py +14 -4
- cognee/infrastructure/llm/config.py +29 -1
- cognee/infrastructure/llm/prompts/answer_hotpot_question.txt +1 -1
- cognee/infrastructure/llm/prompts/answer_hotpot_using_cognee_search.txt +1 -1
- cognee/infrastructure/llm/prompts/answer_simple_question.txt +1 -1
- cognee/infrastructure/llm/prompts/answer_simple_question_restricted.txt +1 -1
- cognee/infrastructure/llm/prompts/categorize_categories.txt +1 -1
- cognee/infrastructure/llm/prompts/classify_content.txt +1 -1
- cognee/infrastructure/llm/prompts/context_for_question.txt +1 -1
- cognee/infrastructure/llm/prompts/graph_context_for_question.txt +1 -1
- cognee/infrastructure/llm/prompts/natural_language_retriever_system.txt +1 -1
- cognee/infrastructure/llm/prompts/patch_gen_instructions.txt +1 -1
- cognee/infrastructure/llm/prompts/search_type_selector_prompt.txt +130 -0
- cognee/infrastructure/llm/prompts/summarize_code.txt +2 -2
- cognee/infrastructure/llm/structured_output_framework/baml/baml_client/__init__.py +57 -0
- cognee/infrastructure/llm/structured_output_framework/baml/baml_client/async_client.py +533 -0
- cognee/infrastructure/llm/structured_output_framework/baml/baml_client/config.py +94 -0
- cognee/infrastructure/llm/structured_output_framework/baml/baml_client/globals.py +37 -0
- cognee/infrastructure/llm/structured_output_framework/baml/baml_client/inlinedbaml.py +21 -0
- cognee/infrastructure/llm/structured_output_framework/baml/baml_client/parser.py +131 -0
- cognee/infrastructure/llm/structured_output_framework/baml/baml_client/runtime.py +266 -0
- cognee/infrastructure/llm/structured_output_framework/baml/baml_client/stream_types.py +137 -0
- cognee/infrastructure/llm/structured_output_framework/baml/baml_client/sync_client.py +550 -0
- cognee/infrastructure/llm/structured_output_framework/baml/baml_client/tracing.py +26 -0
- cognee/infrastructure/llm/structured_output_framework/baml/baml_client/type_builder.py +962 -0
- cognee/infrastructure/llm/structured_output_framework/baml/baml_client/type_map.py +52 -0
- cognee/infrastructure/llm/structured_output_framework/baml/baml_client/types.py +166 -0
- cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extract_categories.baml +109 -0
- cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extract_content_graph.baml +343 -0
- cognee/{modules/data → infrastructure/llm/structured_output_framework/baml/baml_src}/extraction/__init__.py +1 -0
- cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extraction/extract_summary.py +89 -0
- cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extraction/knowledge_graph/extract_content_graph.py +33 -0
- cognee/infrastructure/llm/structured_output_framework/baml/baml_src/generators.baml +18 -0
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/extraction/__init__.py +3 -0
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/extraction/extract_categories.py +12 -0
- cognee/{modules/data → infrastructure/llm/structured_output_framework/litellm_instructor}/extraction/extract_summary.py +16 -7
- cognee/{modules/data → infrastructure/llm/structured_output_framework/litellm_instructor}/extraction/knowledge_graph/extract_content_graph.py +7 -6
- cognee/infrastructure/llm/{anthropic → structured_output_framework/litellm_instructor/llm/anthropic}/adapter.py +10 -4
- cognee/infrastructure/llm/{gemini → structured_output_framework/litellm_instructor/llm/gemini}/adapter.py +6 -5
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/__init__.py +0 -0
- cognee/infrastructure/llm/{generic_llm_api → structured_output_framework/litellm_instructor/llm/generic_llm_api}/adapter.py +7 -3
- cognee/infrastructure/llm/{get_llm_client.py → structured_output_framework/litellm_instructor/llm/get_llm_client.py} +18 -6
- cognee/infrastructure/llm/{llm_interface.py → structured_output_framework/litellm_instructor/llm/llm_interface.py} +2 -2
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/__init__.py +0 -0
- cognee/infrastructure/llm/{ollama → structured_output_framework/litellm_instructor/llm/ollama}/adapter.py +4 -2
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/__init__.py +0 -0
- cognee/infrastructure/llm/{openai → structured_output_framework/litellm_instructor/llm/openai}/adapter.py +6 -4
- cognee/infrastructure/llm/{rate_limiter.py → structured_output_framework/litellm_instructor/llm/rate_limiter.py} +0 -5
- cognee/infrastructure/llm/tokenizer/Gemini/adapter.py +4 -2
- cognee/infrastructure/llm/tokenizer/TikToken/adapter.py +7 -3
- cognee/infrastructure/llm/tokenizer/__init__.py +4 -0
- cognee/infrastructure/llm/utils.py +3 -1
- cognee/infrastructure/loaders/LoaderEngine.py +156 -0
- cognee/infrastructure/loaders/LoaderInterface.py +73 -0
- cognee/infrastructure/loaders/__init__.py +18 -0
- cognee/infrastructure/loaders/core/__init__.py +7 -0
- cognee/infrastructure/loaders/core/audio_loader.py +98 -0
- cognee/infrastructure/loaders/core/image_loader.py +114 -0
- cognee/infrastructure/loaders/core/text_loader.py +90 -0
- cognee/infrastructure/loaders/create_loader_engine.py +32 -0
- cognee/infrastructure/loaders/external/__init__.py +22 -0
- cognee/infrastructure/loaders/external/pypdf_loader.py +96 -0
- cognee/infrastructure/loaders/external/unstructured_loader.py +127 -0
- cognee/infrastructure/loaders/get_loader_engine.py +18 -0
- cognee/infrastructure/loaders/supported_loaders.py +18 -0
- cognee/infrastructure/loaders/use_loader.py +21 -0
- cognee/infrastructure/loaders/utils/__init__.py +0 -0
- cognee/modules/data/methods/__init__.py +1 -0
- cognee/modules/data/methods/get_authorized_dataset.py +23 -0
- cognee/modules/data/models/Data.py +11 -1
- cognee/modules/data/processing/document_types/AudioDocument.py +2 -2
- cognee/modules/data/processing/document_types/ImageDocument.py +2 -2
- cognee/modules/data/processing/document_types/PdfDocument.py +4 -11
- cognee/modules/engine/utils/generate_edge_id.py +5 -0
- cognee/modules/graph/cognee_graph/CogneeGraph.py +9 -18
- cognee/modules/graph/methods/get_formatted_graph_data.py +7 -1
- cognee/modules/graph/utils/get_graph_from_model.py +93 -101
- cognee/modules/ingestion/data_types/TextData.py +8 -2
- cognee/modules/ingestion/save_data_to_file.py +1 -1
- cognee/modules/pipelines/exceptions/__init__.py +1 -0
- cognee/modules/pipelines/exceptions/exceptions.py +12 -0
- cognee/modules/pipelines/models/DataItemStatus.py +5 -0
- cognee/modules/pipelines/models/PipelineRunInfo.py +6 -0
- cognee/modules/pipelines/models/__init__.py +1 -0
- cognee/modules/pipelines/operations/pipeline.py +10 -2
- cognee/modules/pipelines/operations/run_tasks.py +251 -19
- cognee/modules/retrieval/code_retriever.py +3 -5
- cognee/modules/retrieval/completion_retriever.py +1 -1
- cognee/modules/retrieval/context_providers/TripletSearchContextProvider.py +0 -2
- cognee/modules/retrieval/graph_completion_context_extension_retriever.py +0 -2
- cognee/modules/retrieval/graph_completion_cot_retriever.py +8 -9
- cognee/modules/retrieval/natural_language_retriever.py +3 -5
- cognee/modules/retrieval/utils/completion.py +6 -9
- cognee/modules/retrieval/utils/description_to_codepart_search.py +2 -3
- cognee/modules/search/methods/search.py +5 -1
- cognee/modules/search/operations/__init__.py +1 -0
- cognee/modules/search/operations/select_search_type.py +42 -0
- cognee/modules/search/types/SearchType.py +1 -0
- cognee/modules/settings/get_settings.py +0 -4
- cognee/modules/settings/save_vector_db_config.py +1 -1
- cognee/shared/data_models.py +3 -1
- cognee/shared/logging_utils.py +0 -5
- cognee/tasks/chunk_naive_llm_classifier/chunk_naive_llm_classifier.py +2 -2
- cognee/tasks/documents/extract_chunks_from_documents.py +10 -12
- cognee/tasks/entity_completion/entity_extractors/llm_entity_extractor.py +4 -6
- cognee/tasks/graph/cascade_extract/utils/extract_content_nodes_and_relationship_names.py +4 -6
- cognee/tasks/graph/cascade_extract/utils/extract_edge_triplets.py +6 -7
- cognee/tasks/graph/cascade_extract/utils/extract_nodes.py +4 -7
- cognee/tasks/graph/extract_graph_from_code.py +3 -2
- cognee/tasks/graph/extract_graph_from_data.py +4 -3
- cognee/tasks/graph/infer_data_ontology.py +5 -6
- cognee/tasks/ingestion/data_item_to_text_file.py +79 -0
- cognee/tasks/ingestion/ingest_data.py +91 -61
- cognee/tasks/ingestion/resolve_data_directories.py +3 -0
- cognee/tasks/repo_processor/get_repo_file_dependencies.py +3 -0
- cognee/tasks/storage/index_data_points.py +1 -1
- cognee/tasks/storage/index_graph_edges.py +4 -1
- cognee/tasks/summarization/summarize_code.py +2 -3
- cognee/tasks/summarization/summarize_text.py +3 -2
- cognee/tests/test_cognee_server_start.py +12 -7
- cognee/tests/test_deduplication.py +2 -2
- cognee/tests/test_deletion.py +58 -17
- cognee/tests/test_graph_visualization_permissions.py +161 -0
- cognee/tests/test_neptune_analytics_graph.py +309 -0
- cognee/tests/test_neptune_analytics_hybrid.py +176 -0
- cognee/tests/{test_qdrant.py → test_neptune_analytics_vector.py} +86 -16
- cognee/tests/test_pgvector.py +5 -5
- cognee/tests/test_s3.py +1 -6
- cognee/tests/unit/infrastructure/databases/test_rate_limiter.py +11 -10
- cognee/tests/unit/infrastructure/databases/vector/__init__.py +0 -0
- cognee/tests/unit/infrastructure/mock_embedding_engine.py +1 -1
- cognee/tests/unit/infrastructure/test_embedding_rate_limiting_realistic.py +5 -5
- cognee/tests/unit/infrastructure/test_rate_limiting_realistic.py +6 -4
- cognee/tests/unit/infrastructure/test_rate_limiting_retry.py +1 -1
- cognee/tests/unit/interfaces/graph/get_graph_from_model_unit_test.py +61 -3
- cognee/tests/unit/modules/search/search_methods_test.py +55 -0
- {cognee-0.2.2.dev0.dist-info → cognee-0.2.3.dist-info}/METADATA +12 -6
- {cognee-0.2.2.dev0.dist-info → cognee-0.2.3.dist-info}/RECORD +195 -156
- cognee/infrastructure/databases/vector/pinecone/adapter.py +0 -8
- cognee/infrastructure/databases/vector/qdrant/QDrantAdapter.py +0 -514
- cognee/infrastructure/databases/vector/qdrant/__init__.py +0 -2
- cognee/infrastructure/databases/vector/weaviate_db/WeaviateAdapter.py +0 -527
- cognee/infrastructure/databases/vector/weaviate_db/__init__.py +0 -1
- cognee/modules/data/extraction/extract_categories.py +0 -14
- distributed/Dockerfile +0 -34
- distributed/app.py +0 -4
- distributed/entrypoint.py +0 -71
- distributed/entrypoint.sh +0 -5
- distributed/modal_image.py +0 -11
- distributed/queues.py +0 -5
- distributed/tasks/queued_add_data_points.py +0 -13
- distributed/tasks/queued_add_edges.py +0 -13
- distributed/tasks/queued_add_nodes.py +0 -13
- distributed/test.py +0 -28
- distributed/utils.py +0 -19
- distributed/workers/data_point_saving_worker.py +0 -93
- distributed/workers/graph_saving_worker.py +0 -104
- /cognee/infrastructure/databases/{graph/memgraph → hybrid/neptune_analytics}/__init__.py +0 -0
- /cognee/infrastructure/{llm → databases/vector/embeddings}/embedding_rate_limiter.py +0 -0
- /cognee/infrastructure/{databases/vector/pinecone → llm/structured_output_framework}/__init__.py +0 -0
- /cognee/infrastructure/llm/{anthropic → structured_output_framework/baml/baml_src}/__init__.py +0 -0
- /cognee/infrastructure/llm/{gemini/__init__.py → structured_output_framework/baml/baml_src/extraction/extract_categories.py} +0 -0
- /cognee/infrastructure/llm/{generic_llm_api → structured_output_framework/baml/baml_src/extraction/knowledge_graph}/__init__.py +0 -0
- /cognee/infrastructure/llm/{ollama → structured_output_framework/litellm_instructor}/__init__.py +0 -0
- /cognee/{modules/data → infrastructure/llm/structured_output_framework/litellm_instructor}/extraction/knowledge_graph/__init__.py +0 -0
- /cognee/{modules/data → infrastructure/llm/structured_output_framework/litellm_instructor}/extraction/texts.json +0 -0
- /cognee/infrastructure/llm/{openai → structured_output_framework/litellm_instructor/llm}/__init__.py +0 -0
- {distributed → cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic}/__init__.py +0 -0
- {distributed/tasks → cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini}/__init__.py +0 -0
- /cognee/modules/data/{extraction/knowledge_graph → methods}/add_model_class_to_graph.py +0 -0
- {cognee-0.2.2.dev0.dist-info → cognee-0.2.3.dist-info}/WHEEL +0 -0
- {cognee-0.2.2.dev0.dist-info → cognee-0.2.3.dist-info}/licenses/LICENSE +0 -0
- {cognee-0.2.2.dev0.dist-info → cognee-0.2.3.dist-info}/licenses/NOTICE.md +0 -0
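The headline changes in 0.2.3 are first-class Amazon Neptune support (a new neptune_driver graph adapter plus the hybrid NeptuneAnalyticsAdapter), a pluggable loader engine, a BAML-based structured-output framework alongside the litellm/instructor path, and removal of the bundled Qdrant and Weaviate vector adapters. As a minimal sketch of how the new vector backend is selected, assuming a provisioned Neptune Analytics graph whose ID is exported as GRAPH_ID and AWS credentials already configured in the environment (this mirrors the test setup reproduced below):

import os
import cognee

# Point cognee's vector store at Neptune Analytics via a neptune-graph:// URL.
# Per the tests below, an unset URL raises OSError and a malformed one ValueError.
graph_id = os.getenv("GRAPH_ID", "")
cognee.config.set_vector_db_provider("neptune_analytics")
cognee.config.set_vector_db_url(f"neptune-graph://{graph_id}")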
cognee/tests/test_neptune_analytics_hybrid.py ADDED

@@ -0,0 +1,176 @@
+import os
+from dotenv import load_dotenv
+import asyncio
+import pytest
+
+from cognee.modules.chunking.models import DocumentChunk
+from cognee.modules.engine.models import Entity, EntityType
+from cognee.modules.data.processing.document_types import TextDocument
+from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine
+from cognee.shared.logging_utils import get_logger
+from cognee.infrastructure.databases.hybrid.neptune_analytics.NeptuneAnalyticsAdapter import (
+    NeptuneAnalyticsAdapter,
+)
+
+# Set up Amazon credentials in .env file and get the values from environment variables
+load_dotenv()
+graph_id = os.getenv("GRAPH_ID", "")
+
+# get the default embedder
+embedding_engine = get_embedding_engine()
+na_graph = NeptuneAnalyticsAdapter(graph_id)
+na_vector = NeptuneAnalyticsAdapter(graph_id, embedding_engine)
+
+collection = "test_collection"
+
+logger = get_logger("test_neptune_analytics_hybrid")
+
+
+def setup_data():
+    # Define nodes data before the main function
+    # These nodes were defined using openAI from the following prompt:
+    #
+    # Neptune Analytics is an ideal choice for investigatory, exploratory, or data-science workloads
+    # that require fast iteration for data, analytical and algorithmic processing, or vector search on graph data. It
+    # complements Amazon Neptune Database, a popular managed graph database. To perform intensive analysis, you can load
+    # the data from a Neptune Database graph or snapshot into Neptune Analytics. You can also load graph data that's
+    # stored in Amazon S3.
+
+    document = TextDocument(
+        name="text.txt",
+        raw_data_location="git/cognee/examples/database_examples/data_storage/data/text.txt",
+        external_metadata="{}",
+        mime_type="text/plain",
+    )
+    document_chunk = DocumentChunk(
+        text="Neptune Analytics is an ideal choice for investigatory, exploratory, or data-science workloads \n that require fast iteration for data, analytical and algorithmic processing, or vector search on graph data. It \n complements Amazon Neptune Database, a popular managed graph database. To perform intensive analysis, you can load \n the data from a Neptune Database graph or snapshot into Neptune Analytics. You can also load graph data that's \n stored in Amazon S3.\n ",
+        chunk_size=187,
+        chunk_index=0,
+        cut_type="paragraph_end",
+        is_part_of=document,
+    )
+
+    graph_database = EntityType(name="graph database", description="graph database")
+    neptune_analytics_entity = Entity(
+        name="neptune analytics",
+        description="A memory-optimized graph database engine for analytics that processes large amounts of graph data quickly.",
+    )
+    neptune_database_entity = Entity(
+        name="amazon neptune database",
+        description="A popular managed graph database that complements Neptune Analytics.",
+    )
+
+    storage = EntityType(name="storage", description="storage")
+    storage_entity = Entity(
+        name="amazon s3",
+        description="A storage service provided by Amazon Web Services that allows storing graph data.",
+    )
+
+    nodes_data = [
+        document,
+        document_chunk,
+        graph_database,
+        neptune_analytics_entity,
+        neptune_database_entity,
+        storage,
+        storage_entity,
+    ]
+
+    edges_data = [
+        (
+            str(document_chunk.id),
+            str(storage_entity.id),
+            "contains",
+        ),
+        (
+            str(storage_entity.id),
+            str(storage.id),
+            "is_a",
+        ),
+        (
+            str(document_chunk.id),
+            str(neptune_database_entity.id),
+            "contains",
+        ),
+        (
+            str(neptune_database_entity.id),
+            str(graph_database.id),
+            "is_a",
+        ),
+        (
+            str(document_chunk.id),
+            str(document.id),
+            "is_part_of",
+        ),
+        (
+            str(document_chunk.id),
+            str(neptune_analytics_entity.id),
+            "contains",
+        ),
+        (
+            str(neptune_analytics_entity.id),
+            str(graph_database.id),
+            "is_a",
+        ),
+    ]
+    return nodes_data, edges_data
+
+
+async def test_add_graph_then_vector_data():
+    logger.info("------test_add_graph_then_vector_data-------")
+    (nodes, edges) = setup_data()
+    await na_graph.add_nodes(nodes)
+    await na_graph.add_edges(edges)
+    await na_vector.create_data_points(collection, nodes)
+
+    node_ids = [str(node.id) for node in nodes]
+    retrieved_data_points = await na_vector.retrieve(collection, node_ids)
+    retrieved_nodes = await na_graph.get_nodes(node_ids)
+
+    assert len(retrieved_data_points) == len(retrieved_nodes) == len(node_ids)
+
+    # delete all nodes and edges and vectors:
+    await na_graph.delete_graph()
+    await na_vector.prune()
+
+    (nodes, edges) = await na_graph.get_graph_data()
+    assert len(nodes) == 0
+    assert len(edges) == 0
+    logger.info("------PASSED-------")
+
+
+async def test_add_vector_then_node_data():
+    logger.info("------test_add_vector_then_node_data-------")
+    (nodes, edges) = setup_data()
+    await na_vector.create_data_points(collection, nodes)
+    await na_graph.add_nodes(nodes)
+    await na_graph.add_edges(edges)
+
+    node_ids = [str(node.id) for node in nodes]
+    retrieved_data_points = await na_vector.retrieve(collection, node_ids)
+    retrieved_nodes = await na_graph.get_nodes(node_ids)
+
+    assert len(retrieved_data_points) == len(retrieved_nodes) == len(node_ids)
+
+    # delete all nodes and edges and vectors:
+    await na_vector.prune()
+    await na_graph.delete_graph()
+
+    (nodes, edges) = await na_graph.get_graph_data()
+    assert len(nodes) == 0
+    assert len(edges) == 0
+    logger.info("------PASSED-------")
+
+
+def main():
+    """
+    Example script uses neptune analytics for the graph and vector (hybrid) store with small sample data
+    This example demonstrates how to add nodes and vectors to Neptune Analytics, and ensures that
+    the nodes do not conflict
+    """
+    asyncio.run(test_add_graph_then_vector_data())
+    asyncio.run(test_add_vector_then_node_data())
+
+
+if __name__ == "__main__":
+    main()
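The test above drives a single adapter class in two roles: NeptuneAnalyticsAdapter(graph_id) as the graph store and NeptuneAnalyticsAdapter(graph_id, embedding_engine) as the vector store. It expects GRAPH_ID (plus AWS credentials) to be supplied through the environment or a local .env file picked up by load_dotenv().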
cognee/tests/{test_qdrant.py → test_neptune_analytics_vector.py} RENAMED

@@ -1,26 +1,34 @@
 import os
 import pathlib
 import cognee
-
+import uuid
+import pytest
 from cognee.modules.search.operations import get_history
 from cognee.modules.users.methods import get_default_user
 from cognee.shared.logging_utils import get_logger
 from cognee.modules.search.types import SearchType
+from cognee.infrastructure.databases.vector import get_vector_engine
+from cognee.infrastructure.databases.hybrid.neptune_analytics.NeptuneAnalyticsAdapter import (
+    NeptuneAnalyticsAdapter,
+    IndexSchema,
+)
 
 logger = get_logger()
 
 
 async def main():
-
+    graph_id = os.getenv("GRAPH_ID", "")
+    cognee.config.set_vector_db_provider("neptune_analytics")
+    cognee.config.set_vector_db_url(f"neptune-graph://{graph_id}")
     data_directory_path = str(
         pathlib.Path(
-            os.path.join(pathlib.Path(__file__).parent, ".data_storage/
+            os.path.join(pathlib.Path(__file__).parent, ".data_storage/test_neptune")
         ).resolve()
     )
     cognee.config.data_root_directory(data_directory_path)
     cognee_directory_path = str(
         pathlib.Path(
-            os.path.join(pathlib.Path(__file__).parent, ".cognee_system/
+            os.path.join(pathlib.Path(__file__).parent, ".cognee_system/test_neptune")
         ).resolve()
     )
     cognee.config.system_root_directory(cognee_directory_path)

@@ -47,14 +55,8 @@ async def main():
 
     await cognee.cognify([dataset_name])
 
-    from cognee.infrastructure.databases.vector import get_vector_engine
-
     vector_engine = get_vector_engine()
-
-
-    assert len(search_results) != 0, "The search results list is empty."
-
-    random_node = search_results[0]
+    random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0]
     random_node_name = random_node.payload["text"]
 
     search_results = await cognee.search(

@@ -84,16 +86,84 @@ async def main():
     assert len(history) == 6, "Search history is not correct."
 
     await cognee.prune.prune_data()
-
-    assert not os.path.isdir(data_root_directory), "Local data files are not deleted"
+    assert not os.path.isdir(data_directory_path), "Local data files are not deleted"
 
     await cognee.prune.prune_system(metadata=True)
-
-
-
+
+
+async def vector_backend_api_test():
+    cognee.config.set_vector_db_provider("neptune_analytics")
+
+    # When URL is absent
+    cognee.config.set_vector_db_url(None)
+    with pytest.raises(OSError):
+        get_vector_engine()
+
+    # Assert invalid graph ID.
+    cognee.config.set_vector_db_url("invalid_url")
+    with pytest.raises(ValueError):
+        get_vector_engine()
+
+    # Return a valid engine object with valid URL.
+    graph_id = os.getenv("GRAPH_ID", "")
+    cognee.config.set_vector_db_url(f"neptune-graph://{graph_id}")
+    engine = get_vector_engine()
+    assert isinstance(engine, NeptuneAnalyticsAdapter)
+
+    TEST_COLLECTION_NAME = "test"
+    # Data point - 1
+    TEST_UUID = str(uuid.uuid4())
+    TEST_TEXT = "Hello world"
+    datapoint = IndexSchema(id=TEST_UUID, text=TEST_TEXT)
+    # Data point - 2
+    TEST_UUID_2 = str(uuid.uuid4())
+    TEST_TEXT_2 = "Cognee"
+    datapoint_2 = IndexSchema(id=TEST_UUID_2, text=TEST_TEXT_2)
+
+    # Prun all vector_db entries
+    await engine.prune()
+
+    # Always return true
+    has_collection = await engine.has_collection(TEST_COLLECTION_NAME)
+    assert has_collection
+    # No-op
+    await engine.create_collection(TEST_COLLECTION_NAME, IndexSchema)
+
+    # Save data-points
+    await engine.create_data_points(TEST_COLLECTION_NAME, [datapoint, datapoint_2])
+    # Search single text
+    result_search = await engine.search(
+        collection_name=TEST_COLLECTION_NAME,
+        query_text=TEST_TEXT,
+        query_vector=None,
+        limit=10,
+        with_vector=True,
+    )
+    assert len(result_search) == 2
+
+    # # Retrieve data-points
+    result = await engine.retrieve(TEST_COLLECTION_NAME, [TEST_UUID, TEST_UUID_2])
+    assert any(str(r.id) == TEST_UUID and r.payload["text"] == TEST_TEXT for r in result)
+    assert any(str(r.id) == TEST_UUID_2 and r.payload["text"] == TEST_TEXT_2 for r in result)
+    # Search multiple
+    result_search_batch = await engine.batch_search(
+        collection_name=TEST_COLLECTION_NAME,
+        query_texts=[TEST_TEXT, TEST_TEXT_2],
+        limit=10,
+        with_vectors=False,
+    )
+    assert len(result_search_batch) == 2 and all(len(batch) == 2 for batch in result_search_batch)
+
+    # Delete datapoint from vector store
+    await engine.delete_data_points(TEST_COLLECTION_NAME, [TEST_UUID, TEST_UUID_2])
+
+    # Retrieve should return an empty list.
+    result_deleted = await engine.retrieve(TEST_COLLECTION_NAME, [TEST_UUID])
+    assert result_deleted == []
 
 
 if __name__ == "__main__":
     import asyncio
 
     asyncio.run(main())
+    asyncio.run(vector_backend_api_test())
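The vector_backend_api_test added above also pins down adapter-specific semantics (per its own comments): has_collection always returns True and create_collection is a no-op, since Neptune Analytics has no separate collection objects; prune() clears every vector entry; and retrieve() on deleted IDs returns an empty list.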
cognee/tests/test_pgvector.py CHANGED

@@ -37,16 +37,16 @@ async def test_local_file_deletion(data_text, file_location):
     # Get data entry from database based on file path
     data = (
         await session.scalars(
-            select(Data).where(Data.
+            select(Data).where(Data.original_data_location == "file://" + file_location)
         )
     ).one()
-    assert os.path.isfile(data.
-        f"Data location doesn't exist: {data.
+    assert os.path.isfile(data.original_data_location.replace("file://", "")), (
+        f"Data location doesn't exist: {data.original_data_location}"
     )
     # Test local files not created by cognee won't get deleted
     await engine.delete_data_entity(data.id)
-    assert os.path.exists(data.
-        f"Data location doesn't exists: {data.
+    assert os.path.exists(data.original_data_location.replace("file://", "")), (
+        f"Data location doesn't exists: {data.original_data_location}"
     )
 
 
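This hunk matches the Data model change listed above (cognee/modules/data/models/Data.py, +11 -1): file locations are now stored as file:// URIs in original_data_location, so the filesystem checks strip the scheme with .replace("file://", "") before calling os.path.isfile / os.path.exists.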
cognee/tests/test_s3.py CHANGED

@@ -28,13 +28,8 @@ async def main():
     logging.info(type_counts)
     logging.info(edge_type_counts)
 
-    # Assert there is exactly one PdfDocument.
-    assert type_counts.get("PdfDocument", 0) == 1, (
-        f"Expected exactly one PdfDocument, but found {type_counts.get('PdfDocument', 0)}"
-    )
-
     # Assert there is exactly one TextDocument.
-    assert type_counts.get("TextDocument", 0) ==
+    assert type_counts.get("TextDocument", 0) == 2, (
        f"Expected exactly one TextDocument, but found {type_counts.get('TextDocument', 0)}"
     )
 
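Two assertion changes land here: the PdfDocument check is removed outright, and the expected TextDocument count becomes 2. Note that the assertion message still reads "Expected exactly one TextDocument" while the check is == 2; that mismatch is present in the released test itself.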
cognee/tests/unit/infrastructure/databases/test_rate_limiter.py CHANGED

@@ -1,15 +1,16 @@
 """Tests for the LLM rate limiter."""
 
 import pytest
-import asyncio
-import time
 from unittest.mock import patch
-from cognee.infrastructure.llm.rate_limiter import (
+from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.rate_limiter import (
     llm_rate_limiter,
     rate_limit_async,
     rate_limit_sync,
 )
 
+LLM_RATE_LIMITER = "cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.rate_limiter.llm_rate_limiter"
+GET_LLM_CONFIG = "cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.rate_limiter.get_llm_config"
+
 
 @pytest.fixture(autouse=True)
 def reset_limiter_singleton():

@@ -20,7 +21,7 @@ def reset_limiter_singleton():
 
 def test_rate_limiter_initialization():
     """Test that the rate limiter can be initialized properly."""
-    with patch(
+    with patch(GET_LLM_CONFIG) as mock_config:
         mock_config.return_value.llm_rate_limit_enabled = True
         mock_config.return_value.llm_rate_limit_requests = 10
         mock_config.return_value.llm_rate_limit_interval = 60  # 1 minute

@@ -34,7 +35,7 @@ def test_rate_limiter_initialization():
 
 def test_rate_limiter_disabled():
     """Test that the rate limiter is disabled by default."""
-    with patch(
+    with patch(GET_LLM_CONFIG) as mock_config:
         mock_config.return_value.llm_rate_limit_enabled = False
 
         limiter = llm_rate_limiter()

@@ -45,7 +46,7 @@ def test_rate_limiter_disabled():
 
 def test_rate_limiter_singleton():
     """Test that the rate limiter is a singleton."""
-    with patch(
+    with patch(GET_LLM_CONFIG) as mock_config:
         mock_config.return_value.llm_rate_limit_enabled = True
         mock_config.return_value.llm_rate_limit_requests = 5
         mock_config.return_value.llm_rate_limit_interval = 60

@@ -58,7 +59,7 @@ def test_rate_limiter_singleton():
 
 def test_sync_decorator():
     """Test the sync decorator."""
-    with patch(
+    with patch(LLM_RATE_LIMITER) as mock_limiter_class:
         mock_limiter = mock_limiter_class.return_value
         mock_limiter.wait_if_needed.return_value = 0
 

@@ -75,7 +76,7 @@ def test_sync_decorator():
 @pytest.mark.asyncio
 async def test_async_decorator():
     """Test the async decorator."""
-    with patch(
+    with patch(LLM_RATE_LIMITER) as mock_limiter_class:
         mock_limiter = mock_limiter_class.return_value
 
         # Mock an async method with a coroutine

@@ -96,7 +97,7 @@ async def test_async_decorator():
 
 def test_rate_limiting_actual():
     """Test actual rate limiting behavior with a small window."""
-    with patch(
+    with patch(GET_LLM_CONFIG) as mock_config:
         # Configure for 3 requests per minute
         mock_config.return_value.llm_rate_limit_enabled = True
         mock_config.return_value.llm_rate_limit_requests = 3

@@ -117,7 +118,7 @@ def test_rate_limiting_actual():
 
 def test_rate_limit_60_per_minute():
     """Test rate limiting with the default 60 requests per minute limit."""
-    with patch(
+    with patch(GET_LLM_CONFIG) as mock_config:
         # Configure for default values: 60 requests per 60 seconds
         mock_config.return_value.llm_rate_limit_enabled = True
         mock_config.return_value.llm_rate_limit_requests = 60  # 60 requests

cognee/tests/unit/infrastructure/databases/vector/__init__.py (file without changes)

cognee/tests/unit/infrastructure/mock_embedding_engine.py CHANGED

@@ -4,7 +4,7 @@ from typing import List
 from cognee.infrastructure.databases.vector.embeddings.LiteLLMEmbeddingEngine import (
     LiteLLMEmbeddingEngine,
 )
-from cognee.infrastructure.
+from cognee.infrastructure.databases.vector.embeddings.embedding_rate_limiter import (
     embedding_rate_limit_async,
     embedding_sleep_and_retry_async,
 )

cognee/tests/unit/infrastructure/test_embedding_rate_limiting_realistic.py CHANGED

@@ -1,13 +1,13 @@
 import os
 import time
 import asyncio
-from functools import lru_cache
 import logging
 
-from cognee.infrastructure.llm.config import
-
-
-
+from cognee.infrastructure.llm.config import (
+    get_llm_config,
+)
+from cognee.infrastructure.databases.vector.embeddings.embedding_rate_limiter import (
+    EmbeddingRateLimiter,
 )
 from cognee.tests.unit.infrastructure.mock_embedding_engine import MockEmbeddingEngine
 

cognee/tests/unit/infrastructure/test_rate_limiting_realistic.py CHANGED

@@ -1,11 +1,13 @@
 import asyncio
-import time
 import os
-from functools import lru_cache
 from unittest.mock import patch
 from cognee.shared.logging_utils import get_logger
-from cognee.infrastructure.llm.rate_limiter import
-
+from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.rate_limiter import (
+    llm_rate_limiter,
+)
+from cognee.infrastructure.llm.config import (
+    get_llm_config,
+)
 
 
 async def test_rate_limiting_realistic():

cognee/tests/unit/infrastructure/test_rate_limiting_retry.py CHANGED

@@ -1,7 +1,7 @@
 import time
 import asyncio
 from cognee.shared.logging_utils import get_logger
-from cognee.infrastructure.llm.rate_limiter import (
+from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.rate_limiter import (
     sleep_and_retry_sync,
     sleep_and_retry_async,
     is_rate_limit_error,

cognee/tests/unit/interfaces/graph/get_graph_from_model_unit_test.py CHANGED

@@ -1,6 +1,6 @@
 import pytest
-from typing import List
-from cognee.infrastructure.engine import DataPoint
+from typing import List, Any
+from cognee.infrastructure.engine import DataPoint, Edge
 
 from cognee.modules.graph.utils import get_graph_from_model
 

@@ -28,7 +28,20 @@ class Entity(DataPoint):
     metadata: dict = {"index_fields": ["name"]}
 
 
+class Company(DataPoint):
+    name: str
+    employees: List[Any] = None  # Allow flexible edge system with tuples
+    metadata: dict = {"index_fields": ["name"]}
+
+
+class Employee(DataPoint):
+    name: str
+    role: str
+    metadata: dict = {"index_fields": ["name"]}
+
+
 DocumentChunk.model_rebuild()
+Company.model_rebuild()
 
 

@@ -50,7 +63,7 @@ async def test_get_graph_from_model_simple_structure():
     assert len(nodes) == 2, f"Expected 2 nodes, got {len(nodes)}"
     assert len(edges) == 1, f"Expected 1 edges, got {len(edges)}"
 
-    edge_key = str(entity.id)
+    edge_key = f"{str(entity.id)}_{str(entitytype.id)}_is_type"
     assert edge_key in added_edges, f"Edge {edge_key} not found"
 
 

@@ -149,3 +162,48 @@ async def test_get_graph_from_model_no_contains():
 
     assert len(nodes) == 2, f"Expected 2 nodes, got {len(nodes)}"
     assert len(edges) == 1, f"Expected 1 edge, got {len(edges)}"
+
+
+@pytest.mark.asyncio
+async def test_get_graph_from_model_flexible_edges():
+    """Tests the new flexible edge system with mixed relationships"""
+    # Create employees
+    manager = Employee(name="Manager", role="Manager")
+    sales1 = Employee(name="Sales1", role="Sales")
+    sales2 = Employee(name="Sales2", role="Sales")
+    admin1 = Employee(name="Admin1", role="Admin")
+    admin2 = Employee(name="Admin2", role="Admin")
+
+    # Create company with mixed employee relationships
+    company = Company(
+        name="Test Company",
+        employees=[
+            # Weighted relationship
+            (Edge(weight=0.9, relationship_type="manages"), manager),
+            # Multiple weights relationship
+            (
+                Edge(weights={"performance": 0.8, "experience": 0.7}, relationship_type="employs"),
+                sales1,
+            ),
+            # Simple relationship
+            sales2,
+            # Group relationship
+            (Edge(weights={"team_efficiency": 0.8}, relationship_type="employs"), [admin1, admin2]),
+        ],
+    )
+
+    added_nodes = {}
+    added_edges = {}
+    visited_properties = {}
+
+    nodes, edges = await get_graph_from_model(company, added_nodes, added_edges, visited_properties)
+
+    # Should have 6 nodes: company + 5 employees
+    assert len(nodes) == 6, f"Expected 6 nodes, got {len(nodes)}"
+    # Should have 5 edges: 4 employee relationships
+    assert len(edges) == 5, f"Expected 5 edges, got {len(edges)}"
+
+    # Verify all employees are connected
+    employee_ids = {str(emp.id) for emp in [manager, sales1, sales2, admin1, admin2]}
+    edge_target_ids = {str(edge[1]) for edge in edges}
+    assert employee_ids.issubset(edge_target_ids), "Not all employees are connected"
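The new test above exercises the flexible edge system now handled by get_graph_from_model (note Edge is imported from cognee.infrastructure.engine). Distilled from the test, a list field on a DataPoint may mix plain node references, (Edge, node) tuples carrying a single weight or a weights dict plus a relationship_type, and (Edge, [nodes]) tuples that fan one edge spec out to a group. A small sketch under those assumptions; the Person and Team models here are hypothetical, for illustration only:

from typing import Any, List
from cognee.infrastructure.engine import DataPoint, Edge

class Person(DataPoint):  # hypothetical node type, mirrors the test's Employee
    name: str
    metadata: dict = {"index_fields": ["name"]}

class Team(DataPoint):  # hypothetical container, mirrors the test's Company
    name: str
    members: List[Any] = None  # flexible edge field: plain nodes or (Edge, target) tuples
    metadata: dict = {"index_fields": ["name"]}

lead, a, b, c = (Person(name=n) for n in ("Lead", "A", "B", "C"))

team = Team(
    name="Example",
    members=[
        a,  # plain node reference
        (Edge(weight=0.9, relationship_type="manages"), lead),  # single-weight edge
        # one edge spec fanned out to a group of targets:
        (Edge(weights={"tenure": 0.7}, relationship_type="includes"), [b, c]),
    ],
)

Each entry yields one graph edge per reachable target, which is why the test's four entries produce five edges.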
cognee/tests/unit/modules/search/search_methods_test.py CHANGED

@@ -155,6 +155,61 @@ async def test_specific_search_chunks(mock_send_telemetry, mock_chunks_retriever
     assert results[0]["content"] == "Chunk result"
 
 
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "selected_type, retriever_name, expected_content, top_k",
+    [
+        (SearchType.RAG_COMPLETION, "CompletionRetriever", "RAG result from lucky search", 10),
+        (SearchType.CHUNKS, "ChunksRetriever", "Chunk result from lucky search", 5),
+        (SearchType.SUMMARIES, "SummariesRetriever", "Summary from lucky search", 15),
+        (SearchType.INSIGHTS, "InsightsRetriever", "Insight result from lucky search", 20),
+    ],
+)
+@patch.object(search_module, "select_search_type")
+@patch.object(search_module, "send_telemetry")
+async def test_specific_search_feeling_lucky(
+    mock_send_telemetry,
+    mock_select_search_type,
+    selected_type,
+    retriever_name,
+    expected_content,
+    top_k,
+    mock_user,
+):
+    with patch.object(search_module, retriever_name) as mock_retriever_class:
+        # Setup
+        query = f"test query for {retriever_name}"
+        query_type = SearchType.FEELING_LUCKY
+
+        # Mock the intelligent search type selection
+        mock_select_search_type.return_value = selected_type
+
+        # Mock the retriever
+        mock_retriever_instance = MagicMock()
+        mock_retriever_instance.get_completion = AsyncMock(
+            return_value=[{"content": expected_content}]
+        )
+        mock_retriever_class.return_value = mock_retriever_instance
+
+        # Execute
+        results = await specific_search(query_type, query, mock_user, top_k=top_k)
+
+        # Verify
+        mock_select_search_type.assert_called_once_with(query)
+
+        if retriever_name == "CompletionRetriever":
+            mock_retriever_class.assert_called_once_with(
+                system_prompt_path="answer_simple_question.txt", top_k=top_k
+            )
+        else:
+            mock_retriever_class.assert_called_once_with(top_k=top_k)
+
+        mock_retriever_instance.get_completion.assert_called_once_with(query)
+        mock_send_telemetry.assert_called()
+        assert len(results) == 1
+        assert results[0]["content"] == expected_content
+
+
 @pytest.mark.asyncio
 async def test_specific_search_invalid_type(mock_user):
     # Setup