cognee 0.4.1__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as published to one of the supported registries. It is provided for informational purposes only and reflects the package contents exactly as they appear in the public registries.
- cognee/__init__.py +1 -0
- cognee/api/client.py +9 -5
- cognee/api/v1/add/add.py +2 -1
- cognee/api/v1/add/routers/get_add_router.py +3 -1
- cognee/api/v1/cognify/cognify.py +24 -16
- cognee/api/v1/cognify/routers/__init__.py +0 -1
- cognee/api/v1/cognify/routers/get_cognify_router.py +30 -1
- cognee/api/v1/datasets/routers/get_datasets_router.py +3 -3
- cognee/api/v1/ontologies/__init__.py +4 -0
- cognee/api/v1/ontologies/ontologies.py +158 -0
- cognee/api/v1/ontologies/routers/__init__.py +0 -0
- cognee/api/v1/ontologies/routers/get_ontology_router.py +109 -0
- cognee/api/v1/permissions/routers/get_permissions_router.py +41 -1
- cognee/api/v1/search/search.py +4 -0
- cognee/api/v1/ui/node_setup.py +360 -0
- cognee/api/v1/ui/npm_utils.py +50 -0
- cognee/api/v1/ui/ui.py +38 -68
- cognee/cli/commands/cognify_command.py +8 -1
- cognee/cli/config.py +1 -1
- cognee/context_global_variables.py +86 -9
- cognee/eval_framework/Dockerfile +29 -0
- cognee/eval_framework/answer_generation/answer_generation_executor.py +10 -0
- cognee/eval_framework/answer_generation/run_question_answering_module.py +1 -1
- cognee/eval_framework/corpus_builder/task_getters/get_cascade_graph_tasks.py +0 -2
- cognee/eval_framework/corpus_builder/task_getters/get_default_tasks_by_indices.py +4 -4
- cognee/eval_framework/eval_config.py +2 -2
- cognee/eval_framework/modal_run_eval.py +16 -28
- cognee/infrastructure/databases/cache/config.py +3 -1
- cognee/infrastructure/databases/cache/fscache/FsCacheAdapter.py +151 -0
- cognee/infrastructure/databases/cache/get_cache_engine.py +20 -10
- cognee/infrastructure/databases/dataset_database_handler/__init__.py +3 -0
- cognee/infrastructure/databases/dataset_database_handler/dataset_database_handler_interface.py +80 -0
- cognee/infrastructure/databases/dataset_database_handler/supported_dataset_database_handlers.py +18 -0
- cognee/infrastructure/databases/dataset_database_handler/use_dataset_database_handler.py +10 -0
- cognee/infrastructure/databases/exceptions/exceptions.py +16 -0
- cognee/infrastructure/databases/graph/config.py +7 -0
- cognee/infrastructure/databases/graph/get_graph_engine.py +3 -0
- cognee/infrastructure/databases/graph/graph_db_interface.py +15 -0
- cognee/infrastructure/databases/graph/kuzu/KuzuDatasetDatabaseHandler.py +81 -0
- cognee/infrastructure/databases/graph/kuzu/adapter.py +228 -0
- cognee/infrastructure/databases/graph/neo4j_driver/Neo4jAuraDevDatasetDatabaseHandler.py +168 -0
- cognee/infrastructure/databases/graph/neo4j_driver/adapter.py +80 -1
- cognee/infrastructure/databases/hybrid/neptune_analytics/NeptuneAnalyticsAdapter.py +9 -0
- cognee/infrastructure/databases/utils/__init__.py +3 -0
- cognee/infrastructure/databases/utils/get_graph_dataset_database_handler.py +10 -0
- cognee/infrastructure/databases/utils/get_or_create_dataset_database.py +66 -18
- cognee/infrastructure/databases/utils/get_vector_dataset_database_handler.py +10 -0
- cognee/infrastructure/databases/utils/resolve_dataset_database_connection_info.py +30 -0
- cognee/infrastructure/databases/vector/config.py +5 -0
- cognee/infrastructure/databases/vector/create_vector_engine.py +6 -1
- cognee/infrastructure/databases/vector/embeddings/FastembedEmbeddingEngine.py +8 -6
- cognee/infrastructure/databases/vector/embeddings/LiteLLMEmbeddingEngine.py +9 -7
- cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py +11 -13
- cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py +2 -0
- cognee/infrastructure/databases/vector/lancedb/LanceDBDatasetDatabaseHandler.py +50 -0
- cognee/infrastructure/databases/vector/vector_db_interface.py +35 -0
- cognee/infrastructure/engine/models/Edge.py +13 -1
- cognee/infrastructure/files/storage/s3_config.py +2 -0
- cognee/infrastructure/files/utils/guess_file_type.py +4 -0
- cognee/infrastructure/llm/LLMGateway.py +5 -2
- cognee/infrastructure/llm/config.py +37 -0
- cognee/infrastructure/llm/extraction/knowledge_graph/extract_content_graph.py +2 -2
- cognee/infrastructure/llm/structured_output_framework/baml/baml_src/extraction/acreate_structured_output.py +23 -8
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/anthropic/adapter.py +22 -18
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/bedrock/__init__.py +5 -0
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/bedrock/adapter.py +153 -0
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/gemini/adapter.py +47 -38
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/generic_llm_api/adapter.py +46 -37
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/get_llm_client.py +20 -10
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/mistral/adapter.py +23 -11
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/ollama/adapter.py +36 -23
- cognee/infrastructure/llm/structured_output_framework/litellm_instructor/llm/openai/adapter.py +47 -36
- cognee/infrastructure/loaders/LoaderEngine.py +1 -0
- cognee/infrastructure/loaders/core/__init__.py +2 -1
- cognee/infrastructure/loaders/core/csv_loader.py +93 -0
- cognee/infrastructure/loaders/core/text_loader.py +1 -2
- cognee/infrastructure/loaders/external/advanced_pdf_loader.py +0 -9
- cognee/infrastructure/loaders/supported_loaders.py +2 -1
- cognee/memify_pipelines/create_triplet_embeddings.py +53 -0
- cognee/memify_pipelines/persist_sessions_in_knowledge_graph.py +55 -0
- cognee/modules/chunking/CsvChunker.py +35 -0
- cognee/modules/chunking/models/DocumentChunk.py +2 -1
- cognee/modules/chunking/text_chunker_with_overlap.py +124 -0
- cognee/modules/cognify/config.py +2 -0
- cognee/modules/data/deletion/prune_system.py +52 -2
- cognee/modules/data/methods/__init__.py +1 -0
- cognee/modules/data/methods/create_dataset.py +4 -2
- cognee/modules/data/methods/delete_dataset.py +26 -0
- cognee/modules/data/methods/get_dataset_ids.py +5 -1
- cognee/modules/data/methods/get_unique_data_id.py +68 -0
- cognee/modules/data/methods/get_unique_dataset_id.py +66 -4
- cognee/modules/data/models/Dataset.py +2 -0
- cognee/modules/data/processing/document_types/CsvDocument.py +33 -0
- cognee/modules/data/processing/document_types/__init__.py +1 -0
- cognee/modules/engine/models/Triplet.py +9 -0
- cognee/modules/engine/models/__init__.py +1 -0
- cognee/modules/graph/cognee_graph/CogneeGraph.py +89 -39
- cognee/modules/graph/cognee_graph/CogneeGraphElements.py +8 -3
- cognee/modules/graph/utils/expand_with_nodes_and_edges.py +19 -2
- cognee/modules/graph/utils/resolve_edges_to_text.py +48 -49
- cognee/modules/ingestion/identify.py +4 -4
- cognee/modules/memify/memify.py +1 -7
- cognee/modules/notebooks/operations/run_in_local_sandbox.py +3 -0
- cognee/modules/ontology/rdf_xml/RDFLibOntologyResolver.py +55 -23
- cognee/modules/pipelines/operations/pipeline.py +18 -2
- cognee/modules/pipelines/operations/run_tasks_data_item.py +1 -1
- cognee/modules/retrieval/EntityCompletionRetriever.py +10 -3
- cognee/modules/retrieval/__init__.py +1 -1
- cognee/modules/retrieval/base_graph_retriever.py +7 -3
- cognee/modules/retrieval/base_retriever.py +7 -3
- cognee/modules/retrieval/completion_retriever.py +11 -4
- cognee/modules/retrieval/graph_completion_context_extension_retriever.py +10 -2
- cognee/modules/retrieval/graph_completion_cot_retriever.py +18 -51
- cognee/modules/retrieval/graph_completion_retriever.py +14 -1
- cognee/modules/retrieval/graph_summary_completion_retriever.py +4 -0
- cognee/modules/retrieval/register_retriever.py +10 -0
- cognee/modules/retrieval/registered_community_retrievers.py +1 -0
- cognee/modules/retrieval/temporal_retriever.py +13 -2
- cognee/modules/retrieval/triplet_retriever.py +182 -0
- cognee/modules/retrieval/utils/brute_force_triplet_search.py +43 -11
- cognee/modules/retrieval/utils/completion.py +2 -22
- cognee/modules/run_custom_pipeline/__init__.py +1 -0
- cognee/modules/run_custom_pipeline/run_custom_pipeline.py +76 -0
- cognee/modules/search/methods/get_search_type_tools.py +54 -8
- cognee/modules/search/methods/no_access_control_search.py +4 -0
- cognee/modules/search/methods/search.py +26 -3
- cognee/modules/search/types/SearchType.py +1 -1
- cognee/modules/settings/get_settings.py +19 -0
- cognee/modules/users/methods/create_user.py +12 -27
- cognee/modules/users/methods/get_authenticated_user.py +3 -2
- cognee/modules/users/methods/get_default_user.py +4 -2
- cognee/modules/users/methods/get_user.py +1 -1
- cognee/modules/users/methods/get_user_by_email.py +1 -1
- cognee/modules/users/models/DatasetDatabase.py +24 -3
- cognee/modules/users/models/Tenant.py +6 -7
- cognee/modules/users/models/User.py +6 -5
- cognee/modules/users/models/UserTenant.py +12 -0
- cognee/modules/users/models/__init__.py +1 -0
- cognee/modules/users/permissions/methods/get_all_user_permission_datasets.py +13 -13
- cognee/modules/users/roles/methods/add_user_to_role.py +3 -1
- cognee/modules/users/tenants/methods/__init__.py +1 -0
- cognee/modules/users/tenants/methods/add_user_to_tenant.py +21 -12
- cognee/modules/users/tenants/methods/create_tenant.py +22 -8
- cognee/modules/users/tenants/methods/select_tenant.py +62 -0
- cognee/shared/logging_utils.py +6 -0
- cognee/shared/rate_limiting.py +30 -0
- cognee/tasks/chunks/__init__.py +1 -0
- cognee/tasks/chunks/chunk_by_row.py +94 -0
- cognee/tasks/documents/__init__.py +0 -1
- cognee/tasks/documents/classify_documents.py +2 -0
- cognee/tasks/feedback/generate_improved_answers.py +3 -3
- cognee/tasks/graph/extract_graph_from_data.py +9 -10
- cognee/tasks/ingestion/ingest_data.py +1 -1
- cognee/tasks/memify/__init__.py +2 -0
- cognee/tasks/memify/cognify_session.py +41 -0
- cognee/tasks/memify/extract_user_sessions.py +73 -0
- cognee/tasks/memify/get_triplet_datapoints.py +289 -0
- cognee/tasks/storage/add_data_points.py +142 -2
- cognee/tasks/storage/index_data_points.py +33 -22
- cognee/tasks/storage/index_graph_edges.py +37 -57
- cognee/tests/integration/documents/CsvDocument_test.py +70 -0
- cognee/tests/integration/retrieval/test_triplet_retriever.py +84 -0
- cognee/tests/integration/tasks/test_add_data_points.py +139 -0
- cognee/tests/integration/tasks/test_get_triplet_datapoints.py +69 -0
- cognee/tests/tasks/entity_extraction/entity_extraction_test.py +1 -1
- cognee/tests/test_add_docling_document.py +2 -2
- cognee/tests/test_cognee_server_start.py +84 -3
- cognee/tests/test_conversation_history.py +68 -5
- cognee/tests/test_data/example_with_header.csv +3 -0
- cognee/tests/test_dataset_database_handler.py +137 -0
- cognee/tests/test_dataset_delete.py +76 -0
- cognee/tests/test_edge_centered_payload.py +170 -0
- cognee/tests/test_edge_ingestion.py +27 -0
- cognee/tests/test_feedback_enrichment.py +1 -1
- cognee/tests/test_library.py +6 -4
- cognee/tests/test_load.py +62 -0
- cognee/tests/test_multi_tenancy.py +165 -0
- cognee/tests/test_parallel_databases.py +2 -0
- cognee/tests/test_pipeline_cache.py +164 -0
- cognee/tests/test_relational_db_migration.py +54 -2
- cognee/tests/test_search_db.py +44 -2
- cognee/tests/unit/api/test_conditional_authentication_endpoints.py +12 -3
- cognee/tests/unit/api/test_ontology_endpoint.py +252 -0
- cognee/tests/unit/infrastructure/databases/cache/test_cache_config.py +5 -0
- cognee/tests/unit/infrastructure/databases/test_index_data_points.py +27 -0
- cognee/tests/unit/infrastructure/databases/test_index_graph_edges.py +14 -16
- cognee/tests/unit/infrastructure/llm/test_llm_config.py +46 -0
- cognee/tests/unit/infrastructure/mock_embedding_engine.py +3 -7
- cognee/tests/unit/infrastructure/test_embedding_rate_limiting_realistic.py +0 -5
- cognee/tests/unit/modules/chunking/test_text_chunker.py +248 -0
- cognee/tests/unit/modules/chunking/test_text_chunker_with_overlap.py +324 -0
- cognee/tests/unit/modules/graph/cognee_graph_elements_test.py +2 -2
- cognee/tests/unit/modules/graph/cognee_graph_test.py +406 -0
- cognee/tests/unit/modules/memify_tasks/test_cognify_session.py +111 -0
- cognee/tests/unit/modules/memify_tasks/test_extract_user_sessions.py +175 -0
- cognee/tests/unit/modules/memify_tasks/test_get_triplet_datapoints.py +214 -0
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +0 -51
- cognee/tests/unit/modules/retrieval/rag_completion_retriever_test.py +1 -0
- cognee/tests/unit/modules/retrieval/structured_output_test.py +204 -0
- cognee/tests/unit/modules/retrieval/summaries_retriever_test.py +1 -1
- cognee/tests/unit/modules/retrieval/temporal_retriever_test.py +0 -1
- cognee/tests/unit/modules/retrieval/test_brute_force_triplet_search.py +608 -0
- cognee/tests/unit/modules/retrieval/triplet_retriever_test.py +83 -0
- cognee/tests/unit/modules/users/test_conditional_authentication.py +0 -63
- cognee/tests/unit/processing/chunks/chunk_by_row_test.py +52 -0
- cognee/tests/unit/tasks/storage/test_add_data_points.py +288 -0
- {cognee-0.4.1.dist-info → cognee-0.5.0.dist-info}/METADATA +11 -7
- {cognee-0.4.1.dist-info → cognee-0.5.0.dist-info}/RECORD +212 -160
- {cognee-0.4.1.dist-info → cognee-0.5.0.dist-info}/entry_points.txt +0 -1
- cognee/api/v1/cognify/code_graph_pipeline.py +0 -119
- cognee/api/v1/cognify/routers/get_code_pipeline_router.py +0 -90
- cognee/infrastructure/databases/vector/embeddings/embedding_rate_limiter.py +0 -544
- cognee/modules/retrieval/code_retriever.py +0 -232
- cognee/tasks/code/enrich_dependency_graph_checker.py +0 -35
- cognee/tasks/code/get_local_dependencies_checker.py +0 -20
- cognee/tasks/code/get_repo_dependency_graph_checker.py +0 -35
- cognee/tasks/documents/check_permissions_on_dataset.py +0 -26
- cognee/tasks/repo_processor/__init__.py +0 -2
- cognee/tasks/repo_processor/get_local_dependencies.py +0 -335
- cognee/tasks/repo_processor/get_non_code_files.py +0 -158
- cognee/tasks/repo_processor/get_repo_file_dependencies.py +0 -243
- {cognee-0.4.1.dist-info → cognee-0.5.0.dist-info}/WHEEL +0 -0
- {cognee-0.4.1.dist-info → cognee-0.5.0.dist-info}/licenses/LICENSE +0 -0
- {cognee-0.4.1.dist-info → cognee-0.5.0.dist-info}/licenses/NOTICE.md +0 -0
cognee/modules/graph/cognee_graph/CogneeGraph.py
CHANGED

```diff
@@ -56,6 +56,68 @@ class CogneeGraph(CogneeAbstractGraph):
     def get_edges(self) -> List[Edge]:
         return self.edges
 
+    async def _get_nodeset_subgraph(
+        self,
+        adapter,
+        node_type,
+        node_name,
+    ):
+        """Retrieve subgraph based on node type and name."""
+        logger.info("Retrieving graph filtered by node type and node name (NodeSet).")
+        nodes_data, edges_data = await adapter.get_nodeset_subgraph(
+            node_type=node_type, node_name=node_name
+        )
+        if not nodes_data or not edges_data:
+            raise EntityNotFoundError(
+                message="Nodeset does not exist, or empty nodeset projected from the database."
+            )
+        return nodes_data, edges_data
+
+    async def _get_full_or_id_filtered_graph(
+        self,
+        adapter,
+        relevant_ids_to_filter,
+    ):
+        """Retrieve full or ID-filtered graph with fallback."""
+        if relevant_ids_to_filter is None:
+            logger.info("Retrieving full graph.")
+            nodes_data, edges_data = await adapter.get_graph_data()
+            if not nodes_data or not edges_data:
+                raise EntityNotFoundError(message="Empty graph projected from the database.")
+            return nodes_data, edges_data
+
+        get_graph_data_fn = getattr(adapter, "get_id_filtered_graph_data", adapter.get_graph_data)
+        if getattr(adapter.__class__, "get_id_filtered_graph_data", None):
+            logger.info("Retrieving ID-filtered graph from database.")
+            nodes_data, edges_data = await get_graph_data_fn(target_ids=relevant_ids_to_filter)
+        else:
+            logger.info("Retrieving full graph from database.")
+            nodes_data, edges_data = await get_graph_data_fn()
+        if hasattr(adapter, "get_id_filtered_graph_data") and (not nodes_data or not edges_data):
+            logger.warning(
+                "Id filtered graph returned empty, falling back to full graph retrieval."
+            )
+            logger.info("Retrieving full graph")
+            nodes_data, edges_data = await adapter.get_graph_data()
+
+        if not nodes_data or not edges_data:
+            raise EntityNotFoundError("Empty graph projected from the database.")
+        return nodes_data, edges_data
+
+    async def _get_filtered_graph(
+        self,
+        adapter,
+        memory_fragment_filter,
+    ):
+        """Retrieve graph filtered by attributes."""
+        logger.info("Retrieving graph filtered by memory fragment")
+        nodes_data, edges_data = await adapter.get_filtered_graph_data(
+            attribute_filters=memory_fragment_filter
+        )
+        if not nodes_data or not edges_data:
+            raise EntityNotFoundError(message="Empty filtered graph projected from the database.")
+        return nodes_data, edges_data
+
     async def project_graph_from_db(
         self,
         adapter: Union[GraphDBInterface],
@@ -67,40 +129,39 @@ class CogneeGraph(CogneeAbstractGraph):
         memory_fragment_filter=[],
         node_type: Optional[Type] = None,
         node_name: Optional[List[str]] = None,
+        relevant_ids_to_filter: Optional[List[str]] = None,
+        triplet_distance_penalty: float = 3.5,
     ) -> None:
         if node_dimension < 1 or edge_dimension < 1:
             raise InvalidDimensionsError()
         try:
-            import time
-
-            start_time = time.time()
-
-            # Determine projection strategy
             if node_type is not None and node_name not in [None, [], ""]:
-                nodes_data, edges_data = await
-                    node_type
+                nodes_data, edges_data = await self._get_nodeset_subgraph(
+                    adapter, node_type, node_name
                 )
-                if not nodes_data or not edges_data:
-                    raise EntityNotFoundError(
-                        message="Nodeset does not exist, or empty nodetes projected from the database."
-                    )
             elif len(memory_fragment_filter) == 0:
-                nodes_data, edges_data = await
-
-
+                nodes_data, edges_data = await self._get_full_or_id_filtered_graph(
+                    adapter, relevant_ids_to_filter
+                )
             else:
-                nodes_data, edges_data = await
-
+                nodes_data, edges_data = await self._get_filtered_graph(
+                    adapter, memory_fragment_filter
                 )
-            if not nodes_data or not edges_data:
-                raise EntityNotFoundError(
-                    message="Empty filtered graph projected from the database."
-                )
 
+            import time
+
+            start_time = time.time()
             # Process nodes
             for node_id, properties in nodes_data:
                 node_attributes = {key: properties.get(key) for key in node_properties_to_project}
-                self.add_node(
+                self.add_node(
+                    Node(
+                        str(node_id),
+                        node_attributes,
+                        dimension=node_dimension,
+                        node_penalty=triplet_distance_penalty,
+                    )
+                )
 
             # Process edges
             for source_id, target_id, relationship_type, properties in edges_data:
@@ -118,6 +179,7 @@ class CogneeGraph(CogneeAbstractGraph):
                     attributes=edge_attributes,
                     directed=directed,
                     dimension=edge_dimension,
+                    edge_penalty=triplet_distance_penalty,
                 )
                 self.add_edge(edge)
 
@@ -149,30 +211,18 @@ class CogneeGraph(CogneeAbstractGraph):
                     node.add_attribute("vector_distance", score)
                     mapped_nodes += 1
 
-    async def map_vector_distances_to_graph_edges(
-        self, vector_engine, query_vector, edge_distances
-    ) -> None:
+    async def map_vector_distances_to_graph_edges(self, edge_distances) -> None:
         try:
-            if query_vector is None or len(query_vector) == 0:
-                raise ValueError("Failed to generate query embedding.")
-
             if edge_distances is None:
-
-                edge_distances = await vector_engine.search(
-                    collection_name="EdgeType_relationship_name",
-                    query_vector=query_vector,
-                    limit=None,
-                )
-                projection_time = time.time() - start_time
-                logger.info(
-                    f"Edge collection distances were calculated separately from nodes in {projection_time:.2f}s"
-                )
+                return
 
             embedding_map = {result.payload["text"]: result.score for result in edge_distances}
 
             for edge in self.edges:
-
-
+                edge_key = edge.attributes.get("edge_text") or edge.attributes.get(
+                    "relationship_type"
+                )
+                distance = embedding_map.get(edge_key, None)
                 if distance is not None:
                     edge.attributes["vector_distance"] = distance
 
```
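The extracted `_get_full_or_id_filtered_graph` helper probes the adapter for an optional `get_id_filtered_graph_data` method and falls back to projecting the full graph when the capability is missing or the filtered result comes back empty. A minimal runnable sketch of that dispatch-and-fallback pattern; the adapter classes below are made-up stand-ins, not cognee's real adapters:

```python
import asyncio


class FullOnlyAdapter:
    """Made-up adapter that only supports full-graph retrieval."""

    async def get_graph_data(self):
        return [("n1", {}), ("n2", {})], [("n1", "n2", "likes", {})]


class IdFilteringAdapter(FullOnlyAdapter):
    """Made-up adapter that can additionally filter by node ids."""

    async def get_id_filtered_graph_data(self, target_ids):
        nodes, edges = await self.get_graph_data()
        return [n for n in nodes if n[0] in target_ids], edges


async def project(adapter, relevant_ids_to_filter=None):
    if relevant_ids_to_filter is None or not hasattr(adapter, "get_id_filtered_graph_data"):
        # No filter requested, or the adapter lacks the optional capability.
        nodes, edges = await adapter.get_graph_data()
    else:
        nodes, edges = await adapter.get_id_filtered_graph_data(
            target_ids=relevant_ids_to_filter
        )
        if not nodes or not edges:
            # Empty filtered projection: fall back to the full graph,
            # mirroring the warning path in the diff above.
            nodes, edges = await adapter.get_graph_data()
    return nodes, edges


print(asyncio.run(project(FullOnlyAdapter(), ["n1"])))     # full graph
print(asyncio.run(project(IdFilteringAdapter(), ["n1"])))  # filtered to n1
```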
cognee/modules/graph/cognee_graph/CogneeGraphElements.py
CHANGED

```diff
@@ -20,13 +20,17 @@ class Node:
     status: np.ndarray
 
     def __init__(
-        self,
+        self,
+        node_id: str,
+        attributes: Optional[Dict[str, Any]] = None,
+        dimension: int = 1,
+        node_penalty: float = 3.5,
     ):
         if dimension <= 0:
             raise InvalidDimensionsError()
         self.id = node_id
         self.attributes = attributes if attributes is not None else {}
-        self.attributes["vector_distance"] =
+        self.attributes["vector_distance"] = node_penalty
         self.skeleton_neighbours = []
         self.skeleton_edges = []
         self.status = np.ones(dimension, dtype=int)
@@ -105,13 +109,14 @@ class Edge:
         attributes: Optional[Dict[str, Any]] = None,
         directed: bool = True,
         dimension: int = 1,
+        edge_penalty: float = 3.5,
     ):
         if dimension <= 0:
             raise InvalidDimensionsError()
         self.node1 = node1
         self.node2 = node2
         self.attributes = attributes if attributes is not None else {}
-        self.attributes["vector_distance"] =
+        self.attributes["vector_distance"] = edge_penalty
         self.directed = directed
         self.status = np.ones(dimension, dtype=int)
 
```
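The new `node_penalty` / `edge_penalty` parameters seed the `vector_distance` attribute, so graph elements that never match a vector-search result keep a fixed penalty distance (3.5 by default) and rank behind matched elements. A small runnable sketch of the effect, using a simplified stand-in for the real `Node` class:

```python
class Node:
    """Simplified stand-in: vector_distance defaults to a penalty value."""

    def __init__(self, node_id, attributes=None, node_penalty=3.5):
        self.id = node_id
        self.attributes = attributes if attributes is not None else {}
        # Elements never matched by vector search keep this penalty distance.
        self.attributes["vector_distance"] = node_penalty


scores = {"a": 0.12}  # made-up distances from a vector search, keyed by node id
nodes = [Node("a"), Node("b")]
for node in nodes:
    if node.id in scores:
        node.attributes["vector_distance"] = scores[node.id]

print([(n.id, n.attributes["vector_distance"]) for n in nodes])
# [('a', 0.12), ('b', 3.5)] -> unmatched nodes rank behind matched ones
```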
cognee/modules/graph/utils/expand_with_nodes_and_edges.py
CHANGED

```diff
@@ -1,5 +1,6 @@
 from typing import Optional
 
+from cognee.infrastructure.engine.models.Edge import Edge
 from cognee.modules.chunking.models import DocumentChunk
 from cognee.modules.engine.models import Entity, EntityType
 from cognee.modules.engine.utils import (
@@ -243,10 +244,26 @@ def _process_graph_nodes(
                 ontology_relationships,
             )
 
-            # Add entity to data chunk
             if data_chunk.contains is None:
                 data_chunk.contains = []
-
+
+            edge_text = "; ".join(
+                [
+                    "relationship_name: contains",
+                    f"entity_name: {entity_node.name}",
+                    f"entity_description: {entity_node.description}",
+                ]
+            )
+
+            data_chunk.contains.append(
+                (
+                    Edge(
+                        relationship_type="contains",
+                        edge_text=edge_text,
+                    ),
+                    entity_node,
+                )
+            )
 
 
 def _process_graph_edges(
```
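Chunks now record each contained entity as an `(Edge, entity)` pair whose `Edge` carries a pre-rendered `edge_text` string, which is the lookup key used by `map_vector_distances_to_graph_edges` in the CogneeGraph diff above. A sketch of the string being built, with made-up entity values:

```python
# Made-up entity values standing in for entity_node.name / .description.
entity_name = "Alan Turing"
entity_description = "British mathematician and computer scientist."

# Same construction as the diff: a semicolon-joined key/value rendering.
edge_text = "; ".join(
    [
        "relationship_name: contains",
        f"entity_name: {entity_name}",
        f"entity_description: {entity_description}",
    ]
)
print(edge_text)
# relationship_name: contains; entity_name: Alan Turing; entity_description: British mathematician and computer scientist.
```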
cognee/modules/graph/utils/resolve_edges_to_text.py
CHANGED

```diff
@@ -1,71 +1,70 @@
+import string
 from typing import List
-from
+from collections import Counter
 
-
-
-    Converts retrieved graph edges into a human-readable string format.
+from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
+from cognee.modules.retrieval.utils.stop_words import DEFAULT_STOP_WORDS
 
-    Parameters:
-    -----------
 
-
+def _get_top_n_frequent_words(
+    text: str, stop_words: set = None, top_n: int = 3, separator: str = ", "
+) -> str:
+    """Concatenates the top N frequent words in text."""
+    if stop_words is None:
+        stop_words = DEFAULT_STOP_WORDS
 
-
-
+    words = [word.lower().strip(string.punctuation) for word in text.split()]
+    words = [word for word in words if word and word not in stop_words]
 
-
-
+    top_words = [word for word, freq in Counter(words).most_common(top_n)]
+    return separator.join(top_words)
 
-def _get_nodes(retrieved_edges: List[Edge]) -> dict:
-    def _get_title(text: str, first_n_words: int = 7, top_n_words: int = 3) -> str:
-        def _top_n_words(text, stop_words=None, top_n=3, separator=", "):
-            """Concatenates the top N frequent words in text."""
-            if stop_words is None:
-                from cognee.modules.retrieval.utils.stop_words import DEFAULT_STOP_WORDS
 
-
+def _create_title_from_text(text: str, first_n_words: int = 7, top_n_words: int = 3) -> str:
+    """Creates a title by combining first words with most frequent words from the text."""
+    first_words = text.split()[:first_n_words]
+    top_words = _get_top_n_frequent_words(text, top_n=top_n_words)
+    return f"{' '.join(first_words)}... [{top_words}]"
 
-        import string
 
-
+def _extract_nodes_from_edges(retrieved_edges: List[Edge]) -> dict:
+    """Creates a dictionary of nodes with their names and content."""
+    nodes = {}
 
-
-
+    for edge in retrieved_edges:
+        for node in (edge.node1, edge.node2):
+            if node.id in nodes:
+                continue
 
-
+            text = node.attributes.get("text")
+            if text:
+                name = _create_title_from_text(text)
+                content = text
+            else:
+                name = node.attributes.get("name", "Unnamed Node")
+                content = node.attributes.get("description", name)
 
-
+            nodes[node.id] = {"node": node, "name": name, "content": content}
 
-
+    return nodes
 
-    """Creates a title, by combining first words with most frequent words from the text."""
-    first_words = text.split()[:first_n_words]
-    top_words = _top_n_words(text, top_n=first_n_words)
-    return f"{' '.join(first_words)}... [{top_words}]"
 
-
-
-
-    for node in (edge.node1, edge.node2):
-        if node.id not in nodes:
-            text = node.attributes.get("text")
-            if text:
-                name = _get_title(text)
-                content = text
-            else:
-                name = node.attributes.get("name", "Unnamed Node")
-                content = node.attributes.get("description", name)
-    nodes[node.id] = {"node": node, "name": name, "content": content}
-    return nodes
+async def resolve_edges_to_text(retrieved_edges: List[Edge]) -> str:
+    """Converts retrieved graph edges into a human-readable string format."""
+    nodes = _extract_nodes_from_edges(retrieved_edges)
 
-    nodes = _get_nodes(retrieved_edges)
     node_section = "\n".join(
         f"Node: {info['name']}\n__node_content_start__\n{info['content']}\n__node_content_end__\n"
         for info in nodes.values()
     )
-
-
-
-
+
+    connections = []
+    for edge in retrieved_edges:
+        source_name = nodes[edge.node1.id]["name"]
+        target_name = nodes[edge.node2.id]["name"]
+        edge_label = edge.attributes.get("edge_text") or edge.attributes.get("relationship_type")
+        connections.append(f"{source_name} --[{edge_label}]--> {target_name}")
+
+    connection_section = "\n".join(connections)
+
     return f"Nodes:\n{node_section}\n\nConnections:\n{connection_section}"
```
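The refactor flattens the previously nested helpers into module-level functions. A runnable sketch of the two title helpers with a worked example; the stop-word set here is a made-up stand-in for `DEFAULT_STOP_WORDS`:

```python
import string
from collections import Counter

# Made-up stand-in for cognee's DEFAULT_STOP_WORDS.
STOP_WORDS = {"the", "a", "of", "and", "to", "in", "is"}


def get_top_n_frequent_words(text, stop_words=STOP_WORDS, top_n=3, separator=", "):
    """Concatenates the top N frequent words in text (mirrors the new helper)."""
    words = [word.lower().strip(string.punctuation) for word in text.split()]
    words = [word for word in words if word and word not in stop_words]
    return separator.join(word for word, _ in Counter(words).most_common(top_n))


def create_title_from_text(text, first_n_words=7, top_n_words=3):
    """First few words plus the most frequent words, as a node title."""
    first_words = text.split()[:first_n_words]
    top_words = get_top_n_frequent_words(text, top_n=top_n_words)
    return f"{' '.join(first_words)}... [{top_words}]"


text = "The graph stores nodes and edges. Nodes reference edges and edges reference nodes."
print(create_title_from_text(text))
# The graph stores nodes and edges. Nodes... [nodes, edges, reference]
```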
cognee/modules/ingestion/identify.py
CHANGED

```diff
@@ -1,11 +1,11 @@
-from uuid import
+from uuid import UUID
 from .data_types import IngestionData
 
 from cognee.modules.users.models import User
+from cognee.modules.data.methods import get_unique_data_id
 
 
-def identify(data: IngestionData, user: User) ->
+async def identify(data: IngestionData, user: User) -> UUID:
     data_content_hash: str = data.get_identifier()
 
-    return uuid5(NAMESPACE_OID, f"{data_content_hash}{user.id}")
+    return await get_unique_data_id(data_identifier=data_content_hash, user=user)
```
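`identify` is now async and delegates ID generation to `get_unique_data_id`, so call sites must await it (see the `run_tasks_data_item.py` hunk further below). For contrast, the removed behavior derived a deterministic UUID from the content hash plus owner id, roughly as in this sketch (the hash and id values are made up):

```python
from uuid import NAMESPACE_OID, uuid5

data_content_hash = "3c59dc048e88"  # made-up content hash
user_id = "7b52009b64fd"            # made-up owner id

# Old behavior: a deterministic id derived from content hash + owner id,
# so the same file added by the same user always maps to the same data row.
print(uuid5(NAMESPACE_OID, f"{data_content_hash}{user_id}"))
```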
cognee/modules/memify/memify.py
CHANGED

```diff
@@ -12,9 +12,6 @@ from cognee.modules.users.models import User
 from cognee.modules.pipelines.layers.resolve_authorized_user_datasets import (
     resolve_authorized_user_datasets,
 )
-from cognee.modules.pipelines.layers.reset_dataset_pipeline_run_status import (
-    reset_dataset_pipeline_run_status,
-)
 from cognee.modules.engine.operations.setup import setup
 from cognee.modules.pipelines.layers.pipeline_execution_mode import get_pipeline_executor
 from cognee.tasks.memify.extract_subgraph_chunks import extract_subgraph_chunks
@@ -97,10 +94,6 @@ async def memify(
         *enrichment_tasks,
     ]
 
-    await reset_dataset_pipeline_run_status(
-        authorized_dataset.id, user, pipeline_names=["memify_pipeline"]
-    )
-
     # By calling get pipeline executor we get a function that will have the run_pipeline run in the background or a function that we will need to wait for
     pipeline_executor_func = get_pipeline_executor(run_in_background=run_in_background)
 
@@ -113,6 +106,7 @@ async def memify(
         datasets=authorized_dataset.id,
         vector_db_config=vector_db_config,
         graph_db_config=graph_db_config,
+        use_pipeline_cache=False,
         incremental_loading=False,
         pipeline_name="memify_pipeline",
     )
```
cognee/modules/notebooks/operations/run_in_local_sandbox.py
CHANGED

```diff
@@ -2,6 +2,8 @@ import io
 import sys
 import traceback
 
+import cognee
+
 
 def wrap_in_async_handler(user_code: str) -> str:
     return (
@@ -34,6 +36,7 @@ def run_in_local_sandbox(code, environment=None, loop=None):
 
     environment["print"] = customPrintFunction
     environment["running_loop"] = loop
+    environment["cognee"] = cognee
 
     try:
         exec(code, environment)
```
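The sandbox now pre-binds the `cognee` module into the namespace handed to `exec`, so notebook code can reference `cognee` without importing it. The injection pattern, sketched with `math` standing in for `cognee`:

```python
import math  # stand-in for the cognee module that the sandbox injects

# Untrusted/user code that uses the injected name without importing it.
code = "print(injected.sqrt(16))"

environment = {}
environment["injected"] = math  # mirrors environment["cognee"] = cognee

exec(code, environment)  # prints 4.0
```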
cognee/modules/ontology/rdf_xml/RDFLibOntologyResolver.py
CHANGED

```diff
@@ -2,7 +2,7 @@ import os
 import difflib
 from cognee.shared.logging_utils import get_logger
 from collections import deque
-from typing import List, Tuple, Dict, Optional, Any, Union
+from typing import List, Tuple, Dict, Optional, Any, Union, IO
 from rdflib import Graph, URIRef, RDF, RDFS, OWL
 
 from cognee.modules.ontology.exceptions import (
@@ -26,44 +26,76 @@ class RDFLibOntologyResolver(BaseOntologyResolver):
 
     def __init__(
         self,
-        ontology_file: Optional[Union[str, List[str]]] = None,
+        ontology_file: Optional[Union[str, List[str], IO, List[IO]]] = None,
         matching_strategy: Optional[MatchingStrategy] = None,
     ) -> None:
         super().__init__(matching_strategy)
         self.ontology_file = ontology_file
         try:
-
+            self.graph = None
             if ontology_file is not None:
-
+                files_to_load = []
+                file_objects = []
+
+                if hasattr(ontology_file, "read"):
+                    file_objects = [ontology_file]
+                elif isinstance(ontology_file, str):
                     files_to_load = [ontology_file]
                 elif isinstance(ontology_file, list):
-
+                    if all(hasattr(item, "read") for item in ontology_file):
+                        file_objects = ontology_file
+                    else:
+                        files_to_load = ontology_file
                 else:
                     raise ValueError(
-                        f"ontology_file must be a string, list of strings, or None. Got: {type(ontology_file)}"
+                        f"ontology_file must be a string, list of strings, file-like object, list of file-like objects, or None. Got: {type(ontology_file)}"
                     )
 
-
-
-
-
-
-
-
-
+                if file_objects:
+                    self.graph = Graph()
+                    loaded_objects = []
+                    for file_obj in file_objects:
+                        try:
+                            content = file_obj.read()
+                            self.graph.parse(data=content, format="xml")
+                            loaded_objects.append(file_obj)
+                            logger.info("Ontology loaded successfully from file object")
+                        except Exception as e:
+                            logger.warning("Failed to parse ontology file object: %s", str(e))
+
+                    if not loaded_objects:
+                        logger.info(
+                            "No valid ontology file objects found. No owl ontology will be attached to the graph."
+                        )
+                        self.graph = None
                     else:
-                        logger.
-
-
+                        logger.info("Total ontology file objects loaded: %d", len(loaded_objects))
+
+                elif files_to_load:
+                    self.graph = Graph()
+                    loaded_files = []
+                    for file_path in files_to_load:
+                        if os.path.exists(file_path):
+                            self.graph.parse(file_path)
+                            loaded_files.append(file_path)
+                            logger.info("Ontology loaded successfully from file: %s", file_path)
+                        else:
+                            logger.warning(
+                                "Ontology file '%s' not found. Skipping this file.",
+                                file_path,
+                            )
+
+                    if not loaded_files:
+                        logger.info(
+                            "No valid ontology files found. No owl ontology will be attached to the graph."
                         )
-
-
+                        self.graph = None
+                    else:
+                        logger.info("Total ontology files loaded: %d", len(loaded_files))
+                else:
                     logger.info(
-                        "No
+                        "No ontology file provided. No owl ontology will be attached to the graph."
                     )
-                    self.graph = None
-                else:
-                    logger.info("Total ontology files loaded: %d", len(loaded_files))
             else:
                 logger.info(
                     "No ontology file provided. No owl ontology will be attached to the graph."
```
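The resolver now accepts anything with a `.read()` method in addition to file paths. A runnable sketch of the file-object path using rdflib directly (requires `rdflib`; the RDF/XML content is a trivial made-up ontology):

```python
import io

from rdflib import Graph

ontology_xml = """<?xml version="1.0"?>
<rdf:RDF xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#"
         xmlns:owl="http://www.w3.org/2002/07/owl#">
  <owl:Class rdf:about="http://example.org/onto#Vehicle"/>
</rdf:RDF>
"""

file_obj = io.StringIO(ontology_xml)  # any object with .read() now qualifies

graph = Graph()
graph.parse(data=file_obj.read(), format="xml")  # same call shape as the diff
print(len(graph))  # 1 triple parsed from the in-memory ontology
```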
cognee/modules/pipelines/operations/pipeline.py
CHANGED

```diff
@@ -20,6 +20,9 @@ from cognee.modules.pipelines.layers.resolve_authorized_user_datasets import (
 from cognee.modules.pipelines.layers.check_pipeline_run_qualification import (
     check_pipeline_run_qualification,
 )
+from cognee.modules.pipelines.models.PipelineRunInfo import (
+    PipelineRunStarted,
+)
 from typing import Any
 
 logger = get_logger("cognee.pipeline")
@@ -35,6 +38,7 @@ async def run_pipeline(
     pipeline_name: str = "custom_pipeline",
     vector_db_config: dict = None,
     graph_db_config: dict = None,
+    use_pipeline_cache: bool = False,
     incremental_loading: bool = False,
     data_per_batch: int = 20,
 ):
@@ -51,6 +55,7 @@ async def run_pipeline(
         data=data,
         pipeline_name=pipeline_name,
         context={"dataset": dataset},
+        use_pipeline_cache=use_pipeline_cache,
         incremental_loading=incremental_loading,
         data_per_batch=data_per_batch,
     ):
@@ -64,6 +69,7 @@ async def run_pipeline_per_dataset(
     data=None,
     pipeline_name: str = "custom_pipeline",
     context: dict = None,
+    use_pipeline_cache=False,
     incremental_loading=False,
     data_per_batch: int = 20,
 ):
@@ -77,8 +83,18 @@ async def run_pipeline_per_dataset(
     if process_pipeline_status:
         # If pipeline was already processed or is currently being processed
         # return status information to async generator and finish execution
-
-
+        if use_pipeline_cache:
+            # If pipeline caching is enabled we do not proceed with re-processing
+            yield process_pipeline_status
+            return
+        else:
+            # If pipeline caching is disabled we always return pipeline started information and proceed with re-processing
+            yield PipelineRunStarted(
+                pipeline_run_id=process_pipeline_status.pipeline_run_id,
+                dataset_id=dataset.id,
+                dataset_name=dataset.name,
+                payload=data,
+            )
 
     pipeline_run = run_tasks(
         tasks,
```
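The new `use_pipeline_cache` flag decides whether an already-processed dataset short-circuits (yield the cached status and stop) or is re-run after yielding a fresh `PipelineRunStarted`; `memify` passes `use_pipeline_cache=False`, so memify pipelines always re-execute. A runnable sketch of the gate with simplified stand-ins for the status classes:

```python
import asyncio
from dataclasses import dataclass


@dataclass
class PipelineRunCompleted:
    """Stand-in for the cached status object returned by the qualification check."""

    pipeline_run_id: str


@dataclass
class PipelineRunStarted:
    """Stand-in for cognee's PipelineRunInfo.PipelineRunStarted."""

    pipeline_run_id: str


async def run_pipeline_per_dataset(cached_status, use_pipeline_cache):
    if cached_status:
        if use_pipeline_cache:
            # Cache enabled: report the cached status and stop.
            yield cached_status
            return
        # Cache disabled: report a fresh start and fall through to re-process.
        yield PipelineRunStarted(pipeline_run_id=cached_status.pipeline_run_id)
    yield "running tasks ..."


async def main():
    cached = PipelineRunCompleted(pipeline_run_id="run-1")
    async for info in run_pipeline_per_dataset(cached, use_pipeline_cache=True):
        print("cached:", info)
    async for info in run_pipeline_per_dataset(cached, use_pipeline_cache=False):
        print("no cache:", info)


asyncio.run(main())
```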
cognee/modules/pipelines/operations/run_tasks_data_item.py
CHANGED

```diff
@@ -69,7 +69,7 @@ async def run_tasks_data_item_incremental(
             async with open_data_file(file_path) as file:
                 classified_data = ingestion.classify(file)
                 # data_id is the hash of file contents + owner id to avoid duplicate data
-                data_id = ingestion.identify(classified_data, user)
+                data_id = await ingestion.identify(classified_data, user)
         else:
             # If data was already processed by Cognee get data id
             data_id = data_item.id
```
cognee/modules/retrieval/EntityCompletionRetriever.py
CHANGED

```diff
@@ -1,5 +1,5 @@
 import asyncio
-from typing import Any, Optional, List
+from typing import Any, Optional, List, Type
 from cognee.shared.logging_utils import get_logger
 
 from cognee.infrastructure.entities.BaseEntityExtractor import BaseEntityExtractor
@@ -85,8 +85,12 @@ class EntityCompletionRetriever(BaseRetriever):
         return None
 
     async def get_completion(
-        self,
-
+        self,
+        query: str,
+        context: Optional[Any] = None,
+        session_id: Optional[str] = None,
+        response_model: Type = str,
+    ) -> List[Any]:
         """
         Generate completion using provided context or fetch new context.
 
@@ -102,6 +106,7 @@ class EntityCompletionRetriever(BaseRetriever):
             fetched if not provided. (default None)
         - session_id (Optional[str]): Optional session identifier for caching. If None,
             defaults to 'default_session'. (default None)
+        - response_model (Type): The Pydantic model type for structured output. (default str)
 
         Returns:
         --------
@@ -133,6 +138,7 @@ class EntityCompletionRetriever(BaseRetriever):
                     user_prompt_path=self.user_prompt_path,
                     system_prompt_path=self.system_prompt_path,
                     conversation_history=conversation_history,
+                    response_model=response_model,
                 ),
             )
         else:
@@ -141,6 +147,7 @@ class EntityCompletionRetriever(BaseRetriever):
                 context=context,
                 user_prompt_path=self.user_prompt_path,
                 system_prompt_path=self.system_prompt_path,
+                response_model=response_model,
             )
 
         if session_save:
```
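`get_completion` now threads `response_model` through to completion generation, so callers can request structured (Pydantic) output instead of a plain string. A hypothetical sketch of the signature and calling convention; the `EntityAnswer` model and the stand-in body are assumptions, not cognee's implementation:

```python
import asyncio
from typing import Any, List, Optional, Type

from pydantic import BaseModel


class EntityAnswer(BaseModel):
    """Hypothetical structured-output model."""

    entities: List[str]
    summary: str


async def get_completion(
    query: str,
    context: Optional[Any] = None,
    session_id: Optional[str] = None,
    response_model: Type = str,
) -> List[Any]:
    """Stand-in with the same signature as the updated method."""
    if response_model is str:
        return [f"plain answer to: {query}"]
    # A real implementation would request structured LLM output here.
    return [response_model(entities=["ACME"], summary=f"structured answer to: {query}")]


print(asyncio.run(get_completion("Who founded ACME?")))
print(asyncio.run(get_completion("Who founded ACME?", response_model=EntityAnswer)))
```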
cognee/modules/retrieval/__init__.py
CHANGED

```diff
@@ -1 +1 @@
-
+
```