cognee 0.2.3.dev1__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cognee/__init__.py +2 -0
- cognee/__main__.py +4 -0
- cognee/api/client.py +28 -3
- cognee/api/health.py +10 -13
- cognee/api/v1/add/add.py +20 -6
- cognee/api/v1/add/routers/get_add_router.py +12 -37
- cognee/api/v1/cloud/routers/__init__.py +1 -0
- cognee/api/v1/cloud/routers/get_checks_router.py +23 -0
- cognee/api/v1/cognify/code_graph_pipeline.py +14 -3
- cognee/api/v1/cognify/cognify.py +67 -105
- cognee/api/v1/cognify/routers/get_cognify_router.py +11 -3
- cognee/api/v1/datasets/routers/get_datasets_router.py +16 -5
- cognee/api/v1/memify/routers/__init__.py +1 -0
- cognee/api/v1/memify/routers/get_memify_router.py +100 -0
- cognee/api/v1/notebooks/routers/__init__.py +1 -0
- cognee/api/v1/notebooks/routers/get_notebooks_router.py +96 -0
- cognee/api/v1/responses/default_tools.py +4 -0
- cognee/api/v1/responses/dispatch_function.py +6 -1
- cognee/api/v1/responses/models.py +1 -1
- cognee/api/v1/search/routers/get_search_router.py +20 -1
- cognee/api/v1/search/search.py +17 -4
- cognee/api/v1/sync/__init__.py +17 -0
- cognee/api/v1/sync/routers/__init__.py +3 -0
- cognee/api/v1/sync/routers/get_sync_router.py +241 -0
- cognee/api/v1/sync/sync.py +877 -0
- cognee/api/v1/ui/__init__.py +1 -0
- cognee/api/v1/ui/ui.py +529 -0
- cognee/api/v1/users/routers/get_auth_router.py +13 -1
- cognee/base_config.py +10 -1
- cognee/cli/__init__.py +10 -0
- cognee/cli/_cognee.py +273 -0
- cognee/cli/commands/__init__.py +1 -0
- cognee/cli/commands/add_command.py +80 -0
- cognee/cli/commands/cognify_command.py +128 -0
- cognee/cli/commands/config_command.py +225 -0
- cognee/cli/commands/delete_command.py +80 -0
- cognee/cli/commands/search_command.py +149 -0
- cognee/cli/config.py +33 -0
- cognee/cli/debug.py +21 -0
- cognee/cli/echo.py +45 -0
- cognee/cli/exceptions.py +23 -0
- cognee/cli/minimal_cli.py +97 -0
- cognee/cli/reference.py +26 -0
- cognee/cli/suppress_logging.py +12 -0
- cognee/eval_framework/corpus_builder/corpus_builder_executor.py +2 -2
- cognee/eval_framework/eval_config.py +1 -1
- cognee/infrastructure/databases/graph/config.py +10 -4
- cognee/infrastructure/databases/graph/get_graph_engine.py +4 -9
- cognee/infrastructure/databases/graph/kuzu/adapter.py +199 -2
- cognee/infrastructure/databases/graph/neo4j_driver/adapter.py +138 -0
- cognee/infrastructure/databases/relational/__init__.py +2 -0
- cognee/infrastructure/databases/relational/get_async_session.py +15 -0
- cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py +6 -1
- cognee/infrastructure/databases/relational/with_async_session.py +25 -0
- cognee/infrastructure/databases/vector/chromadb/ChromaDBAdapter.py +1 -1
- cognee/infrastructure/databases/vector/config.py +13 -6
- cognee/infrastructure/databases/vector/embeddings/FastembedEmbeddingEngine.py +6 -4
- cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py +16 -7
- cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py +5 -5
- cognee/infrastructure/databases/vector/embeddings/config.py +2 -2
- cognee/infrastructure/databases/vector/embeddings/embedding_rate_limiter.py +2 -6
- cognee/infrastructure/databases/vector/embeddings/get_embedding_engine.py +10 -7
- cognee/infrastructure/files/storage/LocalFileStorage.py +9 -0
- cognee/infrastructure/files/storage/S3FileStorage.py +5 -0
- cognee/infrastructure/files/storage/StorageManager.py +7 -1
- cognee/infrastructure/files/storage/storage.py +16 -0
- cognee/infrastructure/files/utils/get_data_file_path.py +14 -9
- cognee/infrastructure/files/utils/get_file_metadata.py +2 -1
- cognee/infrastructure/llm/LLMGateway.py +32 -5
- cognee/infrastructure/llm/config.py +6 -4
- cognee/infrastructure/llm/prompts/extract_query_time.txt +15 -0
- cognee/infrastructure/llm/prompts/generate_event_entity_prompt.txt +25 -0
- cognee/infrastructure/llm/prompts/generate_event_graph_prompt.txt +30 -0
- cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extraction/knowledge_graph/extract_content_graph.py +16 -5
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/extraction/__init__.py +2 -0
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/extraction/extract_event_entities.py +44 -0
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/extraction/knowledge_graph/__init__.py +1 -0
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/extraction/knowledge_graph/extract_content_graph.py +19 -15
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/extraction/knowledge_graph/extract_event_graph.py +46 -0
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py +3 -3
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py +3 -3
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py +2 -2
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +14 -8
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py +6 -4
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py +28 -4
- cognee/infrastructure/llm/tokenizer/Gemini/adapter.py +2 -2
- cognee/infrastructure/llm/tokenizer/HuggingFace/adapter.py +3 -3
- cognee/infrastructure/llm/tokenizer/Mistral/adapter.py +3 -3
- cognee/infrastructure/llm/tokenizer/TikToken/adapter.py +6 -6
- cognee/infrastructure/llm/utils.py +7 -7
- cognee/infrastructure/utils/run_sync.py +8 -1
- cognee/modules/chunking/models/DocumentChunk.py +4 -3
- cognee/modules/cloud/exceptions/CloudApiKeyMissingError.py +15 -0
- cognee/modules/cloud/exceptions/CloudConnectionError.py +15 -0
- cognee/modules/cloud/exceptions/__init__.py +2 -0
- cognee/modules/cloud/operations/__init__.py +1 -0
- cognee/modules/cloud/operations/check_api_key.py +25 -0
- cognee/modules/data/deletion/prune_system.py +1 -1
- cognee/modules/data/methods/__init__.py +2 -0
- cognee/modules/data/methods/check_dataset_name.py +1 -1
- cognee/modules/data/methods/create_authorized_dataset.py +19 -0
- cognee/modules/data/methods/get_authorized_dataset.py +11 -5
- cognee/modules/data/methods/get_authorized_dataset_by_name.py +16 -0
- cognee/modules/data/methods/get_dataset_data.py +1 -1
- cognee/modules/data/methods/load_or_create_datasets.py +2 -20
- cognee/modules/engine/models/Event.py +16 -0
- cognee/modules/engine/models/Interval.py +8 -0
- cognee/modules/engine/models/Timestamp.py +13 -0
- cognee/modules/engine/models/__init__.py +3 -0
- cognee/modules/engine/utils/__init__.py +2 -0
- cognee/modules/engine/utils/generate_event_datapoint.py +46 -0
- cognee/modules/engine/utils/generate_timestamp_datapoint.py +51 -0
- cognee/modules/graph/cognee_graph/CogneeGraph.py +2 -2
- cognee/modules/graph/methods/get_formatted_graph_data.py +3 -2
- cognee/modules/graph/utils/__init__.py +1 -0
- cognee/modules/graph/utils/resolve_edges_to_text.py +71 -0
- cognee/modules/memify/__init__.py +1 -0
- cognee/modules/memify/memify.py +118 -0
- cognee/modules/notebooks/methods/__init__.py +5 -0
- cognee/modules/notebooks/methods/create_notebook.py +26 -0
- cognee/modules/notebooks/methods/delete_notebook.py +13 -0
- cognee/modules/notebooks/methods/get_notebook.py +21 -0
- cognee/modules/notebooks/methods/get_notebooks.py +18 -0
- cognee/modules/notebooks/methods/update_notebook.py +17 -0
- cognee/modules/notebooks/models/Notebook.py +53 -0
- cognee/modules/notebooks/models/__init__.py +1 -0
- cognee/modules/notebooks/operations/__init__.py +1 -0
- cognee/modules/notebooks/operations/run_in_local_sandbox.py +55 -0
- cognee/modules/pipelines/__init__.py +1 -1
- cognee/modules/pipelines/exceptions/tasks.py +18 -0
- cognee/modules/pipelines/layers/__init__.py +1 -0
- cognee/modules/pipelines/layers/check_pipeline_run_qualification.py +59 -0
- cognee/modules/pipelines/layers/pipeline_execution_mode.py +127 -0
- cognee/modules/pipelines/layers/reset_dataset_pipeline_run_status.py +28 -0
- cognee/modules/pipelines/layers/resolve_authorized_user_dataset.py +34 -0
- cognee/modules/pipelines/layers/resolve_authorized_user_datasets.py +55 -0
- cognee/modules/pipelines/layers/setup_and_check_environment.py +41 -0
- cognee/modules/pipelines/layers/validate_pipeline_tasks.py +20 -0
- cognee/modules/pipelines/methods/__init__.py +2 -0
- cognee/modules/pipelines/methods/get_pipeline_runs_by_dataset.py +34 -0
- cognee/modules/pipelines/methods/reset_pipeline_run_status.py +16 -0
- cognee/modules/pipelines/operations/__init__.py +0 -1
- cognee/modules/pipelines/operations/log_pipeline_run_initiated.py +1 -1
- cognee/modules/pipelines/operations/pipeline.py +24 -138
- cognee/modules/pipelines/operations/run_tasks.py +17 -41
- cognee/modules/retrieval/base_feedback.py +11 -0
- cognee/modules/retrieval/base_graph_retriever.py +18 -0
- cognee/modules/retrieval/base_retriever.py +1 -1
- cognee/modules/retrieval/code_retriever.py +8 -0
- cognee/modules/retrieval/coding_rules_retriever.py +31 -0
- cognee/modules/retrieval/completion_retriever.py +9 -3
- cognee/modules/retrieval/context_providers/TripletSearchContextProvider.py +1 -0
- cognee/modules/retrieval/cypher_search_retriever.py +1 -9
- cognee/modules/retrieval/graph_completion_context_extension_retriever.py +29 -13
- cognee/modules/retrieval/graph_completion_cot_retriever.py +30 -13
- cognee/modules/retrieval/graph_completion_retriever.py +107 -56
- cognee/modules/retrieval/graph_summary_completion_retriever.py +5 -1
- cognee/modules/retrieval/insights_retriever.py +14 -3
- cognee/modules/retrieval/natural_language_retriever.py +0 -4
- cognee/modules/retrieval/summaries_retriever.py +1 -1
- cognee/modules/retrieval/temporal_retriever.py +152 -0
- cognee/modules/retrieval/user_qa_feedback.py +83 -0
- cognee/modules/retrieval/utils/brute_force_triplet_search.py +7 -32
- cognee/modules/retrieval/utils/completion.py +10 -3
- cognee/modules/retrieval/utils/extract_uuid_from_node.py +18 -0
- cognee/modules/retrieval/utils/models.py +40 -0
- cognee/modules/search/methods/get_search_type_tools.py +168 -0
- cognee/modules/search/methods/no_access_control_search.py +47 -0
- cognee/modules/search/methods/search.py +239 -118
- cognee/modules/search/types/SearchResult.py +21 -0
- cognee/modules/search/types/SearchType.py +3 -0
- cognee/modules/search/types/__init__.py +1 -0
- cognee/modules/search/utils/__init__.py +2 -0
- cognee/modules/search/utils/prepare_search_result.py +41 -0
- cognee/modules/search/utils/transform_context_to_graph.py +38 -0
- cognee/modules/settings/get_settings.py +2 -2
- cognee/modules/sync/__init__.py +1 -0
- cognee/modules/sync/methods/__init__.py +23 -0
- cognee/modules/sync/methods/create_sync_operation.py +53 -0
- cognee/modules/sync/methods/get_sync_operation.py +107 -0
- cognee/modules/sync/methods/update_sync_operation.py +248 -0
- cognee/modules/sync/models/SyncOperation.py +142 -0
- cognee/modules/sync/models/__init__.py +3 -0
- cognee/modules/users/__init__.py +0 -1
- cognee/modules/users/methods/__init__.py +4 -1
- cognee/modules/users/methods/create_user.py +26 -1
- cognee/modules/users/methods/get_authenticated_user.py +36 -42
- cognee/modules/users/methods/get_default_user.py +3 -1
- cognee/modules/users/permissions/methods/get_specific_user_permission_datasets.py +2 -1
- cognee/root_dir.py +19 -0
- cognee/shared/CodeGraphEntities.py +1 -0
- cognee/shared/logging_utils.py +143 -32
- cognee/shared/utils.py +0 -1
- cognee/tasks/codingagents/coding_rule_associations.py +127 -0
- cognee/tasks/graph/extract_graph_from_data.py +6 -2
- cognee/tasks/ingestion/save_data_item_to_storage.py +23 -0
- cognee/tasks/memify/__init__.py +2 -0
- cognee/tasks/memify/extract_subgraph.py +7 -0
- cognee/tasks/memify/extract_subgraph_chunks.py +11 -0
- cognee/tasks/repo_processor/get_local_dependencies.py +2 -0
- cognee/tasks/repo_processor/get_repo_file_dependencies.py +144 -47
- cognee/tasks/storage/add_data_points.py +33 -3
- cognee/tasks/temporal_graph/__init__.py +1 -0
- cognee/tasks/temporal_graph/add_entities_to_event.py +85 -0
- cognee/tasks/temporal_graph/enrich_events.py +34 -0
- cognee/tasks/temporal_graph/extract_events_and_entities.py +32 -0
- cognee/tasks/temporal_graph/extract_knowledge_graph_from_events.py +41 -0
- cognee/tasks/temporal_graph/models.py +49 -0
- cognee/tests/integration/cli/__init__.py +3 -0
- cognee/tests/integration/cli/test_cli_integration.py +331 -0
- cognee/tests/integration/documents/PdfDocument_test.py +2 -2
- cognee/tests/integration/documents/TextDocument_test.py +2 -4
- cognee/tests/integration/documents/UnstructuredDocument_test.py +5 -8
- cognee/tests/{test_deletion.py → test_delete_hard.py} +0 -37
- cognee/tests/test_delete_soft.py +85 -0
- cognee/tests/test_kuzu.py +2 -2
- cognee/tests/test_neo4j.py +2 -2
- cognee/tests/test_permissions.py +3 -3
- cognee/tests/test_relational_db_migration.py +7 -5
- cognee/tests/test_search_db.py +136 -23
- cognee/tests/test_temporal_graph.py +167 -0
- cognee/tests/unit/api/__init__.py +1 -0
- cognee/tests/unit/api/test_conditional_authentication_endpoints.py +246 -0
- cognee/tests/unit/cli/__init__.py +3 -0
- cognee/tests/unit/cli/test_cli_commands.py +483 -0
- cognee/tests/unit/cli/test_cli_edge_cases.py +625 -0
- cognee/tests/unit/cli/test_cli_main.py +173 -0
- cognee/tests/unit/cli/test_cli_runner.py +62 -0
- cognee/tests/unit/cli/test_cli_utils.py +127 -0
- cognee/tests/unit/modules/retrieval/chunks_retriever_test.py +18 -2
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py +12 -15
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +10 -15
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py +4 -3
- cognee/tests/unit/modules/retrieval/insights_retriever_test.py +4 -2
- cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py +18 -2
- cognee/tests/unit/modules/retrieval/temporal_retriever_test.py +225 -0
- cognee/tests/unit/modules/users/__init__.py +1 -0
- cognee/tests/unit/modules/users/test_conditional_authentication.py +277 -0
- cognee/tests/unit/processing/utils/utils_test.py +20 -1
- {cognee-0.2.3.dev1.dist-info → cognee-0.3.0.dist-info}/METADATA +13 -9
- {cognee-0.2.3.dev1.dist-info → cognee-0.3.0.dist-info}/RECORD +247 -135
- cognee-0.3.0.dist-info/entry_points.txt +2 -0
- cognee/infrastructure/databases/graph/networkx/adapter.py +0 -1017
- cognee/infrastructure/pipeline/models/Operation.py +0 -60
- cognee/notebooks/github_analysis_step_by_step.ipynb +0 -37
- cognee/tests/tasks/descriptive_metrics/networkx_metrics_test.py +0 -7
- cognee/tests/unit/modules/search/search_methods_test.py +0 -223
- /cognee/{infrastructure/databases/graph/networkx → api/v1/memify}/__init__.py +0 -0
- /cognee/{infrastructure/pipeline/models → tasks/codingagents}/__init__.py +0 -0
- {cognee-0.2.3.dev1.dist-info → cognee-0.3.0.dist-info}/WHEEL +0 -0
- {cognee-0.2.3.dev1.dist-info → cognee-0.3.0.dist-info}/licenses/LICENSE +0 -0
- {cognee-0.2.3.dev1.dist-info → cognee-0.3.0.dist-info}/licenses/NOTICE.md +0 -0
cognee/infrastructure/llm/structured_output_framework/litellm_instructor/extraction/knowledge_graph/extract_content_graph.py
CHANGED

@@ -1,5 +1,5 @@
 import os
-from typing import Type
+from typing import Type, Optional
 from pydantic import BaseModel
 
 from cognee.infrastructure.llm.LLMGateway import LLMGateway
@@ -8,21 +8,25 @@ from cognee.infrastructure.llm.config import (
 )
 
 
-async def extract_content_graph(
-    content: str, response_model: Type[BaseModel]
-):
-    llm_config = get_llm_config()
-    prompt_path = llm_config.graph_prompt_path
-    # Check if the prompt path is an absolute path or just a filename
-    if os.path.isabs(prompt_path):
-        # directory containing the file
-        base_directory = os.path.dirname(prompt_path)
-        # just the filename itself
-        prompt_path = os.path.basename(prompt_path)
+async def extract_content_graph(
+    content: str, response_model: Type[BaseModel], custom_prompt: Optional[str] = None
+):
+    if custom_prompt:
+        system_prompt = custom_prompt
     else:
-        base_directory = None
-
-    system_prompt = LLMGateway.render_prompt(prompt_path, {}, base_directory=base_directory)
+        llm_config = get_llm_config()
+        prompt_path = llm_config.graph_prompt_path
+
+        # Check if the prompt path is an absolute path or just a filename
+        if os.path.isabs(prompt_path):
+            # directory containing the file
+            base_directory = os.path.dirname(prompt_path)
+            # just the filename itself
+            prompt_path = os.path.basename(prompt_path)
+        else:
+            base_directory = None
+
+        system_prompt = LLMGateway.render_prompt(prompt_path, {}, base_directory=base_directory)
 
     content_graph = await LLMGateway.acreate_structured_output(
         content, system_prompt, response_model
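The new custom_prompt parameter lets a caller bypass the configured prompt file entirely. A minimal usage sketch, with a hypothetical MiniGraph response model standing in for cognee's real graph models:

    import asyncio
    from pydantic import BaseModel

    from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.extraction.knowledge_graph.extract_content_graph import (
        extract_content_graph,
    )


    class MiniGraph(BaseModel):  # hypothetical model, for illustration only
        nodes: list[str]
        edges: list[str]


    async def main():
        # When custom_prompt is set, graph_prompt_path is never consulted.
        graph = await extract_content_graph(
            "Ada Lovelace worked with Charles Babbage.",
            MiniGraph,
            custom_prompt="Extract a knowledge graph with nodes and edges.",
        )
        print(graph)


    asyncio.run(main())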
cognee/infrastructure/llm/structured_output_framework/litellm_instructor/extraction/knowledge_graph/extract_event_graph.py
ADDED

@@ -0,0 +1,46 @@
+import os
+from pydantic import BaseModel
+from typing import Type
+from cognee.infrastructure.llm.LLMGateway import LLMGateway
+
+from cognee.infrastructure.llm.config import (
+    get_llm_config,
+)
+
+
+async def extract_event_graph(content: str, response_model: Type[BaseModel]):
+    """
+    Extracts an event graph from the given content using an LLM with a structured output format.
+
+    This function loads a temporal graph extraction prompt from the LLM configuration,
+    renders it as a system prompt, and queries the LLM to produce a structured event
+    graph matching the specified response model.
+
+    Args:
+        content (str): The input text from which to extract the event graph.
+        response_model (Type[BaseModel]): A Pydantic model defining the structure of the expected output.
+
+    Returns:
+        BaseModel: An instance of the response_model populated with the extracted event graph.
+    """
+
+    llm_config = get_llm_config()
+
+    prompt_path = llm_config.temporal_graph_prompt_path
+
+    # Check if the prompt path is an absolute path or just a filename
+    if os.path.isabs(prompt_path):
+        # directory containing the file
+        base_directory = os.path.dirname(prompt_path)
+        # just the filename itself
+        prompt_path = os.path.basename(prompt_path)
+    else:
+        base_directory = None
+
+    system_prompt = LLMGateway.render_prompt(prompt_path, {}, base_directory=base_directory)
+
+    content_graph = await LLMGateway.acreate_structured_output(
+        content, system_prompt, response_model
+    )
+
+    return content_graph
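extract_event_graph follows the same prompt-resolution pattern but reads temporal_graph_prompt_path instead. A sketch with a hypothetical event model (the release's real temporal models live in cognee/tasks/temporal_graph/models.py):

    import asyncio
    from typing import List
    from pydantic import BaseModel

    from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.extraction.knowledge_graph.extract_event_graph import (
        extract_event_graph,
    )


    class EventList(BaseModel):  # hypothetical stand-in for the real models
        events: List[str]


    async def main():
        events = await extract_event_graph(
            "The treaty was signed in 1648, ending the Thirty Years' War.",
            EventList,
        )
        print(events)


    asyncio.run(main())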
cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py
CHANGED
@@ -23,7 +23,7 @@ class AnthropicAdapter(LLMInterface):
     name = "Anthropic"
     model: str
 
-    def __init__(self, max_tokens: int, model: str = None):
+    def __init__(self, max_completion_tokens: int, model: str = None):
         import anthropic
 
         self.aclient = instructor.patch(
@@ -31,7 +31,7 @@ class AnthropicAdapter(LLMInterface):
         )
 
         self.model = model
-        self.max_tokens = max_tokens
+        self.max_completion_tokens = max_completion_tokens
 
     @sleep_and_retry_async()
     @rate_limit_async
@@ -57,7 +57,7 @@ class AnthropicAdapter(LLMInterface):
 
         return await self.aclient(
             model=self.model,
-            max_tokens=4096,
+            max_completion_tokens=4096,
             max_retries=5,
             messages=[
                 {
cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py
CHANGED
@@ -34,7 +34,7 @@ class GeminiAdapter(LLMInterface):
         self,
         api_key: str,
         model: str,
-        max_tokens: int,
+        max_completion_tokens: int,
         endpoint: Optional[str] = None,
         api_version: Optional[str] = None,
         streaming: bool = False,
@@ -44,7 +44,7 @@ class GeminiAdapter(LLMInterface):
         self.endpoint = endpoint
         self.api_version = api_version
         self.streaming = streaming
-        self.max_tokens = max_tokens
+        self.max_completion_tokens = max_completion_tokens
 
     @observe(as_type="generation")
     @sleep_and_retry_async()
@@ -90,7 +90,7 @@ class GeminiAdapter(LLMInterface):
             model=f"{self.model}",
             messages=messages,
             api_key=self.api_key,
-            max_tokens=self.max_tokens,
+            max_completion_tokens=self.max_completion_tokens,
             temperature=0.1,
             response_format=response_schema,
             timeout=100,
cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py
CHANGED

@@ -41,7 +41,7 @@ class GenericAPIAdapter(LLMInterface):
         api_key: str,
         model: str,
         name: str,
-        max_tokens: int,
+        max_completion_tokens: int,
         fallback_model: str = None,
         fallback_api_key: str = None,
         fallback_endpoint: str = None,
@@ -50,7 +50,7 @@ class GenericAPIAdapter(LLMInterface):
         self.model = model
         self.api_key = api_key
         self.endpoint = endpoint
-        self.max_tokens = max_tokens
+        self.max_completion_tokens = max_completion_tokens
 
         self.fallback_model = fallback_model
         self.fallback_api_key = fallback_api_key
cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py
CHANGED
@@ -54,11 +54,15 @@ def get_llm_client():
     # Check if max_token value is defined in liteLLM for given model
     # if not use value from cognee configuration
     from cognee.infrastructure.llm.utils import (
-        get_model_max_tokens,
+        get_model_max_completion_tokens,
     )  # imported here to avoid circular imports
 
-    model_max_tokens = get_model_max_tokens(llm_config.llm_model)
-    max_tokens = model_max_tokens if model_max_tokens else llm_config.llm_max_tokens
+    model_max_completion_tokens = get_model_max_completion_tokens(llm_config.llm_model)
+    max_completion_tokens = (
+        model_max_completion_tokens
+        if model_max_completion_tokens
+        else llm_config.llm_max_completion_tokens
+    )
 
     if provider == LLMProvider.OPENAI:
         if llm_config.llm_api_key is None:
@@ -74,7 +78,7 @@ def get_llm_client():
             api_version=llm_config.llm_api_version,
             model=llm_config.llm_model,
             transcription_model=llm_config.transcription_model,
-            max_tokens=max_tokens,
+            max_completion_tokens=max_completion_tokens,
             streaming=llm_config.llm_streaming,
             fallback_api_key=llm_config.fallback_api_key,
             fallback_endpoint=llm_config.fallback_endpoint,
@@ -94,7 +98,7 @@ def get_llm_client():
             llm_config.llm_api_key,
             llm_config.llm_model,
             "Ollama",
-            max_tokens=max_tokens,
+            max_completion_tokens=max_completion_tokens,
         )
 
     elif provider == LLMProvider.ANTHROPIC:
@@ -102,7 +106,9 @@ def get_llm_client():
             AnthropicAdapter,
         )
 
-        return AnthropicAdapter(max_tokens=max_tokens, model=llm_config.llm_model)
+        return AnthropicAdapter(
+            max_completion_tokens=max_completion_tokens, model=llm_config.llm_model
+        )
 
     elif provider == LLMProvider.CUSTOM:
         if llm_config.llm_api_key is None:
@@ -117,7 +123,7 @@ def get_llm_client():
             llm_config.llm_api_key,
             llm_config.llm_model,
             "Custom",
-            max_tokens=max_tokens,
+            max_completion_tokens=max_completion_tokens,
             fallback_api_key=llm_config.fallback_api_key,
             fallback_endpoint=llm_config.fallback_endpoint,
             fallback_model=llm_config.fallback_model,
@@ -134,7 +140,7 @@ def get_llm_client():
         return GeminiAdapter(
             api_key=llm_config.llm_api_key,
             model=llm_config.llm_model,
-            max_tokens=max_tokens,
+            max_completion_tokens=max_completion_tokens,
             endpoint=llm_config.llm_endpoint,
             api_version=llm_config.llm_api_version,
             streaming=llm_config.llm_streaming,
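The token-limit resolution above prefers liteLLM's per-model registry and only falls back to the cognee setting. The same logic in isolation (the fallback value here is illustrative, not a cognee default):

    import litellm

    llm_max_completion_tokens = 16384  # stand-in for llm_config.llm_max_completion_tokens


    def resolve_max_completion_tokens(model_name: str) -> int:
        # Prefer liteLLM's per-model limit; otherwise use the configured value.
        registry_value = litellm.model_cost.get(model_name, {}).get("max_tokens")
        return registry_value if registry_value else llm_max_completion_tokens


    print(resolve_max_completion_tokens("gpt-4o"))         # value from litellm.model_cost
    print(resolve_max_completion_tokens("unknown-model"))  # 16384 fallback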
cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py
CHANGED
@@ -30,16 +30,18 @@ class OllamaAPIAdapter(LLMInterface):
     - model
     - api_key
     - endpoint
-    - max_tokens
+    - max_completion_tokens
     - aclient
     """
 
-    def __init__(self, endpoint: str, api_key: str, model: str, name: str, max_tokens: int):
+    def __init__(
+        self, endpoint: str, api_key: str, model: str, name: str, max_completion_tokens: int
+    ):
         self.name = name
         self.model = model
         self.api_key = api_key
         self.endpoint = endpoint
-        self.max_tokens = max_tokens
+        self.max_completion_tokens = max_completion_tokens
 
         self.aclient = instructor.from_openai(
             OpenAI(base_url=self.endpoint, api_key=self.api_key), mode=instructor.Mode.JSON
@@ -159,7 +161,7 @@ class OllamaAPIAdapter(LLMInterface):
                     ],
                 }
             ],
-            max_tokens=300,
+            max_completion_tokens=300,
        )
 
         # Ensure response is valid before accessing .choices[0].message.content
cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py
CHANGED
@@ -23,9 +23,12 @@ from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.ll
     sleep_and_retry_sync,
 )
 from cognee.modules.observability.get_observe import get_observe
+from cognee.shared.logging_utils import get_logger
 
 observe = get_observe()
 
+logger = get_logger()
+
 
 class OpenAIAdapter(LLMInterface):
     """
@@ -64,7 +67,7 @@ class OpenAIAdapter(LLMInterface):
         api_version: str,
         model: str,
         transcription_model: str,
-        max_tokens: int,
+        max_completion_tokens: int,
         streaming: bool = False,
         fallback_model: str = None,
         fallback_api_key: str = None,
@@ -77,7 +80,7 @@ class OpenAIAdapter(LLMInterface):
         self.api_key = api_key
         self.endpoint = endpoint
         self.api_version = api_version
-        self.max_tokens = max_tokens
+        self.max_completion_tokens = max_completion_tokens
         self.streaming = streaming
 
         self.fallback_model = fallback_model
@@ -129,6 +132,7 @@ class OpenAIAdapter(LLMInterface):
                 api_version=self.api_version,
                 response_model=response_model,
                 max_retries=self.MAX_RETRIES,
+                extra_body={"reasoning_effort": "minimal"},
             )
         except (
             ContentFilterFinishReasonError,
@@ -139,7 +143,27 @@ class OpenAIAdapter(LLMInterface):
             isinstance(error, InstructorRetryException)
             and "content management policy" not in str(error).lower()
         ):
-            raise error
+            logger.debug(
+                "LLM Model does not support reasoning_effort parameter, trying call without the parameter."
+            )
+            return await self.aclient.chat.completions.create(
+                model=self.model,
+                messages=[
+                    {
+                        "role": "user",
+                        "content": f"""{text_input}""",
+                    },
+                    {
+                        "role": "system",
+                        "content": system_prompt,
+                    },
+                ],
+                api_key=self.api_key,
+                api_base=self.endpoint,
+                api_version=self.api_version,
+                response_model=response_model,
+                max_retries=self.MAX_RETRIES,
+            )
 
         if not (self.fallback_model and self.fallback_api_key):
             raise ContentPolicyFilterError(
@@ -301,7 +325,7 @@ class OpenAIAdapter(LLMInterface):
             api_key=self.api_key,
             api_base=self.endpoint,
             api_version=self.api_version,
-            max_tokens=300,
+            max_completion_tokens=300,
             max_retries=self.MAX_RETRIES,
         )
cognee/infrastructure/llm/tokenizer/Gemini/adapter.py
CHANGED

@@ -17,10 +17,10 @@ class GeminiTokenizer(TokenizerInterface):
     def __init__(
         self,
         model: str,
-        max_tokens: int = 3072,
+        max_completion_tokens: int = 3072,
     ):
         self.model = model
-        self.max_tokens = max_tokens
+        self.max_completion_tokens = max_completion_tokens
 
         # Get LLM API key from config
         from cognee.infrastructure.databases.vector.embeddings.config import get_embedding_config
cognee/infrastructure/llm/tokenizer/HuggingFace/adapter.py
CHANGED

@@ -14,17 +14,17 @@ class HuggingFaceTokenizer(TokenizerInterface):
 
     Instance variables include:
     - model: str
-    - max_tokens: int
+    - max_completion_tokens: int
     - tokenizer: AutoTokenizer
     """
 
     def __init__(
         self,
         model: str,
-        max_tokens: int = 512,
+        max_completion_tokens: int = 512,
     ):
         self.model = model
-        self.max_tokens = max_tokens
+        self.max_completion_tokens = max_completion_tokens
 
         # Import here to make it an optional dependency
         from transformers import AutoTokenizer
cognee/infrastructure/llm/tokenizer/Mistral/adapter.py
CHANGED

@@ -16,17 +16,17 @@ class MistralTokenizer(TokenizerInterface):
 
     Instance variables include:
     - model: str
-    - max_tokens: int
+    - max_completion_tokens: int
     - tokenizer: MistralTokenizer
     """
 
     def __init__(
         self,
         model: str,
-        max_tokens: int = 3072,
+        max_completion_tokens: int = 3072,
     ):
         self.model = model
-        self.max_tokens = max_tokens
+        self.max_completion_tokens = max_completion_tokens
 
         # Import here to make it an optional dependency
         from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
cognee/infrastructure/llm/tokenizer/TikToken/adapter.py
CHANGED

@@ -13,10 +13,10 @@ class TikTokenTokenizer(TokenizerInterface):
     def __init__(
         self,
         model: Optional[str] = None,
-        max_tokens: int = 8191,
+        max_completion_tokens: int = 8191,
     ):
         self.model = model
-        self.max_tokens = max_tokens
+        self.max_completion_tokens = max_completion_tokens
         # Initialize TikToken for GPT based on model
         if model:
             self.tokenizer = tiktoken.encoding_for_model(self.model)
@@ -93,9 +93,9 @@ class TikTokenTokenizer(TokenizerInterface):
         num_tokens = len(self.tokenizer.encode(text))
         return num_tokens
 
-    def trim_text_to_max_tokens(self, text: str) -> str:
+    def trim_text_to_max_completion_tokens(self, text: str) -> str:
         """
-        Trim the text so that the number of tokens does not exceed max_tokens.
+        Trim the text so that the number of tokens does not exceed max_completion_tokens.
 
         Parameters:
         -----------
@@ -111,13 +111,13 @@ class TikTokenTokenizer(TokenizerInterface):
         num_tokens = self.count_tokens(text)
 
         # If the number of tokens is within the limit, return the text as is
-        if num_tokens <= self.max_tokens:
+        if num_tokens <= self.max_completion_tokens:
             return text
 
         # If the number exceeds the limit, trim the text
         # This is a simple trim, it may cut words in half; consider using word boundaries for a cleaner cut
         encoded_text = self.tokenizer.encode(text)
-        trimmed_encoded_text = encoded_text[: self.max_tokens]
+        trimmed_encoded_text = encoded_text[: self.max_completion_tokens]
         # Decoding the trimmed text
         trimmed_text = self.tokenizer.decode(trimmed_encoded_text)
         return trimmed_text
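The renamed trim method still slices at the token level and decodes back to text. The behavior can be reproduced with tiktoken alone (this standalone sketch assumes only the tiktoken package):

    import tiktoken

    encoding = tiktoken.encoding_for_model("gpt-4")
    max_completion_tokens = 5

    text = "one two three four five six seven eight nine ten"
    tokens = encoding.encode(text)
    if len(tokens) > max_completion_tokens:
        # Same simple trim as the adapter: slice the token list, then decode.
        text = encoding.decode(tokens[:max_completion_tokens])
    print(text)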
cognee/infrastructure/llm/utils.py
CHANGED

@@ -32,13 +32,13 @@ def get_max_chunk_tokens():
 
     # We need to make sure chunk size won't take more than half of LLM max context token size
     # but it also can't be bigger than the embedding engine max token size
-    llm_cutoff_point = llm_client.max_tokens // 2  # Round down the division
-    max_chunk_tokens = min(embedding_engine.max_tokens, llm_cutoff_point)
+    llm_cutoff_point = llm_client.max_completion_tokens // 2  # Round down the division
+    max_chunk_tokens = min(embedding_engine.max_completion_tokens, llm_cutoff_point)
 
     return max_chunk_tokens
 
 
-def get_model_max_tokens(model_name: str):
+def get_model_max_completion_tokens(model_name: str):
     """
     Retrieve the maximum token limit for a specified model name if it exists.
 
@@ -56,15 +56,15 @@ def get_model_max_tokens(model_name: str):
 
         Number of max tokens of model, or None if model is unknown
     """
-    max_tokens = None
+    max_completion_tokens = None
 
     if model_name in litellm.model_cost:
-        max_tokens = litellm.model_cost[model_name]["max_tokens"]
-        logger.debug(f"Max input tokens for {model_name}: {max_tokens}")
+        max_completion_tokens = litellm.model_cost[model_name]["max_tokens"]
+        logger.debug(f"Max input tokens for {model_name}: {max_completion_tokens}")
     else:
         logger.info("Model not found in LiteLLM's model_cost.")
 
-    return max_tokens
+    return max_completion_tokens
 
 
 async def test_llm_connection():
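A quick check of the chunk-size rule in get_max_chunk_tokens (the numbers are made up for illustration):

    # Illustrative values, not cognee defaults.
    llm_max_completion_tokens = 16384        # from the LLM client
    embedding_max_completion_tokens = 8191   # from the embedding engine

    llm_cutoff_point = llm_max_completion_tokens // 2  # 8192, rounded down
    max_chunk_tokens = min(embedding_max_completion_tokens, llm_cutoff_point)
    assert max_chunk_tokens == 8191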
cognee/infrastructure/utils/run_sync.py
CHANGED

@@ -8,8 +8,15 @@ def run_sync(coro, timeout=None):
 
     def runner():
         nonlocal result, exception
+
         try:
-            result = asyncio.run(coro)
+            try:
+                running_loop = asyncio.get_running_loop()
+
+                result = asyncio.run_coroutine_threadsafe(coro, running_loop).result(timeout)
+            except RuntimeError:
+                result = asyncio.run(coro)
+
         except Exception as e:
             exception = e
 
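The reworked runner() covers two situations: an event loop already running somewhere (hand the coroutine over with run_coroutine_threadsafe) and no loop at all (start one with asyncio.run). Both branches, demonstrated outside cognee:

    import asyncio
    import threading


    async def work():
        return 42


    # A loop running in a background thread stands in for an already-running loop.
    loop = asyncio.new_event_loop()
    threading.Thread(target=loop.run_forever, daemon=True).start()

    # Branch 1: a loop exists, so schedule onto it and wait for the result.
    print(asyncio.run_coroutine_threadsafe(work(), loop).result(timeout=5))  # 42

    # Branch 2: no running loop in this thread; asyncio.run creates one.
    print(asyncio.run(work()))  # 42

    loop.call_soon_threadsafe(loop.stop)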
cognee/modules/chunking/models/DocumentChunk.py
CHANGED

@@ -1,8 +1,9 @@
-from typing import List
+from typing import List, Union
 
 from cognee.infrastructure.engine import DataPoint
 from cognee.modules.data.processing.document_types import Document
 from cognee.modules.engine.models import Entity
+from cognee.tasks.temporal_graph.models import Event
 
 
 class DocumentChunk(DataPoint):
@@ -20,7 +21,7 @@ class DocumentChunk(DataPoint):
     - chunk_index: The index of the chunk in the original document.
     - cut_type: The type of cut that defined this chunk.
     - is_part_of: The document to which this chunk belongs.
-    - contains: A list of entities contained within the chunk (default is None).
+    - contains: A list of entities or events contained within the chunk (default is None).
     - metadata: A dictionary to hold meta information related to the chunk, including index
       fields.
     """
@@ -30,6 +31,6 @@ class DocumentChunk(DataPoint):
     chunk_index: int
     cut_type: str
     is_part_of: Document
-    contains: List[Entity] = None
+    contains: List[Union[Entity, Event]] = None
 
     metadata: dict = {"index_fields": ["text"]}
cognee/modules/cloud/exceptions/CloudApiKeyMissingError.py
ADDED

@@ -0,0 +1,15 @@
+from fastapi import status
+
+from cognee.exceptions.exceptions import CogneeConfigurationError
+
+
+class CloudApiKeyMissingError(CogneeConfigurationError):
+    """Raised when the API key for the cloud service is not provided."""
+
+    def __init__(
+        self,
+        message: str = "Failed to connect to the cloud service. Please add your API key to local instance.",
+        name: str = "CloudApiKeyMissingError",
+        status_code=status.HTTP_400_BAD_REQUEST,
+    ):
+        super().__init__(message, name, status_code)
cognee/modules/cloud/exceptions/CloudConnectionError.py
ADDED

@@ -0,0 +1,15 @@
+from fastapi import status
+
+from cognee.exceptions.exceptions import CogneeConfigurationError
+
+
+class CloudConnectionError(CogneeConfigurationError):
+    """Raised when the connection to the cloud service fails."""
+
+    def __init__(
+        self,
+        message: str = "Failed to connect to the cloud service. Please check your cloud API key in local instance.",
+        name: str = "CloudConnnectionError",
+        status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
+    ):
+        super().__init__(message, name, status_code)
cognee/modules/cloud/operations/__init__.py
ADDED

@@ -0,0 +1 @@
+from .check_api_key import check_api_key
cognee/modules/cloud/operations/check_api_key.py
ADDED

@@ -0,0 +1,25 @@
+import aiohttp
+
+from cognee.modules.cloud.exceptions import CloudConnectionError
+
+
+async def check_api_key(auth_token: str):
+    cloud_base_url = "http://localhost:8001"
+
+    url = f"{cloud_base_url}/api/api-keys/check"
+    headers = {"X-Api-Key": auth_token}
+
+    try:
+        async with aiohttp.ClientSession() as session:
+            async with session.post(url, headers=headers) as response:
+                if response.status == 200:
+                    return
+                else:
+                    error_text = await response.text()
+
+                    raise CloudConnectionError(
+                        f"Failed to connect to cloud instance: {response.status} - {error_text}"
+                    )
+
+    except Exception as e:
+        raise CloudConnectionError(f"Failed to connect to cloud instance: {str(e)}")
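check_api_key targets a hard-coded local cloud endpoint (http://localhost:8001) and collapses every failure, including the CloudConnectionError it raises itself, into a new CloudConnectionError via the broad except clause. A hedged usage sketch:

    import asyncio

    from cognee.modules.cloud.exceptions import CloudConnectionError
    from cognee.modules.cloud.operations import check_api_key


    async def main():
        try:
            # Returns None when the key is accepted; raises otherwise.
            await check_api_key("my-api-key")  # placeholder token
        except CloudConnectionError as error:
            print(f"Cloud check failed: {error}")


    asyncio.run(main())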
cognee/modules/data/deletion/prune_system.py
CHANGED

@@ -3,7 +3,7 @@ from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine
 from cognee.infrastructure.databases.relational import get_relational_engine
 
 
-async def prune_system(graph=True, vector=True, metadata=False):
+async def prune_system(graph=True, vector=True, metadata=True):
     if graph:
         graph_engine = await get_graph_engine()
         await graph_engine.delete_graph()
cognee/modules/data/methods/__init__.py
CHANGED

@@ -7,6 +7,7 @@ from .get_datasets import get_datasets
 from .get_datasets_by_name import get_datasets_by_name
 from .get_dataset_data import get_dataset_data
 from .get_authorized_dataset import get_authorized_dataset
+from .get_authorized_dataset_by_name import get_authorized_dataset_by_name
 from .get_data import get_data
 from .get_unique_dataset_id import get_unique_dataset_id
 from .get_authorized_existing_datasets import get_authorized_existing_datasets
@@ -18,6 +19,7 @@ from .delete_data import delete_data
 
 # Create
 from .load_or_create_datasets import load_or_create_datasets
+from .create_authorized_dataset import create_authorized_dataset
 
 # Check
 from .check_dataset_name import check_dataset_name
cognee/modules/data/methods/create_authorized_dataset.py
ADDED

@@ -0,0 +1,19 @@
+from cognee.infrastructure.databases.relational import get_relational_engine
+from cognee.modules.users.models import User
+from cognee.modules.data.models import Dataset
+from cognee.modules.users.permissions.methods import give_permission_on_dataset
+from .create_dataset import create_dataset
+
+
+async def create_authorized_dataset(dataset_name: str, user: User) -> Dataset:
+    db_engine = get_relational_engine()
+
+    async with db_engine.get_async_session() as session:
+        new_dataset = await create_dataset(dataset_name, user, session)
+
+        await give_permission_on_dataset(user, new_dataset.id, "read")
+        await give_permission_on_dataset(user, new_dataset.id, "write")
+        await give_permission_on_dataset(user, new_dataset.id, "delete")
+        await give_permission_on_dataset(user, new_dataset.id, "share")
+
+        return new_dataset
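create_authorized_dataset bundles dataset creation with the four permission grants, so callers no longer assemble them by hand. A minimal sketch of calling it with the default user (get_default_user is exported from cognee.modules.users.methods):

    import asyncio

    from cognee.modules.data.methods import create_authorized_dataset
    from cognee.modules.users.methods import get_default_user


    async def main():
        user = await get_default_user()
        dataset = await create_authorized_dataset("my_dataset", user)
        print(dataset.id)  # the owner now holds read/write/delete/share


    asyncio.run(main())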