cognee 0.3.5__py3-none-any.whl → 0.3.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cognee/__init__.py +1 -0
- cognee/api/health.py +2 -12
- cognee/api/v1/add/add.py +46 -6
- cognee/api/v1/add/routers/get_add_router.py +5 -1
- cognee/api/v1/cognify/cognify.py +29 -9
- cognee/api/v1/datasets/datasets.py +11 -0
- cognee/api/v1/responses/default_tools.py +0 -1
- cognee/api/v1/responses/dispatch_function.py +1 -1
- cognee/api/v1/responses/routers/default_tools.py +0 -1
- cognee/api/v1/search/search.py +11 -9
- cognee/api/v1/settings/routers/get_settings_router.py +7 -1
- cognee/api/v1/ui/ui.py +47 -16
- cognee/api/v1/update/routers/get_update_router.py +1 -1
- cognee/api/v1/update/update.py +3 -3
- cognee/cli/_cognee.py +61 -10
- cognee/cli/commands/add_command.py +3 -3
- cognee/cli/commands/cognify_command.py +3 -3
- cognee/cli/commands/config_command.py +9 -7
- cognee/cli/commands/delete_command.py +3 -3
- cognee/cli/commands/search_command.py +3 -7
- cognee/cli/config.py +0 -1
- cognee/context_global_variables.py +5 -0
- cognee/exceptions/exceptions.py +1 -1
- cognee/infrastructure/databases/cache/__init__.py +2 -0
- cognee/infrastructure/databases/cache/cache_db_interface.py +79 -0
- cognee/infrastructure/databases/cache/config.py +44 -0
- cognee/infrastructure/databases/cache/get_cache_engine.py +67 -0
- cognee/infrastructure/databases/cache/redis/RedisAdapter.py +243 -0
- cognee/infrastructure/databases/exceptions/__init__.py +1 -0
- cognee/infrastructure/databases/exceptions/exceptions.py +18 -2
- cognee/infrastructure/databases/graph/get_graph_engine.py +1 -1
- cognee/infrastructure/databases/graph/graph_db_interface.py +5 -0
- cognee/infrastructure/databases/graph/kuzu/adapter.py +67 -44
- cognee/infrastructure/databases/graph/neo4j_driver/adapter.py +13 -3
- cognee/infrastructure/databases/graph/neo4j_driver/deadlock_retry.py +1 -1
- cognee/infrastructure/databases/graph/neptune_driver/neptune_utils.py +1 -1
- cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py +1 -1
- cognee/infrastructure/databases/vector/embeddings/FastembedEmbeddingEngine.py +21 -3
- cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py +17 -10
- cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py +17 -4
- cognee/infrastructure/databases/vector/embeddings/config.py +2 -3
- cognee/infrastructure/databases/vector/exceptions/exceptions.py +1 -1
- cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py +0 -1
- cognee/infrastructure/files/exceptions.py +1 -1
- cognee/infrastructure/files/storage/LocalFileStorage.py +9 -9
- cognee/infrastructure/files/storage/S3FileStorage.py +11 -11
- cognee/infrastructure/files/utils/guess_file_type.py +6 -0
- cognee/infrastructure/llm/prompts/search_type_selector_prompt.txt +0 -5
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py +19 -9
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py +17 -5
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py +17 -5
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +32 -0
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/__init__.py +0 -0
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py +109 -0
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py +33 -8
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py +40 -18
- cognee/infrastructure/loaders/LoaderEngine.py +27 -7
- cognee/infrastructure/loaders/external/__init__.py +7 -0
- cognee/infrastructure/loaders/external/advanced_pdf_loader.py +2 -8
- cognee/infrastructure/loaders/external/beautiful_soup_loader.py +310 -0
- cognee/infrastructure/loaders/supported_loaders.py +7 -0
- cognee/modules/data/exceptions/exceptions.py +1 -1
- cognee/modules/data/methods/__init__.py +3 -0
- cognee/modules/data/methods/get_dataset_data.py +4 -1
- cognee/modules/data/methods/has_dataset_data.py +21 -0
- cognee/modules/engine/models/TableRow.py +0 -1
- cognee/modules/ingestion/save_data_to_file.py +9 -2
- cognee/modules/pipelines/exceptions/exceptions.py +1 -1
- cognee/modules/pipelines/operations/pipeline.py +12 -1
- cognee/modules/pipelines/operations/run_tasks.py +25 -197
- cognee/modules/pipelines/operations/run_tasks_data_item.py +260 -0
- cognee/modules/pipelines/operations/run_tasks_distributed.py +121 -38
- cognee/modules/retrieval/EntityCompletionRetriever.py +48 -8
- cognee/modules/retrieval/base_graph_retriever.py +3 -1
- cognee/modules/retrieval/base_retriever.py +3 -1
- cognee/modules/retrieval/chunks_retriever.py +5 -1
- cognee/modules/retrieval/code_retriever.py +20 -2
- cognee/modules/retrieval/completion_retriever.py +50 -9
- cognee/modules/retrieval/cypher_search_retriever.py +11 -1
- cognee/modules/retrieval/graph_completion_context_extension_retriever.py +47 -8
- cognee/modules/retrieval/graph_completion_cot_retriever.py +32 -1
- cognee/modules/retrieval/graph_completion_retriever.py +54 -10
- cognee/modules/retrieval/lexical_retriever.py +20 -2
- cognee/modules/retrieval/natural_language_retriever.py +10 -1
- cognee/modules/retrieval/summaries_retriever.py +5 -1
- cognee/modules/retrieval/temporal_retriever.py +62 -10
- cognee/modules/retrieval/user_qa_feedback.py +3 -2
- cognee/modules/retrieval/utils/completion.py +5 -0
- cognee/modules/retrieval/utils/description_to_codepart_search.py +1 -1
- cognee/modules/retrieval/utils/session_cache.py +156 -0
- cognee/modules/search/methods/get_search_type_tools.py +0 -5
- cognee/modules/search/methods/no_access_control_search.py +12 -1
- cognee/modules/search/methods/search.py +34 -2
- cognee/modules/search/types/SearchType.py +0 -1
- cognee/modules/settings/get_settings.py +23 -0
- cognee/modules/users/methods/get_authenticated_user.py +3 -1
- cognee/modules/users/methods/get_default_user.py +1 -6
- cognee/modules/users/roles/methods/create_role.py +2 -2
- cognee/modules/users/tenants/methods/create_tenant.py +2 -2
- cognee/shared/exceptions/exceptions.py +1 -1
- cognee/tasks/codingagents/coding_rule_associations.py +1 -2
- cognee/tasks/documents/exceptions/exceptions.py +1 -1
- cognee/tasks/graph/extract_graph_from_data.py +2 -0
- cognee/tasks/ingestion/data_item_to_text_file.py +3 -3
- cognee/tasks/ingestion/ingest_data.py +11 -5
- cognee/tasks/ingestion/save_data_item_to_storage.py +12 -1
- cognee/tasks/storage/add_data_points.py +3 -10
- cognee/tasks/storage/index_data_points.py +19 -14
- cognee/tasks/storage/index_graph_edges.py +25 -11
- cognee/tasks/web_scraper/__init__.py +34 -0
- cognee/tasks/web_scraper/config.py +26 -0
- cognee/tasks/web_scraper/default_url_crawler.py +446 -0
- cognee/tasks/web_scraper/models.py +46 -0
- cognee/tasks/web_scraper/types.py +4 -0
- cognee/tasks/web_scraper/utils.py +142 -0
- cognee/tasks/web_scraper/web_scraper_task.py +396 -0
- cognee/tests/cli_tests/cli_unit_tests/test_cli_utils.py +0 -1
- cognee/tests/integration/web_url_crawler/test_default_url_crawler.py +13 -0
- cognee/tests/integration/web_url_crawler/test_tavily_crawler.py +19 -0
- cognee/tests/integration/web_url_crawler/test_url_adding_e2e.py +344 -0
- cognee/tests/subprocesses/reader.py +25 -0
- cognee/tests/subprocesses/simple_cognify_1.py +31 -0
- cognee/tests/subprocesses/simple_cognify_2.py +31 -0
- cognee/tests/subprocesses/writer.py +32 -0
- cognee/tests/tasks/descriptive_metrics/metrics_test_utils.py +0 -2
- cognee/tests/tasks/descriptive_metrics/neo4j_metrics_test.py +8 -3
- cognee/tests/tasks/entity_extraction/entity_extraction_test.py +89 -0
- cognee/tests/tasks/web_scraping/web_scraping_test.py +172 -0
- cognee/tests/test_add_docling_document.py +56 -0
- cognee/tests/test_chromadb.py +7 -11
- cognee/tests/test_concurrent_subprocess_access.py +76 -0
- cognee/tests/test_conversation_history.py +240 -0
- cognee/tests/test_kuzu.py +27 -15
- cognee/tests/test_lancedb.py +7 -11
- cognee/tests/test_library.py +32 -2
- cognee/tests/test_neo4j.py +24 -16
- cognee/tests/test_neptune_analytics_vector.py +7 -11
- cognee/tests/test_permissions.py +9 -13
- cognee/tests/test_pgvector.py +4 -4
- cognee/tests/test_remote_kuzu.py +8 -11
- cognee/tests/test_s3_file_storage.py +1 -1
- cognee/tests/test_search_db.py +6 -8
- cognee/tests/unit/infrastructure/databases/cache/test_cache_config.py +89 -0
- cognee/tests/unit/modules/retrieval/conversation_history_test.py +154 -0
- {cognee-0.3.5.dist-info → cognee-0.3.7.dist-info}/METADATA +22 -7
- {cognee-0.3.5.dist-info → cognee-0.3.7.dist-info}/RECORD +155 -128
- {cognee-0.3.5.dist-info → cognee-0.3.7.dist-info}/entry_points.txt +1 -0
- distributed/Dockerfile +0 -3
- distributed/entrypoint.py +21 -9
- distributed/signal.py +5 -0
- distributed/workers/data_point_saving_worker.py +64 -34
- distributed/workers/graph_saving_worker.py +71 -47
- cognee/infrastructure/databases/graph/memgraph/memgraph_adapter.py +0 -1116
- cognee/modules/retrieval/insights_retriever.py +0 -133
- cognee/tests/test_memgraph.py +0 -109
- cognee/tests/unit/modules/retrieval/insights_retriever_test.py +0 -251
- distributed/poetry.lock +0 -12238
- distributed/pyproject.toml +0 -185
- {cognee-0.3.5.dist-info → cognee-0.3.7.dist-info}/WHEEL +0 -0
- {cognee-0.3.5.dist-info → cognee-0.3.7.dist-info}/licenses/LICENSE +0 -0
- {cognee-0.3.5.dist-info → cognee-0.3.7.dist-info}/licenses/NOTICE.md +0 -0
cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py
CHANGED
|
@@ -23,6 +23,7 @@ class LLMProvider(Enum):
|
|
|
23
23
|
- ANTHROPIC: Represents the Anthropic provider.
|
|
24
24
|
- CUSTOM: Represents a custom provider option.
|
|
25
25
|
- GEMINI: Represents the Gemini provider.
|
|
26
|
+
- MISTRAL: Represents the Mistral AI provider.
|
|
26
27
|
"""
|
|
27
28
|
|
|
28
29
|
OPENAI = "openai"
|
|
@@ -30,6 +31,7 @@ class LLMProvider(Enum):
|
|
|
30
31
|
ANTHROPIC = "anthropic"
|
|
31
32
|
CUSTOM = "custom"
|
|
32
33
|
GEMINI = "gemini"
|
|
34
|
+
MISTRAL = "mistral"
|
|
33
35
|
|
|
34
36
|
|
|
35
37
|
def get_llm_client(raise_api_key_error: bool = True):
|
|
@@ -145,5 +147,35 @@ def get_llm_client(raise_api_key_error: bool = True):
|
|
|
145
147
|
api_version=llm_config.llm_api_version,
|
|
146
148
|
)
|
|
147
149
|
|
|
150
|
+
elif provider == LLMProvider.MISTRAL:
|
|
151
|
+
if llm_config.llm_api_key is None:
|
|
152
|
+
raise LLMAPIKeyNotSetError()
|
|
153
|
+
|
|
154
|
+
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.mistral.adapter import (
|
|
155
|
+
MistralAdapter,
|
|
156
|
+
)
|
|
157
|
+
|
|
158
|
+
return MistralAdapter(
|
|
159
|
+
api_key=llm_config.llm_api_key,
|
|
160
|
+
model=llm_config.llm_model,
|
|
161
|
+
max_completion_tokens=max_completion_tokens,
|
|
162
|
+
endpoint=llm_config.llm_endpoint,
|
|
163
|
+
)
|
|
164
|
+
|
|
165
|
+
elif provider == LLMProvider.MISTRAL:
|
|
166
|
+
if llm_config.llm_api_key is None:
|
|
167
|
+
raise LLMAPIKeyNotSetError()
|
|
168
|
+
|
|
169
|
+
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.mistral.adapter import (
|
|
170
|
+
MistralAdapter,
|
|
171
|
+
)
|
|
172
|
+
|
|
173
|
+
return MistralAdapter(
|
|
174
|
+
api_key=llm_config.llm_api_key,
|
|
175
|
+
model=llm_config.llm_model,
|
|
176
|
+
max_completion_tokens=max_completion_tokens,
|
|
177
|
+
endpoint=llm_config.llm_endpoint,
|
|
178
|
+
)
|
|
179
|
+
|
|
148
180
|
else:
|
|
149
181
|
raise UnsupportedLLMProviderError(provider)
|
cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/__init__.py
ADDED
|
File without changes
|
cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
1
|
+
import litellm
|
|
2
|
+
import instructor
|
|
3
|
+
from pydantic import BaseModel
|
|
4
|
+
from typing import Type
|
|
5
|
+
from litellm import JSONSchemaValidationError
|
|
6
|
+
|
|
7
|
+
from cognee.shared.logging_utils import get_logger
|
|
8
|
+
from cognee.modules.observability.get_observe import get_observe
|
|
9
|
+
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
|
|
10
|
+
LLMInterface,
|
|
11
|
+
)
|
|
12
|
+
from cognee.infrastructure.llm.config import get_llm_config
|
|
13
|
+
|
|
14
|
+
import logging
|
|
15
|
+
from tenacity import (
|
|
16
|
+
retry,
|
|
17
|
+
stop_after_delay,
|
|
18
|
+
wait_exponential_jitter,
|
|
19
|
+
retry_if_not_exception_type,
|
|
20
|
+
before_sleep_log,
|
|
21
|
+
)
|
|
22
|
+
|
|
23
|
+
logger = get_logger()
|
|
24
|
+
observe = get_observe()
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class MistralAdapter(LLMInterface):
|
|
28
|
+
"""
|
|
29
|
+
Adapter for Mistral AI API, for structured output generation and prompt display.
|
|
30
|
+
|
|
31
|
+
Public methods:
|
|
32
|
+
- acreate_structured_output
|
|
33
|
+
- show_prompt
|
|
34
|
+
"""
|
|
35
|
+
|
|
36
|
+
name = "Mistral"
|
|
37
|
+
model: str
|
|
38
|
+
api_key: str
|
|
39
|
+
max_completion_tokens: int
|
|
40
|
+
|
|
41
|
+
def __init__(self, api_key: str, model: str, max_completion_tokens: int, endpoint: str = None):
|
|
42
|
+
from mistralai import Mistral
|
|
43
|
+
|
|
44
|
+
self.model = model
|
|
45
|
+
self.max_completion_tokens = max_completion_tokens
|
|
46
|
+
|
|
47
|
+
self.aclient = instructor.from_litellm(
|
|
48
|
+
litellm.acompletion,
|
|
49
|
+
mode=instructor.Mode.MISTRAL_TOOLS,
|
|
50
|
+
api_key=get_llm_config().llm_api_key,
|
|
51
|
+
)
|
|
52
|
+
|
|
53
|
+
@retry(
|
|
54
|
+
stop=stop_after_delay(128),
|
|
55
|
+
wait=wait_exponential_jitter(2, 128),
|
|
56
|
+
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
|
|
57
|
+
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
|
58
|
+
reraise=True,
|
|
59
|
+
)
|
|
60
|
+
async def acreate_structured_output(
|
|
61
|
+
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
|
|
62
|
+
) -> BaseModel:
|
|
63
|
+
"""
|
|
64
|
+
Generate a response from the user query.
|
|
65
|
+
|
|
66
|
+
Parameters:
|
|
67
|
+
-----------
|
|
68
|
+
- text_input (str): The input text from the user to be processed.
|
|
69
|
+
- system_prompt (str): A prompt that sets the context for the query.
|
|
70
|
+
- response_model (Type[BaseModel]): The model to structure the response according to
|
|
71
|
+
its format.
|
|
72
|
+
|
|
73
|
+
Returns:
|
|
74
|
+
--------
|
|
75
|
+
- BaseModel: An instance of BaseModel containing the structured response.
|
|
76
|
+
"""
|
|
77
|
+
try:
|
|
78
|
+
messages = [
|
|
79
|
+
{
|
|
80
|
+
"role": "system",
|
|
81
|
+
"content": system_prompt,
|
|
82
|
+
},
|
|
83
|
+
{
|
|
84
|
+
"role": "user",
|
|
85
|
+
"content": f"""Use the given format to extract information
|
|
86
|
+
from the following input: {text_input}""",
|
|
87
|
+
},
|
|
88
|
+
]
|
|
89
|
+
try:
|
|
90
|
+
response = await self.aclient.chat.completions.create(
|
|
91
|
+
model=self.model,
|
|
92
|
+
max_tokens=self.max_completion_tokens,
|
|
93
|
+
max_retries=5,
|
|
94
|
+
messages=messages,
|
|
95
|
+
response_model=response_model,
|
|
96
|
+
)
|
|
97
|
+
if response.choices and response.choices[0].message.content:
|
|
98
|
+
content = response.choices[0].message.content
|
|
99
|
+
return response_model.model_validate_json(content)
|
|
100
|
+
else:
|
|
101
|
+
raise ValueError("Failed to get valid response after retries")
|
|
102
|
+
except litellm.exceptions.BadRequestError as e:
|
|
103
|
+
logger.error(f"Bad request error: {str(e)}")
|
|
104
|
+
raise ValueError(f"Invalid request: {str(e)}")
|
|
105
|
+
|
|
106
|
+
except JSONSchemaValidationError as e:
|
|
107
|
+
logger.error(f"Schema validation failed: {str(e)}")
|
|
108
|
+
logger.debug(f"Raw response: {e.raw_response}")
|
|
109
|
+
raise ValueError(f"Response failed schema validation: {str(e)}")
|
cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py
CHANGED
|
@@ -1,4 +1,6 @@
|
|
|
1
1
|
import base64
|
|
2
|
+
import litellm
|
|
3
|
+
import logging
|
|
2
4
|
import instructor
|
|
3
5
|
from typing import Type
|
|
4
6
|
from openai import OpenAI
|
|
@@ -7,11 +9,17 @@ from pydantic import BaseModel
|
|
|
7
9
|
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
|
|
8
10
|
LLMInterface,
|
|
9
11
|
)
|
|
10
|
-
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.rate_limiter import (
|
|
11
|
-
rate_limit_async,
|
|
12
|
-
sleep_and_retry_async,
|
|
13
|
-
)
|
|
14
12
|
from cognee.infrastructure.files.utils.open_data_file import open_data_file
|
|
13
|
+
from cognee.shared.logging_utils import get_logger
|
|
14
|
+
from tenacity import (
|
|
15
|
+
retry,
|
|
16
|
+
stop_after_delay,
|
|
17
|
+
wait_exponential_jitter,
|
|
18
|
+
retry_if_not_exception_type,
|
|
19
|
+
before_sleep_log,
|
|
20
|
+
)
|
|
21
|
+
|
|
22
|
+
logger = get_logger()
|
|
15
23
|
|
|
16
24
|
|
|
17
25
|
class OllamaAPIAdapter(LLMInterface):
|
|
@@ -47,8 +55,13 @@ class OllamaAPIAdapter(LLMInterface):
|
|
|
47
55
|
OpenAI(base_url=self.endpoint, api_key=self.api_key), mode=instructor.Mode.JSON
|
|
48
56
|
)
|
|
49
57
|
|
|
50
|
-
@
|
|
51
|
-
|
|
58
|
+
@retry(
|
|
59
|
+
stop=stop_after_delay(128),
|
|
60
|
+
wait=wait_exponential_jitter(2, 128),
|
|
61
|
+
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
|
|
62
|
+
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
|
63
|
+
reraise=True,
|
|
64
|
+
)
|
|
52
65
|
async def acreate_structured_output(
|
|
53
66
|
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
|
|
54
67
|
) -> BaseModel:
|
|
@@ -90,7 +103,13 @@ class OllamaAPIAdapter(LLMInterface):
|
|
|
90
103
|
|
|
91
104
|
return response
|
|
92
105
|
|
|
93
|
-
@
|
|
106
|
+
@retry(
|
|
107
|
+
stop=stop_after_delay(128),
|
|
108
|
+
wait=wait_exponential_jitter(2, 128),
|
|
109
|
+
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
|
|
110
|
+
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
|
111
|
+
reraise=True,
|
|
112
|
+
)
|
|
94
113
|
async def create_transcript(self, input_file: str) -> str:
|
|
95
114
|
"""
|
|
96
115
|
Generate an audio transcript from a user query.
|
|
@@ -123,7 +142,13 @@ class OllamaAPIAdapter(LLMInterface):
|
|
|
123
142
|
|
|
124
143
|
return transcription.text
|
|
125
144
|
|
|
126
|
-
@
|
|
145
|
+
@retry(
|
|
146
|
+
stop=stop_after_delay(128),
|
|
147
|
+
wait=wait_exponential_jitter(2, 128),
|
|
148
|
+
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
|
|
149
|
+
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
|
150
|
+
reraise=True,
|
|
151
|
+
)
|
|
127
152
|
async def transcribe_image(self, input_file: str) -> str:
|
|
128
153
|
"""
|
|
129
154
|
Transcribe content from an image using base64 encoding.
|
cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py
CHANGED
|
@@ -7,6 +7,15 @@ from openai import ContentFilterFinishReasonError
|
|
|
7
7
|
from litellm.exceptions import ContentPolicyViolationError
|
|
8
8
|
from instructor.core import InstructorRetryException
|
|
9
9
|
|
|
10
|
+
import logging
|
|
11
|
+
from tenacity import (
|
|
12
|
+
retry,
|
|
13
|
+
stop_after_delay,
|
|
14
|
+
wait_exponential_jitter,
|
|
15
|
+
retry_if_not_exception_type,
|
|
16
|
+
before_sleep_log,
|
|
17
|
+
)
|
|
18
|
+
|
|
10
19
|
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
|
|
11
20
|
LLMInterface,
|
|
12
21
|
)
|
|
@@ -14,19 +23,13 @@ from cognee.infrastructure.llm.exceptions import (
|
|
|
14
23
|
ContentPolicyFilterError,
|
|
15
24
|
)
|
|
16
25
|
from cognee.infrastructure.files.utils.open_data_file import open_data_file
|
|
17
|
-
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.rate_limiter import (
|
|
18
|
-
rate_limit_async,
|
|
19
|
-
rate_limit_sync,
|
|
20
|
-
sleep_and_retry_async,
|
|
21
|
-
sleep_and_retry_sync,
|
|
22
|
-
)
|
|
23
26
|
from cognee.modules.observability.get_observe import get_observe
|
|
24
27
|
from cognee.shared.logging_utils import get_logger
|
|
25
28
|
|
|
26
|
-
observe = get_observe()
|
|
27
|
-
|
|
28
29
|
logger = get_logger()
|
|
29
30
|
|
|
31
|
+
observe = get_observe()
|
|
32
|
+
|
|
30
33
|
|
|
31
34
|
class OpenAIAdapter(LLMInterface):
|
|
32
35
|
"""
|
|
@@ -97,8 +100,13 @@ class OpenAIAdapter(LLMInterface):
|
|
|
97
100
|
self.fallback_endpoint = fallback_endpoint
|
|
98
101
|
|
|
99
102
|
@observe(as_type="generation")
|
|
100
|
-
@
|
|
101
|
-
|
|
103
|
+
@retry(
|
|
104
|
+
stop=stop_after_delay(128),
|
|
105
|
+
wait=wait_exponential_jitter(2, 128),
|
|
106
|
+
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
|
|
107
|
+
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
|
108
|
+
reraise=True,
|
|
109
|
+
)
|
|
102
110
|
async def acreate_structured_output(
|
|
103
111
|
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
|
|
104
112
|
) -> BaseModel:
|
|
@@ -148,10 +156,7 @@ class OpenAIAdapter(LLMInterface):
|
|
|
148
156
|
InstructorRetryException,
|
|
149
157
|
) as e:
|
|
150
158
|
if not (self.fallback_model and self.fallback_api_key):
|
|
151
|
-
raise
|
|
152
|
-
f"The provided input contains content that is not aligned with our content policy: {text_input}"
|
|
153
|
-
) from e
|
|
154
|
-
|
|
159
|
+
raise e
|
|
155
160
|
try:
|
|
156
161
|
return await self.aclient.chat.completions.create(
|
|
157
162
|
model=self.fallback_model,
|
|
@@ -186,8 +191,13 @@ class OpenAIAdapter(LLMInterface):
|
|
|
186
191
|
) from error
|
|
187
192
|
|
|
188
193
|
@observe
|
|
189
|
-
@
|
|
190
|
-
|
|
194
|
+
@retry(
|
|
195
|
+
stop=stop_after_delay(128),
|
|
196
|
+
wait=wait_exponential_jitter(2, 128),
|
|
197
|
+
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
|
|
198
|
+
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
|
199
|
+
reraise=True,
|
|
200
|
+
)
|
|
191
201
|
def create_structured_output(
|
|
192
202
|
self, text_input: str, system_prompt: str, response_model: Type[BaseModel]
|
|
193
203
|
) -> BaseModel:
|
|
@@ -231,7 +241,13 @@ class OpenAIAdapter(LLMInterface):
|
|
|
231
241
|
max_retries=self.MAX_RETRIES,
|
|
232
242
|
)
|
|
233
243
|
|
|
234
|
-
@
|
|
244
|
+
@retry(
|
|
245
|
+
stop=stop_after_delay(128),
|
|
246
|
+
wait=wait_exponential_jitter(2, 128),
|
|
247
|
+
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
|
|
248
|
+
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
|
249
|
+
reraise=True,
|
|
250
|
+
)
|
|
235
251
|
async def create_transcript(self, input):
|
|
236
252
|
"""
|
|
237
253
|
Generate an audio transcript from a user query.
|
|
@@ -263,7 +279,13 @@ class OpenAIAdapter(LLMInterface):
|
|
|
263
279
|
|
|
264
280
|
return transcription
|
|
265
281
|
|
|
266
|
-
@
|
|
282
|
+
@retry(
|
|
283
|
+
stop=stop_after_delay(128),
|
|
284
|
+
wait=wait_exponential_jitter(2, 128),
|
|
285
|
+
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
|
|
286
|
+
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
|
287
|
+
reraise=True,
|
|
288
|
+
)
|
|
267
289
|
async def transcribe_image(self, input) -> BaseModel:
|
|
268
290
|
"""
|
|
269
291
|
Generate a transcription of an image from a user query.
|
|
@@ -27,11 +27,11 @@ class LoaderEngine:
|
|
|
27
27
|
|
|
28
28
|
self.default_loader_priority = [
|
|
29
29
|
"text_loader",
|
|
30
|
-
"advanced_pdf_loader",
|
|
31
30
|
"pypdf_loader",
|
|
32
31
|
"image_loader",
|
|
33
32
|
"audio_loader",
|
|
34
33
|
"unstructured_loader",
|
|
34
|
+
"advanced_pdf_loader",
|
|
35
35
|
]
|
|
36
36
|
|
|
37
37
|
def register_loader(self, loader: LoaderInterface) -> bool:
|
|
@@ -64,7 +64,9 @@ class LoaderEngine:
|
|
|
64
64
|
return True
|
|
65
65
|
|
|
66
66
|
def get_loader(
|
|
67
|
-
self,
|
|
67
|
+
self,
|
|
68
|
+
file_path: str,
|
|
69
|
+
preferred_loaders: dict[str, dict[str, Any]],
|
|
68
70
|
) -> Optional[LoaderInterface]:
|
|
69
71
|
"""
|
|
70
72
|
Get appropriate loader for a file.
|
|
@@ -76,14 +78,21 @@ class LoaderEngine:
|
|
|
76
78
|
Returns:
|
|
77
79
|
LoaderInterface that can handle the file, or None if not found
|
|
78
80
|
"""
|
|
81
|
+
from pathlib import Path
|
|
79
82
|
|
|
80
83
|
file_info = filetype.guess(file_path)
|
|
81
84
|
|
|
85
|
+
path_extension = Path(file_path).suffix.lstrip(".")
|
|
86
|
+
|
|
82
87
|
# Try preferred loaders first
|
|
83
88
|
if preferred_loaders:
|
|
84
89
|
for loader_name in preferred_loaders:
|
|
85
90
|
if loader_name in self._loaders:
|
|
86
91
|
loader = self._loaders[loader_name]
|
|
92
|
+
# Try with path extension first (for text formats like html)
|
|
93
|
+
if loader.can_handle(extension=path_extension, mime_type=file_info.mime):
|
|
94
|
+
return loader
|
|
95
|
+
# Fall back to content-detected extension
|
|
87
96
|
if loader.can_handle(extension=file_info.extension, mime_type=file_info.mime):
|
|
88
97
|
return loader
|
|
89
98
|
else:
|
|
@@ -93,6 +102,10 @@ class LoaderEngine:
|
|
|
93
102
|
for loader_name in self.default_loader_priority:
|
|
94
103
|
if loader_name in self._loaders:
|
|
95
104
|
loader = self._loaders[loader_name]
|
|
105
|
+
# Try with path extension first (for text formats like html)
|
|
106
|
+
if loader.can_handle(extension=path_extension, mime_type=file_info.mime):
|
|
107
|
+
return loader
|
|
108
|
+
# Fall back to content-detected extension
|
|
96
109
|
if loader.can_handle(extension=file_info.extension, mime_type=file_info.mime):
|
|
97
110
|
return loader
|
|
98
111
|
else:
|
|
@@ -105,8 +118,7 @@ class LoaderEngine:
|
|
|
105
118
|
async def load_file(
|
|
106
119
|
self,
|
|
107
120
|
file_path: str,
|
|
108
|
-
|
|
109
|
-
preferred_loaders: Optional[List[str]] = None,
|
|
121
|
+
preferred_loaders: dict[str, dict[str, Any]] = None,
|
|
110
122
|
**kwargs,
|
|
111
123
|
):
|
|
112
124
|
"""
|
|
@@ -114,7 +126,7 @@ class LoaderEngine:
|
|
|
114
126
|
|
|
115
127
|
Args:
|
|
116
128
|
file_path: Path to the file to be processed
|
|
117
|
-
preferred_loaders:
|
|
129
|
+
preferred_loaders: Dict of loader names to their configurations
|
|
118
130
|
**kwargs: Additional loader-specific configuration
|
|
119
131
|
|
|
120
132
|
Raises:
|
|
@@ -126,8 +138,16 @@ class LoaderEngine:
|
|
|
126
138
|
raise ValueError(f"No loader found for file: {file_path}")
|
|
127
139
|
|
|
128
140
|
logger.debug(f"Loading {file_path} with {loader.loader_name}")
|
|
129
|
-
|
|
130
|
-
|
|
141
|
+
|
|
142
|
+
# Extract loader-specific config from preferred_loaders
|
|
143
|
+
loader_config = {}
|
|
144
|
+
if preferred_loaders and loader.loader_name in preferred_loaders:
|
|
145
|
+
loader_config = preferred_loaders[loader.loader_name]
|
|
146
|
+
|
|
147
|
+
# Merge with any additional kwargs (kwargs take precedence)
|
|
148
|
+
merged_kwargs = {**loader_config, **kwargs}
|
|
149
|
+
|
|
150
|
+
return await loader.load(file_path, **merged_kwargs)
|
|
131
151
|
|
|
132
152
|
def get_available_loaders(self) -> List[str]:
|
|
133
153
|
"""
|
|
@@ -14,14 +14,6 @@ from cognee.infrastructure.loaders.external.pypdf_loader import PyPdfLoader
|
|
|
14
14
|
|
|
15
15
|
logger = get_logger(__name__)
|
|
16
16
|
|
|
17
|
-
try:
|
|
18
|
-
from unstructured.partition.pdf import partition_pdf
|
|
19
|
-
except ImportError as e:
|
|
20
|
-
logger.info(
|
|
21
|
-
"unstructured[pdf] not installed, can't use AdvancedPdfLoader, will use PyPdfLoader instead."
|
|
22
|
-
)
|
|
23
|
-
raise ImportError from e
|
|
24
|
-
|
|
25
17
|
|
|
26
18
|
@dataclass
|
|
27
19
|
class _PageBuffer:
|
|
@@ -88,6 +80,8 @@ class AdvancedPdfLoader(LoaderInterface):
|
|
|
88
80
|
**kwargs,
|
|
89
81
|
}
|
|
90
82
|
# Use partition to extract elements
|
|
83
|
+
from unstructured.partition.pdf import partition_pdf
|
|
84
|
+
|
|
91
85
|
elements = partition_pdf(**partition_kwargs)
|
|
92
86
|
|
|
93
87
|
# Process elements into text content
|