cognee 0.5.0.dev1__py3-none-any.whl → 0.5.1.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cognee/api/v1/add/add.py +2 -1
- cognee/api/v1/datasets/routers/get_datasets_router.py +1 -0
- cognee/api/v1/memify/routers/get_memify_router.py +1 -0
- cognee/infrastructure/databases/relational/config.py +16 -1
- cognee/infrastructure/databases/relational/create_relational_engine.py +13 -3
- cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py +24 -2
- cognee/infrastructure/databases/vector/create_vector_engine.py +9 -2
- cognee/infrastructure/llm/LLMGateway.py +0 -13
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py +17 -12
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py +31 -25
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py +132 -7
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +5 -5
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/llm_interface.py +2 -6
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py +58 -13
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py +0 -1
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py +25 -131
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/types.py +10 -0
- cognee/modules/data/models/Data.py +2 -1
- cognee/modules/retrieval/triplet_retriever.py +1 -1
- cognee/modules/retrieval/utils/brute_force_triplet_search.py +0 -18
- cognee/tasks/ingestion/data_item.py +8 -0
- cognee/tasks/ingestion/ingest_data.py +12 -1
- cognee/tasks/ingestion/save_data_item_to_storage.py +5 -0
- cognee/tests/integration/retrieval/test_chunks_retriever.py +252 -0
- cognee/tests/integration/retrieval/test_graph_completion_retriever.py +268 -0
- cognee/tests/integration/retrieval/test_graph_completion_retriever_context_extension.py +226 -0
- cognee/tests/integration/retrieval/test_graph_completion_retriever_cot.py +218 -0
- cognee/tests/integration/retrieval/test_rag_completion_retriever.py +254 -0
- cognee/tests/{unit/modules/retrieval/structured_output_test.py → integration/retrieval/test_structured_output.py} +87 -77
- cognee/tests/integration/retrieval/test_summaries_retriever.py +184 -0
- cognee/tests/integration/retrieval/test_temporal_retriever.py +306 -0
- cognee/tests/integration/retrieval/test_triplet_retriever.py +35 -0
- cognee/tests/test_custom_data_label.py +68 -0
- cognee/tests/test_search_db.py +334 -181
- cognee/tests/unit/eval_framework/benchmark_adapters_test.py +25 -0
- cognee/tests/unit/eval_framework/corpus_builder_test.py +33 -4
- cognee/tests/unit/infrastructure/databases/relational/test_RelationalConfig.py +69 -0
- cognee/tests/unit/modules/retrieval/chunks_retriever_test.py +181 -199
- cognee/tests/unit/modules/retrieval/conversation_history_test.py +338 -0
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py +454 -162
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +674 -156
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py +625 -200
- cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py +319 -203
- cognee/tests/unit/modules/retrieval/summaries_retriever_test.py +189 -155
- cognee/tests/unit/modules/retrieval/temporal_retriever_test.py +539 -58
- cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py +218 -9
- cognee/tests/unit/modules/retrieval/test_completion.py +343 -0
- cognee/tests/unit/modules/retrieval/test_graph_summary_completion_retriever.py +157 -0
- cognee/tests/unit/modules/retrieval/test_user_qa_feedback.py +312 -0
- cognee/tests/unit/modules/retrieval/triplet_retriever_test.py +246 -0
- {cognee-0.5.0.dev1.dist-info → cognee-0.5.1.dev0.dist-info}/METADATA +1 -1
- {cognee-0.5.0.dev1.dist-info → cognee-0.5.1.dev0.dist-info}/RECORD +56 -42
- {cognee-0.5.0.dev1.dist-info → cognee-0.5.1.dev0.dist-info}/WHEEL +0 -0
- {cognee-0.5.0.dev1.dist-info → cognee-0.5.1.dev0.dist-info}/entry_points.txt +0 -0
- {cognee-0.5.0.dev1.dist-info → cognee-0.5.1.dev0.dist-info}/licenses/LICENSE +0 -0
- {cognee-0.5.0.dev1.dist-info → cognee-0.5.1.dev0.dist-info}/licenses/NOTICE.md +0 -0
cognee/api/v1/add/add.py
CHANGED
|
@@ -10,13 +10,14 @@ from cognee.modules.pipelines.layers.reset_dataset_pipeline_run_status import (
|
|
|
10
10
|
)
|
|
11
11
|
from cognee.modules.engine.operations.setup import setup
|
|
12
12
|
from cognee.tasks.ingestion import ingest_data, resolve_data_directories
|
|
13
|
+
from cognee.tasks.ingestion.data_item import DataItem
|
|
13
14
|
from cognee.shared.logging_utils import get_logger
|
|
14
15
|
|
|
15
16
|
logger = get_logger()
|
|
16
17
|
|
|
17
18
|
|
|
18
19
|
async def add(
|
|
19
|
-
data: Union[BinaryIO, list[BinaryIO], str, list[str]],
|
|
20
|
+
data: Union[BinaryIO, list[BinaryIO], str, list[str], DataItem, list[DataItem]],
|
|
20
21
|
dataset_name: str = "main_dataset",
|
|
21
22
|
user: User = None,
|
|
22
23
|
node_set: Optional[List[str]] = None,
|
|
@@ -90,6 +90,7 @@ def get_memify_router() -> APIRouter:
|
|
|
90
90
|
dataset=payload.dataset_id if payload.dataset_id else payload.dataset_name,
|
|
91
91
|
node_name=payload.node_name,
|
|
92
92
|
user=user,
|
|
93
|
+
run_in_background=payload.run_in_background,
|
|
93
94
|
)
|
|
94
95
|
|
|
95
96
|
if isinstance(memify_run, PipelineRunErrored):
|
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import os
|
|
2
|
+
import json
|
|
2
3
|
import pydantic
|
|
3
4
|
from typing import Union
|
|
4
5
|
from functools import lru_cache
|
|
@@ -19,6 +20,7 @@ class RelationalConfig(BaseSettings):
|
|
|
19
20
|
db_username: Union[str, None] = None # "cognee"
|
|
20
21
|
db_password: Union[str, None] = None # "cognee"
|
|
21
22
|
db_provider: str = "sqlite"
|
|
23
|
+
database_connect_args: Union[str, None] = None
|
|
22
24
|
|
|
23
25
|
model_config = SettingsConfigDict(env_file=".env", extra="allow")
|
|
24
26
|
|
|
@@ -30,6 +32,17 @@ class RelationalConfig(BaseSettings):
|
|
|
30
32
|
databases_directory_path = os.path.join(base_config.system_root_directory, "databases")
|
|
31
33
|
self.db_path = databases_directory_path
|
|
32
34
|
|
|
35
|
+
# Parse database_connect_args if provided as JSON string
|
|
36
|
+
if self.database_connect_args and isinstance(self.database_connect_args, str):
|
|
37
|
+
try:
|
|
38
|
+
parsed_args = json.loads(self.database_connect_args)
|
|
39
|
+
if isinstance(parsed_args, dict):
|
|
40
|
+
self.database_connect_args = parsed_args
|
|
41
|
+
else:
|
|
42
|
+
self.database_connect_args = {}
|
|
43
|
+
except json.JSONDecodeError:
|
|
44
|
+
self.database_connect_args = {}
|
|
45
|
+
|
|
33
46
|
return self
|
|
34
47
|
|
|
35
48
|
def to_dict(self) -> dict:
|
|
@@ -40,7 +53,8 @@ class RelationalConfig(BaseSettings):
|
|
|
40
53
|
--------
|
|
41
54
|
|
|
42
55
|
- dict: A dictionary containing database configuration settings including db_path,
|
|
43
|
-
db_name, db_host, db_port, db_username, db_password, and
|
|
56
|
+
db_name, db_host, db_port, db_username, db_password, db_provider, and
|
|
57
|
+
database_connect_args.
|
|
44
58
|
"""
|
|
45
59
|
return {
|
|
46
60
|
"db_path": self.db_path,
|
|
@@ -50,6 +64,7 @@ class RelationalConfig(BaseSettings):
|
|
|
50
64
|
"db_username": self.db_username,
|
|
51
65
|
"db_password": self.db_password,
|
|
52
66
|
"db_provider": self.db_provider,
|
|
67
|
+
"database_connect_args": self.database_connect_args,
|
|
53
68
|
}
|
|
54
69
|
|
|
55
70
|
|
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
from sqlalchemy import URL
|
|
1
2
|
from .sqlalchemy.SqlAlchemyAdapter import SQLAlchemyAdapter
|
|
2
3
|
from functools import lru_cache
|
|
3
4
|
|
|
@@ -11,6 +12,7 @@ def create_relational_engine(
|
|
|
11
12
|
db_username: str,
|
|
12
13
|
db_password: str,
|
|
13
14
|
db_provider: str,
|
|
15
|
+
database_connect_args: dict = None,
|
|
14
16
|
):
|
|
15
17
|
"""
|
|
16
18
|
Create a relational database engine based on the specified parameters.
|
|
@@ -29,6 +31,7 @@ def create_relational_engine(
|
|
|
29
31
|
- db_password (str): The password for database authentication, required for
|
|
30
32
|
PostgreSQL.
|
|
31
33
|
- db_provider (str): The type of database provider (e.g., 'sqlite' or 'postgres').
|
|
34
|
+
- database_connect_args (dict, optional): Database driver connection arguments.
|
|
32
35
|
|
|
33
36
|
Returns:
|
|
34
37
|
--------
|
|
@@ -43,12 +46,19 @@ def create_relational_engine(
|
|
|
43
46
|
# Test if asyncpg is available
|
|
44
47
|
import asyncpg
|
|
45
48
|
|
|
46
|
-
|
|
47
|
-
|
|
49
|
+
# Handle special characters in username and password like # or @
|
|
50
|
+
connection_string = URL.create(
|
|
51
|
+
"postgresql+asyncpg",
|
|
52
|
+
username=db_username,
|
|
53
|
+
password=db_password,
|
|
54
|
+
host=db_host,
|
|
55
|
+
port=int(db_port),
|
|
56
|
+
database=db_name,
|
|
48
57
|
)
|
|
58
|
+
|
|
49
59
|
except ImportError:
|
|
50
60
|
raise ImportError(
|
|
51
61
|
"PostgreSQL dependencies are not installed. Please install with 'pip install cognee\"[postgres]\"' or 'pip install cognee\"[postgres-binary]\"' to use PostgreSQL functionality."
|
|
52
62
|
)
|
|
53
63
|
|
|
54
|
-
return SQLAlchemyAdapter(connection_string)
|
|
64
|
+
return SQLAlchemyAdapter(connection_string, connect_args=database_connect_args)
|
|
@@ -29,10 +29,31 @@ class SQLAlchemyAdapter:
|
|
|
29
29
|
functions.
|
|
30
30
|
"""
|
|
31
31
|
|
|
32
|
-
def __init__(self, connection_string: str):
|
|
32
|
+
def __init__(self, connection_string: str, connect_args: dict = None):
|
|
33
|
+
"""
|
|
34
|
+
Initialize the SQLAlchemy adapter with connection settings.
|
|
35
|
+
|
|
36
|
+
Parameters:
|
|
37
|
+
-----------
|
|
38
|
+
connection_string (str): The database connection string (e.g., 'sqlite:///path/to/db'
|
|
39
|
+
or 'postgresql://user:pass@host:port/db').
|
|
40
|
+
connect_args (dict, optional): Database driver connection arguments.
|
|
41
|
+
Configuration is loaded from RelationalConfig.database_connect_args, which reads
|
|
42
|
+
from the DATABASE_CONNECT_ARGS environment variable.
|
|
43
|
+
|
|
44
|
+
Examples:
|
|
45
|
+
PostgreSQL with SSL:
|
|
46
|
+
DATABASE_CONNECT_ARGS='{"sslmode": "require", "connect_timeout": 10}'
|
|
47
|
+
|
|
48
|
+
SQLite with custom timeout:
|
|
49
|
+
DATABASE_CONNECT_ARGS='{"timeout": 60}'
|
|
50
|
+
"""
|
|
33
51
|
self.db_path: str = None
|
|
34
52
|
self.db_uri: str = connection_string
|
|
35
53
|
|
|
54
|
+
# Use provided connect_args (already parsed from config)
|
|
55
|
+
final_connect_args = connect_args or {}
|
|
56
|
+
|
|
36
57
|
if "sqlite" in connection_string:
|
|
37
58
|
[prefix, db_path] = connection_string.split("///")
|
|
38
59
|
self.db_path = db_path
|
|
@@ -53,7 +74,7 @@ class SQLAlchemyAdapter:
|
|
|
53
74
|
self.engine = create_async_engine(
|
|
54
75
|
connection_string,
|
|
55
76
|
poolclass=NullPool,
|
|
56
|
-
connect_args={"timeout": 30},
|
|
77
|
+
connect_args={**{"timeout": 30}, **final_connect_args},
|
|
57
78
|
)
|
|
58
79
|
else:
|
|
59
80
|
self.engine = create_async_engine(
|
|
@@ -63,6 +84,7 @@ class SQLAlchemyAdapter:
|
|
|
63
84
|
pool_recycle=280,
|
|
64
85
|
pool_pre_ping=True,
|
|
65
86
|
pool_timeout=280,
|
|
87
|
+
connect_args=final_connect_args,
|
|
66
88
|
)
|
|
67
89
|
|
|
68
90
|
self.sessionmaker = async_sessionmaker(bind=self.engine, expire_on_commit=False)
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
from sqlalchemy import URL
|
|
2
|
+
|
|
1
3
|
from .supported_databases import supported_databases
|
|
2
4
|
from .embeddings import get_embedding_engine
|
|
3
5
|
from cognee.infrastructure.databases.graph.config import get_graph_context_config
|
|
@@ -66,8 +68,13 @@ def create_vector_engine(
|
|
|
66
68
|
if not (db_host and db_port and db_name and db_username and db_password):
|
|
67
69
|
raise EnvironmentError("Missing requred pgvector credentials!")
|
|
68
70
|
|
|
69
|
-
connection_string
|
|
70
|
-
|
|
71
|
+
connection_string = URL.create(
|
|
72
|
+
"postgresql+asyncpg",
|
|
73
|
+
username=db_username,
|
|
74
|
+
password=db_password,
|
|
75
|
+
host=db_host,
|
|
76
|
+
port=int(db_port),
|
|
77
|
+
database=db_name,
|
|
71
78
|
)
|
|
72
79
|
|
|
73
80
|
try:
|
|
@@ -37,19 +37,6 @@ class LLMGateway:
|
|
|
37
37
|
**kwargs,
|
|
38
38
|
)
|
|
39
39
|
|
|
40
|
-
@staticmethod
|
|
41
|
-
def create_structured_output(
|
|
42
|
-
text_input: str, system_prompt: str, response_model: Type[BaseModel]
|
|
43
|
-
) -> BaseModel:
|
|
44
|
-
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.get_llm_client import (
|
|
45
|
-
get_llm_client,
|
|
46
|
-
)
|
|
47
|
-
|
|
48
|
-
llm_client = get_llm_client()
|
|
49
|
-
return llm_client.create_structured_output(
|
|
50
|
-
text_input=text_input, system_prompt=system_prompt, response_model=response_model
|
|
51
|
-
)
|
|
52
|
-
|
|
53
40
|
@staticmethod
|
|
54
41
|
def create_transcript(input) -> Coroutine:
|
|
55
42
|
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.get_llm_client import (
|
cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py
CHANGED
|
@@ -3,7 +3,9 @@ from typing import Type
|
|
|
3
3
|
from pydantic import BaseModel
|
|
4
4
|
import litellm
|
|
5
5
|
import instructor
|
|
6
|
+
import anthropic
|
|
6
7
|
from cognee.shared.logging_utils import get_logger
|
|
8
|
+
from cognee.modules.observability.get_observe import get_observe
|
|
7
9
|
from tenacity import (
|
|
8
10
|
retry,
|
|
9
11
|
stop_after_delay,
|
|
@@ -12,38 +14,41 @@ from tenacity import (
|
|
|
12
14
|
before_sleep_log,
|
|
13
15
|
)
|
|
14
16
|
|
|
15
|
-
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.
|
|
16
|
-
|
|
17
|
+
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.generic_llm_api.adapter import (
|
|
18
|
+
GenericAPIAdapter,
|
|
17
19
|
)
|
|
18
20
|
from cognee.shared.rate_limiting import llm_rate_limiter_context_manager
|
|
19
21
|
from cognee.infrastructure.llm.config import get_llm_config
|
|
20
22
|
|
|
21
23
|
logger = get_logger()
|
|
24
|
+
observe = get_observe()
|
|
22
25
|
|
|
23
26
|
|
|
24
|
-
class AnthropicAdapter(
|
|
27
|
+
class AnthropicAdapter(GenericAPIAdapter):
|
|
25
28
|
"""
|
|
26
29
|
Adapter for interfacing with the Anthropic API, enabling structured output generation
|
|
27
30
|
and prompt display.
|
|
28
31
|
"""
|
|
29
32
|
|
|
30
|
-
name = "Anthropic"
|
|
31
|
-
model: str
|
|
32
33
|
default_instructor_mode = "anthropic_tools"
|
|
33
34
|
|
|
34
|
-
def __init__(
|
|
35
|
-
|
|
36
|
-
|
|
35
|
+
def __init__(
|
|
36
|
+
self, api_key: str, model: str, max_completion_tokens: int, instructor_mode: str = None
|
|
37
|
+
):
|
|
38
|
+
super().__init__(
|
|
39
|
+
api_key=api_key,
|
|
40
|
+
model=model,
|
|
41
|
+
max_completion_tokens=max_completion_tokens,
|
|
42
|
+
name="Anthropic",
|
|
43
|
+
)
|
|
37
44
|
self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
|
|
38
45
|
|
|
39
46
|
self.aclient = instructor.patch(
|
|
40
|
-
create=anthropic.AsyncAnthropic(api_key=
|
|
47
|
+
create=anthropic.AsyncAnthropic(api_key=self.api_key).messages.create,
|
|
41
48
|
mode=instructor.Mode(self.instructor_mode),
|
|
42
49
|
)
|
|
43
50
|
|
|
44
|
-
|
|
45
|
-
self.max_completion_tokens = max_completion_tokens
|
|
46
|
-
|
|
51
|
+
@observe(as_type="generation")
|
|
47
52
|
@retry(
|
|
48
53
|
stop=stop_after_delay(128),
|
|
49
54
|
wait=wait_exponential_jitter(8, 128),
|
cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
"""Adapter for
|
|
1
|
+
"""Adapter for Gemini API LLM provider"""
|
|
2
2
|
|
|
3
3
|
import litellm
|
|
4
4
|
import instructor
|
|
@@ -8,13 +8,9 @@ from openai import ContentFilterFinishReasonError
|
|
|
8
8
|
from litellm.exceptions import ContentPolicyViolationError
|
|
9
9
|
from instructor.core import InstructorRetryException
|
|
10
10
|
|
|
11
|
-
from cognee.infrastructure.llm.exceptions import ContentPolicyFilterError
|
|
12
|
-
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
|
|
13
|
-
LLMInterface,
|
|
14
|
-
)
|
|
15
11
|
import logging
|
|
16
12
|
from cognee.shared.rate_limiting import llm_rate_limiter_context_manager
|
|
17
|
-
|
|
13
|
+
|
|
18
14
|
from tenacity import (
|
|
19
15
|
retry,
|
|
20
16
|
stop_after_delay,
|
|
@@ -23,55 +19,65 @@ from tenacity import (
|
|
|
23
19
|
before_sleep_log,
|
|
24
20
|
)
|
|
25
21
|
|
|
22
|
+
from cognee.infrastructure.llm.exceptions import ContentPolicyFilterError
|
|
23
|
+
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.generic_llm_api.adapter import (
|
|
24
|
+
GenericAPIAdapter,
|
|
25
|
+
)
|
|
26
|
+
from cognee.shared.logging_utils import get_logger
|
|
27
|
+
from cognee.modules.observability.get_observe import get_observe
|
|
28
|
+
|
|
26
29
|
logger = get_logger()
|
|
30
|
+
observe = get_observe()
|
|
27
31
|
|
|
28
32
|
|
|
29
|
-
class GeminiAdapter(
|
|
33
|
+
class GeminiAdapter(GenericAPIAdapter):
|
|
30
34
|
"""
|
|
31
35
|
Adapter for Gemini API LLM provider.
|
|
32
36
|
|
|
33
37
|
This class initializes the API adapter with necessary credentials and configurations for
|
|
34
38
|
interacting with the gemini LLM models. It provides methods for creating structured outputs
|
|
35
|
-
based on user input and system prompts.
|
|
39
|
+
based on user input and system prompts, as well as multimodal processing capabilities.
|
|
36
40
|
|
|
37
41
|
Public methods:
|
|
38
|
-
- acreate_structured_output(text_input: str, system_prompt: str, response_model:
|
|
39
|
-
|
|
42
|
+
- acreate_structured_output(text_input: str, system_prompt: str, response_model: Type[BaseModel]) -> BaseModel
|
|
43
|
+
- create_transcript(input) -> BaseModel: Transcribe audio files to text
|
|
44
|
+
- transcribe_image(input) -> BaseModel: Inherited from GenericAPIAdapter
|
|
40
45
|
"""
|
|
41
46
|
|
|
42
|
-
name: str
|
|
43
|
-
model: str
|
|
44
|
-
api_key: str
|
|
45
47
|
default_instructor_mode = "json_mode"
|
|
46
48
|
|
|
47
49
|
def __init__(
|
|
48
50
|
self,
|
|
49
|
-
endpoint,
|
|
50
51
|
api_key: str,
|
|
51
52
|
model: str,
|
|
52
|
-
api_version: str,
|
|
53
53
|
max_completion_tokens: int,
|
|
54
|
+
endpoint: str = None,
|
|
55
|
+
api_version: str = None,
|
|
56
|
+
transcription_model: str = None,
|
|
54
57
|
instructor_mode: str = None,
|
|
55
58
|
fallback_model: str = None,
|
|
56
59
|
fallback_api_key: str = None,
|
|
57
60
|
fallback_endpoint: str = None,
|
|
58
61
|
):
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
62
|
+
super().__init__(
|
|
63
|
+
api_key=api_key,
|
|
64
|
+
model=model,
|
|
65
|
+
max_completion_tokens=max_completion_tokens,
|
|
66
|
+
name="Gemini",
|
|
67
|
+
endpoint=endpoint,
|
|
68
|
+
api_version=api_version,
|
|
69
|
+
transcription_model=transcription_model,
|
|
70
|
+
fallback_model=fallback_model,
|
|
71
|
+
fallback_api_key=fallback_api_key,
|
|
72
|
+
fallback_endpoint=fallback_endpoint,
|
|
73
|
+
)
|
|
69
74
|
self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
|
|
70
75
|
|
|
71
76
|
self.aclient = instructor.from_litellm(
|
|
72
77
|
litellm.acompletion, mode=instructor.Mode(self.instructor_mode)
|
|
73
78
|
)
|
|
74
79
|
|
|
80
|
+
@observe(as_type="generation")
|
|
75
81
|
@retry(
|
|
76
82
|
stop=stop_after_delay(128),
|
|
77
83
|
wait=wait_exponential_jitter(8, 128),
|
|
@@ -1,8 +1,10 @@
|
|
|
1
1
|
"""Adapter for Generic API LLM provider API"""
|
|
2
2
|
|
|
3
|
+
import base64
|
|
4
|
+
import mimetypes
|
|
3
5
|
import litellm
|
|
4
6
|
import instructor
|
|
5
|
-
from typing import Type
|
|
7
|
+
from typing import Type, Optional
|
|
6
8
|
from pydantic import BaseModel
|
|
7
9
|
from openai import ContentFilterFinishReasonError
|
|
8
10
|
from litellm.exceptions import ContentPolicyViolationError
|
|
@@ -12,6 +14,8 @@ from cognee.infrastructure.llm.exceptions import ContentPolicyFilterError
|
|
|
12
14
|
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
|
|
13
15
|
LLMInterface,
|
|
14
16
|
)
|
|
17
|
+
from cognee.infrastructure.files.utils.open_data_file import open_data_file
|
|
18
|
+
from cognee.modules.observability.get_observe import get_observe
|
|
15
19
|
import logging
|
|
16
20
|
from cognee.shared.rate_limiting import llm_rate_limiter_context_manager
|
|
17
21
|
from cognee.shared.logging_utils import get_logger
|
|
@@ -23,7 +27,12 @@ from tenacity import (
|
|
|
23
27
|
before_sleep_log,
|
|
24
28
|
)
|
|
25
29
|
|
|
30
|
+
from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.types import (
|
|
31
|
+
TranscriptionReturnType,
|
|
32
|
+
)
|
|
33
|
+
|
|
26
34
|
logger = get_logger()
|
|
35
|
+
observe = get_observe()
|
|
27
36
|
|
|
28
37
|
|
|
29
38
|
class GenericAPIAdapter(LLMInterface):
|
|
@@ -39,18 +48,19 @@ class GenericAPIAdapter(LLMInterface):
|
|
|
39
48
|
Type[BaseModel]) -> BaseModel
|
|
40
49
|
"""
|
|
41
50
|
|
|
42
|
-
|
|
43
|
-
model: str
|
|
44
|
-
api_key: str
|
|
51
|
+
MAX_RETRIES = 5
|
|
45
52
|
default_instructor_mode = "json_mode"
|
|
46
53
|
|
|
47
54
|
def __init__(
|
|
48
55
|
self,
|
|
49
|
-
endpoint,
|
|
50
56
|
api_key: str,
|
|
51
57
|
model: str,
|
|
52
|
-
name: str,
|
|
53
58
|
max_completion_tokens: int,
|
|
59
|
+
name: str,
|
|
60
|
+
endpoint: str = None,
|
|
61
|
+
api_version: str = None,
|
|
62
|
+
transcription_model: str = None,
|
|
63
|
+
image_transcribe_model: str = None,
|
|
54
64
|
instructor_mode: str = None,
|
|
55
65
|
fallback_model: str = None,
|
|
56
66
|
fallback_api_key: str = None,
|
|
@@ -59,9 +69,11 @@ class GenericAPIAdapter(LLMInterface):
|
|
|
59
69
|
self.name = name
|
|
60
70
|
self.model = model
|
|
61
71
|
self.api_key = api_key
|
|
72
|
+
self.api_version = api_version
|
|
62
73
|
self.endpoint = endpoint
|
|
63
74
|
self.max_completion_tokens = max_completion_tokens
|
|
64
|
-
|
|
75
|
+
self.transcription_model = transcription_model or model
|
|
76
|
+
self.image_transcribe_model = image_transcribe_model or model
|
|
65
77
|
self.fallback_model = fallback_model
|
|
66
78
|
self.fallback_api_key = fallback_api_key
|
|
67
79
|
self.fallback_endpoint = fallback_endpoint
|
|
@@ -72,6 +84,7 @@ class GenericAPIAdapter(LLMInterface):
|
|
|
72
84
|
litellm.acompletion, mode=instructor.Mode(self.instructor_mode)
|
|
73
85
|
)
|
|
74
86
|
|
|
87
|
+
@observe(as_type="generation")
|
|
75
88
|
@retry(
|
|
76
89
|
stop=stop_after_delay(128),
|
|
77
90
|
wait=wait_exponential_jitter(8, 128),
|
|
@@ -173,3 +186,115 @@ class GenericAPIAdapter(LLMInterface):
|
|
|
173
186
|
raise ContentPolicyFilterError(
|
|
174
187
|
f"The provided input contains content that is not aligned with our content policy: {text_input}"
|
|
175
188
|
) from error
|
|
189
|
+
|
|
190
|
+
@observe(as_type="transcription")
|
|
191
|
+
@retry(
|
|
192
|
+
stop=stop_after_delay(128),
|
|
193
|
+
wait=wait_exponential_jitter(2, 128),
|
|
194
|
+
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
|
|
195
|
+
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
|
196
|
+
reraise=True,
|
|
197
|
+
)
|
|
198
|
+
async def create_transcript(self, input) -> TranscriptionReturnType:
|
|
199
|
+
"""
|
|
200
|
+
Generate an audio transcript from a user query.
|
|
201
|
+
|
|
202
|
+
This method creates a transcript from the specified audio file, raising a
|
|
203
|
+
FileNotFoundError if the file does not exist. The audio file is processed and the
|
|
204
|
+
transcription is retrieved from the API.
|
|
205
|
+
|
|
206
|
+
Parameters:
|
|
207
|
+
-----------
|
|
208
|
+
- input: The path to the audio file that needs to be transcribed.
|
|
209
|
+
|
|
210
|
+
Returns:
|
|
211
|
+
--------
|
|
212
|
+
The generated transcription of the audio file.
|
|
213
|
+
"""
|
|
214
|
+
async with open_data_file(input, mode="rb") as audio_file:
|
|
215
|
+
encoded_string = base64.b64encode(audio_file.read()).decode("utf-8")
|
|
216
|
+
mime_type, _ = mimetypes.guess_type(input)
|
|
217
|
+
if not mime_type or not mime_type.startswith("audio/"):
|
|
218
|
+
raise ValueError(
|
|
219
|
+
f"Could not determine MIME type for audio file: {input}. Is the extension correct?"
|
|
220
|
+
)
|
|
221
|
+
response = await litellm.acompletion(
|
|
222
|
+
model=self.transcription_model,
|
|
223
|
+
messages=[
|
|
224
|
+
{
|
|
225
|
+
"role": "user",
|
|
226
|
+
"content": [
|
|
227
|
+
{
|
|
228
|
+
"type": "file",
|
|
229
|
+
"file": {"file_data": f"data:{mime_type};base64,{encoded_string}"},
|
|
230
|
+
},
|
|
231
|
+
{"type": "text", "text": "Transcribe the following audio precisely."},
|
|
232
|
+
],
|
|
233
|
+
}
|
|
234
|
+
],
|
|
235
|
+
api_key=self.api_key,
|
|
236
|
+
api_version=self.api_version,
|
|
237
|
+
max_completion_tokens=self.max_completion_tokens,
|
|
238
|
+
api_base=self.endpoint,
|
|
239
|
+
max_retries=self.MAX_RETRIES,
|
|
240
|
+
)
|
|
241
|
+
|
|
242
|
+
return TranscriptionReturnType(response.choices[0].message.content, response)
|
|
243
|
+
|
|
244
|
+
@observe(as_type="transcribe_image")
|
|
245
|
+
@retry(
|
|
246
|
+
stop=stop_after_delay(128),
|
|
247
|
+
wait=wait_exponential_jitter(2, 128),
|
|
248
|
+
retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
|
|
249
|
+
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
|
250
|
+
reraise=True,
|
|
251
|
+
)
|
|
252
|
+
async def transcribe_image(self, input) -> BaseModel:
|
|
253
|
+
"""
|
|
254
|
+
Generate a transcription of an image from a user query.
|
|
255
|
+
|
|
256
|
+
This method encodes the image and sends a request to the API to obtain a
|
|
257
|
+
description of the contents of the image.
|
|
258
|
+
|
|
259
|
+
Parameters:
|
|
260
|
+
-----------
|
|
261
|
+
- input: The path to the image file that needs to be transcribed.
|
|
262
|
+
|
|
263
|
+
Returns:
|
|
264
|
+
--------
|
|
265
|
+
- BaseModel: A structured output generated by the model, returned as an instance of
|
|
266
|
+
BaseModel.
|
|
267
|
+
"""
|
|
268
|
+
async with open_data_file(input, mode="rb") as image_file:
|
|
269
|
+
encoded_image = base64.b64encode(image_file.read()).decode("utf-8")
|
|
270
|
+
mime_type, _ = mimetypes.guess_type(input)
|
|
271
|
+
if not mime_type or not mime_type.startswith("image/"):
|
|
272
|
+
raise ValueError(
|
|
273
|
+
f"Could not determine MIME type for image file: {input}. Is the extension correct?"
|
|
274
|
+
)
|
|
275
|
+
response = await litellm.acompletion(
|
|
276
|
+
model=self.image_transcribe_model,
|
|
277
|
+
messages=[
|
|
278
|
+
{
|
|
279
|
+
"role": "user",
|
|
280
|
+
"content": [
|
|
281
|
+
{
|
|
282
|
+
"type": "text",
|
|
283
|
+
"text": "What's in this image?",
|
|
284
|
+
},
|
|
285
|
+
{
|
|
286
|
+
"type": "image_url",
|
|
287
|
+
"image_url": {
|
|
288
|
+
"url": f"data:{mime_type};base64,{encoded_image}",
|
|
289
|
+
},
|
|
290
|
+
},
|
|
291
|
+
],
|
|
292
|
+
}
|
|
293
|
+
],
|
|
294
|
+
api_key=self.api_key,
|
|
295
|
+
api_base=self.endpoint,
|
|
296
|
+
api_version=self.api_version,
|
|
297
|
+
max_completion_tokens=300,
|
|
298
|
+
max_retries=self.MAX_RETRIES,
|
|
299
|
+
)
|
|
300
|
+
return response
|
cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py
CHANGED
|
@@ -103,7 +103,7 @@ def get_llm_client(raise_api_key_error: bool = True):
|
|
|
103
103
|
llm_config.llm_api_key,
|
|
104
104
|
llm_config.llm_model,
|
|
105
105
|
"Ollama",
|
|
106
|
-
max_completion_tokens
|
|
106
|
+
max_completion_tokens,
|
|
107
107
|
instructor_mode=llm_config.llm_instructor_mode.lower(),
|
|
108
108
|
)
|
|
109
109
|
|
|
@@ -113,8 +113,9 @@ def get_llm_client(raise_api_key_error: bool = True):
|
|
|
113
113
|
)
|
|
114
114
|
|
|
115
115
|
return AnthropicAdapter(
|
|
116
|
-
|
|
117
|
-
|
|
116
|
+
llm_config.llm_api_key,
|
|
117
|
+
llm_config.llm_model,
|
|
118
|
+
max_completion_tokens,
|
|
118
119
|
instructor_mode=llm_config.llm_instructor_mode.lower(),
|
|
119
120
|
)
|
|
120
121
|
|
|
@@ -127,11 +128,10 @@ def get_llm_client(raise_api_key_error: bool = True):
|
|
|
127
128
|
)
|
|
128
129
|
|
|
129
130
|
return GenericAPIAdapter(
|
|
130
|
-
llm_config.llm_endpoint,
|
|
131
131
|
llm_config.llm_api_key,
|
|
132
132
|
llm_config.llm_model,
|
|
133
|
+
max_completion_tokens,
|
|
133
134
|
"Custom",
|
|
134
|
-
max_completion_tokens=max_completion_tokens,
|
|
135
135
|
instructor_mode=llm_config.llm_instructor_mode.lower(),
|
|
136
136
|
fallback_api_key=llm_config.fallback_api_key,
|
|
137
137
|
fallback_endpoint=llm_config.fallback_endpoint,
|
cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/llm_interface.py
CHANGED
|
@@ -3,18 +3,14 @@
|
|
|
3
3
|
from typing import Type, Protocol
|
|
4
4
|
from abc import abstractmethod
|
|
5
5
|
from pydantic import BaseModel
|
|
6
|
-
from cognee.infrastructure.llm.LLMGateway import LLMGateway
|
|
7
6
|
|
|
8
7
|
|
|
9
8
|
class LLMInterface(Protocol):
|
|
10
9
|
"""
|
|
11
|
-
Define an interface for LLM models with methods for structured output and prompt
|
|
12
|
-
display.
|
|
10
|
+
Define an interface for LLM models with methods for structured output, multimodal processing, and prompt display.
|
|
13
11
|
|
|
14
12
|
Methods:
|
|
15
|
-
- acreate_structured_output(text_input: str, system_prompt: str, response_model:
|
|
16
|
-
Type[BaseModel])
|
|
17
|
-
- show_prompt(text_input: str, system_prompt: str)
|
|
13
|
+
- acreate_structured_output(text_input: str, system_prompt: str, response_model: Type[BaseModel])
|
|
18
14
|
"""
|
|
19
15
|
|
|
20
16
|
@abstractmethod
|