cognee 0.5.0.dev1__py3-none-any.whl → 0.5.1.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. cognee/api/v1/add/add.py +2 -1
  2. cognee/api/v1/datasets/routers/get_datasets_router.py +1 -0
  3. cognee/api/v1/memify/routers/get_memify_router.py +1 -0
  4. cognee/infrastructure/databases/relational/config.py +16 -1
  5. cognee/infrastructure/databases/relational/create_relational_engine.py +13 -3
  6. cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py +24 -2
  7. cognee/infrastructure/databases/vector/create_vector_engine.py +9 -2
  8. cognee/infrastructure/llm/LLMGateway.py +0 -13
  9. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py +17 -12
  10. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py +31 -25
  11. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py +132 -7
  12. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +5 -5
  13. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/llm_interface.py +2 -6
  14. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py +58 -13
  15. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py +0 -1
  16. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py +25 -131
  17. cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/types.py +10 -0
  18. cognee/modules/data/models/Data.py +2 -1
  19. cognee/modules/retrieval/triplet_retriever.py +1 -1
  20. cognee/modules/retrieval/utils/brute_force_triplet_search.py +0 -18
  21. cognee/tasks/ingestion/data_item.py +8 -0
  22. cognee/tasks/ingestion/ingest_data.py +12 -1
  23. cognee/tasks/ingestion/save_data_item_to_storage.py +5 -0
  24. cognee/tests/integration/retrieval/test_chunks_retriever.py +252 -0
  25. cognee/tests/integration/retrieval/test_graph_completion_retriever.py +268 -0
  26. cognee/tests/integration/retrieval/test_graph_completion_retriever_context_extension.py +226 -0
  27. cognee/tests/integration/retrieval/test_graph_completion_retriever_cot.py +218 -0
  28. cognee/tests/integration/retrieval/test_rag_completion_retriever.py +254 -0
  29. cognee/tests/{unit/modules/retrieval/structured_output_test.py → integration/retrieval/test_structured_output.py} +87 -77
  30. cognee/tests/integration/retrieval/test_summaries_retriever.py +184 -0
  31. cognee/tests/integration/retrieval/test_temporal_retriever.py +306 -0
  32. cognee/tests/integration/retrieval/test_triplet_retriever.py +35 -0
  33. cognee/tests/test_custom_data_label.py +68 -0
  34. cognee/tests/test_search_db.py +334 -181
  35. cognee/tests/unit/eval_framework/benchmark_adapters_test.py +25 -0
  36. cognee/tests/unit/eval_framework/corpus_builder_test.py +33 -4
  37. cognee/tests/unit/infrastructure/databases/relational/test_RelationalConfig.py +69 -0
  38. cognee/tests/unit/modules/retrieval/chunks_retriever_test.py +181 -199
  39. cognee/tests/unit/modules/retrieval/conversation_history_test.py +338 -0
  40. cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py +454 -162
  41. cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +674 -156
  42. cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py +625 -200
  43. cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py +319 -203
  44. cognee/tests/unit/modules/retrieval/summaries_retriever_test.py +189 -155
  45. cognee/tests/unit/modules/retrieval/temporal_retriever_test.py +539 -58
  46. cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py +218 -9
  47. cognee/tests/unit/modules/retrieval/test_completion.py +343 -0
  48. cognee/tests/unit/modules/retrieval/test_graph_summary_completion_retriever.py +157 -0
  49. cognee/tests/unit/modules/retrieval/test_user_qa_feedback.py +312 -0
  50. cognee/tests/unit/modules/retrieval/triplet_retriever_test.py +246 -0
  51. {cognee-0.5.0.dev1.dist-info → cognee-0.5.1.dev0.dist-info}/METADATA +1 -1
  52. {cognee-0.5.0.dev1.dist-info → cognee-0.5.1.dev0.dist-info}/RECORD +56 -42
  53. {cognee-0.5.0.dev1.dist-info → cognee-0.5.1.dev0.dist-info}/WHEEL +0 -0
  54. {cognee-0.5.0.dev1.dist-info → cognee-0.5.1.dev0.dist-info}/entry_points.txt +0 -0
  55. {cognee-0.5.0.dev1.dist-info → cognee-0.5.1.dev0.dist-info}/licenses/LICENSE +0 -0
  56. {cognee-0.5.0.dev1.dist-info → cognee-0.5.1.dev0.dist-info}/licenses/NOTICE.md +0 -0
cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py

@@ -1,13 +1,13 @@
 import litellm
 import instructor
 from pydantic import BaseModel
-from typing import Type
+from typing import Type, Optional
 from litellm import JSONSchemaValidationError
-
+from cognee.infrastructure.files.utils.open_data_file import open_data_file
 from cognee.shared.logging_utils import get_logger
 from cognee.modules.observability.get_observe import get_observe
-from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
-    LLMInterface,
+from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.generic_llm_api.adapter import (
+    GenericAPIAdapter,
 )
 from cognee.infrastructure.llm.config import get_llm_config
 from cognee.shared.rate_limiting import llm_rate_limiter_context_manager
@@ -20,12 +20,14 @@ from tenacity import (
     retry_if_not_exception_type,
     before_sleep_log,
 )
+from ..types import TranscriptionReturnType
+from mistralai import Mistral

 logger = get_logger()
 observe = get_observe()


-class MistralAdapter(LLMInterface):
+class MistralAdapter(GenericAPIAdapter):
     """
     Adapter for Mistral AI API, for structured output generation and prompt display.

@@ -34,10 +36,6 @@ class MistralAdapter(LLMInterface):
     - show_prompt
     """

-    name = "Mistral"
-    model: str
-    api_key: str
-    max_completion_tokens: int
     default_instructor_mode = "mistral_tools"

     def __init__(
@@ -46,12 +44,19 @@ class MistralAdapter(LLMInterface):
         model: str,
         max_completion_tokens: int,
         endpoint: str = None,
+        transcription_model: str = None,
+        image_transcribe_model: str = None,
         instructor_mode: str = None,
     ):
-        from mistralai import Mistral
-
-        self.model = model
-        self.max_completion_tokens = max_completion_tokens
+        super().__init__(
+            api_key=api_key,
+            model=model,
+            max_completion_tokens=max_completion_tokens,
+            name="Mistral",
+            endpoint=endpoint,
+            transcription_model=transcription_model,
+            image_transcribe_model=image_transcribe_model,
+        )

         self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode

@@ -60,7 +65,9 @@
             mode=instructor.Mode(self.instructor_mode),
             api_key=get_llm_config().llm_api_key,
         )
+        self.mistral_client = Mistral(api_key=self.api_key)

+    @observe(as_type="generation")
     @retry(
         stop=stop_after_delay(128),
         wait=wait_exponential_jitter(8, 128),
@@ -119,3 +126,41 @@
             logger.error(f"Schema validation failed: {str(e)}")
             logger.debug(f"Raw response: {e.raw_response}")
             raise ValueError(f"Response failed schema validation: {str(e)}")
+
+    @observe(as_type="transcription")
+    @retry(
+        stop=stop_after_delay(128),
+        wait=wait_exponential_jitter(2, 128),
+        retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
+        before_sleep=before_sleep_log(logger, logging.DEBUG),
+        reraise=True,
+    )
+    async def create_transcript(self, input) -> Optional[TranscriptionReturnType]:
+        """
+        Generate an audio transcript from a user query.
+
+        This method creates a transcript from the specified audio file.
+        The audio file is processed and the transcription is retrieved from the API.
+
+        Parameters:
+        -----------
+        - input: The path to the audio file that needs to be transcribed.
+
+        Returns:
+        --------
+        The generated transcription of the audio file.
+        """
+        transcription_model = self.transcription_model
+        if self.transcription_model.startswith("mistral"):
+            transcription_model = self.transcription_model.split("/")[-1]
+        file_name = input.split("/")[-1]
+        async with open_data_file(input, mode="rb") as f:
+            transcription_response = self.mistral_client.audio.transcriptions.complete(
+                model=transcription_model,
+                file={
+                    "content": f,
+                    "file_name": file_name,
+                },
+            )
+
+        return TranscriptionReturnType(transcription_response.text, transcription_response)
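With this hunk, audio transcription is available directly on MistralAdapter instead of only through the OpenAI path. A minimal sketch of the new call path, assuming a local audio file and a voxtral transcription model name (both hypothetical; any "mistral/"-prefixed name is stripped to its last segment before the SDK call, as shown above):

    import asyncio

    from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.mistral.adapter import (
        MistralAdapter,
    )

    adapter = MistralAdapter(
        api_key="...",  # hypothetical key; the instructor client also reads get_llm_config()
        model="mistral/mistral-large-latest",  # assumed structured-output model name
        max_completion_tokens=4096,
        transcription_model="mistral/voxtral-mini-latest",  # assumed; "mistral/" prefix is stripped
    )

    async def main():
        # create_transcript opens the file via open_data_file and calls
        # self.mistral_client.audio.transcriptions.complete under the hood.
        result = await adapter.create_transcript("meeting.mp3")
        print(result.text)     # plain transcript string
        print(result.payload)  # full provider response object

    asyncio.run(main())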
cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py

@@ -12,7 +12,6 @@ from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.ll
 from cognee.infrastructure.files.utils.open_data_file import open_data_file
 from cognee.shared.logging_utils import get_logger
 from cognee.shared.rate_limiting import llm_rate_limiter_context_manager
-
 from tenacity import (
     retry,
     stop_after_delay,
cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py

@@ -1,4 +1,3 @@
-import base64
 import litellm
 import instructor
 from typing import Type
@@ -16,8 +15,8 @@ from tenacity import (
     before_sleep_log,
 )

-from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
-    LLMInterface,
+from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.generic_llm_api.adapter import (
+    GenericAPIAdapter,
 )
 from cognee.infrastructure.llm.exceptions import (
     ContentPolicyFilterError,
@@ -26,13 +25,16 @@ from cognee.shared.rate_limiting import llm_rate_limiter_context_manager
 from cognee.infrastructure.files.utils.open_data_file import open_data_file
 from cognee.modules.observability.get_observe import get_observe
 from cognee.shared.logging_utils import get_logger
+from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.types import (
+    TranscriptionReturnType,
+)

 logger = get_logger()

 observe = get_observe()


-class OpenAIAdapter(LLMInterface):
+class OpenAIAdapter(GenericAPIAdapter):
     """
     Adapter for OpenAI's GPT-3, GPT-4 API.

@@ -53,12 +55,7 @@ class OpenAIAdapter(LLMInterface):
     - MAX_RETRIES
     """

-    name = "OpenAI"
-    model: str
-    api_key: str
-    api_version: str
     default_instructor_mode = "json_schema_mode"
-
     MAX_RETRIES = 5

     """Adapter for OpenAI's GPT-3, GPT=4 API"""
@@ -66,17 +63,29 @@ class OpenAIAdapter(LLMInterface):
     def __init__(
         self,
         api_key: str,
-        endpoint: str,
-        api_version: str,
         model: str,
-        transcription_model: str,
         max_completion_tokens: int,
+        endpoint: str = None,
+        api_version: str = None,
+        transcription_model: str = None,
         instructor_mode: str = None,
         streaming: bool = False,
         fallback_model: str = None,
         fallback_api_key: str = None,
         fallback_endpoint: str = None,
     ):
+        super().__init__(
+            api_key=api_key,
+            model=model,
+            max_completion_tokens=max_completion_tokens,
+            name="OpenAI",
+            endpoint=endpoint,
+            api_version=api_version,
+            transcription_model=transcription_model,
+            fallback_model=fallback_model,
+            fallback_api_key=fallback_api_key,
+            fallback_endpoint=fallback_endpoint,
+        )
         self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
         # TODO: With gpt5 series models OpenAI expects JSON_SCHEMA as a mode for structured outputs.
         # Make sure all new gpt models will work with this mode as well.
@@ -91,18 +100,8 @@ class OpenAIAdapter(LLMInterface):
         self.aclient = instructor.from_litellm(litellm.acompletion)
         self.client = instructor.from_litellm(litellm.completion)

-        self.transcription_model = transcription_model
-        self.model = model
-        self.api_key = api_key
-        self.endpoint = endpoint
-        self.api_version = api_version
-        self.max_completion_tokens = max_completion_tokens
         self.streaming = streaming

-        self.fallback_model = fallback_model
-        self.fallback_api_key = fallback_api_key
-        self.fallback_endpoint = fallback_endpoint
-
     @observe(as_type="generation")
     @retry(
         stop=stop_after_delay(128),
@@ -198,7 +197,7 @@ class OpenAIAdapter(LLMInterface):
             f"The provided input contains content that is not aligned with our content policy: {text_input}"
         ) from error

-    @observe
+    @observe(as_type="transcription")
     @retry(
         stop=stop_after_delay(128),
         wait=wait_exponential_jitter(2, 128),
@@ -206,58 +205,7 @@ class OpenAIAdapter(LLMInterface):
         before_sleep=before_sleep_log(logger, logging.DEBUG),
         reraise=True,
     )
-    def create_structured_output(
-        self, text_input: str, system_prompt: str, response_model: Type[BaseModel], **kwargs
-    ) -> BaseModel:
-        """
-        Generate a response from a user query.
-
-        This method creates structured output by sending a synchronous request to the OpenAI API
-        using the provided parameters to generate a completion based on the user input and
-        system prompt.
-
-        Parameters:
-        -----------
-        - text_input (str): The input text provided by the user for generating a response.
-        - system_prompt (str): The system's prompt to guide the model's response.
-        - response_model (Type[BaseModel]): The expected model type for the response.
-
-        Returns:
-        --------
-        - BaseModel: A structured output generated by the model, returned as an instance of
-          BaseModel.
-        """
-
-        return self.client.chat.completions.create(
-            model=self.model,
-            messages=[
-                {
-                    "role": "user",
-                    "content": f"""{text_input}""",
-                },
-                {
-                    "role": "system",
-                    "content": system_prompt,
-                },
-            ],
-            api_key=self.api_key,
-            api_base=self.endpoint,
-            api_version=self.api_version,
-            response_model=response_model,
-            max_retries=self.MAX_RETRIES,
-            **kwargs,
-        )
-
-    @retry(
-        stop=stop_after_delay(128),
-        wait=wait_exponential_jitter(2, 128),
-        retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
-        before_sleep=before_sleep_log(logger, logging.DEBUG),
-        reraise=True,
-    )
-    async def create_transcript(self, input, **kwargs):
+    async def create_transcript(self, input, **kwargs) -> TranscriptionReturnType:
         """
         Generate an audio transcript from a user query.

@@ -286,60 +234,6 @@ class OpenAIAdapter(LLMInterface):
             max_retries=self.MAX_RETRIES,
             **kwargs,
         )
+        return TranscriptionReturnType(transcription.text, transcription)

-        return transcription
-
-    @retry(
-        stop=stop_after_delay(128),
-        wait=wait_exponential_jitter(2, 128),
-        retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
-        before_sleep=before_sleep_log(logger, logging.DEBUG),
-        reraise=True,
-    )
-    async def transcribe_image(self, input, **kwargs) -> BaseModel:
-        """
-        Generate a transcription of an image from a user query.
-
-        This method encodes the image and sends a request to the OpenAI API to obtain a
-        description of the contents of the image.
-
-        Parameters:
-        -----------
-        - input: The path to the image file that needs to be transcribed.
-
-        Returns:
-        --------
-        - BaseModel: A structured output generated by the model, returned as an instance of
-          BaseModel.
-        """
-        async with open_data_file(input, mode="rb") as image_file:
-            encoded_image = base64.b64encode(image_file.read()).decode("utf-8")
-
-        return litellm.completion(
-            model=self.model,
-            messages=[
-                {
-                    "role": "user",
-                    "content": [
-                        {
-                            "type": "text",
-                            "text": "What's in this image?",
-                        },
-                        {
-                            "type": "image_url",
-                            "image_url": {
-                                "url": f"data:image/jpeg;base64,{encoded_image}",
-                            },
-                        },
-                    ],
-                }
-            ],
-            api_key=self.api_key,
-            api_base=self.endpoint,
-            api_version=self.api_version,
-            max_completion_tokens=300,
-            max_retries=self.MAX_RETRIES,
-            **kwargs,
-        )
+    # transcribe_image is inherited from GenericAPIAdapter
cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/types.py

@@ -0,0 +1,10 @@
+from pydantic import BaseModel
+
+
+class TranscriptionReturnType:
+    text: str
+    payload: BaseModel
+
+    def __init__(self, text: str, payload: BaseModel):
+        self.text = text
+        self.payload = payload
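The new container is deliberately thin: text carries the transcript string most callers want, while payload keeps the full provider response for anyone who needs metadata. A self-contained illustration of the contract, using a stand-in payload model (FakeProviderResponse is hypothetical, for demonstration only):

    from pydantic import BaseModel

    from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.types import (
        TranscriptionReturnType,
    )

    class FakeProviderResponse(BaseModel):  # stand-in for an OpenAI/Mistral response
        text: str
        language: str = "en"

    raw = FakeProviderResponse(text="hello world")
    result = TranscriptionReturnType(raw.text, raw)

    assert result.text == "hello world"  # uniform access across both adapters
    assert result.payload is raw         # original response stays reachable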
cognee/modules/data/models/Data.py

@@ -13,7 +13,7 @@ class Data(Base):
     __tablename__ = "data"

     id = Column(UUID, primary_key=True, default=uuid4)
-
+    label = Column(String, nullable=True)
     name = Column(String)
     extension = Column(String)
     mime_type = Column(String)
@@ -49,6 +49,7 @@
         return {
             "id": str(self.id),
             "name": self.name,
+            "label": self.label,
             "extension": self.extension,
             "mimeType": self.mime_type,
             "rawDataLocation": self.raw_data_location,
cognee/modules/retrieval/triplet_retriever.py

@@ -36,7 +36,7 @@ class TripletRetriever(BaseRetriever):
         """Initialize retriever with optional custom prompt paths."""
         self.user_prompt_path = user_prompt_path
         self.system_prompt_path = system_prompt_path
-        self.top_k = top_k if top_k is not None else 1
+        self.top_k = top_k if top_k is not None else 5
         self.system_prompt = system_prompt

     async def get_context(self, query: str) -> str:
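The user-visible effect of the new default, in a short sketch (assumes a dataset has already been cognified so there are triplets to search; the query string is illustrative):

    import asyncio

    from cognee.modules.retrieval.triplet_retriever import TripletRetriever

    async def main():
        retriever = TripletRetriever()  # top_k now defaults to 5 instead of 1
        context = await retriever.get_context("Who founded the company?")
        print(context)  # assembled from up to five matching triplets rather than one

    asyncio.run(main())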
cognee/modules/retrieval/utils/brute_force_triplet_search.py

@@ -16,24 +16,6 @@ logger = get_logger(level=ERROR)


 def format_triplets(edges):
-    print("\n\n\n")
-
-    def filter_attributes(obj, attributes):
-        """Helper function to filter out non-None properties, including nested dicts."""
-        result = {}
-        for attr in attributes:
-            value = getattr(obj, attr, None)
-            if value is not None:
-                # If the value is a dict, extract relevant keys from it
-                if isinstance(value, dict):
-                    nested_values = {
-                        k: v for k, v in value.items() if k in attributes and v is not None
-                    }
-                    result[attr] = nested_values
-                else:
-                    result[attr] = value
-        return result
-
     triplets = []
     for edge in edges:
         node1 = edge.node1
cognee/tasks/ingestion/data_item.py

@@ -0,0 +1,8 @@
+from dataclasses import dataclass
+from typing import Any, Optional
+
+
+@dataclass
+class DataItem:
+    data: Any
+    label: Optional[str] = None
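DataItem is the wrapper that carries a user-supplied label through ingestion. A hedged usage sketch; passing the wrapper straight to cognee.add is an assumption here (suggested by the add.py change and the new test_custom_data_label.py, though only the ingestion path is shown in this diff), and the file path is hypothetical:

    import asyncio

    import cognee
    from cognee.tasks.ingestion.data_item import DataItem

    async def main():
        items = [
            DataItem(data="Acme Corp was founded in 2001.", label="company-facts"),
            DataItem(data="/path/to/report.pdf", label="q3-report"),  # hypothetical path
        ]
        # Assumed entry point: ingest_data unwraps each item and stores the
        # label on the corresponding Data row (Data.label).
        await cognee.add(items)

    asyncio.run(main())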
cognee/tasks/ingestion/ingest_data.py

@@ -20,6 +20,7 @@ from cognee.modules.data.methods import (

 from .save_data_item_to_storage import save_data_item_to_storage
 from .data_item_to_text_file import data_item_to_text_file
+from .data_item import DataItem


 async def ingest_data(
@@ -78,8 +79,16 @@
     dataset_data_map = {str(data.id): True for data in dataset_data}

     for data_item in data:
+        # Support for DataItem (custom label + data wrapper)
+        current_label = None
+        underlying_data = data_item
+
+        if isinstance(data_item, DataItem):
+            underlying_data = data_item.data
+            current_label = data_item.label
+
         # Get file path of data item or create a file if it doesn't exist
-        original_file_path = await save_data_item_to_storage(data_item)
+        original_file_path = await save_data_item_to_storage(underlying_data)
         # Transform file path to be OS usable
         actual_file_path = get_data_file_path(original_file_path)

@@ -139,6 +148,7 @@
             data_point.external_metadata = ext_metadata
             data_point.node_set = json.dumps(node_set) if node_set else None
             data_point.tenant_id = user.tenant_id if user.tenant_id else None
+            data_point.label = current_label

             # Check if data is already in dataset
             if str(data_point.id) in dataset_data_map:
@@ -169,6 +179,7 @@
                 tenant_id=user.tenant_id if user.tenant_id else None,
                 pipeline_status={},
                 token_count=-1,
+                label=current_label,
             )

             new_datapoints.append(data_point)
cognee/tasks/ingestion/save_data_item_to_storage.py

@@ -9,6 +9,7 @@ from cognee.shared.logging_utils import get_logger
 from pydantic_settings import BaseSettings, SettingsConfigDict

 from cognee.tasks.web_scraper.utils import fetch_page_content
+from cognee.tasks.ingestion.data_item import DataItem


 logger = get_logger()
@@ -95,5 +96,9 @@ async def save_data_item_to_storage(data_item: Union[BinaryIO, str, Any]) -> str
         # data is text, save it to data storage and return the file path
         return await save_data_to_file(data_item)

+    if isinstance(data_item, DataItem):
+        # If instance is DataItem use the underlying data
+        return await save_data_item_to_storage(data_item.data)
+
     # data is not a supported type
     raise IngestionError(message=f"Data type not supported: {type(data_item)}")