cognee 0.5.1__py3-none-any.whl → 0.5.1.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cognee/api/v1/add/add.py +2 -1
- cognee/api/v1/datasets/routers/get_datasets_router.py +1 -0
- cognee/api/v1/memify/routers/get_memify_router.py +1 -0
- cognee/api/v1/search/search.py +0 -4
- cognee/infrastructure/databases/relational/config.py +16 -1
- cognee/infrastructure/databases/relational/create_relational_engine.py +13 -3
- cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py +24 -2
- cognee/infrastructure/databases/vector/create_vector_engine.py +9 -2
- cognee/infrastructure/llm/LLMGateway.py +0 -13
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py +17 -12
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py +31 -25
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py +132 -7
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +5 -5
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/llm_interface.py +2 -6
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py +58 -13
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py +0 -1
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py +25 -131
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/types.py +10 -0
- cognee/modules/data/models/Data.py +2 -1
- cognee/modules/retrieval/triplet_retriever.py +1 -1
- cognee/modules/retrieval/utils/brute_force_triplet_search.py +0 -18
- cognee/modules/search/methods/search.py +18 -25
- cognee/tasks/ingestion/data_item.py +8 -0
- cognee/tasks/ingestion/ingest_data.py +12 -1
- cognee/tasks/ingestion/save_data_item_to_storage.py +5 -0
- cognee/tests/integration/retrieval/test_chunks_retriever.py +252 -0
- cognee/tests/integration/retrieval/test_graph_completion_retriever.py +268 -0
- cognee/tests/integration/retrieval/test_graph_completion_retriever_context_extension.py +226 -0
- cognee/tests/integration/retrieval/test_graph_completion_retriever_cot.py +218 -0
- cognee/tests/integration/retrieval/test_rag_completion_retriever.py +254 -0
- cognee/tests/{unit/modules/retrieval/structured_output_test.py → integration/retrieval/test_structured_output.py} +87 -77
- cognee/tests/integration/retrieval/test_summaries_retriever.py +184 -0
- cognee/tests/integration/retrieval/test_temporal_retriever.py +306 -0
- cognee/tests/integration/retrieval/test_triplet_retriever.py +35 -0
- cognee/tests/test_custom_data_label.py +68 -0
- cognee/tests/test_search_db.py +334 -181
- cognee/tests/unit/eval_framework/benchmark_adapters_test.py +25 -0
- cognee/tests/unit/eval_framework/corpus_builder_test.py +33 -4
- cognee/tests/unit/infrastructure/databases/relational/test_RelationalConfig.py +69 -0
- cognee/tests/unit/modules/retrieval/chunks_retriever_test.py +181 -199
- cognee/tests/unit/modules/retrieval/conversation_history_test.py +338 -0
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py +454 -162
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +674 -156
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_test.py +625 -200
- cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py +319 -203
- cognee/tests/unit/modules/retrieval/summaries_retriever_test.py +189 -155
- cognee/tests/unit/modules/retrieval/temporal_retriever_test.py +539 -58
- cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py +218 -9
- cognee/tests/unit/modules/retrieval/test_completion.py +343 -0
- cognee/tests/unit/modules/retrieval/test_graph_summary_completion_retriever.py +157 -0
- cognee/tests/unit/modules/retrieval/test_user_qa_feedback.py +312 -0
- cognee/tests/unit/modules/retrieval/triplet_retriever_test.py +246 -0
- cognee/tests/unit/modules/search/test_search.py +0 -100
- {cognee-0.5.1.dist-info → cognee-0.5.1.dev0.dist-info}/METADATA +1 -1
- {cognee-0.5.1.dist-info → cognee-0.5.1.dev0.dist-info}/RECORD +58 -45
- {cognee-0.5.1.dist-info → cognee-0.5.1.dev0.dist-info}/WHEEL +0 -0
- {cognee-0.5.1.dist-info → cognee-0.5.1.dev0.dist-info}/entry_points.txt +0 -0
- {cognee-0.5.1.dist-info → cognee-0.5.1.dev0.dist-info}/licenses/LICENSE +0 -0
- {cognee-0.5.1.dist-info → cognee-0.5.1.dev0.dist-info}/licenses/NOTICE.md +0 -0
cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py
CHANGED

@@ -1,13 +1,13 @@
 import litellm
 import instructor
 from pydantic import BaseModel
-from typing import Type
+from typing import Type, Optional
 from litellm import JSONSchemaValidationError
-
+from cognee.infrastructure.files.utils.open_data_file import open_data_file
 from cognee.shared.logging_utils import get_logger
 from cognee.modules.observability.get_observe import get_observe
-from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
-    LLMInterface,
+from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.generic_llm_api.adapter import (
+    GenericAPIAdapter,
 )
 from cognee.infrastructure.llm.config import get_llm_config
 from cognee.shared.rate_limiting import llm_rate_limiter_context_manager

@@ -20,12 +20,14 @@ from tenacity import (
     retry_if_not_exception_type,
     before_sleep_log,
 )
+from ..types import TranscriptionReturnType
+from mistralai import Mistral

 logger = get_logger()
 observe = get_observe()


-class MistralAdapter(LLMInterface):
+class MistralAdapter(GenericAPIAdapter):
     """
     Adapter for Mistral AI API, for structured output generation and prompt display.

@@ -34,10 +36,6 @@ class MistralAdapter(LLMInterface):
     - show_prompt
     """

-    name = "Mistral"
-    model: str
-    api_key: str
-    max_completion_tokens: int
     default_instructor_mode = "mistral_tools"

     def __init__(

@@ -46,12 +44,19 @@ class MistralAdapter(LLMInterface):
         model: str,
         max_completion_tokens: int,
         endpoint: str = None,
+        transcription_model: str = None,
+        image_transcribe_model: str = None,
         instructor_mode: str = None,
     ):
-
-
-
-
+        super().__init__(
+            api_key=api_key,
+            model=model,
+            max_completion_tokens=max_completion_tokens,
+            name="Mistral",
+            endpoint=endpoint,
+            transcription_model=transcription_model,
+            image_transcribe_model=image_transcribe_model,
+        )

         self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode

@@ -60,7 +65,9 @@ class MistralAdapter(LLMInterface):
             mode=instructor.Mode(self.instructor_mode),
             api_key=get_llm_config().llm_api_key,
         )
+        self.mistral_client = Mistral(api_key=self.api_key)

+    @observe(as_type="generation")
     @retry(
         stop=stop_after_delay(128),
         wait=wait_exponential_jitter(8, 128),

@@ -119,3 +126,41 @@ class MistralAdapter(LLMInterface):
             logger.error(f"Schema validation failed: {str(e)}")
             logger.debug(f"Raw response: {e.raw_response}")
             raise ValueError(f"Response failed schema validation: {str(e)}")
+
+    @observe(as_type="transcription")
+    @retry(
+        stop=stop_after_delay(128),
+        wait=wait_exponential_jitter(2, 128),
+        retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
+        before_sleep=before_sleep_log(logger, logging.DEBUG),
+        reraise=True,
+    )
+    async def create_transcript(self, input) -> Optional[TranscriptionReturnType]:
+        """
+        Generate an audio transcript from a user query.
+
+        This method creates a transcript from the specified audio file.
+        The audio file is processed and the transcription is retrieved from the API.
+
+        Parameters:
+        -----------
+        - input: The path to the audio file that needs to be transcribed.
+
+        Returns:
+        --------
+        The generated transcription of the audio file.
+        """
+        transcription_model = self.transcription_model
+        if self.transcription_model.startswith("mistral"):
+            transcription_model = self.transcription_model.split("/")[-1]
+        file_name = input.split("/")[-1]
+        async with open_data_file(input, mode="rb") as f:
+            transcription_response = self.mistral_client.audio.transcriptions.complete(
+                model=transcription_model,
+                file={
+                    "content": f,
+                    "file_name": file_name,
+                },
+            )
+
+        return TranscriptionReturnType(transcription_response.text, transcription_response)
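For context, a minimal usage sketch of the new Mistral transcription path. The API key, model names, and file path below are illustrative placeholders, and the .text attribute on the result is an assumption inferred from the positional TranscriptionReturnType(text, raw_response) call in the diff above:

import asyncio

from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.mistral.adapter import (
    MistralAdapter,
)


async def main():
    # Keyword arguments mirror the new __init__ signature in the diff above.
    # create_transcript strips a leading "mistral/" prefix from the
    # transcription model name before calling the Mistral SDK.
    adapter = MistralAdapter(
        api_key="<MISTRAL_API_KEY>",  # placeholder, not a real credential
        model="mistral/mistral-large-latest",  # illustrative model name
        max_completion_tokens=4096,
        transcription_model="mistral/voxtral-mini-latest",  # illustrative
    )
    result = await adapter.create_transcript("/path/to/audio.mp3")
    # Assumes TranscriptionReturnType exposes the transcript as its first field.
    print(result.text)


asyncio.run(main())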
cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py
CHANGED

@@ -12,7 +12,6 @@ from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.ll
 from cognee.infrastructure.files.utils.open_data_file import open_data_file
 from cognee.shared.logging_utils import get_logger
 from cognee.shared.rate_limiting import llm_rate_limiter_context_manager
-
 from tenacity import (
     retry,
     stop_after_delay,
cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py
CHANGED

@@ -1,4 +1,3 @@
-import base64
 import litellm
 import instructor
 from typing import Type

@@ -16,8 +15,8 @@ from tenacity import (
     before_sleep_log,
 )

-from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
-    LLMInterface,
+from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.generic_llm_api.adapter import (
+    GenericAPIAdapter,
 )
 from cognee.infrastructure.llm.exceptions import (
     ContentPolicyFilterError,

@@ -26,13 +25,16 @@ from cognee.shared.rate_limiting import llm_rate_limiter_context_manager
 from cognee.infrastructure.files.utils.open_data_file import open_data_file
 from cognee.modules.observability.get_observe import get_observe
 from cognee.shared.logging_utils import get_logger
+from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.types import (
+    TranscriptionReturnType,
+)

 logger = get_logger()

 observe = get_observe()


-class OpenAIAdapter(LLMInterface):
+class OpenAIAdapter(GenericAPIAdapter):
     """
     Adapter for OpenAI's GPT-3, GPT-4 API.

@@ -53,12 +55,7 @@ class OpenAIAdapter(LLMInterface):
     - MAX_RETRIES
     """

-    name = "OpenAI"
-    model: str
-    api_key: str
-    api_version: str
     default_instructor_mode = "json_schema_mode"
-
     MAX_RETRIES = 5

     """Adapter for OpenAI's GPT-3, GPT=4 API"""

@@ -66,17 +63,29 @@ class OpenAIAdapter(LLMInterface):
     def __init__(
         self,
         api_key: str,
-        endpoint: str,
-        api_version: str,
         model: str,
-        transcription_model: str,
         max_completion_tokens: int,
+        endpoint: str = None,
+        api_version: str = None,
+        transcription_model: str = None,
         instructor_mode: str = None,
         streaming: bool = False,
         fallback_model: str = None,
         fallback_api_key: str = None,
         fallback_endpoint: str = None,
     ):
+        super().__init__(
+            api_key=api_key,
+            model=model,
+            max_completion_tokens=max_completion_tokens,
+            name="OpenAI",
+            endpoint=endpoint,
+            api_version=api_version,
+            transcription_model=transcription_model,
+            fallback_model=fallback_model,
+            fallback_api_key=fallback_api_key,
+            fallback_endpoint=fallback_endpoint,
+        )
         self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode
         # TODO: With gpt5 series models OpenAI expects JSON_SCHEMA as a mode for structured outputs.
         # Make sure all new gpt models will work with this mode as well.

@@ -91,18 +100,8 @@ class OpenAIAdapter(LLMInterface):
         self.aclient = instructor.from_litellm(litellm.acompletion)
         self.client = instructor.from_litellm(litellm.completion)

-        self.transcription_model = transcription_model
-        self.model = model
-        self.api_key = api_key
-        self.endpoint = endpoint
-        self.api_version = api_version
-        self.max_completion_tokens = max_completion_tokens
         self.streaming = streaming

-        self.fallback_model = fallback_model
-        self.fallback_api_key = fallback_api_key
-        self.fallback_endpoint = fallback_endpoint
-
     @observe(as_type="generation")
     @retry(
         stop=stop_after_delay(128),

@@ -198,7 +197,7 @@ class OpenAIAdapter(LLMInterface):
                 f"The provided input contains content that is not aligned with our content policy: {text_input}"
             ) from error

-    @observe
+    @observe(as_type="transcription")
     @retry(
         stop=stop_after_delay(128),
         wait=wait_exponential_jitter(2, 128),

@@ -206,58 +205,7 @@ class OpenAIAdapter(LLMInterface):
         before_sleep=before_sleep_log(logger, logging.DEBUG),
         reraise=True,
     )
-    def create_structured_output(
-        self, text_input: str, system_prompt: str, response_model: Type[BaseModel], **kwargs
-    ) -> BaseModel:
-        """
-        Generate a response from a user query.
-
-        This method creates structured output by sending a synchronous request to the OpenAI API
-        using the provided parameters to generate a completion based on the user input and
-        system prompt.
-
-        Parameters:
-        -----------
-
-        - text_input (str): The input text provided by the user for generating a response.
-        - system_prompt (str): The system's prompt to guide the model's response.
-        - response_model (Type[BaseModel]): The expected model type for the response.
-
-        Returns:
-        --------
-
-        - BaseModel: A structured output generated by the model, returned as an instance of
-          BaseModel.
-        """
-
-        return self.client.chat.completions.create(
-            model=self.model,
-            messages=[
-                {
-                    "role": "user",
-                    "content": f"""{text_input}""",
-                },
-                {
-                    "role": "system",
-                    "content": system_prompt,
-                },
-            ],
-            api_key=self.api_key,
-            api_base=self.endpoint,
-            api_version=self.api_version,
-            response_model=response_model,
-            max_retries=self.MAX_RETRIES,
-            **kwargs,
-        )
-
-    @retry(
-        stop=stop_after_delay(128),
-        wait=wait_exponential_jitter(2, 128),
-        retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
-        before_sleep=before_sleep_log(logger, logging.DEBUG),
-        reraise=True,
-    )
-    async def create_transcript(self, input, **kwargs):
+    async def create_transcript(self, input, **kwargs) -> TranscriptionReturnType:
         """
         Generate an audio transcript from a user query.

@@ -286,60 +234,6 @@ class OpenAIAdapter(LLMInterface):
             max_retries=self.MAX_RETRIES,
             **kwargs,
         )
+        return TranscriptionReturnType(transcription.text, transcription)

-
-
-    @retry(
-        stop=stop_after_delay(128),
-        wait=wait_exponential_jitter(2, 128),
-        retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
-        before_sleep=before_sleep_log(logger, logging.DEBUG),
-        reraise=True,
-    )
-    async def transcribe_image(self, input, **kwargs) -> BaseModel:
-        """
-        Generate a transcription of an image from a user query.
-
-        This method encodes the image and sends a request to the OpenAI API to obtain a
-        description of the contents of the image.
-
-        Parameters:
-        -----------
-
-        - input: The path to the image file that needs to be transcribed.
-
-        Returns:
-        --------
-
-        - BaseModel: A structured output generated by the model, returned as an instance of
-          BaseModel.
-        """
-        async with open_data_file(input, mode="rb") as image_file:
-            encoded_image = base64.b64encode(image_file.read()).decode("utf-8")
-
-        return litellm.completion(
-            model=self.model,
-            messages=[
-                {
-                    "role": "user",
-                    "content": [
-                        {
-                            "type": "text",
-                            "text": "What's in this image?",
-                        },
-                        {
-                            "type": "image_url",
-                            "image_url": {
-                                "url": f"data:image/jpeg;base64,{encoded_image}",
-                            },
-                        },
-                    ],
-                },
-            ],
-            api_key=self.api_key,
-            api_base=self.endpoint,
-            api_version=self.api_version,
-            max_completion_tokens=300,
-            max_retries=self.MAX_RETRIES,
-            **kwargs,
-        )
+    # transcribe_image is inherited from GenericAPIAdapter
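Both adapters now funnel transcription results through TranscriptionReturnType, added in types.py (+10 lines in the file list above). Its definition is not shown in this diff; a plausible minimal shape, inferred only from the positional TranscriptionReturnType(text, raw_response) call sites above, would be:

from dataclasses import dataclass
from typing import Any


@dataclass
class TranscriptionReturnType:
    # Field order inferred from the call sites in the diff:
    # the transcript text first, then the raw provider response.
    text: str
    raw_response: Any

Normalizing on a single return type lets callers treat OpenAI and Mistral transcription results interchangeably instead of depending on each provider's raw response object.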
cognee/modules/data/models/Data.py
CHANGED

@@ -13,7 +13,7 @@ class Data(Base):
     __tablename__ = "data"

     id = Column(UUID, primary_key=True, default=uuid4)
-
+    label = Column(String, nullable=True)
     name = Column(String)
     extension = Column(String)
     mime_type = Column(String)

@@ -49,6 +49,7 @@ class Data(Base):
         return {
             "id": str(self.id),
             "name": self.name,
+            "label": self.label,
             "extension": self.extension,
             "mimeType": self.mime_type,
             "rawDataLocation": self.raw_data_location,
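With the new nullable column, the serialization method shown above now carries the label next to the file metadata. A sketch of the resulting dict, with keys taken from the diff and values invented for illustration:

# Keys follow the dict built above; every value here is made up.
example_row = {
    "id": "6f1c3e9a-0b2d-4c8e-9f1a-2d3b4c5e6f70",
    "name": "support_ticket_001",
    "label": "support-ticket",  # None when data was ingested without a label
    "extension": "txt",
    "mimeType": "text/plain",
    "rawDataLocation": "/data/raw/support_ticket_001.txt",
}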
cognee/modules/retrieval/triplet_retriever.py
CHANGED

@@ -36,7 +36,7 @@ class TripletRetriever(BaseRetriever):
         """Initialize retriever with optional custom prompt paths."""
         self.user_prompt_path = user_prompt_path
         self.system_prompt_path = system_prompt_path
-        self.top_k = top_k if top_k is not None else
+        self.top_k = top_k if top_k is not None else 5
         self.system_prompt = system_prompt

     async def get_context(self, query: str) -> str:
cognee/modules/retrieval/utils/brute_force_triplet_search.py
CHANGED

@@ -16,24 +16,6 @@ logger = get_logger(level=ERROR)


 def format_triplets(edges):
-    print("\n\n\n")
-
-    def filter_attributes(obj, attributes):
-        """Helper function to filter out non-None properties, including nested dicts."""
-        result = {}
-        for attr in attributes:
-            value = getattr(obj, attr, None)
-            if value is not None:
-                # If the value is a dict, extract relevant keys from it
-                if isinstance(value, dict):
-                    nested_values = {
-                        k: v for k, v in value.items() if k in attributes and v is not None
-                    }
-                    result[attr] = nested_values
-                else:
-                    result[attr] = value
-        return result
-
     triplets = []
     for edge in edges:
         node1 = edge.node1
cognee/modules/search/methods/search.py
CHANGED

@@ -49,7 +49,6 @@ async def search(
     session_id: Optional[str] = None,
     wide_search_top_k: Optional[int] = 100,
     triplet_distance_penalty: Optional[float] = 3.5,
-    verbose: bool = False,
 ) -> Union[CombinedSearchResult, List[SearchResult]]:
     """

@@ -141,7 +140,6 @@ async def search(
     )

     if use_combined_context:
-        # Note: combined context search must always be verbose and return a CombinedSearchResult with graphs info
         prepared_search_results = await prepare_search_result(
             search_results[0] if isinstance(search_results, list) else search_results
         )

@@ -175,30 +173,25 @@ async def search(
         datasets = prepared_search_results["datasets"]

         if only_context:
-
-
-
-
-
-
-
-
-
-
-            return_value.append(search_result_dict)
+            return_value.append(
+                {
+                    "search_result": [context] if context else None,
+                    "dataset_id": datasets[0].id,
+                    "dataset_name": datasets[0].name,
+                    "dataset_tenant_id": datasets[0].tenant_id,
+                    "graphs": graphs,
+                }
+            )
         else:
-
-
-
-
-
-
-
-
-
-            return_value.append(search_result_dict)
-
+            return_value.append(
+                {
+                    "search_result": [result] if result else None,
+                    "dataset_id": datasets[0].id,
+                    "dataset_name": datasets[0].name,
+                    "dataset_tenant_id": datasets[0].tenant_id,
+                    "graphs": graphs,
+                }
+            )
         return return_value
     else:
         return_value = []
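The net effect is that each per-dataset entry now carries graph context unconditionally, where previously the removed verbose flag gated what the result dict contained. A sketch of one element of the returned list, with keys taken from the append calls above and values purely illustrative:

# Shape of one entry in the list returned by search(); all values are illustrative.
entry = {
    "search_result": ["<retriever output>"],  # None when nothing matched
    "dataset_id": "<UUID of datasets[0]>",
    "dataset_name": "my_dataset",
    "dataset_tenant_id": "<tenant UUID or None>",
    "graphs": "<graph context prepared by prepare_search_result>",
}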
cognee/tasks/ingestion/ingest_data.py
CHANGED

@@ -20,6 +20,7 @@ from cognee.modules.data.methods import (

 from .save_data_item_to_storage import save_data_item_to_storage
 from .data_item_to_text_file import data_item_to_text_file
+from .data_item import DataItem


 async def ingest_data(

@@ -78,8 +79,16 @@ async def ingest_data(
     dataset_data_map = {str(data.id): True for data in dataset_data}

     for data_item in data:
+        # Support for DataItem (custom label + data wrapper)
+        current_label = None
+        underlying_data = data_item
+
+        if isinstance(data_item, DataItem):
+            underlying_data = data_item.data
+            current_label = data_item.label
+
         # Get file path of data item or create a file if it doesn't exist
-        original_file_path = await save_data_item_to_storage(data_item)
+        original_file_path = await save_data_item_to_storage(underlying_data)
         # Transform file path to be OS usable
         actual_file_path = get_data_file_path(original_file_path)

@@ -139,6 +148,7 @@ async def ingest_data(
             data_point.external_metadata = ext_metadata
             data_point.node_set = json.dumps(node_set) if node_set else None
             data_point.tenant_id = user.tenant_id if user.tenant_id else None
+            data_point.label = current_label

             # Check if data is already in dataset
             if str(data_point.id) in dataset_data_map:

@@ -169,6 +179,7 @@ async def ingest_data(
                 tenant_id=user.tenant_id if user.tenant_id else None,
                 pipeline_status={},
                 token_count=-1,
+                label=current_label,
             )

             new_datapoints.append(data_point)
cognee/tasks/ingestion/save_data_item_to_storage.py
CHANGED

@@ -9,6 +9,7 @@ from cognee.shared.logging_utils import get_logger
 from pydantic_settings import BaseSettings, SettingsConfigDict

 from cognee.tasks.web_scraper.utils import fetch_page_content
+from cognee.tasks.ingestion.data_item import DataItem


 logger = get_logger()

@@ -95,5 +96,9 @@ async def save_data_item_to_storage(data_item: Union[BinaryIO, str, Any]) -> str
         # data is text, save it to data storage and return the file path
         return await save_data_to_file(data_item)

+    if isinstance(data_item, DataItem):
+        # If instance is DataItem use the underlying data
+        return await save_data_item_to_storage(data_item.data)
+
     # data is not a supported type
     raise IngestionError(message=f"Data type not supported: {type(data_item)}")