cognee 0.2.3.dev1__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cognee/__init__.py +2 -0
- cognee/__main__.py +4 -0
- cognee/api/client.py +28 -3
- cognee/api/health.py +10 -13
- cognee/api/v1/add/add.py +20 -6
- cognee/api/v1/add/routers/get_add_router.py +12 -37
- cognee/api/v1/cloud/routers/__init__.py +1 -0
- cognee/api/v1/cloud/routers/get_checks_router.py +23 -0
- cognee/api/v1/cognify/code_graph_pipeline.py +14 -3
- cognee/api/v1/cognify/cognify.py +67 -105
- cognee/api/v1/cognify/routers/get_cognify_router.py +11 -3
- cognee/api/v1/datasets/routers/get_datasets_router.py +16 -5
- cognee/api/v1/memify/routers/__init__.py +1 -0
- cognee/api/v1/memify/routers/get_memify_router.py +100 -0
- cognee/api/v1/notebooks/routers/__init__.py +1 -0
- cognee/api/v1/notebooks/routers/get_notebooks_router.py +96 -0
- cognee/api/v1/responses/default_tools.py +4 -0
- cognee/api/v1/responses/dispatch_function.py +6 -1
- cognee/api/v1/responses/models.py +1 -1
- cognee/api/v1/search/routers/get_search_router.py +20 -1
- cognee/api/v1/search/search.py +17 -4
- cognee/api/v1/sync/__init__.py +17 -0
- cognee/api/v1/sync/routers/__init__.py +3 -0
- cognee/api/v1/sync/routers/get_sync_router.py +241 -0
- cognee/api/v1/sync/sync.py +877 -0
- cognee/api/v1/ui/__init__.py +1 -0
- cognee/api/v1/ui/ui.py +529 -0
- cognee/api/v1/users/routers/get_auth_router.py +13 -1
- cognee/base_config.py +10 -1
- cognee/cli/__init__.py +10 -0
- cognee/cli/_cognee.py +273 -0
- cognee/cli/commands/__init__.py +1 -0
- cognee/cli/commands/add_command.py +80 -0
- cognee/cli/commands/cognify_command.py +128 -0
- cognee/cli/commands/config_command.py +225 -0
- cognee/cli/commands/delete_command.py +80 -0
- cognee/cli/commands/search_command.py +149 -0
- cognee/cli/config.py +33 -0
- cognee/cli/debug.py +21 -0
- cognee/cli/echo.py +45 -0
- cognee/cli/exceptions.py +23 -0
- cognee/cli/minimal_cli.py +97 -0
- cognee/cli/reference.py +26 -0
- cognee/cli/suppress_logging.py +12 -0
- cognee/eval_framework/corpus_builder/corpus_builder_executor.py +2 -2
- cognee/eval_framework/eval_config.py +1 -1
- cognee/infrastructure/databases/graph/config.py +10 -4
- cognee/infrastructure/databases/graph/get_graph_engine.py +4 -9
- cognee/infrastructure/databases/graph/kuzu/adapter.py +199 -2
- cognee/infrastructure/databases/graph/neo4j_driver/adapter.py +138 -0
- cognee/infrastructure/databases/relational/__init__.py +2 -0
- cognee/infrastructure/databases/relational/get_async_session.py +15 -0
- cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py +6 -1
- cognee/infrastructure/databases/relational/with_async_session.py +25 -0
- cognee/infrastructure/databases/vector/chromadb/ChromaDBAdapter.py +1 -1
- cognee/infrastructure/databases/vector/config.py +13 -6
- cognee/infrastructure/databases/vector/embeddings/FastembedEmbeddingEngine.py +6 -4
- cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py +16 -7
- cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py +5 -5
- cognee/infrastructure/databases/vector/embeddings/config.py +2 -2
- cognee/infrastructure/databases/vector/embeddings/embedding_rate_limiter.py +2 -6
- cognee/infrastructure/databases/vector/embeddings/get_embedding_engine.py +10 -7
- cognee/infrastructure/files/storage/LocalFileStorage.py +9 -0
- cognee/infrastructure/files/storage/S3FileStorage.py +5 -0
- cognee/infrastructure/files/storage/StorageManager.py +7 -1
- cognee/infrastructure/files/storage/storage.py +16 -0
- cognee/infrastructure/files/utils/get_data_file_path.py +14 -9
- cognee/infrastructure/files/utils/get_file_metadata.py +2 -1
- cognee/infrastructure/llm/LLMGateway.py +32 -5
- cognee/infrastructure/llm/config.py +6 -4
- cognee/infrastructure/llm/prompts/extract_query_time.txt +15 -0
- cognee/infrastructure/llm/prompts/generate_event_entity_prompt.txt +25 -0
- cognee/infrastructure/llm/prompts/generate_event_graph_prompt.txt +30 -0
- cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extraction/knowledge_graph/extract_content_graph.py +16 -5
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/extraction/__init__.py +2 -0
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/extraction/extract_event_entities.py +44 -0
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/extraction/knowledge_graph/__init__.py +1 -0
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/extraction/knowledge_graph/extract_content_graph.py +19 -15
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/extraction/knowledge_graph/extract_event_graph.py +46 -0
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py +3 -3
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py +3 -3
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py +2 -2
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +14 -8
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py +6 -4
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py +28 -4
- cognee/infrastructure/llm/tokenizer/Gemini/adapter.py +2 -2
- cognee/infrastructure/llm/tokenizer/HuggingFace/adapter.py +3 -3
- cognee/infrastructure/llm/tokenizer/Mistral/adapter.py +3 -3
- cognee/infrastructure/llm/tokenizer/TikToken/adapter.py +6 -6
- cognee/infrastructure/llm/utils.py +7 -7
- cognee/infrastructure/utils/run_sync.py +8 -1
- cognee/modules/chunking/models/DocumentChunk.py +4 -3
- cognee/modules/cloud/exceptions/CloudApiKeyMissingError.py +15 -0
- cognee/modules/cloud/exceptions/CloudConnectionError.py +15 -0
- cognee/modules/cloud/exceptions/__init__.py +2 -0
- cognee/modules/cloud/operations/__init__.py +1 -0
- cognee/modules/cloud/operations/check_api_key.py +25 -0
- cognee/modules/data/deletion/prune_system.py +1 -1
- cognee/modules/data/methods/__init__.py +2 -0
- cognee/modules/data/methods/check_dataset_name.py +1 -1
- cognee/modules/data/methods/create_authorized_dataset.py +19 -0
- cognee/modules/data/methods/get_authorized_dataset.py +11 -5
- cognee/modules/data/methods/get_authorized_dataset_by_name.py +16 -0
- cognee/modules/data/methods/get_dataset_data.py +1 -1
- cognee/modules/data/methods/load_or_create_datasets.py +2 -20
- cognee/modules/engine/models/Event.py +16 -0
- cognee/modules/engine/models/Interval.py +8 -0
- cognee/modules/engine/models/Timestamp.py +13 -0
- cognee/modules/engine/models/__init__.py +3 -0
- cognee/modules/engine/utils/__init__.py +2 -0
- cognee/modules/engine/utils/generate_event_datapoint.py +46 -0
- cognee/modules/engine/utils/generate_timestamp_datapoint.py +51 -0
- cognee/modules/graph/cognee_graph/CogneeGraph.py +2 -2
- cognee/modules/graph/methods/get_formatted_graph_data.py +3 -2
- cognee/modules/graph/utils/__init__.py +1 -0
- cognee/modules/graph/utils/resolve_edges_to_text.py +71 -0
- cognee/modules/memify/__init__.py +1 -0
- cognee/modules/memify/memify.py +118 -0
- cognee/modules/notebooks/methods/__init__.py +5 -0
- cognee/modules/notebooks/methods/create_notebook.py +26 -0
- cognee/modules/notebooks/methods/delete_notebook.py +13 -0
- cognee/modules/notebooks/methods/get_notebook.py +21 -0
- cognee/modules/notebooks/methods/get_notebooks.py +18 -0
- cognee/modules/notebooks/methods/update_notebook.py +17 -0
- cognee/modules/notebooks/models/Notebook.py +53 -0
- cognee/modules/notebooks/models/__init__.py +1 -0
- cognee/modules/notebooks/operations/__init__.py +1 -0
- cognee/modules/notebooks/operations/run_in_local_sandbox.py +55 -0
- cognee/modules/pipelines/__init__.py +1 -1
- cognee/modules/pipelines/exceptions/tasks.py +18 -0
- cognee/modules/pipelines/layers/__init__.py +1 -0
- cognee/modules/pipelines/layers/check_pipeline_run_qualification.py +59 -0
- cognee/modules/pipelines/layers/pipeline_execution_mode.py +127 -0
- cognee/modules/pipelines/layers/reset_dataset_pipeline_run_status.py +28 -0
- cognee/modules/pipelines/layers/resolve_authorized_user_dataset.py +34 -0
- cognee/modules/pipelines/layers/resolve_authorized_user_datasets.py +55 -0
- cognee/modules/pipelines/layers/setup_and_check_environment.py +41 -0
- cognee/modules/pipelines/layers/validate_pipeline_tasks.py +20 -0
- cognee/modules/pipelines/methods/__init__.py +2 -0
- cognee/modules/pipelines/methods/get_pipeline_runs_by_dataset.py +34 -0
- cognee/modules/pipelines/methods/reset_pipeline_run_status.py +16 -0
- cognee/modules/pipelines/operations/__init__.py +0 -1
- cognee/modules/pipelines/operations/log_pipeline_run_initiated.py +1 -1
- cognee/modules/pipelines/operations/pipeline.py +24 -138
- cognee/modules/pipelines/operations/run_tasks.py +17 -41
- cognee/modules/retrieval/base_feedback.py +11 -0
- cognee/modules/retrieval/base_graph_retriever.py +18 -0
- cognee/modules/retrieval/base_retriever.py +1 -1
- cognee/modules/retrieval/code_retriever.py +8 -0
- cognee/modules/retrieval/coding_rules_retriever.py +31 -0
- cognee/modules/retrieval/completion_retriever.py +9 -3
- cognee/modules/retrieval/context_providers/TripletSearchContextProvider.py +1 -0
- cognee/modules/retrieval/cypher_search_retriever.py +1 -9
- cognee/modules/retrieval/graph_completion_context_extension_retriever.py +29 -13
- cognee/modules/retrieval/graph_completion_cot_retriever.py +30 -13
- cognee/modules/retrieval/graph_completion_retriever.py +107 -56
- cognee/modules/retrieval/graph_summary_completion_retriever.py +5 -1
- cognee/modules/retrieval/insights_retriever.py +14 -3
- cognee/modules/retrieval/natural_language_retriever.py +0 -4
- cognee/modules/retrieval/summaries_retriever.py +1 -1
- cognee/modules/retrieval/temporal_retriever.py +152 -0
- cognee/modules/retrieval/user_qa_feedback.py +83 -0
- cognee/modules/retrieval/utils/brute_force_triplet_search.py +7 -32
- cognee/modules/retrieval/utils/completion.py +10 -3
- cognee/modules/retrieval/utils/extract_uuid_from_node.py +18 -0
- cognee/modules/retrieval/utils/models.py +40 -0
- cognee/modules/search/methods/get_search_type_tools.py +168 -0
- cognee/modules/search/methods/no_access_control_search.py +47 -0
- cognee/modules/search/methods/search.py +239 -118
- cognee/modules/search/types/SearchResult.py +21 -0
- cognee/modules/search/types/SearchType.py +3 -0
- cognee/modules/search/types/__init__.py +1 -0
- cognee/modules/search/utils/__init__.py +2 -0
- cognee/modules/search/utils/prepare_search_result.py +41 -0
- cognee/modules/search/utils/transform_context_to_graph.py +38 -0
- cognee/modules/settings/get_settings.py +2 -2
- cognee/modules/sync/__init__.py +1 -0
- cognee/modules/sync/methods/__init__.py +23 -0
- cognee/modules/sync/methods/create_sync_operation.py +53 -0
- cognee/modules/sync/methods/get_sync_operation.py +107 -0
- cognee/modules/sync/methods/update_sync_operation.py +248 -0
- cognee/modules/sync/models/SyncOperation.py +142 -0
- cognee/modules/sync/models/__init__.py +3 -0
- cognee/modules/users/__init__.py +0 -1
- cognee/modules/users/methods/__init__.py +4 -1
- cognee/modules/users/methods/create_user.py +26 -1
- cognee/modules/users/methods/get_authenticated_user.py +36 -42
- cognee/modules/users/methods/get_default_user.py +3 -1
- cognee/modules/users/permissions/methods/get_specific_user_permission_datasets.py +2 -1
- cognee/root_dir.py +19 -0
- cognee/shared/CodeGraphEntities.py +1 -0
- cognee/shared/logging_utils.py +143 -32
- cognee/shared/utils.py +0 -1
- cognee/tasks/codingagents/coding_rule_associations.py +127 -0
- cognee/tasks/graph/extract_graph_from_data.py +6 -2
- cognee/tasks/ingestion/save_data_item_to_storage.py +23 -0
- cognee/tasks/memify/__init__.py +2 -0
- cognee/tasks/memify/extract_subgraph.py +7 -0
- cognee/tasks/memify/extract_subgraph_chunks.py +11 -0
- cognee/tasks/repo_processor/get_local_dependencies.py +2 -0
- cognee/tasks/repo_processor/get_repo_file_dependencies.py +144 -47
- cognee/tasks/storage/add_data_points.py +33 -3
- cognee/tasks/temporal_graph/__init__.py +1 -0
- cognee/tasks/temporal_graph/add_entities_to_event.py +85 -0
- cognee/tasks/temporal_graph/enrich_events.py +34 -0
- cognee/tasks/temporal_graph/extract_events_and_entities.py +32 -0
- cognee/tasks/temporal_graph/extract_knowledge_graph_from_events.py +41 -0
- cognee/tasks/temporal_graph/models.py +49 -0
- cognee/tests/integration/cli/__init__.py +3 -0
- cognee/tests/integration/cli/test_cli_integration.py +331 -0
- cognee/tests/integration/documents/PdfDocument_test.py +2 -2
- cognee/tests/integration/documents/TextDocument_test.py +2 -4
- cognee/tests/integration/documents/UnstructuredDocument_test.py +5 -8
- cognee/tests/{test_deletion.py → test_delete_hard.py} +0 -37
- cognee/tests/test_delete_soft.py +85 -0
- cognee/tests/test_kuzu.py +2 -2
- cognee/tests/test_neo4j.py +2 -2
- cognee/tests/test_permissions.py +3 -3
- cognee/tests/test_relational_db_migration.py +7 -5
- cognee/tests/test_search_db.py +136 -23
- cognee/tests/test_temporal_graph.py +167 -0
- cognee/tests/unit/api/__init__.py +1 -0
- cognee/tests/unit/api/test_conditional_authentication_endpoints.py +246 -0
- cognee/tests/unit/cli/__init__.py +3 -0
- cognee/tests/unit/cli/test_cli_commands.py +483 -0
- cognee/tests/unit/cli/test_cli_edge_cases.py +625 -0
- cognee/tests/unit/cli/test_cli_main.py +173 -0
- cognee/tests/unit/cli/test_cli_runner.py +62 -0
- cognee/tests/unit/cli/test_cli_utils.py +127 -0
- cognee/tests/unit/modules/retrieval/chunks_retriever_test.py +18 -2
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py +12 -15
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +10 -15
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py +4 -3
- cognee/tests/unit/modules/retrieval/insights_retriever_test.py +4 -2
- cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py +18 -2
- cognee/tests/unit/modules/retrieval/temporal_retriever_test.py +225 -0
- cognee/tests/unit/modules/users/__init__.py +1 -0
- cognee/tests/unit/modules/users/test_conditional_authentication.py +277 -0
- cognee/tests/unit/processing/utils/utils_test.py +20 -1
- {cognee-0.2.3.dev1.dist-info → cognee-0.3.0.dist-info}/METADATA +13 -9
- {cognee-0.2.3.dev1.dist-info → cognee-0.3.0.dist-info}/RECORD +247 -135
- cognee-0.3.0.dist-info/entry_points.txt +2 -0
- cognee/infrastructure/databases/graph/networkx/adapter.py +0 -1017
- cognee/infrastructure/pipeline/models/Operation.py +0 -60
- cognee/notebooks/github_analysis_step_by_step.ipynb +0 -37
- cognee/tests/tasks/descriptive_metrics/networkx_metrics_test.py +0 -7
- cognee/tests/unit/modules/search/search_methods_test.py +0 -223
- /cognee/{infrastructure/databases/graph/networkx → api/v1/memify}/__init__.py +0 -0
- /cognee/{infrastructure/pipeline/models → tasks/codingagents}/__init__.py +0 -0
- {cognee-0.2.3.dev1.dist-info → cognee-0.3.0.dist-info}/WHEEL +0 -0
- {cognee-0.2.3.dev1.dist-info → cognee-0.3.0.dist-info}/licenses/LICENSE +0 -0
- {cognee-0.2.3.dev1.dist-info → cognee-0.3.0.dist-info}/licenses/NOTICE.md +0 -0
cognee/cli/reference.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
from abc import abstractmethod
|
|
2
|
+
from typing import Protocol, Optional
|
|
3
|
+
import argparse
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class SupportsCliCommand(Protocol):
|
|
7
|
+
"""Protocol for defining one cognee cli command"""
|
|
8
|
+
|
|
9
|
+
command_string: str
|
|
10
|
+
"""name of the command"""
|
|
11
|
+
help_string: str
|
|
12
|
+
"""the help string for argparse"""
|
|
13
|
+
description: Optional[str]
|
|
14
|
+
"""the more detailed description for argparse, may include markdown for the docs"""
|
|
15
|
+
docs_url: Optional[str]
|
|
16
|
+
"""the default docs url to be printed in case of an exception"""
|
|
17
|
+
|
|
18
|
+
@abstractmethod
|
|
19
|
+
def configure_parser(self, parser: argparse.ArgumentParser) -> None:
|
|
20
|
+
"""Configures the parser for the given argument"""
|
|
21
|
+
...
|
|
22
|
+
|
|
23
|
+
@abstractmethod
|
|
24
|
+
def execute(self, args: argparse.Namespace) -> None:
|
|
25
|
+
"""Executes the command with the given arguments"""
|
|
26
|
+
...
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Module to suppress verbose logging before any cognee imports.
|
|
3
|
+
This must be imported before any other cognee modules.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import os
|
|
7
|
+
|
|
8
|
+
# Set CLI mode to suppress verbose logging
|
|
9
|
+
os.environ["COGNEE_CLI_MODE"] = "true"
|
|
10
|
+
|
|
11
|
+
# Also set log level to ERROR for extra safety
|
|
12
|
+
os.environ["LOG_LEVEL"] = "ERROR"
|
|
@@ -5,7 +5,7 @@ from typing import Optional, Tuple, List, Dict, Union, Any, Callable, Awaitable
|
|
|
5
5
|
from cognee.eval_framework.benchmark_adapters.benchmark_adapters import BenchmarkAdapter
|
|
6
6
|
from cognee.modules.chunking.TextChunker import TextChunker
|
|
7
7
|
from cognee.modules.pipelines.tasks.task import Task
|
|
8
|
-
from cognee.modules.pipelines import
|
|
8
|
+
from cognee.modules.pipelines import run_pipeline
|
|
9
9
|
|
|
10
10
|
logger = get_logger(level=ERROR)
|
|
11
11
|
|
|
@@ -61,7 +61,7 @@ class CorpusBuilderExecutor:
|
|
|
61
61
|
await cognee.add(self.raw_corpus)
|
|
62
62
|
|
|
63
63
|
tasks = await self.task_getter(chunk_size=chunk_size, chunker=chunker)
|
|
64
|
-
pipeline_run =
|
|
64
|
+
pipeline_run = run_pipeline(tasks=tasks)
|
|
65
65
|
|
|
66
66
|
async for run_info in pipeline_run:
|
|
67
67
|
print(run_info)
|
|
@@ -6,6 +6,7 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
|
|
|
6
6
|
import pydantic
|
|
7
7
|
from pydantic import Field
|
|
8
8
|
from cognee.base_config import get_base_config
|
|
9
|
+
from cognee.root_dir import ensure_absolute_path
|
|
9
10
|
from cognee.shared.data_models import KnowledgeGraph
|
|
10
11
|
|
|
11
12
|
|
|
@@ -51,15 +52,20 @@ class GraphConfig(BaseSettings):
|
|
|
51
52
|
@pydantic.model_validator(mode="after")
|
|
52
53
|
def fill_derived(cls, values):
|
|
53
54
|
provider = values.graph_database_provider.lower()
|
|
55
|
+
base_config = get_base_config()
|
|
54
56
|
|
|
55
57
|
# Set default filename if no filename is provided
|
|
56
58
|
if not values.graph_filename:
|
|
57
59
|
values.graph_filename = f"cognee_graph_{provider}"
|
|
58
60
|
|
|
59
|
-
#
|
|
60
|
-
if
|
|
61
|
-
|
|
62
|
-
|
|
61
|
+
# Handle graph file path
|
|
62
|
+
if values.graph_file_path:
|
|
63
|
+
# Check if absolute path is provided
|
|
64
|
+
values.graph_file_path = ensure_absolute_path(
|
|
65
|
+
os.path.join(values.graph_file_path, values.graph_filename)
|
|
66
|
+
)
|
|
67
|
+
else:
|
|
68
|
+
# Default path
|
|
63
69
|
databases_directory_path = os.path.join(base_config.system_root_directory, "databases")
|
|
64
70
|
values.graph_file_path = os.path.join(databases_directory_path, values.graph_filename)
|
|
65
71
|
|
|
@@ -21,10 +21,6 @@ async def get_graph_engine() -> GraphDBInterface:
|
|
|
21
21
|
if hasattr(graph_client, "initialize"):
|
|
22
22
|
await graph_client.initialize()
|
|
23
23
|
|
|
24
|
-
# Handle loading of graph for NetworkX
|
|
25
|
-
if config["graph_database_provider"].lower() == "networkx" and graph_client.graph is None:
|
|
26
|
-
await graph_client.load_graph_from_file()
|
|
27
|
-
|
|
28
24
|
return graph_client
|
|
29
25
|
|
|
30
26
|
|
|
@@ -181,8 +177,7 @@ def create_graph_engine(
|
|
|
181
177
|
graph_id=graph_identifier,
|
|
182
178
|
)
|
|
183
179
|
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
return graph_client
|
|
180
|
+
raise EnvironmentError(
|
|
181
|
+
f"Unsupported graph database provider: {graph_database_provider}. "
|
|
182
|
+
f"Supported providers are: {', '.join(list(supported_databases.keys()) + ['neo4j', 'falkordb', 'kuzu', 'kuzu-remote', 'memgraph', 'neptune', 'neptune_analytics'])}"
|
|
183
|
+
)
|
|
@@ -21,6 +21,8 @@ from cognee.infrastructure.databases.graph.graph_db_interface import (
|
|
|
21
21
|
)
|
|
22
22
|
from cognee.infrastructure.engine import DataPoint
|
|
23
23
|
from cognee.modules.storage.utils import JSONEncoder
|
|
24
|
+
from cognee.modules.engine.utils.generate_timestamp_datapoint import date_to_int
|
|
25
|
+
from cognee.tasks.temporal_graph.models import Timestamp
|
|
24
26
|
|
|
25
27
|
logger = get_logger()
|
|
26
28
|
|
|
@@ -106,6 +108,18 @@ class KuzuAdapter(GraphDBInterface):
|
|
|
106
108
|
|
|
107
109
|
self.db.init_database()
|
|
108
110
|
self.connection = Connection(self.db)
|
|
111
|
+
|
|
112
|
+
try:
|
|
113
|
+
self.connection.execute("INSTALL JSON;")
|
|
114
|
+
except Exception as e:
|
|
115
|
+
logger.info(f"JSON extension already installed or not needed: {e}")
|
|
116
|
+
|
|
117
|
+
try:
|
|
118
|
+
self.connection.execute("LOAD EXTENSION JSON;")
|
|
119
|
+
logger.info("Loaded JSON extension")
|
|
120
|
+
except Exception as e:
|
|
121
|
+
logger.info(f"JSON extension already loaded or unavailable: {e}")
|
|
122
|
+
|
|
109
123
|
# Create node table with essential fields and timestamp
|
|
110
124
|
self.connection.execute("""
|
|
111
125
|
CREATE NODE TABLE IF NOT EXISTS Node(
|
|
@@ -138,8 +152,9 @@ class KuzuAdapter(GraphDBInterface):
|
|
|
138
152
|
|
|
139
153
|
s3_file_storage = S3FileStorage("")
|
|
140
154
|
|
|
141
|
-
|
|
142
|
-
self.
|
|
155
|
+
if self.connection:
|
|
156
|
+
async with self.KUZU_ASYNC_LOCK:
|
|
157
|
+
self.connection.execute("CHECKPOINT;")
|
|
143
158
|
|
|
144
159
|
s3_file_storage.s3.put(self.temp_graph_file, self.db_path, recursive=True)
|
|
145
160
|
|
|
@@ -1631,3 +1646,185 @@ class KuzuAdapter(GraphDBInterface):
|
|
|
1631
1646
|
"""
|
|
1632
1647
|
result = await self.query(query)
|
|
1633
1648
|
return [record[0] for record in result] if result else []
|
|
1649
|
+
|
|
1650
|
+
async def get_last_user_interaction_ids(self, limit: int) -> List[str]:
|
|
1651
|
+
"""
|
|
1652
|
+
Retrieve the IDs of the most recent CogneeUserInteraction nodes.
|
|
1653
|
+
Parameters:
|
|
1654
|
+
-----------
|
|
1655
|
+
- limit (int): The maximum number of interaction IDs to return.
|
|
1656
|
+
Returns:
|
|
1657
|
+
--------
|
|
1658
|
+
- List[str]: A list of interaction IDs, sorted by created_at descending.
|
|
1659
|
+
"""
|
|
1660
|
+
|
|
1661
|
+
query = """
|
|
1662
|
+
MATCH (n)
|
|
1663
|
+
WHERE n.type = 'CogneeUserInteraction'
|
|
1664
|
+
RETURN n.id as id
|
|
1665
|
+
ORDER BY n.created_at DESC
|
|
1666
|
+
LIMIT $limit
|
|
1667
|
+
"""
|
|
1668
|
+
rows = await self.query(query, {"limit": limit})
|
|
1669
|
+
|
|
1670
|
+
id_list = [row[0] for row in rows]
|
|
1671
|
+
return id_list
|
|
1672
|
+
|
|
1673
|
+
async def apply_feedback_weight(
|
|
1674
|
+
self,
|
|
1675
|
+
node_ids: List[str],
|
|
1676
|
+
weight: float,
|
|
1677
|
+
) -> None:
|
|
1678
|
+
"""
|
|
1679
|
+
Increment `feedback_weight` inside r.properties JSON for edges where
|
|
1680
|
+
relationship_name = 'used_graph_element_to_answer'.
|
|
1681
|
+
|
|
1682
|
+
"""
|
|
1683
|
+
# Step 1: fetch matching edges
|
|
1684
|
+
query = """
|
|
1685
|
+
MATCH (n:Node)-[r:EDGE]->()
|
|
1686
|
+
WHERE n.id IN $node_ids AND r.relationship_name = 'used_graph_element_to_answer'
|
|
1687
|
+
RETURN r.properties, n.id
|
|
1688
|
+
"""
|
|
1689
|
+
results = await self.query(query, {"node_ids": node_ids})
|
|
1690
|
+
|
|
1691
|
+
# Step 2: update JSON client-side
|
|
1692
|
+
updates = []
|
|
1693
|
+
for props_json, source_id in results:
|
|
1694
|
+
try:
|
|
1695
|
+
props = json.loads(props_json) if props_json else {}
|
|
1696
|
+
except json.JSONDecodeError:
|
|
1697
|
+
props = {}
|
|
1698
|
+
|
|
1699
|
+
props["feedback_weight"] = props.get("feedback_weight", 0) + weight
|
|
1700
|
+
updates.append((source_id, json.dumps(props)))
|
|
1701
|
+
|
|
1702
|
+
# Step 3: write back
|
|
1703
|
+
for node_id, new_props in updates:
|
|
1704
|
+
update_query = """
|
|
1705
|
+
MATCH (n:Node)-[r:EDGE]->()
|
|
1706
|
+
WHERE n.id = $node_id AND r.relationship_name = 'used_graph_element_to_answer'
|
|
1707
|
+
SET r.properties = $props
|
|
1708
|
+
"""
|
|
1709
|
+
await self.query(update_query, {"node_id": node_id, "props": new_props})
|
|
1710
|
+
|
|
1711
|
+
async def collect_events(self, ids: List[str]) -> Any:
|
|
1712
|
+
"""
|
|
1713
|
+
Collect all Event-type nodes reachable within 1..2 hops
|
|
1714
|
+
from the given node IDs.
|
|
1715
|
+
|
|
1716
|
+
Args:
|
|
1717
|
+
graph_engine: Object exposing an async .query(str) -> Any
|
|
1718
|
+
ids: List of node IDs (strings)
|
|
1719
|
+
|
|
1720
|
+
Returns:
|
|
1721
|
+
List of events
|
|
1722
|
+
"""
|
|
1723
|
+
|
|
1724
|
+
event_collection_cypher = """UNWIND [{quoted}] AS uid
|
|
1725
|
+
MATCH (start {{id: uid}})
|
|
1726
|
+
MATCH (start)-[*1..2]-(event)
|
|
1727
|
+
WHERE event.type = 'Event'
|
|
1728
|
+
WITH DISTINCT event
|
|
1729
|
+
RETURN collect(event) AS events;
|
|
1730
|
+
"""
|
|
1731
|
+
|
|
1732
|
+
query = event_collection_cypher.format(quoted=ids)
|
|
1733
|
+
result = await self.query(query)
|
|
1734
|
+
events = []
|
|
1735
|
+
for node in result[0][0]:
|
|
1736
|
+
props = json.loads(node["properties"])
|
|
1737
|
+
|
|
1738
|
+
event = {
|
|
1739
|
+
"id": node["id"],
|
|
1740
|
+
"name": node["name"],
|
|
1741
|
+
"description": props.get("description"),
|
|
1742
|
+
}
|
|
1743
|
+
|
|
1744
|
+
if props.get("location"):
|
|
1745
|
+
event["location"] = props["location"]
|
|
1746
|
+
|
|
1747
|
+
events.append(event)
|
|
1748
|
+
|
|
1749
|
+
return [{"events": events}]
|
|
1750
|
+
|
|
1751
|
+
async def collect_time_ids(
|
|
1752
|
+
self,
|
|
1753
|
+
time_from: Optional[Timestamp] = None,
|
|
1754
|
+
time_to: Optional[Timestamp] = None,
|
|
1755
|
+
) -> str:
|
|
1756
|
+
"""
|
|
1757
|
+
Collect IDs of Timestamp nodes between time_from and time_to.
|
|
1758
|
+
|
|
1759
|
+
Args:
|
|
1760
|
+
graph_engine: Object exposing an async .query(query, params) -> list[dict]
|
|
1761
|
+
time_from: Lower bound int (inclusive), optional
|
|
1762
|
+
time_to: Upper bound int (inclusive), optional
|
|
1763
|
+
|
|
1764
|
+
Returns:
|
|
1765
|
+
A string of quoted IDs: "'id1', 'id2', 'id3'"
|
|
1766
|
+
(ready for use in a Cypher UNWIND clause).
|
|
1767
|
+
"""
|
|
1768
|
+
|
|
1769
|
+
ids: List[str] = []
|
|
1770
|
+
|
|
1771
|
+
if time_from and time_to:
|
|
1772
|
+
time_from = date_to_int(time_from)
|
|
1773
|
+
time_to = date_to_int(time_to)
|
|
1774
|
+
|
|
1775
|
+
cypher = f"""
|
|
1776
|
+
MATCH (n:Node)
|
|
1777
|
+
WHERE n.type = 'Timestamp'
|
|
1778
|
+
// Extract time_at from the JSON string and cast to INT64
|
|
1779
|
+
WITH n, json_extract(n.properties, '$.time_at') AS t_str
|
|
1780
|
+
WITH n,
|
|
1781
|
+
CASE
|
|
1782
|
+
WHEN t_str IS NULL OR t_str = '' THEN NULL
|
|
1783
|
+
ELSE CAST(t_str AS INT64)
|
|
1784
|
+
END AS t
|
|
1785
|
+
WHERE t >= {time_from}
|
|
1786
|
+
AND t <= {time_to}
|
|
1787
|
+
RETURN n.id as id
|
|
1788
|
+
"""
|
|
1789
|
+
|
|
1790
|
+
elif time_from:
|
|
1791
|
+
time_from = date_to_int(time_from)
|
|
1792
|
+
|
|
1793
|
+
cypher = f"""
|
|
1794
|
+
MATCH (n:Node)
|
|
1795
|
+
WHERE n.type = 'Timestamp'
|
|
1796
|
+
// Extract time_at from the JSON string and cast to INT64
|
|
1797
|
+
WITH n, json_extract(n.properties, '$.time_at') AS t_str
|
|
1798
|
+
WITH n,
|
|
1799
|
+
CASE
|
|
1800
|
+
WHEN t_str IS NULL OR t_str = '' THEN NULL
|
|
1801
|
+
ELSE CAST(t_str AS INT64)
|
|
1802
|
+
END AS t
|
|
1803
|
+
WHERE t >= {time_from}
|
|
1804
|
+
RETURN n.id as id
|
|
1805
|
+
"""
|
|
1806
|
+
|
|
1807
|
+
elif time_to:
|
|
1808
|
+
time_to = date_to_int(time_to)
|
|
1809
|
+
|
|
1810
|
+
cypher = f"""
|
|
1811
|
+
MATCH (n:Node)
|
|
1812
|
+
WHERE n.type = 'Timestamp'
|
|
1813
|
+
// Extract time_at from the JSON string and cast to INT64
|
|
1814
|
+
WITH n, json_extract(n.properties, '$.time_at') AS t_str
|
|
1815
|
+
WITH n,
|
|
1816
|
+
CASE
|
|
1817
|
+
WHEN t_str IS NULL OR t_str = '' THEN NULL
|
|
1818
|
+
ELSE CAST(t_str AS INT64)
|
|
1819
|
+
END AS t
|
|
1820
|
+
WHERE t <= {time_to}
|
|
1821
|
+
RETURN n.id as id
|
|
1822
|
+
"""
|
|
1823
|
+
|
|
1824
|
+
else:
|
|
1825
|
+
return ids
|
|
1826
|
+
|
|
1827
|
+
time_nodes = await self.query(cypher)
|
|
1828
|
+
time_ids_list = [item[0] for item in time_nodes]
|
|
1829
|
+
|
|
1830
|
+
return ", ".join(f"'{uid}'" for uid in time_ids_list)
|
|
@@ -11,6 +11,8 @@ from contextlib import asynccontextmanager
|
|
|
11
11
|
from typing import Optional, Any, List, Dict, Type, Tuple
|
|
12
12
|
|
|
13
13
|
from cognee.infrastructure.engine import DataPoint
|
|
14
|
+
from cognee.modules.engine.utils.generate_timestamp_datapoint import date_to_int
|
|
15
|
+
from cognee.tasks.temporal_graph.models import Timestamp
|
|
14
16
|
from cognee.shared.logging_utils import get_logger, ERROR
|
|
15
17
|
from cognee.infrastructure.databases.graph.graph_db_interface import (
|
|
16
18
|
GraphDBInterface,
|
|
@@ -1322,3 +1324,139 @@ class Neo4jAdapter(GraphDBInterface):
|
|
|
1322
1324
|
"""
|
|
1323
1325
|
result = await self.query(query)
|
|
1324
1326
|
return [record["n"] for record in result] if result else []
|
|
1327
|
+
|
|
1328
|
+
async def get_last_user_interaction_ids(self, limit: int) -> List[str]:
|
|
1329
|
+
"""
|
|
1330
|
+
Retrieve the IDs of the most recent CogneeUserInteraction nodes.
|
|
1331
|
+
Parameters:
|
|
1332
|
+
-----------
|
|
1333
|
+
- limit (int): The maximum number of interaction IDs to return.
|
|
1334
|
+
Returns:
|
|
1335
|
+
--------
|
|
1336
|
+
- List[str]: A list of interaction IDs, sorted by created_at descending.
|
|
1337
|
+
"""
|
|
1338
|
+
|
|
1339
|
+
query = """
|
|
1340
|
+
MATCH (n)
|
|
1341
|
+
WHERE n.type = 'CogneeUserInteraction'
|
|
1342
|
+
RETURN n.id as id
|
|
1343
|
+
ORDER BY n.created_at DESC
|
|
1344
|
+
LIMIT $limit
|
|
1345
|
+
"""
|
|
1346
|
+
rows = await self.query(query, {"limit": limit})
|
|
1347
|
+
|
|
1348
|
+
id_list = [row["id"] for row in rows if "id" in row]
|
|
1349
|
+
return id_list
|
|
1350
|
+
|
|
1351
|
+
async def apply_feedback_weight(
|
|
1352
|
+
self,
|
|
1353
|
+
node_ids: List[str],
|
|
1354
|
+
weight: float,
|
|
1355
|
+
) -> None:
|
|
1356
|
+
"""
|
|
1357
|
+
Increment `feedback_weight` on relationships `:used_graph_element_to_answer`
|
|
1358
|
+
outgoing from nodes whose `id` is in `node_ids`.
|
|
1359
|
+
|
|
1360
|
+
Args:
|
|
1361
|
+
node_ids: List of node IDs to match.
|
|
1362
|
+
weight: Amount to add to `r.feedback_weight` (can be negative).
|
|
1363
|
+
|
|
1364
|
+
Side effects:
|
|
1365
|
+
Updates relationship property `feedback_weight`, defaulting missing values to 0.
|
|
1366
|
+
"""
|
|
1367
|
+
query = """
|
|
1368
|
+
MATCH (n)-[r]->()
|
|
1369
|
+
WHERE n.id IN $node_ids AND r.relationship_name = 'used_graph_element_to_answer'
|
|
1370
|
+
SET r.feedback_weight = coalesce(r.feedback_weight, 0) + $weight
|
|
1371
|
+
"""
|
|
1372
|
+
await self.query(
|
|
1373
|
+
query,
|
|
1374
|
+
params={"weight": float(weight), "node_ids": list(node_ids)},
|
|
1375
|
+
)
|
|
1376
|
+
|
|
1377
|
+
async def collect_events(self, ids: List[str]) -> Any:
|
|
1378
|
+
"""
|
|
1379
|
+
Collect all Event-type nodes reachable within 1..2 hops
|
|
1380
|
+
from the given node IDs.
|
|
1381
|
+
|
|
1382
|
+
Args:
|
|
1383
|
+
graph_engine: Object exposing an async .query(str) -> Any
|
|
1384
|
+
ids: List of node IDs (strings)
|
|
1385
|
+
|
|
1386
|
+
Returns:
|
|
1387
|
+
List of events
|
|
1388
|
+
"""
|
|
1389
|
+
|
|
1390
|
+
event_collection_cypher = """UNWIND [{quoted}] AS uid
|
|
1391
|
+
MATCH (start {{id: uid}})
|
|
1392
|
+
MATCH (start)-[*1..2]-(event)
|
|
1393
|
+
WHERE event.type = 'Event'
|
|
1394
|
+
WITH DISTINCT event
|
|
1395
|
+
RETURN collect(event) AS events;
|
|
1396
|
+
"""
|
|
1397
|
+
|
|
1398
|
+
query = event_collection_cypher.format(quoted=ids)
|
|
1399
|
+
return await self.query(query)
|
|
1400
|
+
|
|
1401
|
+
async def collect_time_ids(
    self,
    time_from: Optional[Timestamp] = None,
    time_to: Optional[Timestamp] = None,
) -> str:
    """
    Collect IDs of Timestamp nodes whose `time_at` lies between
    `time_from` and `time_to` (inclusive where given).

    Args:
        time_from: Lower bound (inclusive), optional.
        time_to: Upper bound (inclusive), optional.

    Returns:
        A string of quoted IDs: "'id1', 'id2', 'id3'"
        (ready for use in a Cypher UNWIND clause).
        Empty string when neither bound was supplied.
    """

    # Fix: the previous implementation returned an empty *list* here despite
    # the declared `-> str` return type; an empty string is equally falsy and
    # honors the annotation.
    if not time_from and not time_to:
        return ""

    # Build the WHERE clause once instead of three near-duplicate queries.
    conditions = ["n.type = 'Timestamp'"]
    params = {}

    if time_from:
        params["time_from"] = date_to_int(time_from)
        conditions.append("n.time_at >= $time_from")

    if time_to:
        params["time_to"] = date_to_int(time_to)
        conditions.append("n.time_at <= $time_to")

    cypher = f"""
    MATCH (n)
    WHERE {" AND ".join(conditions)}
    RETURN n.id AS id
    """

    time_nodes = await self.query(cypher, params)
    time_ids_list = [item["id"] for item in time_nodes if "id" in item]

    return ", ".join(f"'{uid}'" for uid in time_ids_list)
|
|
@@ -1,6 +1,8 @@
|
|
|
1
1
|
from .ModelBase import Base
|
|
2
2
|
from .config import get_relational_config
|
|
3
3
|
from .config import get_migration_config
|
|
4
|
+
from .get_async_session import get_async_session
|
|
5
|
+
from .with_async_session import with_async_session
|
|
4
6
|
from .create_db_and_tables import create_db_and_tables
|
|
5
7
|
from .get_relational_engine import get_relational_engine
|
|
6
8
|
from .get_migration_relational_engine import get_migration_relational_engine
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
from typing import AsyncGenerator
|
|
2
|
+
from contextlib import asynccontextmanager
|
|
3
|
+
from sqlalchemy.ext.asyncio import AsyncSession
|
|
4
|
+
|
|
5
|
+
from .get_relational_engine import get_relational_engine
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@asynccontextmanager
async def get_async_session(auto_commit=False) -> AsyncGenerator[AsyncSession, None]:
    """
    Yield an async relational-database session as a context manager.

    Args:
        auto_commit: When True, commit the session after the caller's
            block completes without raising.
    """
    engine = get_relational_engine()
    async with engine.get_async_session() as session:
        yield session
        # Reached only if the caller's block did not raise.
        if auto_commit:
            await session.commit()
|
|
@@ -57,7 +57,12 @@ class SQLAlchemyAdapter:
|
|
|
57
57
|
)
|
|
58
58
|
else:
|
|
59
59
|
self.engine = create_async_engine(
|
|
60
|
-
connection_string,
|
|
60
|
+
connection_string,
|
|
61
|
+
pool_size=5,
|
|
62
|
+
max_overflow=10,
|
|
63
|
+
pool_recycle=280,
|
|
64
|
+
pool_pre_ping=True,
|
|
65
|
+
pool_timeout=280,
|
|
61
66
|
)
|
|
62
67
|
|
|
63
68
|
self.sessionmaker = async_sessionmaker(bind=self.engine, expire_on_commit=False)
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
from typing import Any, Callable, Optional
|
|
2
|
+
from sqlalchemy.ext.asyncio import AsyncSession
|
|
3
|
+
from .get_async_session import get_async_session
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def get_session_from_args(args):
    """
    Return the last positional argument if it is an AsyncSession, else None.

    Fix: guard against empty `args` — the previous implementation indexed
    `args[-1]` unconditionally and raised IndexError for zero-argument calls.
    """
    if not args:
        return None
    candidate = args[-1]
    return candidate if isinstance(candidate, AsyncSession) else None
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def with_async_session(func: Callable[..., Any]) -> Callable[..., Any]:
    """
    Decorator ensuring `func` is called with an AsyncSession.

    If the caller already supplied a session (via the `session` keyword or as
    the last positional argument), `func` runs as-is and the caller remains
    responsible for committing. Otherwise a fresh session is opened, passed
    as `session=`, and committed after `func` succeeds.
    """
    from functools import wraps

    # Fix: apply functools.wraps so the wrapped coroutine keeps func's
    # __name__/__doc__ for logging and introspection.
    @wraps(func)
    async def wrapper(*args, **kwargs):
        # NOTE(review): a falsy (but non-None) session value would be treated
        # as absent here; kept as-is since AsyncSession instances are truthy.
        session = kwargs.get("session") or get_session_from_args(args)  # type: Optional[AsyncSession]

        if session is None:
            async with get_async_session() as session:
                result = await func(*args, **kwargs, session=session)
                await session.commit()
                return result
        else:
            return await func(*args, **kwargs)

    return wrapper
|
|
@@ -538,7 +538,7 @@ class ChromaDBAdapter(VectorDBInterface):
|
|
|
538
538
|
Returns True upon successful deletion of all collections.
|
|
539
539
|
"""
|
|
540
540
|
client = await self.get_connection()
|
|
541
|
-
collections = await
|
|
541
|
+
collections = await client.list_collections()
|
|
542
542
|
for collection_name in collections:
|
|
543
543
|
await client.delete_collection(collection_name)
|
|
544
544
|
return True
|
|
@@ -1,9 +1,11 @@
|
|
|
1
1
|
import os
|
|
2
2
|
import pydantic
|
|
3
|
+
from pathlib import Path
|
|
3
4
|
from functools import lru_cache
|
|
4
5
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
|
5
6
|
|
|
6
7
|
from cognee.base_config import get_base_config
|
|
8
|
+
from cognee.root_dir import ensure_absolute_path
|
|
7
9
|
|
|
8
10
|
|
|
9
11
|
class VectorConfig(BaseSettings):
|
|
@@ -11,11 +13,9 @@ class VectorConfig(BaseSettings):
|
|
|
11
13
|
Manage the configuration settings for the vector database.
|
|
12
14
|
|
|
13
15
|
Public methods:
|
|
14
|
-
|
|
15
16
|
- to_dict: Convert the configuration to a dictionary.
|
|
16
17
|
|
|
17
18
|
Instance variables:
|
|
18
|
-
|
|
19
19
|
- vector_db_url: The URL of the vector database.
|
|
20
20
|
- vector_db_port: The port for the vector database.
|
|
21
21
|
- vector_db_key: The key for accessing the vector database.
|
|
@@ -30,10 +30,17 @@ class VectorConfig(BaseSettings):
|
|
|
30
30
|
model_config = SettingsConfigDict(env_file=".env", extra="allow")
|
|
31
31
|
|
|
32
32
|
@pydantic.model_validator(mode="after")
|
|
33
|
-
def
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
33
|
+
def validate_paths(cls, values):
|
|
34
|
+
base_config = get_base_config()
|
|
35
|
+
|
|
36
|
+
# If vector_db_url is provided and is not a path skip checking if path is absolute (as it can also be a url)
|
|
37
|
+
if values.vector_db_url and Path(values.vector_db_url).exists():
|
|
38
|
+
# Relative path to absolute
|
|
39
|
+
values.vector_db_url = ensure_absolute_path(
|
|
40
|
+
values.vector_db_url,
|
|
41
|
+
)
|
|
42
|
+
else:
|
|
43
|
+
# Default path
|
|
37
44
|
databases_directory_path = os.path.join(base_config.system_root_directory, "databases")
|
|
38
45
|
values.vector_db_url = os.path.join(databases_directory_path, "cognee.lancedb")
|
|
39
46
|
|
|
@@ -4,7 +4,7 @@ from fastembed import TextEmbedding
|
|
|
4
4
|
import litellm
|
|
5
5
|
import os
|
|
6
6
|
from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
|
|
7
|
-
from cognee.infrastructure.databases.exceptions
|
|
7
|
+
from cognee.infrastructure.databases.exceptions import EmbeddingException
|
|
8
8
|
from cognee.infrastructure.llm.tokenizer.TikToken import (
|
|
9
9
|
TikTokenTokenizer,
|
|
10
10
|
)
|
|
@@ -41,11 +41,11 @@ class FastembedEmbeddingEngine(EmbeddingEngine):
|
|
|
41
41
|
self,
|
|
42
42
|
model: Optional[str] = "openai/text-embedding-3-large",
|
|
43
43
|
dimensions: Optional[int] = 3072,
|
|
44
|
-
|
|
44
|
+
max_completion_tokens: int = 512,
|
|
45
45
|
):
|
|
46
46
|
self.model = model
|
|
47
47
|
self.dimensions = dimensions
|
|
48
|
-
self.
|
|
48
|
+
self.max_completion_tokens = max_completion_tokens
|
|
49
49
|
self.tokenizer = self.get_tokenizer()
|
|
50
50
|
# self.retry_count = 0
|
|
51
51
|
self.embedding_model = TextEmbedding(model_name=model)
|
|
@@ -112,7 +112,9 @@ class FastembedEmbeddingEngine(EmbeddingEngine):
|
|
|
112
112
|
"""
|
|
113
113
|
logger.debug("Loading tokenizer for FastembedEmbeddingEngine...")
|
|
114
114
|
|
|
115
|
-
tokenizer = TikTokenTokenizer(
|
|
115
|
+
tokenizer = TikTokenTokenizer(
|
|
116
|
+
model="gpt-4o", max_completion_tokens=self.max_completion_tokens
|
|
117
|
+
)
|
|
116
118
|
|
|
117
119
|
logger.debug("Tokenizer loaded for for FastembedEmbeddingEngine")
|
|
118
120
|
return tokenizer
|