cognee 0.3.6__py3-none-any.whl → 0.3.7.dev1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cognee/__init__.py +1 -0
- cognee/api/health.py +2 -12
- cognee/api/v1/add/add.py +46 -6
- cognee/api/v1/add/routers/get_add_router.py +11 -2
- cognee/api/v1/cognify/cognify.py +29 -9
- cognee/api/v1/cognify/routers/get_cognify_router.py +2 -1
- cognee/api/v1/datasets/datasets.py +11 -0
- cognee/api/v1/datasets/routers/get_datasets_router.py +8 -0
- cognee/api/v1/delete/routers/get_delete_router.py +2 -0
- cognee/api/v1/memify/routers/get_memify_router.py +2 -1
- cognee/api/v1/permissions/routers/get_permissions_router.py +6 -0
- cognee/api/v1/responses/default_tools.py +0 -1
- cognee/api/v1/responses/dispatch_function.py +1 -1
- cognee/api/v1/responses/routers/default_tools.py +0 -1
- cognee/api/v1/search/routers/get_search_router.py +3 -3
- cognee/api/v1/search/search.py +11 -9
- cognee/api/v1/settings/routers/get_settings_router.py +7 -1
- cognee/api/v1/sync/routers/get_sync_router.py +3 -0
- cognee/api/v1/ui/ui.py +45 -16
- cognee/api/v1/update/routers/get_update_router.py +3 -1
- cognee/api/v1/update/update.py +3 -3
- cognee/api/v1/users/routers/get_visualize_router.py +2 -0
- cognee/cli/_cognee.py +61 -10
- cognee/cli/commands/add_command.py +3 -3
- cognee/cli/commands/cognify_command.py +3 -3
- cognee/cli/commands/config_command.py +9 -7
- cognee/cli/commands/delete_command.py +3 -3
- cognee/cli/commands/search_command.py +3 -7
- cognee/cli/config.py +0 -1
- cognee/context_global_variables.py +5 -0
- cognee/exceptions/exceptions.py +1 -1
- cognee/infrastructure/databases/cache/__init__.py +2 -0
- cognee/infrastructure/databases/cache/cache_db_interface.py +79 -0
- cognee/infrastructure/databases/cache/config.py +44 -0
- cognee/infrastructure/databases/cache/get_cache_engine.py +67 -0
- cognee/infrastructure/databases/cache/redis/RedisAdapter.py +243 -0
- cognee/infrastructure/databases/exceptions/__init__.py +1 -0
- cognee/infrastructure/databases/exceptions/exceptions.py +18 -2
- cognee/infrastructure/databases/graph/get_graph_engine.py +1 -1
- cognee/infrastructure/databases/graph/graph_db_interface.py +5 -0
- cognee/infrastructure/databases/graph/kuzu/adapter.py +76 -47
- cognee/infrastructure/databases/graph/neo4j_driver/adapter.py +13 -3
- cognee/infrastructure/databases/graph/neo4j_driver/deadlock_retry.py +1 -1
- cognee/infrastructure/databases/graph/neptune_driver/neptune_utils.py +1 -1
- cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py +1 -1
- cognee/infrastructure/databases/vector/embeddings/FastembedEmbeddingEngine.py +21 -3
- cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py +17 -10
- cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py +17 -4
- cognee/infrastructure/databases/vector/embeddings/config.py +2 -3
- cognee/infrastructure/databases/vector/exceptions/exceptions.py +1 -1
- cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py +0 -1
- cognee/infrastructure/files/exceptions.py +1 -1
- cognee/infrastructure/files/storage/LocalFileStorage.py +9 -9
- cognee/infrastructure/files/storage/S3FileStorage.py +11 -11
- cognee/infrastructure/files/utils/guess_file_type.py +6 -0
- cognee/infrastructure/llm/prompts/feedback_reaction_prompt.txt +14 -0
- cognee/infrastructure/llm/prompts/feedback_report_prompt.txt +13 -0
- cognee/infrastructure/llm/prompts/feedback_user_context_prompt.txt +5 -0
- cognee/infrastructure/llm/prompts/search_type_selector_prompt.txt +0 -5
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py +19 -9
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py +17 -5
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py +17 -5
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +32 -0
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/__init__.py +0 -0
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py +109 -0
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py +33 -8
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py +40 -18
- cognee/infrastructure/loaders/LoaderEngine.py +27 -7
- cognee/infrastructure/loaders/external/__init__.py +7 -0
- cognee/infrastructure/loaders/external/advanced_pdf_loader.py +2 -8
- cognee/infrastructure/loaders/external/beautiful_soup_loader.py +310 -0
- cognee/infrastructure/loaders/supported_loaders.py +7 -0
- cognee/modules/data/exceptions/exceptions.py +1 -1
- cognee/modules/data/methods/__init__.py +3 -0
- cognee/modules/data/methods/get_dataset_data.py +4 -1
- cognee/modules/data/methods/has_dataset_data.py +21 -0
- cognee/modules/engine/models/TableRow.py +0 -1
- cognee/modules/ingestion/save_data_to_file.py +9 -2
- cognee/modules/pipelines/exceptions/exceptions.py +1 -1
- cognee/modules/pipelines/operations/pipeline.py +12 -1
- cognee/modules/pipelines/operations/run_tasks.py +25 -197
- cognee/modules/pipelines/operations/run_tasks_base.py +7 -0
- cognee/modules/pipelines/operations/run_tasks_data_item.py +260 -0
- cognee/modules/pipelines/operations/run_tasks_distributed.py +121 -38
- cognee/modules/pipelines/operations/run_tasks_with_telemetry.py +9 -1
- cognee/modules/retrieval/EntityCompletionRetriever.py +48 -8
- cognee/modules/retrieval/base_graph_retriever.py +3 -1
- cognee/modules/retrieval/base_retriever.py +3 -1
- cognee/modules/retrieval/chunks_retriever.py +5 -1
- cognee/modules/retrieval/code_retriever.py +20 -2
- cognee/modules/retrieval/completion_retriever.py +50 -9
- cognee/modules/retrieval/cypher_search_retriever.py +11 -1
- cognee/modules/retrieval/graph_completion_context_extension_retriever.py +47 -8
- cognee/modules/retrieval/graph_completion_cot_retriever.py +152 -22
- cognee/modules/retrieval/graph_completion_retriever.py +54 -10
- cognee/modules/retrieval/lexical_retriever.py +20 -2
- cognee/modules/retrieval/natural_language_retriever.py +10 -1
- cognee/modules/retrieval/summaries_retriever.py +5 -1
- cognee/modules/retrieval/temporal_retriever.py +62 -10
- cognee/modules/retrieval/user_qa_feedback.py +3 -2
- cognee/modules/retrieval/utils/completion.py +30 -4
- cognee/modules/retrieval/utils/description_to_codepart_search.py +1 -1
- cognee/modules/retrieval/utils/session_cache.py +156 -0
- cognee/modules/search/methods/get_search_type_tools.py +0 -5
- cognee/modules/search/methods/no_access_control_search.py +12 -1
- cognee/modules/search/methods/search.py +51 -5
- cognee/modules/search/types/SearchType.py +0 -1
- cognee/modules/settings/get_settings.py +23 -0
- cognee/modules/users/methods/get_authenticated_user.py +3 -1
- cognee/modules/users/methods/get_default_user.py +1 -6
- cognee/modules/users/roles/methods/create_role.py +2 -2
- cognee/modules/users/tenants/methods/create_tenant.py +2 -2
- cognee/shared/exceptions/exceptions.py +1 -1
- cognee/shared/logging_utils.py +18 -11
- cognee/shared/utils.py +24 -2
- cognee/tasks/codingagents/coding_rule_associations.py +1 -2
- cognee/tasks/documents/exceptions/exceptions.py +1 -1
- cognee/tasks/feedback/__init__.py +13 -0
- cognee/tasks/feedback/create_enrichments.py +84 -0
- cognee/tasks/feedback/extract_feedback_interactions.py +230 -0
- cognee/tasks/feedback/generate_improved_answers.py +130 -0
- cognee/tasks/feedback/link_enrichments_to_feedback.py +67 -0
- cognee/tasks/feedback/models.py +26 -0
- cognee/tasks/graph/extract_graph_from_data.py +2 -0
- cognee/tasks/ingestion/data_item_to_text_file.py +3 -3
- cognee/tasks/ingestion/ingest_data.py +11 -5
- cognee/tasks/ingestion/save_data_item_to_storage.py +12 -1
- cognee/tasks/storage/add_data_points.py +3 -10
- cognee/tasks/storage/index_data_points.py +19 -14
- cognee/tasks/storage/index_graph_edges.py +25 -11
- cognee/tasks/web_scraper/__init__.py +34 -0
- cognee/tasks/web_scraper/config.py +26 -0
- cognee/tasks/web_scraper/default_url_crawler.py +446 -0
- cognee/tasks/web_scraper/models.py +46 -0
- cognee/tasks/web_scraper/types.py +4 -0
- cognee/tasks/web_scraper/utils.py +142 -0
- cognee/tasks/web_scraper/web_scraper_task.py +396 -0
- cognee/tests/cli_tests/cli_unit_tests/test_cli_utils.py +0 -1
- cognee/tests/integration/web_url_crawler/test_default_url_crawler.py +13 -0
- cognee/tests/integration/web_url_crawler/test_tavily_crawler.py +19 -0
- cognee/tests/integration/web_url_crawler/test_url_adding_e2e.py +344 -0
- cognee/tests/subprocesses/reader.py +25 -0
- cognee/tests/subprocesses/simple_cognify_1.py +31 -0
- cognee/tests/subprocesses/simple_cognify_2.py +31 -0
- cognee/tests/subprocesses/writer.py +32 -0
- cognee/tests/tasks/descriptive_metrics/metrics_test_utils.py +0 -2
- cognee/tests/tasks/descriptive_metrics/neo4j_metrics_test.py +8 -3
- cognee/tests/tasks/entity_extraction/entity_extraction_test.py +89 -0
- cognee/tests/tasks/web_scraping/web_scraping_test.py +172 -0
- cognee/tests/test_add_docling_document.py +56 -0
- cognee/tests/test_chromadb.py +7 -11
- cognee/tests/test_concurrent_subprocess_access.py +76 -0
- cognee/tests/test_conversation_history.py +240 -0
- cognee/tests/test_feedback_enrichment.py +174 -0
- cognee/tests/test_kuzu.py +27 -15
- cognee/tests/test_lancedb.py +7 -11
- cognee/tests/test_library.py +32 -2
- cognee/tests/test_neo4j.py +24 -16
- cognee/tests/test_neptune_analytics_vector.py +7 -11
- cognee/tests/test_permissions.py +9 -13
- cognee/tests/test_pgvector.py +4 -4
- cognee/tests/test_remote_kuzu.py +8 -11
- cognee/tests/test_s3_file_storage.py +1 -1
- cognee/tests/test_search_db.py +6 -8
- cognee/tests/unit/infrastructure/databases/cache/test_cache_config.py +89 -0
- cognee/tests/unit/modules/retrieval/conversation_history_test.py +154 -0
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +51 -0
- {cognee-0.3.6.dist-info → cognee-0.3.7.dev1.dist-info}/METADATA +21 -6
- {cognee-0.3.6.dist-info → cognee-0.3.7.dev1.dist-info}/RECORD +178 -139
- {cognee-0.3.6.dist-info → cognee-0.3.7.dev1.dist-info}/entry_points.txt +1 -0
- distributed/Dockerfile +0 -3
- distributed/entrypoint.py +21 -9
- distributed/signal.py +5 -0
- distributed/workers/data_point_saving_worker.py +64 -34
- distributed/workers/graph_saving_worker.py +71 -47
- cognee/infrastructure/databases/graph/memgraph/memgraph_adapter.py +0 -1116
- cognee/modules/retrieval/insights_retriever.py +0 -133
- cognee/tests/test_memgraph.py +0 -109
- cognee/tests/unit/modules/retrieval/insights_retriever_test.py +0 -251
- {cognee-0.3.6.dist-info → cognee-0.3.7.dev1.dist-info}/WHEEL +0 -0
- {cognee-0.3.6.dist-info → cognee-0.3.7.dev1.dist-info}/licenses/LICENSE +0 -0
- {cognee-0.3.6.dist-info → cognee-0.3.7.dev1.dist-info}/licenses/NOTICE.md +0 -0
cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py
CHANGED

@@ -3,8 +3,16 @@ from cognee.shared.logging_utils import get_logger
 import aiohttp
 from typing import List, Optional
 import os
-
+import litellm
+import logging
 import aiohttp.http_exceptions
+from tenacity import (
+    retry,
+    stop_after_delay,
+    wait_exponential_jitter,
+    retry_if_not_exception_type,
+    before_sleep_log,
+)
 
 from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
 from cognee.infrastructure.llm.tokenizer.HuggingFace import (
@@ -69,7 +77,6 @@ class OllamaEmbeddingEngine(EmbeddingEngine):
         enable_mocking = str(enable_mocking).lower()
         self.mock = enable_mocking in ("true", "1", "yes")
 
-    @embedding_rate_limit_async
     async def embed_text(self, text: List[str]) -> List[List[float]]:
         """
         Generate embedding vectors for a list of text prompts.
@@ -92,7 +99,13 @@ class OllamaEmbeddingEngine(EmbeddingEngine):
         embeddings = await asyncio.gather(*[self._get_embedding(prompt) for prompt in text])
         return embeddings
 
-    @
+    @retry(
+        stop=stop_after_delay(128),
+        wait=wait_exponential_jitter(2, 128),
+        retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
+        before_sleep=before_sleep_log(logger, logging.DEBUG),
+        reraise=True,
+    )
     async def _get_embedding(self, prompt: str) -> List[float]:
         """
         Internal method to call the Ollama embeddings endpoint for a single prompt.
@@ -111,7 +124,7 @@ class OllamaEmbeddingEngine(EmbeddingEngine):
                 self.endpoint, json=payload, headers=headers, timeout=60.0
             ) as response:
                 data = await response.json()
-                return data["
+                return data["embeddings"][0]
 
     def get_vector_size(self) -> int:
         """
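The same tenacity policy recurs throughout this release: exponential backoff with jitter starting around 2 seconds and capped at 128 seconds, a hard stop 128 seconds after the first attempt, no retries for litellm's NotFoundError, a debug-level log line before each sleep, and the original exception re-raised at the end. A minimal, self-contained sketch of how that policy behaves; the flaky_call function and its failure count are hypothetical stand-ins, and ValueError stands in for the non-retryable litellm.exceptions.NotFoundError:

import logging
from tenacity import (
    retry,
    stop_after_delay,
    wait_exponential_jitter,
    retry_if_not_exception_type,
    before_sleep_log,
)

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

attempts = {"count": 0}


@retry(
    stop=stop_after_delay(128),  # give up 128 seconds after the first attempt
    wait=wait_exponential_jitter(2, 128),  # ~2s initial backoff, jittered, capped at 128s
    retry=retry_if_not_exception_type(ValueError),  # ValueError is never retried
    before_sleep=before_sleep_log(logger, logging.DEBUG),  # log each backoff
    reraise=True,  # surface the last real exception instead of tenacity's RetryError
)
def flaky_call() -> str:
    attempts["count"] += 1
    if attempts["count"] < 3:
        raise ConnectionError("transient failure")  # retried with backoff
    return "ok"


print(flaky_call())  # fails twice, backs off, then prints "ok"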
cognee/infrastructure/databases/vector/embeddings/config.py
CHANGED

@@ -24,11 +24,10 @@ class EmbeddingConfig(BaseSettings):
     model_config = SettingsConfigDict(env_file=".env", extra="allow")
 
     def model_post_init(self, __context) -> None:
-        # If embedding batch size is not defined use 2048 as default for OpenAI and 100 for all other embedding models
         if not self.embedding_batch_size and self.embedding_provider.lower() == "openai":
-            self.embedding_batch_size =
+            self.embedding_batch_size = 36
         elif not self.embedding_batch_size:
-            self.embedding_batch_size =
+            self.embedding_batch_size = 36
 
     def to_dict(self) -> dict:
         """
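For context on the batch-size change: with pydantic-settings, model_post_init runs after values have been loaded from the environment and .env, so an explicitly configured batch size is kept and only an unset one falls back to the new default of 36. A simplified sketch of the pattern (field set reduced for illustration, not cognee's full config):

from typing import Optional
from pydantic_settings import BaseSettings, SettingsConfigDict


class EmbeddingConfig(BaseSettings):
    embedding_provider: str = "openai"
    embedding_batch_size: Optional[int] = None

    model_config = SettingsConfigDict(env_file=".env", extra="allow")

    def model_post_init(self, __context) -> None:
        # Any explicitly configured batch size wins; otherwise fall back to 36.
        if not self.embedding_batch_size:
            self.embedding_batch_size = 36


config = EmbeddingConfig()
print(config.embedding_batch_size)  # 36 unless EMBEDDING_BATCH_SIZE is set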
cognee/infrastructure/databases/vector/exceptions/exceptions.py
CHANGED

@@ -15,7 +15,7 @@ class CollectionNotFoundError(CogneeValidationError):
         self,
         message,
         name: str = "CollectionNotFoundError",
-        status_code: int = status.
+        status_code: int = status.HTTP_422_UNPROCESSABLE_CONTENT,
         log=True,
         log_level="DEBUG",
     ):
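The new constant name follows Starlette's rename of HTTP_422_UNPROCESSABLE_ENTITY to HTTP_422_UNPROCESSABLE_CONTENT; both evaluate to the integer 422. A hedged compatibility sketch for code that may also run against older Starlette releases that only ship the old name:

from starlette import status

# Newer Starlette exposes HTTP_422_UNPROCESSABLE_CONTENT; older releases
# only have the deprecated HTTP_422_UNPROCESSABLE_ENTITY. Both are 422.
HTTP_422 = getattr(
    status,
    "HTTP_422_UNPROCESSABLE_CONTENT",
    status.HTTP_422_UNPROCESSABLE_ENTITY,
)

assert HTTP_422 == 422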
cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py
CHANGED

@@ -324,7 +324,6 @@ class LanceDBAdapter(VectorDBInterface):
 
     def get_data_point_schema(self, model_type: BaseModel):
         related_models_fields = []
-
         for field_name, field_config in model_type.model_fields.items():
             if hasattr(field_config, "model_fields"):
                 related_models_fields.append(field_name)
cognee/infrastructure/files/exceptions.py
CHANGED

@@ -8,6 +8,6 @@ class FileContentHashingError(Exception):
         self,
         message: str = "Failed to hash content of the file.",
         name: str = "FileContentHashingError",
-        status_code=status.
+        status_code=status.HTTP_422_UNPROCESSABLE_CONTENT,
     ):
         super().__init__(message, name, status_code)
cognee/infrastructure/files/storage/LocalFileStorage.py
CHANGED

@@ -82,16 +82,16 @@ class LocalFileStorage(Storage):
         self.ensure_directory_exists(file_dir_path)
 
         if overwrite or not os.path.exists(full_file_path):
-
-                full_file_path,
-                mode="w" if isinstance(data, str) else "wb",
-                encoding="utf-8" if isinstance(data, str) else None,
-            ) as file:
-                if hasattr(data, "read"):
-                    data.seek(0)
-                    file.write(data.read())
-                else:
+            if isinstance(data, str):
+                with open(full_file_path, mode="w", encoding="utf-8", newline="\n") as file:
                     file.write(data)
+            else:
+                with open(full_file_path, mode="wb") as file:
+                    if hasattr(data, "read"):
+                        data.seek(0)
+                        file.write(data.read())
+                    else:
+                        file.write(data)
 
             file.close()
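The rewritten write path branches on payload type instead of computing mode and encoding inline: strings go through a text-mode handle with UTF-8 and newline="\n" (normalizing line endings across platforms), while bytes and readable objects use a binary handle. A standalone sketch of the same branching; the save helper, paths, and payloads are illustrative:

import io


def save(full_file_path: str, data) -> None:
    # Text payloads: normalized newlines, explicit UTF-8.
    if isinstance(data, str):
        with open(full_file_path, mode="w", encoding="utf-8", newline="\n") as file:
            file.write(data)
    # Binary payloads: raw bytes, or a readable object rewound to the start.
    else:
        with open(full_file_path, mode="wb") as file:
            if hasattr(data, "read"):
                data.seek(0)
                file.write(data.read())
            else:
                file.write(data)


save("note.txt", "line one\nline two\n")
save("blob.bin", io.BytesIO(b"\x00\x01"))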
cognee/infrastructure/files/storage/S3FileStorage.py
CHANGED

@@ -70,18 +70,18 @@ class S3FileStorage(Storage):
         if overwrite or not await self.file_exists(file_path):
 
             def save_data_to_file():
-
-
-
-
-                ) as file:
-                    if hasattr(data, "read"):
-                        data.seek(0)
-                        file.write(data.read())
-                    else:
+                if isinstance(data, str):
+                    with self.s3.open(
+                        full_file_path, mode="w", encoding="utf-8", newline="\n"
+                    ) as file:
                         file.write(data)
-
-
+                else:
+                    with self.s3.open(full_file_path, mode="wb") as file:
+                        if hasattr(data, "read"):
+                            data.seek(0)
+                            file.write(data.read())
+                        else:
+                            file.write(data)
 
             await run_async(save_data_to_file)
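save_data_to_file stays synchronous because s3fs file handles block; run_async then moves the call off the event loop. A sketch of the same idea using the standard library's asyncio.to_thread as a stand-in for cognee's run_async helper, whose exact signature this diff does not show:

import asyncio


def blocking_write() -> str:
    # Stand-in for the blocking s3fs write; any blocking IO goes here.
    with open("example.txt", mode="w", encoding="utf-8", newline="\n") as file:
        file.write("payload")
    return "done"


async def store() -> None:
    # Offload the blocking call so the event loop stays responsive.
    result = await asyncio.to_thread(blocking_write)
    print(result)


asyncio.run(store())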
cognee/infrastructure/files/utils/guess_file_type.py
CHANGED

@@ -124,6 +124,12 @@ def guess_file_type(file: BinaryIO) -> filetype.Type:
     """
     file_type = filetype.guess(file)
 
+    # If file type could not be determined consider it a plain text file as they don't have magic number encoding
+    if file_type is None:
+        from filetype.types.base import Type
+
+        file_type = Type("text/plain", "txt")
+
     if file_type is None:
         raise FileTypeException(f"Unknown file detected: {file.name}.")
 
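filetype.guess relies on magic-number sniffing, and plain text has no magic bytes, so unidentifiable input previously fell through to FileTypeException; the new block assumes such input is plain text. A small sketch of the fallback, assuming the filetype package:

import filetype
from filetype.types.base import Type


def guess_with_text_fallback(data) -> Type:
    file_type = filetype.guess(data)
    # No magic number matched: assume plain text rather than failing.
    if file_type is None:
        file_type = Type("text/plain", "txt")
    return file_type


detected = guess_with_text_fallback(b"just some prose")
print(detected.mime, detected.extension)  # text/plain txt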
cognee/infrastructure/llm/prompts/feedback_reaction_prompt.txt
ADDED

@@ -0,0 +1,14 @@
+A question was previously answered, but the answer received negative feedback.
+Please reconsider and improve the response.
+
+Question: {question}
+Context originally used: {context}
+Previous answer: {wrong_answer}
+Feedback on that answer: {negative_feedback}
+
+Task: Provide a better response. The new answer should be short and direct.
+Then explain briefly why this answer is better.
+
+Format your reply as:
+Answer: <improved answer>
+Explanation: <short explanation>
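The template's curly-brace placeholders suggest it is rendered with ordinary str.format-style substitution before being sent to the model; how cognee actually loads and renders prompt files is not shown in this diff. A hypothetical rendering with illustrative values:

feedback_reaction_template = (
    "Question: {question}\n"
    "Context originally used: {context}\n"
    "Previous answer: {wrong_answer}\n"
    "Feedback on that answer: {negative_feedback}\n"
)

rendered = feedback_reaction_template.format(
    question="When was Einstein born?",
    context="Biographical chunk about Albert Einstein.",
    wrong_answer="1887",
    negative_feedback="The year is wrong; please re-check the sources.",
)
print(rendered)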
cognee/infrastructure/llm/prompts/feedback_report_prompt.txt
ADDED

@@ -0,0 +1,13 @@
+Write a concise, stand-alone paragraph that explains the correct answer to the question below.
+The paragraph should read naturally on its own, providing all necessary context and reasoning
+so the answer is clear and well-supported.
+
+Question: {question}
+Correct answer: {improved_answer}
+Supporting context: {new_context}
+
+Your paragraph should:
+- First sentence clearly states the correct answer as a full sentence
+- Remainder flows from first sentence and provides explanation based on context
+- Use simple, direct language that is easy to follow
+- Use shorter sentences, no long-winded explanations
cognee/infrastructure/llm/prompts/search_type_selector_prompt.txt
CHANGED

@@ -10,8 +10,6 @@ Here are the available `SearchType` tools and their specific functions:
 - Summarizing large amounts of information
 - Quick understanding of complex subjects
 
-* **`INSIGHTS`**: The `INSIGHTS` search type discovers connections and relationships between entities in the knowledge graph.
-
 **Best for:**
 
 - Discovering how entities are connected

@@ -95,9 +93,6 @@ Here are the available `SearchType` tools and their specific functions:
 Query: "Summarize the key findings from these research papers"
 Response: `SUMMARIES`
 
-Query: "What is the relationship between the methodologies used in these papers?"
-Response: `INSIGHTS`
-
 Query: "When was Einstein born?"
 Response: `CHUNKS`
 
cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py
CHANGED

@@ -1,19 +1,24 @@
+import logging
 from typing import Type
 from pydantic import BaseModel
+import litellm
 import instructor
+from cognee.shared.logging_utils import get_logger
+from tenacity import (
+    retry,
+    stop_after_delay,
+    wait_exponential_jitter,
+    retry_if_not_exception_type,
+    before_sleep_log,
+)
 
-from cognee.infrastructure.llm.exceptions import MissingSystemPromptPathError
 from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
     LLMInterface,
 )
-from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.rate_limiter import (
-    rate_limit_async,
-    sleep_and_retry_async,
-)
-
-from cognee.infrastructure.llm.LLMGateway import LLMGateway
 from cognee.infrastructure.llm.config import get_llm_config
 
+logger = get_logger()
+
 
 class AnthropicAdapter(LLMInterface):
     """

@@ -35,8 +40,13 @@ class AnthropicAdapter(LLMInterface):
         self.model = model
         self.max_completion_tokens = max_completion_tokens
 
-    @
-
+    @retry(
+        stop=stop_after_delay(128),
+        wait=wait_exponential_jitter(2, 128),
+        retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
+        before_sleep=before_sleep_log(logger, logging.DEBUG),
+        reraise=True,
+    )
     async def acreate_structured_output(
         self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
     ) -> BaseModel:
cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py
CHANGED

@@ -12,11 +12,18 @@ from cognee.infrastructure.llm.exceptions import ContentPolicyFilterError
 from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
     LLMInterface,
 )
-
-
-
+import logging
+from cognee.shared.logging_utils import get_logger
+from tenacity import (
+    retry,
+    stop_after_delay,
+    wait_exponential_jitter,
+    retry_if_not_exception_type,
+    before_sleep_log,
 )
 
+logger = get_logger()
+
 
 class GeminiAdapter(LLMInterface):
     """

@@ -58,8 +65,13 @@ class GeminiAdapter(LLMInterface):
 
         self.aclient = instructor.from_litellm(litellm.acompletion, mode=instructor.Mode.JSON)
 
-    @
-
+    @retry(
+        stop=stop_after_delay(128),
+        wait=wait_exponential_jitter(2, 128),
+        retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
+        before_sleep=before_sleep_log(logger, logging.DEBUG),
+        reraise=True,
+    )
     async def acreate_structured_output(
         self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
     ) -> BaseModel:
cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py
CHANGED

@@ -12,11 +12,18 @@ from cognee.infrastructure.llm.exceptions import ContentPolicyFilterError
 from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
     LLMInterface,
 )
-
-
-
+import logging
+from cognee.shared.logging_utils import get_logger
+from tenacity import (
+    retry,
+    stop_after_delay,
+    wait_exponential_jitter,
+    retry_if_not_exception_type,
+    before_sleep_log,
 )
 
+logger = get_logger()
+
 
 class GenericAPIAdapter(LLMInterface):
     """

@@ -58,8 +65,13 @@ class GenericAPIAdapter(LLMInterface):
 
         self.aclient = instructor.from_litellm(litellm.acompletion, mode=instructor.Mode.JSON)
 
-    @
-
+    @retry(
+        stop=stop_after_delay(128),
+        wait=wait_exponential_jitter(2, 128),
+        retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
+        before_sleep=before_sleep_log(logger, logging.DEBUG),
+        reraise=True,
+    )
     async def acreate_structured_output(
         self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
     ) -> BaseModel:
cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py
CHANGED

@@ -23,6 +23,7 @@ class LLMProvider(Enum):
     - ANTHROPIC: Represents the Anthropic provider.
     - CUSTOM: Represents a custom provider option.
     - GEMINI: Represents the Gemini provider.
+    - MISTRAL: Represents the Mistral AI provider.
     """
 
     OPENAI = "openai"

@@ -30,6 +31,7 @@ class LLMProvider(Enum):
     ANTHROPIC = "anthropic"
     CUSTOM = "custom"
     GEMINI = "gemini"
+    MISTRAL = "mistral"
 
 
 def get_llm_client(raise_api_key_error: bool = True):

@@ -145,5 +147,35 @@ def get_llm_client(raise_api_key_error: bool = True):
             api_version=llm_config.llm_api_version,
         )
 
+    elif provider == LLMProvider.MISTRAL:
+        if llm_config.llm_api_key is None:
+            raise LLMAPIKeyNotSetError()
+
+        from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.mistral.adapter import (
+            MistralAdapter,
+        )
+
+        return MistralAdapter(
+            api_key=llm_config.llm_api_key,
+            model=llm_config.llm_model,
+            max_completion_tokens=max_completion_tokens,
+            endpoint=llm_config.llm_endpoint,
+        )
+
+    elif provider == LLMProvider.MISTRAL:
+        if llm_config.llm_api_key is None:
+            raise LLMAPIKeyNotSetError()
+
+        from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.mistral.adapter import (
+            MistralAdapter,
+        )
+
+        return MistralAdapter(
+            api_key=llm_config.llm_api_key,
+            model=llm_config.llm_model,
+            max_completion_tokens=max_completion_tokens,
+            endpoint=llm_config.llm_endpoint,
+        )
+
     else:
         raise UnsupportedLLMProviderError(provider)
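Note that the second, identical elif provider == LLMProvider.MISTRAL branch is unreachable, since the first one always returns. With the enum member and branch in place, selecting Mistral should reduce to configuring the provider; the environment variable names below are assumptions inferred from the llm_config fields visible in this diff:

import os

# Hypothetical configuration; cognee is assumed to read these into llm_config.
os.environ["LLM_PROVIDER"] = "mistral"
os.environ["LLM_API_KEY"] = "sk-..."  # placeholder, not a real key
os.environ["LLM_MODEL"] = "mistral-small-latest"

from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.get_llm_client import (
    get_llm_client,
)

client = get_llm_client()  # dispatches to MistralAdapter for provider == "mistral"
print(type(client).__name__)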
cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/__init__.py
ADDED

File without changes
cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py
ADDED

@@ -0,0 +1,109 @@
+import litellm
+import instructor
+from pydantic import BaseModel
+from typing import Type
+from litellm import JSONSchemaValidationError
+
+from cognee.shared.logging_utils import get_logger
+from cognee.modules.observability.get_observe import get_observe
+from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
+    LLMInterface,
+)
+from cognee.infrastructure.llm.config import get_llm_config
+
+import logging
+from tenacity import (
+    retry,
+    stop_after_delay,
+    wait_exponential_jitter,
+    retry_if_not_exception_type,
+    before_sleep_log,
+)
+
+logger = get_logger()
+observe = get_observe()
+
+
+class MistralAdapter(LLMInterface):
+    """
+    Adapter for Mistral AI API, for structured output generation and prompt display.
+
+    Public methods:
+    - acreate_structured_output
+    - show_prompt
+    """
+
+    name = "Mistral"
+    model: str
+    api_key: str
+    max_completion_tokens: int
+
+    def __init__(self, api_key: str, model: str, max_completion_tokens: int, endpoint: str = None):
+        from mistralai import Mistral
+
+        self.model = model
+        self.max_completion_tokens = max_completion_tokens
+
+        self.aclient = instructor.from_litellm(
+            litellm.acompletion,
+            mode=instructor.Mode.MISTRAL_TOOLS,
+            api_key=get_llm_config().llm_api_key,
+        )
+
+    @retry(
+        stop=stop_after_delay(128),
+        wait=wait_exponential_jitter(2, 128),
+        retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
+        before_sleep=before_sleep_log(logger, logging.DEBUG),
+        reraise=True,
+    )
+    async def acreate_structured_output(
+        self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
+    ) -> BaseModel:
+        """
+        Generate a response from the user query.
+
+        Parameters:
+        -----------
+        - text_input (str): The input text from the user to be processed.
+        - system_prompt (str): A prompt that sets the context for the query.
+        - response_model (Type[BaseModel]): The model to structure the response according to
+          its format.
+
+        Returns:
+        --------
+        - BaseModel: An instance of BaseModel containing the structured response.
+        """
+        try:
+            messages = [
+                {
+                    "role": "system",
+                    "content": system_prompt,
+                },
+                {
+                    "role": "user",
+                    "content": f"""Use the given format to extract information
+                    from the following input: {text_input}""",
+                },
+            ]
+            try:
+                response = await self.aclient.chat.completions.create(
+                    model=self.model,
+                    max_tokens=self.max_completion_tokens,
+                    max_retries=5,
+                    messages=messages,
+                    response_model=response_model,
+                )
+                if response.choices and response.choices[0].message.content:
+                    content = response.choices[0].message.content
+                    return response_model.model_validate_json(content)
+                else:
+                    raise ValueError("Failed to get valid response after retries")
+            except litellm.exceptions.BadRequestError as e:
+                logger.error(f"Bad request error: {str(e)}")
+                raise ValueError(f"Invalid request: {str(e)}")
+
+        except JSONSchemaValidationError as e:
+            logger.error(f"Schema validation failed: {str(e)}")
+            logger.debug(f"Raw response: {e.raw_response}")
+            raise ValueError(f"Response failed schema validation: {str(e)}")
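A hedged usage sketch for the new adapter, based only on the constructor and method signatures above; the API key and model name are placeholders, and actual behavior depends on instructor's Mistral tools mode:

import asyncio
from pydantic import BaseModel

from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.mistral.adapter import (
    MistralAdapter,
)


class PersonFacts(BaseModel):
    name: str
    birth_year: int


async def main() -> None:
    adapter = MistralAdapter(
        api_key="sk-...",  # placeholder key
        model="mistral/mistral-small-latest",  # placeholder model name
        max_completion_tokens=1024,
    )
    result = await adapter.acreate_structured_output(
        text_input="Albert Einstein was born in 1879.",
        system_prompt="Extract structured facts from the input.",
        response_model=PersonFacts,
    )
    print(result)


asyncio.run(main())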
cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py
CHANGED

@@ -1,4 +1,6 @@
 import base64
+import litellm
+import logging
 import instructor
 from typing import Type
 from openai import OpenAI

@@ -7,11 +9,17 @@ from pydantic import BaseModel
 from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
     LLMInterface,
 )
-from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.rate_limiter import (
-    rate_limit_async,
-    sleep_and_retry_async,
-)
 from cognee.infrastructure.files.utils.open_data_file import open_data_file
+from cognee.shared.logging_utils import get_logger
+from tenacity import (
+    retry,
+    stop_after_delay,
+    wait_exponential_jitter,
+    retry_if_not_exception_type,
+    before_sleep_log,
+)
+
+logger = get_logger()
 
 
 class OllamaAPIAdapter(LLMInterface):

@@ -47,8 +55,13 @@ class OllamaAPIAdapter(LLMInterface):
             OpenAI(base_url=self.endpoint, api_key=self.api_key), mode=instructor.Mode.JSON
         )
 
-    @
-
+    @retry(
+        stop=stop_after_delay(128),
+        wait=wait_exponential_jitter(2, 128),
+        retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
+        before_sleep=before_sleep_log(logger, logging.DEBUG),
+        reraise=True,
+    )
     async def acreate_structured_output(
         self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
     ) -> BaseModel:

@@ -90,7 +103,13 @@ class OllamaAPIAdapter(LLMInterface):
 
         return response
 
-    @
+    @retry(
+        stop=stop_after_delay(128),
+        wait=wait_exponential_jitter(2, 128),
+        retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
+        before_sleep=before_sleep_log(logger, logging.DEBUG),
+        reraise=True,
+    )
     async def create_transcript(self, input_file: str) -> str:
         """
         Generate an audio transcript from a user query.

@@ -123,7 +142,13 @@ class OllamaAPIAdapter(LLMInterface):
 
         return transcription.text
 
-    @
+    @retry(
+        stop=stop_after_delay(128),
+        wait=wait_exponential_jitter(2, 128),
+        retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
+        before_sleep=before_sleep_log(logger, logging.DEBUG),
+        reraise=True,
+    )
    async def transcribe_image(self, input_file: str) -> str:
        """
        Transcribe content from an image using base64 encoding.