cognee 0.4.1__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cognee/__init__.py +1 -0
- cognee/api/client.py +9 -5
- cognee/api/v1/add/add.py +2 -1
- cognee/api/v1/add/routers/get_add_router.py +3 -1
- cognee/api/v1/cognify/cognify.py +24 -16
- cognee/api/v1/cognify/routers/__init__.py +0 -1
- cognee/api/v1/cognify/routers/get_cognify_router.py +30 -1
- cognee/api/v1/datasets/routers/get_datasets_router.py +3 -3
- cognee/api/v1/ontologies/__init__.py +4 -0
- cognee/api/v1/ontologies/ontologies.py +158 -0
- cognee/api/v1/ontologies/routers/__init__.py +0 -0
- cognee/api/v1/ontologies/routers/get_ontology_router.py +109 -0
- cognee/api/v1/permissions/routers/get_permissions_router.py +41 -1
- cognee/api/v1/search/search.py +4 -0
- cognee/api/v1/ui/node_setup.py +360 -0
- cognee/api/v1/ui/npm_utils.py +50 -0
- cognee/api/v1/ui/ui.py +38 -68
- cognee/cli/commands/cognify_command.py +8 -1
- cognee/cli/config.py +1 -1
- cognee/context_global_variables.py +86 -9
- cognee/eval_framework/Dockerfile +29 -0
- cognee/eval_framework/answer_generation/answer_generation_executor.py +10 -0
- cognee/eval_framework/answer_generation/run_question_answering_module.py +1 -1
- cognee/eval_framework/corpus_builder/task_getters/get_cascade_graph_tasks.py +0 -2
- cognee/eval_framework/corpus_builder/task_getters/get_default_tasks_by_indices.py +4 -4
- cognee/eval_framework/eval_config.py +2 -2
- cognee/eval_framework/modal_run_eval.py +16 -28
- cognee/infrastructure/databases/cache/config.py +3 -1
- cognee/infrastructure/databases/cache/fscache/FsCacheAdapter.py +151 -0
- cognee/infrastructure/databases/cache/get_cache_engine.py +20 -10
- cognee/infrastructure/databases/dataset_database_handler/__init__.py +3 -0
- cognee/infrastructure/databases/dataset_database_handler/dataset_database_handler_interface.py +80 -0
- cognee/infrastructure/databases/dataset_database_handler/supported_dataset_database_handlers.py +18 -0
- cognee/infrastructure/databases/dataset_database_handler/use_dataset_database_handler.py +10 -0
- cognee/infrastructure/databases/exceptions/exceptions.py +16 -0
- cognee/infrastructure/databases/graph/config.py +7 -0
- cognee/infrastructure/databases/graph/get_graph_engine.py +3 -0
- cognee/infrastructure/databases/graph/graph_db_interface.py +15 -0
- cognee/infrastructure/databases/graph/kuzu/KuzuDatasetDatabaseHandler.py +81 -0
- cognee/infrastructure/databases/graph/kuzu/adapter.py +228 -0
- cognee/infrastructure/databases/graph/neo4j_driver/Neo4jAuraDevDatasetDatabaseHandler.py +168 -0
- cognee/infrastructure/databases/graph/neo4j_driver/adapter.py +80 -1
- cognee/infrastructure/databases/hybrid/neptune_analytics/NeptuneAnalyticsAdapter.py +9 -0
- cognee/infrastructure/databases/utils/__init__.py +3 -0
- cognee/infrastructure/databases/utils/get_graph_dataset_database_handler.py +10 -0
- cognee/infrastructure/databases/utils/get_or_create_dataset_database.py +66 -18
- cognee/infrastructure/databases/utils/get_vector_dataset_database_handler.py +10 -0
- cognee/infrastructure/databases/utils/resolve_dataset_database_connection_info.py +30 -0
- cognee/infrastructure/databases/vector/config.py +5 -0
- cognee/infrastructure/databases/vector/create_vector_engine.py +6 -1
- cognee/infrastructure/databases/vector/embeddings/FastembedEmbeddingEngine.py +8 -6
- cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py +9 -7
- cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py +11 -13
- cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py +2 -0
- cognee/infrastructure/databases/vector/lancedb/LanceDBDatasetDatabaseHandler.py +50 -0
- cognee/infrastructure/databases/vector/vector_db_interface.py +35 -0
- cognee/infrastructure/engine/models/Edge.py +13 -1
- cognee/infrastructure/files/storage/s3_config.py +2 -0
- cognee/infrastructure/files/utils/guess_file_type.py +4 -0
- cognee/infrastructure/llm/LLMGateway.py +5 -2
- cognee/infrastructure/llm/config.py +37 -0
- cognee/infrastructure/llm/extraction/knowledge_graph/extract_content_graph.py +2 -2
- cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extraction/acreate_structured_output.py +23 -8
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py +22 -18
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/bedrock/__init__.py +5 -0
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/bedrock/adapter.py +153 -0
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py +47 -38
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py +46 -37
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +20 -10
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py +23 -11
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py +36 -23
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py +47 -36
- cognee/infrastructure/loaders/LoaderEngine.py +1 -0
- cognee/infrastructure/loaders/core/__init__.py +2 -1
- cognee/infrastructure/loaders/core/csv_loader.py +93 -0
- cognee/infrastructure/loaders/core/text_loader.py +1 -2
- cognee/infrastructure/loaders/external/advanced_pdf_loader.py +0 -9
- cognee/infrastructure/loaders/supported_loaders.py +2 -1
- cognee/memify_pipelines/create_triplet_embeddings.py +53 -0
- cognee/memify_pipelines/persist_sessions_in_knowledge_graph.py +55 -0
- cognee/modules/chunking/CsvChunker.py +35 -0
- cognee/modules/chunking/models/DocumentChunk.py +2 -1
- cognee/modules/chunking/text_chunker_with_overlap.py +124 -0
- cognee/modules/cognify/config.py +2 -0
- cognee/modules/data/deletion/prune_system.py +52 -2
- cognee/modules/data/methods/__init__.py +1 -0
- cognee/modules/data/methods/create_dataset.py +4 -2
- cognee/modules/data/methods/delete_dataset.py +26 -0
- cognee/modules/data/methods/get_dataset_ids.py +5 -1
- cognee/modules/data/methods/get_unique_data_id.py +68 -0
- cognee/modules/data/methods/get_unique_dataset_id.py +66 -4
- cognee/modules/data/models/Dataset.py +2 -0
- cognee/modules/data/processing/document_types/CsvDocument.py +33 -0
- cognee/modules/data/processing/document_types/__init__.py +1 -0
- cognee/modules/engine/models/Triplet.py +9 -0
- cognee/modules/engine/models/__init__.py +1 -0
- cognee/modules/graph/cognee_graph/CogneeGraph.py +89 -39
- cognee/modules/graph/cognee_graph/CogneeGraphElements.py +8 -3
- cognee/modules/graph/utils/expand_with_nodes_and_edges.py +19 -2
- cognee/modules/graph/utils/resolve_edges_to_text.py +48 -49
- cognee/modules/ingestion/identify.py +4 -4
- cognee/modules/memify/memify.py +1 -7
- cognee/modules/notebooks/operations/run_in_local_sandbox.py +3 -0
- cognee/modules/ontology/rdf_xml/RDFLibOntologyResolver.py +55 -23
- cognee/modules/pipelines/operations/pipeline.py +18 -2
- cognee/modules/pipelines/operations/run_tasks_data_item.py +1 -1
- cognee/modules/retrieval/EntityCompletionRetriever.py +10 -3
- cognee/modules/retrieval/__init__.py +1 -1
- cognee/modules/retrieval/base_graph_retriever.py +7 -3
- cognee/modules/retrieval/base_retriever.py +7 -3
- cognee/modules/retrieval/completion_retriever.py +11 -4
- cognee/modules/retrieval/graph_completion_context_extension_retriever.py +10 -2
- cognee/modules/retrieval/graph_completion_cot_retriever.py +18 -51
- cognee/modules/retrieval/graph_completion_retriever.py +14 -1
- cognee/modules/retrieval/graph_summary_completion_retriever.py +4 -0
- cognee/modules/retrieval/register_retriever.py +10 -0
- cognee/modules/retrieval/registered_community_retrievers.py +1 -0
- cognee/modules/retrieval/temporal_retriever.py +13 -2
- cognee/modules/retrieval/triplet_retriever.py +182 -0
- cognee/modules/retrieval/utils/brute_force_triplet_search.py +43 -11
- cognee/modules/retrieval/utils/completion.py +2 -22
- cognee/modules/run_custom_pipeline/__init__.py +1 -0
- cognee/modules/run_custom_pipeline/run_custom_pipeline.py +76 -0
- cognee/modules/search/methods/get_search_type_tools.py +54 -8
- cognee/modules/search/methods/no_access_control_search.py +4 -0
- cognee/modules/search/methods/search.py +26 -3
- cognee/modules/search/types/SearchType.py +1 -1
- cognee/modules/settings/get_settings.py +19 -0
- cognee/modules/users/methods/create_user.py +12 -27
- cognee/modules/users/methods/get_authenticated_user.py +3 -2
- cognee/modules/users/methods/get_default_user.py +4 -2
- cognee/modules/users/methods/get_user.py +1 -1
- cognee/modules/users/methods/get_user_by_email.py +1 -1
- cognee/modules/users/models/DatasetDatabase.py +24 -3
- cognee/modules/users/models/Tenant.py +6 -7
- cognee/modules/users/models/User.py +6 -5
- cognee/modules/users/models/UserTenant.py +12 -0
- cognee/modules/users/models/__init__.py +1 -0
- cognee/modules/users/permissions/methods/get_all_user_permission_datasets.py +13 -13
- cognee/modules/users/roles/methods/add_user_to_role.py +3 -1
- cognee/modules/users/tenants/methods/__init__.py +1 -0
- cognee/modules/users/tenants/methods/add_user_to_tenant.py +21 -12
- cognee/modules/users/tenants/methods/create_tenant.py +22 -8
- cognee/modules/users/tenants/methods/select_tenant.py +62 -0
- cognee/shared/logging_utils.py +6 -0
- cognee/shared/rate_limiting.py +30 -0
- cognee/tasks/chunks/__init__.py +1 -0
- cognee/tasks/chunks/chunk_by_row.py +94 -0
- cognee/tasks/documents/__init__.py +0 -1
- cognee/tasks/documents/classify_documents.py +2 -0
- cognee/tasks/feedback/generate_improved_answers.py +3 -3
- cognee/tasks/graph/extract_graph_from_data.py +9 -10
- cognee/tasks/ingestion/ingest_data.py +1 -1
- cognee/tasks/memify/__init__.py +2 -0
- cognee/tasks/memify/cognify_session.py +41 -0
- cognee/tasks/memify/extract_user_sessions.py +73 -0
- cognee/tasks/memify/get_triplet_datapoints.py +289 -0
- cognee/tasks/storage/add_data_points.py +142 -2
- cognee/tasks/storage/index_data_points.py +33 -22
- cognee/tasks/storage/index_graph_edges.py +37 -57
- cognee/tests/integration/documents/CsvDocument_test.py +70 -0
- cognee/tests/integration/retrieval/test_triplet_retriever.py +84 -0
- cognee/tests/integration/tasks/test_add_data_points.py +139 -0
- cognee/tests/integration/tasks/test_get_triplet_datapoints.py +69 -0
- cognee/tests/tasks/entity_extraction/entity_extraction_test.py +1 -1
- cognee/tests/test_add_docling_document.py +2 -2
- cognee/tests/test_cognee_server_start.py +84 -3
- cognee/tests/test_conversation_history.py +68 -5
- cognee/tests/test_data/example_with_header.csv +3 -0
- cognee/tests/test_dataset_database_handler.py +137 -0
- cognee/tests/test_dataset_delete.py +76 -0
- cognee/tests/test_edge_centered_payload.py +170 -0
- cognee/tests/test_edge_ingestion.py +27 -0
- cognee/tests/test_feedback_enrichment.py +1 -1
- cognee/tests/test_library.py +6 -4
- cognee/tests/test_load.py +62 -0
- cognee/tests/test_multi_tenancy.py +165 -0
- cognee/tests/test_parallel_databases.py +2 -0
- cognee/tests/test_pipeline_cache.py +164 -0
- cognee/tests/test_relational_db_migration.py +54 -2
- cognee/tests/test_search_db.py +44 -2
- cognee/tests/unit/api/test_conditional_authentication_endpoints.py +12 -3
- cognee/tests/unit/api/test_ontology_endpoint.py +252 -0
- cognee/tests/unit/infrastructure/databases/cache/test_cache_config.py +5 -0
- cognee/tests/unit/infrastructure/databases/test_index_data_points.py +27 -0
- cognee/tests/unit/infrastructure/databases/test_index_graph_edges.py +14 -16
- cognee/tests/unit/infrastructure/llm/test_llm_config.py +46 -0
- cognee/tests/unit/infrastructure/mock_embedding_engine.py +3 -7
- cognee/tests/unit/infrastructure/test_embedding_rate_limiting_realistic.py +0 -5
- cognee/tests/unit/modules/chunking/test_text_chunker.py +248 -0
- cognee/tests/unit/modules/chunking/test_text_chunker_with_overlap.py +324 -0
- cognee/tests/unit/modules/graph/cognee_graph_elements_test.py +2 -2
- cognee/tests/unit/modules/graph/cognee_graph_test.py +406 -0
- cognee/tests/unit/modules/memify_tasks/test_cognify_session.py +111 -0
- cognee/tests/unit/modules/memify_tasks/test_extract_user_sessions.py +175 -0
- cognee/tests/unit/modules/memify_tasks/test_get_triplet_datapoints.py +214 -0
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +0 -51
- cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py +1 -0
- cognee/tests/unit/modules/retrieval/structured_output_test.py +204 -0
- cognee/tests/unit/modules/retrieval/summaries_retriever_test.py +1 -1
- cognee/tests/unit/modules/retrieval/temporal_retriever_test.py +0 -1
- cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py +608 -0
- cognee/tests/unit/modules/retrieval/triplet_retriever_test.py +83 -0
- cognee/tests/unit/modules/users/test_conditional_authentication.py +0 -63
- cognee/tests/unit/processing/chunks/chunk_by_row_test.py +52 -0
- cognee/tests/unit/tasks/storage/test_add_data_points.py +288 -0
- {cognee-0.4.1.dist-info → cognee-0.5.0.dist-info}/METADATA +11 -7
- {cognee-0.4.1.dist-info → cognee-0.5.0.dist-info}/RECORD +212 -160
- {cognee-0.4.1.dist-info → cognee-0.5.0.dist-info}/entry_points.txt +0 -1
- cognee/api/v1/cognify/code_graph_pipeline.py +0 -119
- cognee/api/v1/cognify/routers/get_code_pipeline_router.py +0 -90
- cognee/infrastructure/databases/vector/embeddings/embedding_rate_limiter.py +0 -544
- cognee/modules/retrieval/code_retriever.py +0 -232
- cognee/tasks/code/enrich_dependency_graph_checker.py +0 -35
- cognee/tasks/code/get_local_dependencies_checker.py +0 -20
- cognee/tasks/code/get_repo_dependency_graph_checker.py +0 -35
- cognee/tasks/documents/check_permissions_on_dataset.py +0 -26
- cognee/tasks/repo_processor/__init__.py +0 -2
- cognee/tasks/repo_processor/get_local_dependencies.py +0 -335
- cognee/tasks/repo_processor/get_non_code_files.py +0 -158
- cognee/tasks/repo_processor/get_repo_file_dependencies.py +0 -243
- {cognee-0.4.1.dist-info → cognee-0.5.0.dist-info}/WHEEL +0 -0
- {cognee-0.4.1.dist-info → cognee-0.5.0.dist-info}/licenses/LICENSE +0 -0
- {cognee-0.4.1.dist-info → cognee-0.5.0.dist-info}/licenses/NOTICE.md +0 -0
cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py
CHANGED
```diff
@@ -24,6 +24,7 @@ class LLMProvider(Enum):
     - CUSTOM: Represents a custom provider option.
     - GEMINI: Represents the Gemini provider.
     - MISTRAL: Represents the Mistral AI provider.
+    - BEDROCK: Represents the AWS Bedrock provider.
     """
 
     OPENAI = "openai"
@@ -32,6 +33,7 @@ class LLMProvider(Enum):
     CUSTOM = "custom"
     GEMINI = "gemini"
     MISTRAL = "mistral"
+    BEDROCK = "bedrock"
 
 
 def get_llm_client(raise_api_key_error: bool = True):
@@ -81,6 +83,7 @@ def get_llm_client(raise_api_key_error: bool = True):
             model=llm_config.llm_model,
             transcription_model=llm_config.transcription_model,
             max_completion_tokens=max_completion_tokens,
+            instructor_mode=llm_config.llm_instructor_mode.lower(),
             streaming=llm_config.llm_streaming,
             fallback_api_key=llm_config.fallback_api_key,
             fallback_endpoint=llm_config.fallback_endpoint,
@@ -101,6 +104,7 @@ def get_llm_client(raise_api_key_error: bool = True):
             llm_config.llm_model,
             "Ollama",
             max_completion_tokens=max_completion_tokens,
+            instructor_mode=llm_config.llm_instructor_mode.lower(),
         )
 
     elif provider == LLMProvider.ANTHROPIC:
@@ -109,7 +113,9 @@ def get_llm_client(raise_api_key_error: bool = True):
         )
 
         return AnthropicAdapter(
-            max_completion_tokens=max_completion_tokens,
+            max_completion_tokens=max_completion_tokens,
+            model=llm_config.llm_model,
+            instructor_mode=llm_config.llm_instructor_mode.lower(),
         )
 
     elif provider == LLMProvider.CUSTOM:
@@ -126,6 +132,7 @@ def get_llm_client(raise_api_key_error: bool = True):
             llm_config.llm_model,
             "Custom",
             max_completion_tokens=max_completion_tokens,
+            instructor_mode=llm_config.llm_instructor_mode.lower(),
             fallback_api_key=llm_config.fallback_api_key,
             fallback_endpoint=llm_config.fallback_endpoint,
             fallback_model=llm_config.fallback_model,
@@ -145,10 +152,11 @@ def get_llm_client(raise_api_key_error: bool = True):
             max_completion_tokens=max_completion_tokens,
             endpoint=llm_config.llm_endpoint,
             api_version=llm_config.llm_api_version,
+            instructor_mode=llm_config.llm_instructor_mode.lower(),
         )
 
     elif provider == LLMProvider.MISTRAL:
-        if llm_config.llm_api_key is None:
+        if llm_config.llm_api_key is None and raise_api_key_error:
             raise LLMAPIKeyNotSetError()
 
         from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.mistral.adapter import (
@@ -160,21 +168,23 @@ def get_llm_client(raise_api_key_error: bool = True):
             model=llm_config.llm_model,
             max_completion_tokens=max_completion_tokens,
             endpoint=llm_config.llm_endpoint,
+            instructor_mode=llm_config.llm_instructor_mode.lower(),
         )
 
-    elif provider == LLMProvider.
-        if llm_config.llm_api_key is None:
-
+    elif provider == LLMProvider.BEDROCK:
+        # if llm_config.llm_api_key is None and raise_api_key_error:
+        #     raise LLMAPIKeyNotSetError()
 
-        from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.
-
+        from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.bedrock.adapter import (
+            BedrockAdapter,
         )
 
-        return
-            api_key=llm_config.llm_api_key,
+        return BedrockAdapter(
             model=llm_config.llm_model,
+            api_key=llm_config.llm_api_key,
             max_completion_tokens=max_completion_tokens,
-
+            streaming=llm_config.llm_streaming,
+            instructor_mode=llm_config.llm_instructor_mode.lower(),
         )
 
     else:
```
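Every adapter now receives an `instructor_mode` string from the new `llm_instructor_mode` setting (added in `cognee/infrastructure/llm/config.py`) and resolves it with `instructor.Mode(...)`. A minimal sketch of the by-value enum lookup that call relies on; the mode strings shown are the per-adapter defaults introduced in this release, which map to existing `instructor.Mode` members, and the failure case is standard Python enum behavior rather than cognee-specific handling:

```python
import instructor

# instructor.Mode is a string-valued Enum, so calling it with the lowered
# config string resolves the member whose value matches.
assert instructor.Mode("json_schema_mode") is instructor.Mode.JSON_SCHEMA
assert instructor.Mode("mistral_tools") is instructor.Mode.MISTRAL_TOOLS
assert instructor.Mode("json_mode") is instructor.Mode.JSON

# A misconfigured llm_instructor_mode fails fast with ValueError:
try:
    instructor.Mode("not_a_mode")
except ValueError as error:
    print(error)
```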
cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py
CHANGED
```diff
@@ -10,6 +10,7 @@
     LLMInterface,
 )
 from cognee.infrastructure.llm.config import get_llm_config
+from cognee.shared.rate_limiting import llm_rate_limiter_context_manager
 
 import logging
 from tenacity import (
@@ -37,28 +38,38 @@ class MistralAdapter(LLMInterface):
     model: str
     api_key: str
     max_completion_tokens: int
+    default_instructor_mode = "mistral_tools"
 
-    def __init__(
+    def __init__(
+        self,
+        api_key: str,
+        model: str,
+        max_completion_tokens: int,
+        endpoint: str = None,
+        instructor_mode: str = None,
+    ):
         from mistralai import Mistral
 
         self.model = model
         self.max_completion_tokens = max_completion_tokens
 
+        self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
+
         self.aclient = instructor.from_litellm(
             litellm.acompletion,
-            mode=instructor.Mode.
+            mode=instructor.Mode(self.instructor_mode),
             api_key=get_llm_config().llm_api_key,
         )
 
     @retry(
         stop=stop_after_delay(128),
-        wait=wait_exponential_jitter(
+        wait=wait_exponential_jitter(8, 128),
         retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
         before_sleep=before_sleep_log(logger, logging.DEBUG),
         reraise=True,
     )
     async def acreate_structured_output(
-        self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
+        self, text_input: str, system_prompt: str, response_model: Type[BaseModel], **kwargs
     ) -> BaseModel:
         """
         Generate a response from the user query.
@@ -87,13 +98,14 @@ class MistralAdapter(LLMInterface):
             },
         ]
         try:
-
-
-
-
-
-
-
+            async with llm_rate_limiter_context_manager():
+                response = await self.aclient.chat.completions.create(
+                    model=self.model,
+                    max_tokens=self.max_completion_tokens,
+                    max_retries=2,
+                    messages=messages,
+                    response_model=response_model,
+                )
             if response.choices and response.choices[0].message.content:
                 content = response.choices[0].message.content
                 return response_model.model_validate_json(content)
```
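This adapter and the ones below now wrap their completion calls in the new `llm_rate_limiter_context_manager` from `cognee/shared/rate_limiting.py` (+30 lines, not shown in this diff). A hypothetical sketch of the pattern, assuming a simple semaphore-based limiter; the concurrency cap and the actual cognee implementation (which may gate on requests per interval instead) are assumptions:

```python
import asyncio
from contextlib import asynccontextmanager

# Assumed cap on concurrent LLM calls; illustrative, not cognee's value.
_semaphore = asyncio.Semaphore(8)

@asynccontextmanager
async def llm_rate_limiter_context_manager():
    # Acquire a slot before the LLM call and release it afterwards, so any
    # adapter can wrap its completion call in "async with ...".
    async with _semaphore:
        yield

async def main():
    async with llm_rate_limiter_context_manager():
        await asyncio.sleep(0)  # stand-in for the adapter's LLM call

asyncio.run(main())
```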
cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py
CHANGED
```diff
@@ -11,6 +11,8 @@
 )
 from cognee.infrastructure.files.utils.open_data_file import open_data_file
 from cognee.shared.logging_utils import get_logger
+from cognee.shared.rate_limiting import llm_rate_limiter_context_manager
+
 from tenacity import (
     retry,
     stop_after_delay,
@@ -42,8 +44,16 @@ class OllamaAPIAdapter(LLMInterface):
     - aclient
     """
 
+    default_instructor_mode = "json_mode"
+
     def __init__(
-        self,
+        self,
+        endpoint: str,
+        api_key: str,
+        model: str,
+        name: str,
+        max_completion_tokens: int,
+        instructor_mode: str = None,
     ):
         self.name = name
         self.model = model
@@ -51,19 +61,22 @@ class OllamaAPIAdapter(LLMInterface):
         self.endpoint = endpoint
         self.max_completion_tokens = max_completion_tokens
 
+        self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
+
         self.aclient = instructor.from_openai(
-            OpenAI(base_url=self.endpoint, api_key=self.api_key),
+            OpenAI(base_url=self.endpoint, api_key=self.api_key),
+            mode=instructor.Mode(self.instructor_mode),
         )
 
     @retry(
         stop=stop_after_delay(128),
-        wait=wait_exponential_jitter(
+        wait=wait_exponential_jitter(8, 128),
         retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
         before_sleep=before_sleep_log(logger, logging.DEBUG),
         reraise=True,
     )
     async def acreate_structured_output(
-        self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
+        self, text_input: str, system_prompt: str, response_model: Type[BaseModel], **kwargs
     ) -> BaseModel:
         """
         Generate a structured output from the LLM using the provided text and system prompt.
@@ -84,33 +97,33 @@ class OllamaAPIAdapter(LLMInterface):
 
         - BaseModel: A structured output that conforms to the specified response model.
         """
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        async with llm_rate_limiter_context_manager():
+            response = self.aclient.chat.completions.create(
+                model=self.model,
+                messages=[
+                    {
+                        "role": "user",
+                        "content": f"{text_input}",
+                    },
+                    {
+                        "role": "system",
+                        "content": system_prompt,
+                    },
+                ],
+                max_retries=2,
+                response_model=response_model,
+            )
 
         return response
 
     @retry(
         stop=stop_after_delay(128),
-        wait=wait_exponential_jitter(
+        wait=wait_exponential_jitter(8, 128),
         retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
         before_sleep=before_sleep_log(logger, logging.DEBUG),
         reraise=True,
     )
-    async def create_transcript(self, input_file: str) -> str:
+    async def create_transcript(self, input_file: str, **kwargs) -> str:
         """
         Generate an audio transcript from a user query.
 
@@ -149,7 +162,7 @@ class OllamaAPIAdapter(LLMInterface):
         before_sleep=before_sleep_log(logger, logging.DEBUG),
         reraise=True,
     )
-    async def transcribe_image(self, input_file: str) -> str:
+    async def transcribe_image(self, input_file: str, **kwargs) -> str:
         """
         Transcribe content from an image using base64 encoding.
 
```
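The retry decorators across these adapters now pass explicit arguments to `wait_exponential_jitter`: in tenacity that positional pair means `initial=8` and `max=128`, i.e. back off roughly 8s, 16s, 32s, 64s, then cap at 128s (each wait plus up to 1s of jitter), giving up once 128 seconds have elapsed since the first attempt. A self-contained illustration of the same policy, using `KeyError` as a stand-in for `litellm.exceptions.NotFoundError`:

```python
import logging
from tenacity import (
    retry,
    stop_after_delay,
    wait_exponential_jitter,
    retry_if_not_exception_type,
    before_sleep_log,
)

logger = logging.getLogger(__name__)

@retry(
    stop=stop_after_delay(128),                   # give up 128s after the first call
    wait=wait_exponential_jitter(8, 128),         # initial=8, max=128, plus jitter
    retry=retry_if_not_exception_type(KeyError),  # non-retryable error class
    before_sleep=before_sleep_log(logger, logging.DEBUG),
    reraise=True,
)
def flaky():
    # Retried with the backoff above; a KeyError would propagate immediately.
    raise TimeoutError("transient failure")
```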
cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py
CHANGED
```diff
@@ -22,6 +22,7 @@
 from cognee.infrastructure.llm.exceptions import (
     ContentPolicyFilterError,
 )
+from cognee.shared.rate_limiting import llm_rate_limiter_context_manager
 from cognee.infrastructure.files.utils.open_data_file import open_data_file
 from cognee.modules.observability.get_observe import get_observe
 from cognee.shared.logging_utils import get_logger
@@ -56,6 +57,7 @@ class OpenAIAdapter(LLMInterface):
     model: str
     api_key: str
     api_version: str
+    default_instructor_mode = "json_schema_mode"
 
     MAX_RETRIES = 5
 
@@ -69,19 +71,21 @@ class OpenAIAdapter(LLMInterface):
         model: str,
         transcription_model: str,
         max_completion_tokens: int,
+        instructor_mode: str = None,
         streaming: bool = False,
         fallback_model: str = None,
         fallback_api_key: str = None,
         fallback_endpoint: str = None,
     ):
+        self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
         # TODO: With gpt5 series models OpenAI expects JSON_SCHEMA as a mode for structured outputs.
         # Make sure all new gpt models will work with this mode as well.
         if "gpt-5" in model:
             self.aclient = instructor.from_litellm(
-                litellm.acompletion, mode=instructor.Mode.
+                litellm.acompletion, mode=instructor.Mode(self.instructor_mode)
             )
             self.client = instructor.from_litellm(
-                litellm.completion, mode=instructor.Mode.
+                litellm.completion, mode=instructor.Mode(self.instructor_mode)
             )
         else:
             self.aclient = instructor.from_litellm(litellm.acompletion)
@@ -102,13 +106,13 @@ class OpenAIAdapter(LLMInterface):
     @observe(as_type="generation")
     @retry(
         stop=stop_after_delay(128),
-        wait=wait_exponential_jitter(
+        wait=wait_exponential_jitter(8, 128),
         retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
         before_sleep=before_sleep_log(logger, logging.DEBUG),
         reraise=True,
     )
     async def acreate_structured_output(
-        self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
+        self, text_input: str, system_prompt: str, response_model: Type[BaseModel], **kwargs
     ) -> BaseModel:
         """
         Generate a response from a user query.
@@ -132,34 +136,9 @@ class OpenAIAdapter(LLMInterface):
         """
 
         try:
-
-                model=self.model,
-                messages=[
-                    {
-                        "role": "user",
-                        "content": f"""{text_input}""",
-                    },
-                    {
-                        "role": "system",
-                        "content": system_prompt,
-                    },
-                ],
-                api_key=self.api_key,
-                api_base=self.endpoint,
-                api_version=self.api_version,
-                response_model=response_model,
-                max_retries=self.MAX_RETRIES,
-            )
-        except (
-            ContentFilterFinishReasonError,
-            ContentPolicyViolationError,
-            InstructorRetryException,
-        ) as e:
-            if not (self.fallback_model and self.fallback_api_key):
-                raise e
-            try:
+            async with llm_rate_limiter_context_manager():
                 return await self.aclient.chat.completions.create(
-                    model=self.
+                    model=self.model,
                     messages=[
                         {
                             "role": "user",
@@ -170,11 +149,40 @@ class OpenAIAdapter(LLMInterface):
                             "content": system_prompt,
                         },
                     ],
-                    api_key=self.
-
+                    api_key=self.api_key,
+                    api_base=self.endpoint,
+                    api_version=self.api_version,
                     response_model=response_model,
                     max_retries=self.MAX_RETRIES,
+                    **kwargs,
                 )
+        except (
+            ContentFilterFinishReasonError,
+            ContentPolicyViolationError,
+            InstructorRetryException,
+        ) as e:
+            if not (self.fallback_model and self.fallback_api_key):
+                raise e
+            try:
+                async with llm_rate_limiter_context_manager():
+                    return await self.aclient.chat.completions.create(
+                        model=self.fallback_model,
+                        messages=[
+                            {
+                                "role": "user",
+                                "content": f"""{text_input}""",
+                            },
+                            {
+                                "role": "system",
+                                "content": system_prompt,
+                            },
+                        ],
+                        api_key=self.fallback_api_key,
+                        # api_base=self.fallback_endpoint,
+                        response_model=response_model,
+                        max_retries=self.MAX_RETRIES,
+                        **kwargs,
+                    )
         except (
             ContentFilterFinishReasonError,
             ContentPolicyViolationError,
@@ -199,7 +207,7 @@ class OpenAIAdapter(LLMInterface):
         reraise=True,
     )
     def create_structured_output(
-        self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
+        self, text_input: str, system_prompt: str, response_model: Type[BaseModel], **kwargs
     ) -> BaseModel:
         """
         Generate a response from a user query.
@@ -239,6 +247,7 @@ class OpenAIAdapter(LLMInterface):
             api_version=self.api_version,
             response_model=response_model,
             max_retries=self.MAX_RETRIES,
+            **kwargs,
         )
 
     @retry(
@@ -248,7 +257,7 @@ class OpenAIAdapter(LLMInterface):
         before_sleep=before_sleep_log(logger, logging.DEBUG),
         reraise=True,
     )
-    async def create_transcript(self, input):
+    async def create_transcript(self, input, **kwargs):
         """
         Generate an audio transcript from a user query.
 
@@ -275,6 +284,7 @@ class OpenAIAdapter(LLMInterface):
             api_base=self.endpoint,
             api_version=self.api_version,
             max_retries=self.MAX_RETRIES,
+            **kwargs,
        )
 
         return transcription
@@ -286,7 +296,7 @@ class OpenAIAdapter(LLMInterface):
         before_sleep=before_sleep_log(logger, logging.DEBUG),
         reraise=True,
     )
-    async def transcribe_image(self, input) -> BaseModel:
+    async def transcribe_image(self, input, **kwargs) -> BaseModel:
         """
         Generate a transcription of an image from a user query.
 
@@ -331,4 +341,5 @@ class OpenAIAdapter(LLMInterface):
             api_version=self.api_version,
             max_completion_tokens=300,
             max_retries=self.MAX_RETRIES,
+            **kwargs,
         )
```
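The `**kwargs` threaded through `acreate_structured_output` and the other entry points is forwarded to the underlying litellm call. A hedged usage sketch of the updated signature; the `Person` model, the prompts, and the `temperature` kwarg are illustrative, not taken from the package, and `get_llm_client()` assumes a provider and API key are already configured:

```python
import asyncio
from pydantic import BaseModel

from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.get_llm_client import (
    get_llm_client,
)

class Person(BaseModel):  # hypothetical response model
    name: str
    age: int

async def main():
    client = get_llm_client()  # adapter chosen from llm_config
    person = await client.acreate_structured_output(
        text_input="Ada Lovelace was 36 when she died.",
        system_prompt="Extract the person mentioned in the text.",
        response_model=Person,
        temperature=0,  # example of a kwarg now passed through to litellm
    )
    print(person)

asyncio.run(main())
```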
cognee/infrastructure/loaders/core/__init__.py
CHANGED

```diff
@@ -3,5 +3,6 @@
 from .text_loader import TextLoader
 from .audio_loader import AudioLoader
 from .image_loader import ImageLoader
+from .csv_loader import CsvLoader
 
-__all__ = ["TextLoader", "AudioLoader", "ImageLoader"]
+__all__ = ["TextLoader", "AudioLoader", "ImageLoader", "CsvLoader"]
```
cognee/infrastructure/loaders/core/csv_loader.py
ADDED

```diff
@@ -0,0 +1,93 @@
+import os
+from typing import List
+import csv
+from cognee.infrastructure.loaders.LoaderInterface import LoaderInterface
+from cognee.infrastructure.files.storage import get_file_storage, get_storage_config
+from cognee.infrastructure.files.utils.get_file_metadata import get_file_metadata
+
+
+class CsvLoader(LoaderInterface):
+    """
+    Core CSV file loader that handles basic CSV file formats.
+    """
+
+    @property
+    def supported_extensions(self) -> List[str]:
+        """Supported text file extensions."""
+        return [
+            "csv",
+        ]
+
+    @property
+    def supported_mime_types(self) -> List[str]:
+        """Supported MIME types for text content."""
+        return [
+            "text/csv",
+        ]
+
+    @property
+    def loader_name(self) -> str:
+        """Unique identifier for this loader."""
+        return "csv_loader"
+
+    def can_handle(self, extension: str, mime_type: str) -> bool:
+        """
+        Check if this loader can handle the given file.
+
+        Args:
+            extension: File extension
+            mime_type: Optional MIME type
+
+        Returns:
+            True if file can be handled, False otherwise
+        """
+        if extension in self.supported_extensions and mime_type in self.supported_mime_types:
+            return True
+
+        return False
+
+    async def load(self, file_path: str, encoding: str = "utf-8", **kwargs):
+        """
+        Load and process the csv file.
+
+        Args:
+            file_path: Path to the file to load
+            encoding: Text encoding to use (default: utf-8)
+            **kwargs: Additional configuration (unused)
+
+        Returns:
+            LoaderResult containing the file content and metadata
+
+        Raises:
+            FileNotFoundError: If file doesn't exist
+            UnicodeDecodeError: If file cannot be decoded with specified encoding
+            OSError: If file cannot be read
+        """
+        if not os.path.exists(file_path):
+            raise FileNotFoundError(f"File not found: {file_path}")
+
+        with open(file_path, "rb") as f:
+            file_metadata = await get_file_metadata(f)
+            # Name ingested file of current loader based on original file content hash
+            storage_file_name = "text_" + file_metadata["content_hash"] + ".txt"
+
+        row_texts = []
+        row_index = 1
+
+        with open(file_path, "r", encoding=encoding, newline="") as file:
+            reader = csv.DictReader(file)
+            for row in reader:
+                pairs = [f"{str(k)}: {str(v)}" for k, v in row.items()]
+                row_text = ", ".join(pairs)
+                row_texts.append(f"Row {row_index}:\n{row_text}\n")
+                row_index += 1
+
+        content = "\n".join(row_texts)
+
+        storage_config = get_storage_config()
+        data_root_directory = storage_config["data_root_directory"]
+        storage = get_file_storage(data_root_directory)
+
+        full_file_path = await storage.store(storage_file_name, content)
+
+        return full_file_path
```
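CsvLoader flattens each record into a `Row N:` block of `header: value` pairs before storing the result as plain text. Reproducing that rendering with the standard library on a made-up two-row CSV:

```python
import csv
import io

# Illustrative sample; any CSV with a header row works the same way.
sample = "name,role\nAda,engineer\nAlan,scientist\n"

row_texts = []
for index, row in enumerate(csv.DictReader(io.StringIO(sample)), start=1):
    pairs = [f"{k}: {v}" for k, v in row.items()]
    row_texts.append(f"Row {index}:\n" + ", ".join(pairs) + "\n")

print("\n".join(row_texts))
# Row 1:
# name: Ada, role: engineer
#
# Row 2:
# name: Alan, role: scientist
```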
cognee/infrastructure/loaders/core/text_loader.py
CHANGED

```diff
@@ -16,7 +16,7 @@ class TextLoader(LoaderInterface):
     @property
     def supported_extensions(self) -> List[str]:
         """Supported text file extensions."""
-        return ["txt", "md", "
+        return ["txt", "md", "json", "xml", "yaml", "yml", "log"]
 
     @property
     def supported_mime_types(self) -> List[str]:
@@ -24,7 +24,6 @@ class TextLoader(LoaderInterface):
         return [
             "text/plain",
             "text/markdown",
-            "text/csv",
             "application/json",
             "text/xml",
             "application/xml",
```
cognee/infrastructure/loaders/external/advanced_pdf_loader.py
CHANGED

```diff
@@ -227,12 +227,3 @@ class AdvancedPdfLoader(LoaderInterface):
         if value is None:
             return ""
         return str(value).replace("\xa0", " ").strip()
-
-
-if __name__ == "__main__":
-    loader = AdvancedPdfLoader()
-    asyncio.run(
-        loader.load(
-            "/Users/xiaotao/work/cognee/cognee/infrastructure/loaders/external/attention_is_all_you_need.pdf"
-        )
-    )
```
cognee/infrastructure/loaders/supported_loaders.py
CHANGED

```diff
@@ -1,5 +1,5 @@
 from cognee.infrastructure.loaders.external import PyPdfLoader
-from cognee.infrastructure.loaders.core import TextLoader, AudioLoader, ImageLoader
+from cognee.infrastructure.loaders.core import TextLoader, AudioLoader, ImageLoader, CsvLoader
 
 # Registry for loader implementations
 supported_loaders = {
@@ -7,6 +7,7 @@ supported_loaders = {
     TextLoader.loader_name: TextLoader,
     ImageLoader.loader_name: ImageLoader,
     AudioLoader.loader_name: AudioLoader,
+    CsvLoader.loader_name: CsvLoader,
 }
 
 # Try adding optional loaders
```
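With CsvLoader registered, `.csv` files with a `text/csv` MIME type resolve to the new loader rather than TextLoader (which dropped `text/csv` above). A hypothetical dispatch sketch over the registry; cognee's actual `LoaderEngine` selection logic may differ:

```python
from cognee.infrastructure.loaders.supported_loaders import supported_loaders

def pick_loader(extension: str, mime_type: str):
    # Instantiate each registered loader and return the first that claims
    # the (extension, mime_type) pair via its can_handle check.
    for loader_cls in supported_loaders.values():
        loader = loader_cls()
        if loader.can_handle(extension, mime_type):
            return loader
    return None

print(pick_loader("csv", "text/csv"))    # expected: a CsvLoader instance
print(pick_loader("txt", "text/plain"))  # expected: a TextLoader instance
```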
cognee/memify_pipelines/create_triplet_embeddings.py
ADDED

```diff
@@ -0,0 +1,53 @@
+from typing import Any
+
+from cognee import memify
+from cognee.context_global_variables import (
+    set_database_global_context_variables,
+)
+from cognee.exceptions import CogneeValidationError
+from cognee.modules.data.methods import get_authorized_existing_datasets
+from cognee.shared.logging_utils import get_logger
+from cognee.modules.pipelines.tasks.task import Task
+from cognee.modules.users.models import User
+from cognee.tasks.memify.get_triplet_datapoints import get_triplet_datapoints
+from cognee.tasks.storage import index_data_points
+
+logger = get_logger("create_triplet_embeddings")
+
+
+async def create_triplet_embeddings(
+    user: User,
+    dataset: str = "main_dataset",
+    run_in_background: bool = False,
+    triplets_batch_size: int = 100,
+) -> dict[str, Any]:
+    dataset_to_write = await get_authorized_existing_datasets(
+        user=user, datasets=[dataset], permission_type="write"
+    )
+
+    if not dataset_to_write:
+        raise CogneeValidationError(
+            message=f"User does not have write access to dataset: {dataset}",
+            log=False,
+        )
+
+    await set_database_global_context_variables(
+        dataset_to_write[0].id, dataset_to_write[0].owner_id
+    )
+
+    extraction_tasks = [Task(get_triplet_datapoints, triplets_batch_size=triplets_batch_size)]
+
+    enrichment_tasks = [
+        Task(index_data_points, task_config={"batch_size": triplets_batch_size}),
+    ]
+
+    result = await memify(
+        extraction_tasks=extraction_tasks,
+        enrichment_tasks=enrichment_tasks,
+        dataset=dataset_to_write[0].id,
+        data=[{}],
+        user=user,
+        run_in_background=run_in_background,
+    )
+
+    return result
```
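A hedged sketch of invoking the new pipeline; `get_default_user` comes from `cognee/modules/users/methods` elsewhere in this release, the argument values shown are just the function's own defaults, and a populated "main_dataset" is assumed:

```python
import asyncio

from cognee.memify_pipelines.create_triplet_embeddings import create_triplet_embeddings
from cognee.modules.users.methods import get_default_user

async def main():
    user = await get_default_user()
    # Extracts triplet datapoints from the dataset's graph, then indexes them
    # in batches of triplets_batch_size via index_data_points.
    result = await create_triplet_embeddings(
        user=user,
        dataset="main_dataset",
        triplets_batch_size=100,
    )
    print(result)

asyncio.run(main())
```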