cognee 0.5.0.dev1__py3-none-any.whl → 0.5.1.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. cognee/api/v1/add/add.py +2 -1
  2. cognee/api/v1/datasets/routers/get_datasets_router.py +1 -0
  3. cognee/api/v1/memify/routers/get_memify_router.py +1 -0
  4. cognee/infrastructure/databases/relational/config.py +16 -1
  5. cognee/infrastructure/databases/relational/create_relational_engine.py +13 -3
  6. cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py +24 -2
  7. cognee/infrastructure/databases/vector/create_vector_engine.py +9 -2
  8. cognee/infrastructure/llm/LLMGateway.py +0 -13
  9. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py +17 -12
  10. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py +31 -25
  11. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py +132 -7
  12. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +5 -5
  13. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/llm_interface.py +2 -6
  14. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py +58 -13
  15. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py +0 -1
  16. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py +25 -131
  17. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/types.py +10 -0
  18. cognee/modules/data/models/Data.py +2 -1
  19. cognee/modules/retrieval/triplet_retriever.py +1 -1
  20. cognee/modules/retrieval/utils/brute_force_triplet_search.py +0 -18
  21. cognee/tasks/ingestion/data_item.py +8 -0
  22. cognee/tasks/ingestion/ingest_data.py +12 -1
  23. cognee/tasks/ingestion/save_data_item_to_storage.py +5 -0
  24. cognee/tests/integration/retrieval/test_chunks_retriever.py +252 -0
  25. cognee/tests/integration/retrieval/test_graph_completion_retriever.py +268 -0
  26. cognee/tests/integration/retrieval/test_graph_completion_retriever_context_extension.py +226 -0
  27. cognee/tests/integration/retrieval/test_graph_completion_retriever_cot.py +218 -0
  28. cognee/tests/integration/retrieval/test_rag_completion_retriever.py +254 -0
  29. cognee/tests/{unit/modules/retrieval/structured_output_test.py → integration/retrieval/test_structured_output.py} +87 -77
  30. cognee/tests/integration/retrieval/test_summaries_retriever.py +184 -0
  31. cognee/tests/integration/retrieval/test_temporal_retriever.py +306 -0
  32. cognee/tests/integration/retrieval/test_triplet_retriever.py +35 -0
  33. cognee/tests/test_custom_data_label.py +68 -0
  34. cognee/tests/test_search_db.py +334 -181
  35. cognee/tests/unit/eval_framework/benchmark_adapters_test.py +25 -0
  36. cognee/tests/unit/eval_framework/corpus_builder_test.py +33 -4
  37. cognee/tests/unit/infrastructure/databases/relational/test_RelationalConfig.py +69 -0
  38. cognee/tests/unit/modules/retrieval/chunks_retriever_test.py +181 -199
  39. cognee/tests/unit/modules/retrieval/conversation_history_test.py +338 -0
  40. cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py +454 -162
  41. cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +674 -156
  42. cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py +625 -200
  43. cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py +319 -203
  44. cognee/tests/unit/modules/retrieval/summaries_retriever_test.py +189 -155
  45. cognee/tests/unit/modules/retrieval/temporal_retriever_test.py +539 -58
  46. cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py +218 -9
  47. cognee/tests/unit/modules/retrieval/test_completion.py +343 -0
  48. cognee/tests/unit/modules/retrieval/test_graph_summary_completion_retriever.py +157 -0
  49. cognee/tests/unit/modules/retrieval/test_user_qa_feedback.py +312 -0
  50. cognee/tests/unit/modules/retrieval/triplet_retriever_test.py +246 -0
  51. {cognee-0.5.0.dev1.dist-info → cognee-0.5.1.dev0.dist-info}/METADATA +1 -1
  52. {cognee-0.5.0.dev1.dist-info → cognee-0.5.1.dev0.dist-info}/RECORD +56 -42
  53. {cognee-0.5.0.dev1.dist-info → cognee-0.5.1.dev0.dist-info}/WHEEL +0 -0
  54. {cognee-0.5.0.dev1.dist-info → cognee-0.5.1.dev0.dist-info}/entry_points.txt +0 -0
  55. {cognee-0.5.0.dev1.dist-info → cognee-0.5.1.dev0.dist-info}/licenses/LICENSE +0 -0
  56. {cognee-0.5.0.dev1.dist-info → cognee-0.5.1.dev0.dist-info}/licenses/NOTICE.md +0 -0
cognee/api/v1/add/add.py CHANGED
@@ -10,13 +10,14 @@ from cognee.modules.pipelines.layers.reset_dataset_pipeline_run_status import (
 )
 from cognee.modules.engine.operations.setup import setup
 from cognee.tasks.ingestion import ingest_data, resolve_data_directories
+from cognee.tasks.ingestion.data_item import DataItem
 from cognee.shared.logging_utils import get_logger

 logger = get_logger()


 async def add(
-    data: Union[BinaryIO, list[BinaryIO], str, list[str]],
+    data: Union[BinaryIO, list[BinaryIO], str, list[str], DataItem, list[DataItem]],
     dataset_name: str = "main_dataset",
     user: User = None,
     node_set: Optional[List[str]] = None,
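
Usage note: add() now also accepts DataItem wrappers. The sketch below is hypothetical: DataItem is defined in cognee/tasks/ingestion/data_item.py (+8 lines, not shown in this section), and the data/label field names are guesses inferred from the new DataDTO.label field and the test_custom_data_label.py test added in this release.

    import cognee
    from cognee.tasks.ingestion.data_item import DataItem

    # Hypothetical: the exact DataItem fields are not visible in this diff.
    item = DataItem(data="/path/to/report.pdf", label="q3-report")
    await cognee.add(item, dataset_name="main_dataset")
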
cognee/api/v1/datasets/routers/get_datasets_router.py CHANGED
@@ -44,6 +44,7 @@ class DatasetDTO(OutDTO):
 class DataDTO(OutDTO):
     id: UUID
     name: str
+    label: Optional[str] = None
     created_at: datetime
     updated_at: Optional[datetime] = None
     extension: str
cognee/api/v1/memify/routers/get_memify_router.py CHANGED
@@ -90,6 +90,7 @@ def get_memify_router() -> APIRouter:
             dataset=payload.dataset_id if payload.dataset_id else payload.dataset_name,
             node_name=payload.node_name,
             user=user,
+            run_in_background=payload.run_in_background,
         )

         if isinstance(memify_run, PipelineRunErrored):
cognee/infrastructure/databases/relational/config.py CHANGED
@@ -1,4 +1,5 @@
 import os
+import json
 import pydantic
 from typing import Union
 from functools import lru_cache
@@ -19,6 +20,7 @@ class RelationalConfig(BaseSettings):
     db_username: Union[str, None] = None  # "cognee"
     db_password: Union[str, None] = None  # "cognee"
     db_provider: str = "sqlite"
+    database_connect_args: Union[str, None] = None

     model_config = SettingsConfigDict(env_file=".env", extra="allow")

@@ -30,6 +32,17 @@ class RelationalConfig(BaseSettings):
         databases_directory_path = os.path.join(base_config.system_root_directory, "databases")
         self.db_path = databases_directory_path

+        # Parse database_connect_args if provided as JSON string
+        if self.database_connect_args and isinstance(self.database_connect_args, str):
+            try:
+                parsed_args = json.loads(self.database_connect_args)
+                if isinstance(parsed_args, dict):
+                    self.database_connect_args = parsed_args
+                else:
+                    self.database_connect_args = {}
+            except json.JSONDecodeError:
+                self.database_connect_args = {}
+
         return self

     def to_dict(self) -> dict:
@@ -40,7 +53,8 @@ class RelationalConfig(BaseSettings):
         --------

         - dict: A dictionary containing database configuration settings including db_path,
-          db_name, db_host, db_port, db_username, db_password, and db_provider.
+          db_name, db_host, db_port, db_username, db_password, db_provider, and
+          database_connect_args.
         """
         return {
             "db_path": self.db_path,
@@ -50,6 +64,7 @@ class RelationalConfig(BaseSettings):
             "db_username": self.db_username,
             "db_password": self.db_password,
             "db_provider": self.db_provider,
+            "database_connect_args": self.database_connect_args,
         }

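Usage note: because RelationalConfig is a pydantic BaseSettings, the new field can be driven entirely from the environment. A minimal sketch, assuming the usual get_relational_config() accessor; since that accessor is lru_cache-d, the variable must be set before the first call.

    import os

    # Must be a JSON object; any other JSON value (or invalid JSON) collapses to {}.
    os.environ["DATABASE_CONNECT_ARGS"] = '{"sslmode": "require", "connect_timeout": 10}'

    from cognee.infrastructure.databases.relational.config import get_relational_config

    config = get_relational_config()
    print(config.database_connect_args)  # {'sslmode': 'require', 'connect_timeout': 10}
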
cognee/infrastructure/databases/relational/create_relational_engine.py CHANGED
@@ -1,3 +1,4 @@
+from sqlalchemy import URL
 from .sqlalchemy.SqlAlchemyAdapter import SQLAlchemyAdapter
 from functools import lru_cache

@@ -11,6 +12,7 @@ def create_relational_engine(
     db_username: str,
     db_password: str,
     db_provider: str,
+    database_connect_args: dict = None,
 ):
     """
     Create a relational database engine based on the specified parameters.
@@ -29,6 +31,7 @@ def create_relational_engine(
     - db_password (str): The password for database authentication, required for
       PostgreSQL.
     - db_provider (str): The type of database provider (e.g., 'sqlite' or 'postgres').
+    - database_connect_args (dict, optional): Database driver connection arguments.

     Returns:
     --------
@@ -43,12 +46,19 @@ def create_relational_engine(
             # Test if asyncpg is available
             import asyncpg

-            connection_string = (
-                f"postgresql+asyncpg://{db_username}:{db_password}@{db_host}:{db_port}/{db_name}"
+            # Handle special characters in username and password like # or @
+            connection_string = URL.create(
+                "postgresql+asyncpg",
+                username=db_username,
+                password=db_password,
+                host=db_host,
+                port=int(db_port),
+                database=db_name,
             )
+
         except ImportError:
             raise ImportError(
                 "PostgreSQL dependencies are not installed. Please install with 'pip install cognee\"[postgres]\"' or 'pip install cognee\"[postgres-binary]\"' to use PostgreSQL functionality."
             )

-    return SQLAlchemyAdapter(connection_string)
+    return SQLAlchemyAdapter(connection_string, connect_args=database_connect_args)
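
The switch from an f-string to sqlalchemy.URL.create matters when credentials contain URL-reserved characters such as @ or #. A standalone illustration of standard SQLAlchemy behavior (independent of cognee):

    from sqlalchemy import URL

    # With the old f-string, a password like "p@ss#word" would be parsed as part
    # of the host/fragment. URL.create percent-encodes it when the URL is rendered.
    url = URL.create(
        "postgresql+asyncpg",
        username="cognee",
        password="p@ss#word",
        host="localhost",
        port=5432,
        database="cognee",
    )
    print(url.render_as_string(hide_password=False))
    # postgresql+asyncpg://cognee:p%40ss%23word@localhost:5432/cognee
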
cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py CHANGED
@@ -29,10 +29,31 @@ class SQLAlchemyAdapter:
     functions.
     """

-    def __init__(self, connection_string: str):
+    def __init__(self, connection_string: str, connect_args: dict = None):
+        """
+        Initialize the SQLAlchemy adapter with connection settings.
+
+        Parameters:
+        -----------
+        connection_string (str): The database connection string (e.g., 'sqlite:///path/to/db'
+            or 'postgresql://user:pass@host:port/db').
+        connect_args (dict, optional): Database driver connection arguments.
+            Configuration is loaded from RelationalConfig.database_connect_args, which reads
+            from the DATABASE_CONNECT_ARGS environment variable.
+
+        Examples:
+            PostgreSQL with SSL:
+                DATABASE_CONNECT_ARGS='{"sslmode": "require", "connect_timeout": 10}'
+
+            SQLite with custom timeout:
+                DATABASE_CONNECT_ARGS='{"timeout": 60}'
+        """
         self.db_path: str = None
         self.db_uri: str = connection_string

+        # Use provided connect_args (already parsed from config)
+        final_connect_args = connect_args or {}
+
         if "sqlite" in connection_string:
             [prefix, db_path] = connection_string.split("///")
             self.db_path = db_path
@@ -53,7 +74,7 @@ class SQLAlchemyAdapter:
             self.engine = create_async_engine(
                 connection_string,
                 poolclass=NullPool,
-                connect_args={"timeout": 30},
+                connect_args={**{"timeout": 30}, **final_connect_args},
             )
         else:
             self.engine = create_async_engine(
@@ -63,6 +84,7 @@ class SQLAlchemyAdapter:
                 pool_recycle=280,
                 pool_pre_ping=True,
                 pool_timeout=280,
+                connect_args=final_connect_args,
             )

         self.sessionmaker = async_sessionmaker(bind=self.engine, expire_on_commit=False)
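
In the SQLite branch, the default timeout is merged with the user-supplied arguments via dict unpacking, and the later unpack wins. A two-line sketch of that override semantics:

    final_connect_args = {"timeout": 60}  # e.g. parsed from DATABASE_CONNECT_ARGS
    merged = {**{"timeout": 30}, **final_connect_args}
    print(merged)  # {'timeout': 60}, the user value overrides the default
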
cognee/infrastructure/databases/vector/create_vector_engine.py CHANGED
@@ -1,3 +1,5 @@
+from sqlalchemy import URL
+
 from .supported_databases import supported_databases
 from .embeddings import get_embedding_engine
 from cognee.infrastructure.databases.graph.config import get_graph_context_config
@@ -66,8 +68,13 @@ def create_vector_engine(
         if not (db_host and db_port and db_name and db_username and db_password):
             raise EnvironmentError("Missing requred pgvector credentials!")

-        connection_string: str = (
-            f"postgresql+asyncpg://{db_username}:{db_password}@{db_host}:{db_port}/{db_name}"
+        connection_string = URL.create(
+            "postgresql+asyncpg",
+            username=db_username,
+            password=db_password,
+            host=db_host,
+            port=int(db_port),
+            database=db_name,
        )

         try:
cognee/infrastructure/llm/LLMGateway.py CHANGED
@@ -37,19 +37,6 @@ class LLMGateway:
             **kwargs,
         )

-    @staticmethod
-    def create_structured_output(
-        text_input: str, system_prompt: str, response_model: Type[BaseModel]
-    ) -> BaseModel:
-        from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.get_llm_client import (
-            get_llm_client,
-        )
-
-        llm_client = get_llm_client()
-        return llm_client.create_structured_output(
-            text_input=text_input, system_prompt=system_prompt, response_model=response_model
-        )
-
     @staticmethod
     def create_transcript(input) -> Coroutine:
         from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.get_llm_client import (
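
The synchronous create_structured_output wrapper is removed; the async variant is the remaining entry point. A hedged migration sketch, assuming acreate_structured_output keeps the signature implied by the removed method:

    import asyncio
    from pydantic import BaseModel

    from cognee.infrastructure.llm.LLMGateway import LLMGateway


    class Answer(BaseModel):
        text: str


    async def main():
        # Previously: LLMGateway.create_structured_output(...) (sync, removed here).
        result = await LLMGateway.acreate_structured_output(
            text_input="What is a knowledge graph?",
            system_prompt="Answer briefly.",
            response_model=Answer,
        )
        print(result)


    asyncio.run(main())
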
cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py CHANGED
@@ -3,7 +3,9 @@ from typing import Type
 from pydantic import BaseModel
 import litellm
 import instructor
+import anthropic
 from cognee.shared.logging_utils import get_logger
+from cognee.modules.observability.get_observe import get_observe
 from tenacity import (
     retry,
     stop_after_delay,
@@ -12,38 +14,41 @@ from tenacity import (
     before_sleep_log,
 )

-from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
-    LLMInterface,
+from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.generic_llm_api.adapter import (
+    GenericAPIAdapter,
 )
 from cognee.shared.rate_limiting import llm_rate_limiter_context_manager
 from cognee.infrastructure.llm.config import get_llm_config

 logger = get_logger()
+observe = get_observe()


-class AnthropicAdapter(LLMInterface):
+class AnthropicAdapter(GenericAPIAdapter):
     """
     Adapter for interfacing with the Anthropic API, enabling structured output generation
     and prompt display.
     """

-    name = "Anthropic"
-    model: str
     default_instructor_mode = "anthropic_tools"

-    def __init__(self, max_completion_tokens: int, model: str = None, instructor_mode: str = None):
-        import anthropic
-
+    def __init__(
+        self, api_key: str, model: str, max_completion_tokens: int, instructor_mode: str = None
+    ):
+        super().__init__(
+            api_key=api_key,
+            model=model,
+            max_completion_tokens=max_completion_tokens,
+            name="Anthropic",
+        )
         self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode

         self.aclient = instructor.patch(
-            create=anthropic.AsyncAnthropic(api_key=get_llm_config().llm_api_key).messages.create,
+            create=anthropic.AsyncAnthropic(api_key=self.api_key).messages.create,
             mode=instructor.Mode(self.instructor_mode),
         )

-        self.model = model
-        self.max_completion_tokens = max_completion_tokens
-
+    @observe(as_type="generation")
     @retry(
         stop=stop_after_delay(128),
         wait=wait_exponential_jitter(8, 128),
cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py CHANGED
@@ -1,4 +1,4 @@
-"""Adapter for Generic API LLM provider API"""
+"""Adapter for Gemini API LLM provider"""

 import litellm
 import instructor
@@ -8,13 +8,9 @@ from openai import ContentFilterFinishReasonError
 from litellm.exceptions import ContentPolicyViolationError
 from instructor.core import InstructorRetryException

-from cognee.infrastructure.llm.exceptions import ContentPolicyFilterError
-from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
-    LLMInterface,
-)
 import logging
 from cognee.shared.rate_limiting import llm_rate_limiter_context_manager
-from cognee.shared.logging_utils import get_logger
+
 from tenacity import (
     retry,
     stop_after_delay,
@@ -23,55 +19,65 @@ from tenacity import (
     before_sleep_log,
 )

+from cognee.infrastructure.llm.exceptions import ContentPolicyFilterError
+from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.generic_llm_api.adapter import (
+    GenericAPIAdapter,
+)
+from cognee.shared.logging_utils import get_logger
+from cognee.modules.observability.get_observe import get_observe
+
 logger = get_logger()
+observe = get_observe()


-class GeminiAdapter(LLMInterface):
+class GeminiAdapter(GenericAPIAdapter):
     """
     Adapter for Gemini API LLM provider.

     This class initializes the API adapter with necessary credentials and configurations for
     interacting with the gemini LLM models. It provides methods for creating structured outputs
-    based on user input and system prompts.
+    based on user input and system prompts, as well as multimodal processing capabilities.

     Public methods:
-    - acreate_structured_output(text_input: str, system_prompt: str, response_model:
-      Type[BaseModel]) -> BaseModel
+    - acreate_structured_output(text_input: str, system_prompt: str, response_model: Type[BaseModel]) -> BaseModel
+    - create_transcript(input) -> BaseModel: Transcribe audio files to text
+    - transcribe_image(input) -> BaseModel: Inherited from GenericAPIAdapter
     """

-    name: str
-    model: str
-    api_key: str
     default_instructor_mode = "json_mode"

     def __init__(
         self,
-        endpoint,
         api_key: str,
         model: str,
-        api_version: str,
         max_completion_tokens: int,
+        endpoint: str = None,
+        api_version: str = None,
+        transcription_model: str = None,
         instructor_mode: str = None,
         fallback_model: str = None,
         fallback_api_key: str = None,
         fallback_endpoint: str = None,
     ):
-        self.model = model
-        self.api_key = api_key
-        self.endpoint = endpoint
-        self.api_version = api_version
-        self.max_completion_tokens = max_completion_tokens
-
-        self.fallback_model = fallback_model
-        self.fallback_api_key = fallback_api_key
-        self.fallback_endpoint = fallback_endpoint
-
+        super().__init__(
+            api_key=api_key,
+            model=model,
+            max_completion_tokens=max_completion_tokens,
+            name="Gemini",
+            endpoint=endpoint,
+            api_version=api_version,
+            transcription_model=transcription_model,
+            fallback_model=fallback_model,
+            fallback_api_key=fallback_api_key,
+            fallback_endpoint=fallback_endpoint,
+        )
         self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode

         self.aclient = instructor.from_litellm(
             litellm.acompletion, mode=instructor.Mode(self.instructor_mode)
         )

+    @observe(as_type="generation")
     @retry(
         stop=stop_after_delay(128),
         wait=wait_exponential_jitter(8, 128),
cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py CHANGED
@@ -1,8 +1,10 @@
 """Adapter for Generic API LLM provider API"""

+import base64
+import mimetypes
 import litellm
 import instructor
-from typing import Type
+from typing import Type, Optional
 from pydantic import BaseModel
 from openai import ContentFilterFinishReasonError
 from litellm.exceptions import ContentPolicyViolationError
@@ -12,6 +14,8 @@ from cognee.infrastructure.llm.exceptions import ContentPolicyFilterError
 from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
     LLMInterface,
 )
+from cognee.infrastructure.files.utils.open_data_file import open_data_file
+from cognee.modules.observability.get_observe import get_observe
 import logging
 from cognee.shared.rate_limiting import llm_rate_limiter_context_manager
 from cognee.shared.logging_utils import get_logger
@@ -23,7 +27,12 @@ from tenacity import (
     before_sleep_log,
 )

+from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.types import (
+    TranscriptionReturnType,
+)
+
 logger = get_logger()
+observe = get_observe()


 class GenericAPIAdapter(LLMInterface):
@@ -39,18 +48,19 @@ class GenericAPIAdapter(LLMInterface):
       Type[BaseModel]) -> BaseModel
     """

-    name: str
-    model: str
-    api_key: str
+    MAX_RETRIES = 5
     default_instructor_mode = "json_mode"

     def __init__(
         self,
-        endpoint,
         api_key: str,
         model: str,
-        name: str,
         max_completion_tokens: int,
+        name: str,
+        endpoint: str = None,
+        api_version: str = None,
+        transcription_model: str = None,
+        image_transcribe_model: str = None,
         instructor_mode: str = None,
         fallback_model: str = None,
         fallback_api_key: str = None,
@@ -59,9 +69,11 @@ class GenericAPIAdapter(LLMInterface):
         self.name = name
         self.model = model
         self.api_key = api_key
+        self.api_version = api_version
         self.endpoint = endpoint
         self.max_completion_tokens = max_completion_tokens
-
+        self.transcription_model = transcription_model or model
+        self.image_transcribe_model = image_transcribe_model or model
         self.fallback_model = fallback_model
         self.fallback_api_key = fallback_api_key
         self.fallback_endpoint = fallback_endpoint
@@ -72,6 +84,7 @@ class GenericAPIAdapter(LLMInterface):
             litellm.acompletion, mode=instructor.Mode(self.instructor_mode)
         )

+    @observe(as_type="generation")
     @retry(
         stop=stop_after_delay(128),
         wait=wait_exponential_jitter(8, 128),
@@ -173,3 +186,115 @@ class GenericAPIAdapter(LLMInterface):
             raise ContentPolicyFilterError(
                 f"The provided input contains content that is not aligned with our content policy: {text_input}"
             ) from error
+
+    @observe(as_type="transcription")
+    @retry(
+        stop=stop_after_delay(128),
+        wait=wait_exponential_jitter(2, 128),
+        retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
+        before_sleep=before_sleep_log(logger, logging.DEBUG),
+        reraise=True,
+    )
+    async def create_transcript(self, input) -> TranscriptionReturnType:
+        """
+        Generate an audio transcript from a user query.
+
+        This method creates a transcript from the specified audio file, raising a
+        FileNotFoundError if the file does not exist. The audio file is processed and the
+        transcription is retrieved from the API.
+
+        Parameters:
+        -----------
+        - input: The path to the audio file that needs to be transcribed.
+
+        Returns:
+        --------
+        The generated transcription of the audio file.
+        """
+        async with open_data_file(input, mode="rb") as audio_file:
+            encoded_string = base64.b64encode(audio_file.read()).decode("utf-8")
+        mime_type, _ = mimetypes.guess_type(input)
+        if not mime_type or not mime_type.startswith("audio/"):
+            raise ValueError(
+                f"Could not determine MIME type for audio file: {input}. Is the extension correct?"
+            )
+        response = await litellm.acompletion(
+            model=self.transcription_model,
+            messages=[
+                {
+                    "role": "user",
+                    "content": [
+                        {
+                            "type": "file",
+                            "file": {"file_data": f"data:{mime_type};base64,{encoded_string}"},
+                        },
+                        {"type": "text", "text": "Transcribe the following audio precisely."},
+                    ],
+                }
+            ],
+            api_key=self.api_key,
+            api_version=self.api_version,
+            max_completion_tokens=self.max_completion_tokens,
+            api_base=self.endpoint,
+            max_retries=self.MAX_RETRIES,
+        )
+
+        return TranscriptionReturnType(response.choices[0].message.content, response)
+
+    @observe(as_type="transcribe_image")
+    @retry(
+        stop=stop_after_delay(128),
+        wait=wait_exponential_jitter(2, 128),
+        retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
+        before_sleep=before_sleep_log(logger, logging.DEBUG),
+        reraise=True,
+    )
+    async def transcribe_image(self, input) -> BaseModel:
+        """
+        Generate a transcription of an image from a user query.
+
+        This method encodes the image and sends a request to the API to obtain a
+        description of the contents of the image.
+
+        Parameters:
+        -----------
+        - input: The path to the image file that needs to be transcribed.
+
+        Returns:
+        --------
+        - BaseModel: A structured output generated by the model, returned as an instance of
+          BaseModel.
+        """
+        async with open_data_file(input, mode="rb") as image_file:
+            encoded_image = base64.b64encode(image_file.read()).decode("utf-8")
+        mime_type, _ = mimetypes.guess_type(input)
+        if not mime_type or not mime_type.startswith("image/"):
+            raise ValueError(
+                f"Could not determine MIME type for image file: {input}. Is the extension correct?"
+            )
+        response = await litellm.acompletion(
+            model=self.image_transcribe_model,
+            messages=[
+                {
+                    "role": "user",
+                    "content": [
+                        {
+                            "type": "text",
+                            "text": "What's in this image?",
+                        },
+                        {
+                            "type": "image_url",
+                            "image_url": {
+                                "url": f"data:{mime_type};base64,{encoded_image}",
+                            },
+                        },
+                    ],
+                }
+            ],
+            api_key=self.api_key,
+            api_base=self.endpoint,
+            api_version=self.api_version,
+            max_completion_tokens=300,
+            max_retries=self.MAX_RETRIES,
+        )
+        return response
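
With these additions, any litellm-backed provider gains basic audio and image handling. A usage sketch against the new constructor signature; the endpoint, model names, and file paths are illustrative placeholders, not values from the diff:

    import asyncio

    from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.generic_llm_api.adapter import (
        GenericAPIAdapter,
    )


    async def main():
        adapter = GenericAPIAdapter(
            api_key="sk-placeholder",
            model="gpt-4o-mini",
            max_completion_tokens=1024,
            name="Custom",
            endpoint="https://api.example.com/v1",
        )

        # MIME types are inferred from the file extensions, so these must be real
        # .mp3/.png paths. TranscriptionReturnType wraps the transcribed text plus
        # the raw response; its fields live in llm/types.py (+10, not shown here).
        transcript = await adapter.create_transcript("meeting.mp3")
        print(transcript)

        image_response = await adapter.transcribe_image("diagram.png")
        print(image_response.choices[0].message.content)


    asyncio.run(main())
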
cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py CHANGED
@@ -103,7 +103,7 @@ def get_llm_client(raise_api_key_error: bool = True):
             llm_config.llm_api_key,
             llm_config.llm_model,
             "Ollama",
-            max_completion_tokens=max_completion_tokens,
+            max_completion_tokens,
             instructor_mode=llm_config.llm_instructor_mode.lower(),
         )

@@ -113,8 +113,9 @@ def get_llm_client(raise_api_key_error: bool = True):
         )

         return AnthropicAdapter(
-            max_completion_tokens=max_completion_tokens,
-            model=llm_config.llm_model,
+            llm_config.llm_api_key,
+            llm_config.llm_model,
+            max_completion_tokens,
             instructor_mode=llm_config.llm_instructor_mode.lower(),
         )

@@ -127,11 +128,10 @@ def get_llm_client(raise_api_key_error: bool = True):
         )

         return GenericAPIAdapter(
-            llm_config.llm_endpoint,
             llm_config.llm_api_key,
             llm_config.llm_model,
+            max_completion_tokens,
             "Custom",
-            max_completion_tokens=max_completion_tokens,
             instructor_mode=llm_config.llm_instructor_mode.lower(),
             fallback_api_key=llm_config.fallback_api_key,
             fallback_endpoint=llm_config.fallback_endpoint,
cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/llm_interface.py CHANGED
@@ -3,18 +3,14 @@
 from typing import Type, Protocol
 from abc import abstractmethod
 from pydantic import BaseModel
-from cognee.infrastructure.llm.LLMGateway import LLMGateway


 class LLMInterface(Protocol):
     """
-    Define an interface for LLM models with methods for structured output and prompt
-    display.
+    Define an interface for LLM models with methods for structured output, multimodal processing, and prompt display.

     Methods:
-    - acreate_structured_output(text_input: str, system_prompt: str, response_model:
-      Type[BaseModel])
-    - show_prompt(text_input: str, system_prompt: str)
+    - acreate_structured_output(text_input: str, system_prompt: str, response_model: Type[BaseModel])
     """

     @abstractmethod