cognee 0.3.0.dev0__py3-none-any.whl → 0.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cognee/__init__.py +1 -0
- cognee/api/v1/save/save.py +335 -0
- cognee/api/v1/search/routers/get_search_router.py +3 -3
- cognee/api/v1/ui/__init__.py +1 -0
- cognee/api/v1/ui/ui.py +624 -0
- cognee/cli/_cognee.py +102 -0
- cognee/modules/retrieval/graph_completion_context_extension_retriever.py +1 -1
- cognee/modules/retrieval/graph_completion_cot_retriever.py +1 -1
- cognee/modules/retrieval/graph_completion_retriever.py +1 -1
- cognee/modules/retrieval/insights_retriever.py +12 -11
- cognee/modules/retrieval/temporal_retriever.py +1 -1
- cognee/modules/search/methods/search.py +31 -8
- cognee/tests/test_permissions.py +3 -3
- cognee/tests/test_relational_db_migration.py +3 -5
- cognee/tests/test_save_export_path.py +116 -0
- cognee/tests/test_search_db.py +10 -7
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_context_extension_test.py +12 -6
- cognee/tests/unit/modules/retrieval/graph_completion_retriever_cot_test.py +12 -6
- cognee/tests/unit/modules/retrieval/insights_retriever_test.py +2 -4
- {cognee-0.3.0.dev0.dist-info → cognee-0.3.2.dist-info}/METADATA +2 -2
- {cognee-0.3.0.dev0.dist-info → cognee-0.3.2.dist-info}/RECORD +34 -30
- distributed/pyproject.toml +1 -1
- /cognee/tests/{integration/cli → cli_tests/cli_integration_tests}/__init__.py +0 -0
- /cognee/tests/{integration/cli → cli_tests/cli_integration_tests}/test_cli_integration.py +0 -0
- /cognee/tests/{unit/cli → cli_tests/cli_unit_tests}/__init__.py +0 -0
- /cognee/tests/{unit/cli → cli_tests/cli_unit_tests}/test_cli_commands.py +0 -0
- /cognee/tests/{unit/cli → cli_tests/cli_unit_tests}/test_cli_edge_cases.py +0 -0
- /cognee/tests/{unit/cli → cli_tests/cli_unit_tests}/test_cli_main.py +0 -0
- /cognee/tests/{unit/cli → cli_tests/cli_unit_tests}/test_cli_runner.py +0 -0
- /cognee/tests/{unit/cli → cli_tests/cli_unit_tests}/test_cli_utils.py +0 -0
- {cognee-0.3.0.dev0.dist-info → cognee-0.3.2.dist-info}/WHEEL +0 -0
- {cognee-0.3.0.dev0.dist-info → cognee-0.3.2.dist-info}/entry_points.txt +0 -0
- {cognee-0.3.0.dev0.dist-info → cognee-0.3.2.dist-info}/licenses/LICENSE +0 -0
- {cognee-0.3.0.dev0.dist-info → cognee-0.3.2.dist-info}/licenses/NOTICE.md +0 -0
cognee/cli/_cognee.py
CHANGED
|
@@ -1,6 +1,8 @@
|
|
|
1
1
|
import sys
|
|
2
2
|
import os
|
|
3
3
|
import argparse
|
|
4
|
+
import signal
|
|
5
|
+
import subprocess
|
|
4
6
|
from typing import Any, Sequence, Dict, Type, cast, List
|
|
5
7
|
import click
|
|
6
8
|
|
|
@@ -51,6 +53,31 @@ class DebugAction(argparse.Action):
|
|
|
51
53
|
fmt.note("Debug mode enabled. Full stack traces will be shown.")
|
|
52
54
|
|
|
53
55
|
|
|
56
|
+
class UiAction(argparse.Action):
    """argparse action behind the ``-ui`` command-line flag.

    The flag is a bare switch (``nargs=0``). When it appears on the command
    line, the action records that an action ran by setting the module-level
    ``ACTION_EXECUTED`` and stores ``start_ui = True`` on the parsed namespace.
    ``main()`` checks that attribute after parsing to decide whether to launch
    the web UI.
    """

    def __init__(
        self,
        option_strings: Sequence[str],
        dest: Any = argparse.SUPPRESS,
        default: Any = argparse.SUPPRESS,
        help: str = None,
    ) -> None:
        # nargs=0: the flag consumes no value. default=SUPPRESS keeps the
        # flag's own dest off the namespace when the flag is absent.
        super().__init__(
            option_strings=option_strings,
            dest=dest,
            default=default,
            nargs=0,
            help=help,
        )

    def __call__(
        self,
        parser: argparse.ArgumentParser,
        namespace: argparse.Namespace,
        values: Any,
        option_string: str = None,
    ) -> None:
        """Mark that the UI was requested; ``values`` is unused (nargs=0)."""
        global ACTION_EXECUTED
        ACTION_EXECUTED = True
        # Read back by main() via hasattr(args, "start_ui").
        setattr(namespace, "start_ui", True)
|
|
79
|
+
|
|
80
|
+
|
|
54
81
|
# Debug functionality is now in cognee.cli.debug module
|
|
55
82
|
|
|
56
83
|
|
|
@@ -97,6 +124,11 @@ def _create_parser() -> tuple[argparse.ArgumentParser, Dict[str, SupportsCliComm
|
|
|
97
124
|
action=DebugAction,
|
|
98
125
|
help="Enable debug mode to show full stack traces on exceptions",
|
|
99
126
|
)
|
|
127
|
+
parser.add_argument(
|
|
128
|
+
"-ui",
|
|
129
|
+
action=UiAction,
|
|
130
|
+
help="Start the cognee web UI interface",
|
|
131
|
+
)
|
|
100
132
|
|
|
101
133
|
subparsers = parser.add_subparsers(title="Available commands", dest="command")
|
|
102
134
|
|
|
@@ -140,6 +172,76 @@ def main() -> int:
|
|
|
140
172
|
parser, installed_commands = _create_parser()
|
|
141
173
|
args = parser.parse_args()
|
|
142
174
|
|
|
175
|
+
# Handle UI flag
|
|
176
|
+
if hasattr(args, "start_ui") and args.start_ui:
|
|
177
|
+
spawned_pids = []
|
|
178
|
+
|
|
179
|
+
def signal_handler(signum, frame):
|
|
180
|
+
"""Handle Ctrl+C and other termination signals"""
|
|
181
|
+
nonlocal spawned_pids
|
|
182
|
+
fmt.echo("\nShutting down UI server...")
|
|
183
|
+
|
|
184
|
+
for pid in spawned_pids:
|
|
185
|
+
try:
|
|
186
|
+
pgid = os.getpgid(pid)
|
|
187
|
+
os.killpg(pgid, signal.SIGTERM)
|
|
188
|
+
fmt.success(f"✓ Process group {pgid} (PID {pid}) terminated.")
|
|
189
|
+
except (OSError, ProcessLookupError) as e:
|
|
190
|
+
fmt.warning(f"Could not terminate process {pid}: {e}")
|
|
191
|
+
|
|
192
|
+
sys.exit(0)
|
|
193
|
+
|
|
194
|
+
signal.signal(signal.SIGINT, signal_handler) # Ctrl+C
|
|
195
|
+
signal.signal(signal.SIGTERM, signal_handler) # Termination request
|
|
196
|
+
|
|
197
|
+
try:
|
|
198
|
+
from cognee import start_ui
|
|
199
|
+
|
|
200
|
+
fmt.echo("Starting cognee UI...")
|
|
201
|
+
|
|
202
|
+
# Callback to capture PIDs of all spawned processes
|
|
203
|
+
def pid_callback(pid):
|
|
204
|
+
nonlocal spawned_pids
|
|
205
|
+
spawned_pids.append(pid)
|
|
206
|
+
|
|
207
|
+
server_process = start_ui(
|
|
208
|
+
host="localhost",
|
|
209
|
+
port=3000,
|
|
210
|
+
open_browser=True,
|
|
211
|
+
start_backend=True,
|
|
212
|
+
auto_download=True,
|
|
213
|
+
pid_callback=pid_callback,
|
|
214
|
+
)
|
|
215
|
+
|
|
216
|
+
if server_process:
|
|
217
|
+
fmt.success("UI server started successfully!")
|
|
218
|
+
fmt.echo("The interface is available at: http://localhost:3000")
|
|
219
|
+
fmt.echo("The API backend is available at: http://localhost:8000")
|
|
220
|
+
fmt.note("Press Ctrl+C to stop the server...")
|
|
221
|
+
|
|
222
|
+
try:
|
|
223
|
+
# Keep the server running
|
|
224
|
+
import time
|
|
225
|
+
|
|
226
|
+
while server_process.poll() is None: # While process is still running
|
|
227
|
+
time.sleep(1)
|
|
228
|
+
except KeyboardInterrupt:
|
|
229
|
+
# This shouldn't happen now due to signal handler, but kept for safety
|
|
230
|
+
signal_handler(signal.SIGINT, None)
|
|
231
|
+
|
|
232
|
+
return 0
|
|
233
|
+
else:
|
|
234
|
+
fmt.error("Failed to start UI server. Check the logs above for details.")
|
|
235
|
+
signal_handler(signal.SIGTERM, None)
|
|
236
|
+
return 1
|
|
237
|
+
|
|
238
|
+
except Exception as ex:
|
|
239
|
+
fmt.error(f"Error starting UI: {str(ex)}")
|
|
240
|
+
signal_handler(signal.SIGTERM, None)
|
|
241
|
+
if debug.is_debug_enabled():
|
|
242
|
+
raise ex
|
|
243
|
+
return 1
|
|
244
|
+
|
|
143
245
|
if cmd := installed_commands.get(args.command):
|
|
144
246
|
try:
|
|
145
247
|
cmd.execute(args)
|
|
@@ -171,7 +171,7 @@ class GraphCompletionRetriever(BaseGraphRetriever):
|
|
|
171
171
|
question=query, answer=completion, context=context_text, triplets=triplets
|
|
172
172
|
)
|
|
173
173
|
|
|
174
|
-
return completion
|
|
174
|
+
return [completion]
|
|
175
175
|
|
|
176
176
|
async def save_qa(self, question: str, answer: str, context: str, triplets: List) -> None:
|
|
177
177
|
"""
|
|
@@ -96,17 +96,18 @@ class InsightsRetriever(BaseGraphRetriever):
|
|
|
96
96
|
unique_node_connections_map[unique_id] = True
|
|
97
97
|
unique_node_connections.append(node_connection)
|
|
98
98
|
|
|
99
|
-
return
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
99
|
+
return unique_node_connections
|
|
100
|
+
# return [
|
|
101
|
+
# Edge(
|
|
102
|
+
# node1=Node(node_id=connection[0]["id"], attributes=connection[0]),
|
|
103
|
+
# node2=Node(node_id=connection[2]["id"], attributes=connection[2]),
|
|
104
|
+
# attributes={
|
|
105
|
+
# **connection[1],
|
|
106
|
+
# "relationship_type": connection[1]["relationship_name"],
|
|
107
|
+
# },
|
|
108
|
+
# )
|
|
109
|
+
# for connection in unique_node_connections
|
|
110
|
+
# ]
|
|
110
111
|
|
|
111
112
|
async def get_completion(self, query: str, context: Optional[Any] = None) -> Any:
|
|
112
113
|
"""
|
|
@@ -132,14 +132,37 @@ async def search(
|
|
|
132
132
|
],
|
|
133
133
|
)
|
|
134
134
|
else:
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
135
|
+
# This is for maintaining backwards compatibility
|
|
136
|
+
if os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true":
|
|
137
|
+
return_value = []
|
|
138
|
+
for search_result in search_results:
|
|
139
|
+
result, context, datasets = search_result
|
|
140
|
+
return_value.append(
|
|
141
|
+
{
|
|
142
|
+
"search_result": result,
|
|
143
|
+
"dataset_id": datasets[0].id,
|
|
144
|
+
"dataset_name": datasets[0].name,
|
|
145
|
+
}
|
|
146
|
+
)
|
|
147
|
+
return return_value
|
|
148
|
+
else:
|
|
149
|
+
return_value = []
|
|
150
|
+
for search_result in search_results:
|
|
151
|
+
result, context, datasets = search_result
|
|
152
|
+
return_value.append(result)
|
|
153
|
+
# For maintaining backwards compatibility
|
|
154
|
+
if len(return_value) == 1 and isinstance(return_value[0], list):
|
|
155
|
+
return return_value[0]
|
|
156
|
+
else:
|
|
157
|
+
return return_value
|
|
158
|
+
# return [
|
|
159
|
+
# SearchResult(
|
|
160
|
+
# search_result=result,
|
|
161
|
+
# dataset_id=datasets[min(index, len(datasets) - 1)].id if datasets else None,
|
|
162
|
+
# dataset_name=datasets[min(index, len(datasets) - 1)].name if datasets else None,
|
|
163
|
+
# )
|
|
164
|
+
# for index, (result, _, datasets) in enumerate(search_results)
|
|
165
|
+
# ]
|
|
143
166
|
|
|
144
167
|
|
|
145
168
|
async def authorized_search(
|
cognee/tests/test_permissions.py
CHANGED
|
@@ -79,7 +79,7 @@ async def main():
|
|
|
79
79
|
print("\n\nExtracted sentences are:\n")
|
|
80
80
|
for result in search_results:
|
|
81
81
|
print(f"{result}\n")
|
|
82
|
-
assert search_results[0]
|
|
82
|
+
assert search_results[0]["dataset_name"] == "NLP", (
|
|
83
83
|
f"Dict must contain dataset name 'NLP': {search_results[0]}"
|
|
84
84
|
)
|
|
85
85
|
|
|
@@ -93,7 +93,7 @@ async def main():
|
|
|
93
93
|
print("\n\nExtracted sentences are:\n")
|
|
94
94
|
for result in search_results:
|
|
95
95
|
print(f"{result}\n")
|
|
96
|
-
assert search_results[0]
|
|
96
|
+
assert search_results[0]["dataset_name"] == "QUANTUM", (
|
|
97
97
|
f"Dict must contain dataset name 'QUANTUM': {search_results[0]}"
|
|
98
98
|
)
|
|
99
99
|
|
|
@@ -170,7 +170,7 @@ async def main():
|
|
|
170
170
|
for result in search_results:
|
|
171
171
|
print(f"{result}\n")
|
|
172
172
|
|
|
173
|
-
assert search_results[0]
|
|
173
|
+
assert search_results[0]["dataset_name"] == "QUANTUM", (
|
|
174
174
|
f"Dict must contain dataset name 'QUANTUM': {search_results[0]}"
|
|
175
175
|
)
|
|
176
176
|
|
|
@@ -45,15 +45,13 @@ async def relational_db_migration():
|
|
|
45
45
|
await migrate_relational_database(graph_engine, schema=schema)
|
|
46
46
|
|
|
47
47
|
# 1. Search the graph
|
|
48
|
-
search_results
|
|
48
|
+
search_results = await cognee.search(
|
|
49
49
|
query_type=SearchType.GRAPH_COMPLETION, query_text="Tell me about the artist AC/DC"
|
|
50
|
-
)
|
|
50
|
+
)
|
|
51
51
|
print("Search results:", search_results)
|
|
52
52
|
|
|
53
53
|
# 2. Assert that the search results contain "AC/DC"
|
|
54
|
-
assert any("AC/DC" in r
|
|
55
|
-
"AC/DC not found in search results!"
|
|
56
|
-
)
|
|
54
|
+
assert any("AC/DC" in r for r in search_results), "AC/DC not found in search results!"
|
|
57
55
|
|
|
58
56
|
migration_db_provider = migration_engine.engine.dialect.name
|
|
59
57
|
if migration_db_provider == "postgresql":
|
|
@@ -0,0 +1,116 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import asyncio
|
|
3
|
+
from uuid import uuid4
|
|
4
|
+
|
|
5
|
+
import pytest
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@pytest.mark.asyncio
async def test_save_uses_custom_export_path(tmp_path, monkeypatch):
    """Check that ``save`` exports each dataset under a caller-supplied root directory.

    The test monkeypatches every collaborator that ``save`` looks up on its
    module: dataset discovery, per-dataset data listing, summary retrieval,
    question generation and the per-question searches. Only the export layout
    is exercised:

    - one directory per dataset under ``export_root_directory``;
    - an ``index.md`` in each directory;
    - one markdown file per data item;
    - distinct file names when two items in the same dataset share a name.
    """
    # Import target after tmp fixtures are ready
    from cognee.api.v1.save import save as save_mod

    def is_under(path, root):
        """Return True when *path* lies inside *root*, compared by path components.

        A plain ``str.startswith`` check would also accept sibling directories
        whose names merely begin with the root's name (e.g. ``my_exports_old``).
        """
        root = os.path.abspath(root)
        return os.path.commonpath([os.path.abspath(path), root]) == root

    # Prepare two mock datasets
    class Dataset:
        # Stand-in exposing only the ``id`` and ``name`` attributes.
        def __init__(self, id_, name):
            self.id = id_
            self.name = name

    ds1 = Dataset(uuid4(), "dataset_alpha")
    ds2 = Dataset(uuid4(), "dataset_beta")

    # Mock dataset discovery: both datasets are returned regardless of arguments.
    async def mock_get_authorized_existing_datasets(datasets, permission_type, user):
        return [ds1, ds2]

    monkeypatch.setattr(
        save_mod, "get_authorized_existing_datasets", mock_get_authorized_existing_datasets
    )

    # Mock data items (with filename collision in ds1)
    class DataItem:
        def __init__(self, id_, name, original_path=None):
            self.id = id_
            self.name = name
            self.original_data_location = original_path

    ds1_items = [
        DataItem(uuid4(), "report.txt", "/root/a/report.txt"),
        DataItem(uuid4(), "report.txt", "/root/b/report.txt"),  # collision
    ]
    ds2_items = [
        DataItem(uuid4(), "notes.md", "/root/x/notes.md"),
    ]

    async def mock_get_dataset_data(dataset_id):
        if dataset_id == ds1.id:
            return ds1_items
        if dataset_id == ds2.id:
            return ds2_items
        return []

    monkeypatch.setattr(save_mod, "get_dataset_data", mock_get_dataset_data)

    # Mock summary retrieval
    async def mock_get_document_summaries_text(data_id: str) -> str:
        return "This is a summary."

    monkeypatch.setattr(save_mod, "_get_document_summaries_text", mock_get_document_summaries_text)

    # Mock questions
    async def mock_generate_questions(file_name: str, summary_text: str):
        return ["Q1?", "Q2?", "Q3?"]

    monkeypatch.setattr(save_mod, "_generate_questions", mock_generate_questions)

    # Mock searches per question: one canned hit per requested search type.
    async def mock_run_searches_for_question(question, dataset_id, search_types, top_k):
        return {st.value: [f"{question} -> ok"] for st in search_types}

    monkeypatch.setattr(save_mod, "_run_searches_for_question", mock_run_searches_for_question)

    # Use custom export path
    export_dir = tmp_path / "my_exports"
    export_dir_str = str(export_dir)

    # Run
    result = await save_mod.save(
        datasets=None,
        export_root_directory=export_dir_str,
        max_questions=3,
        search_types=["GRAPH_COMPLETION", "INSIGHTS", "CHUNKS"],
        top_k=2,
        include_summary=True,
        include_ascii_tree=True,
        concurrency=2,
        timeout=None,
    )

    # Verify the returned mapping (keyed by str(dataset id)) points inside our custom path
    assert str(ds1.id) in result and str(ds2.id) in result
    ds1_dir = result[str(ds1.id)]
    ds2_dir = result[str(ds2.id)]
    assert is_under(ds1_dir, export_dir_str), ds1_dir
    assert is_under(ds2_dir, export_dir_str), ds2_dir

    # Verify directories and files exist
    assert os.path.isdir(ds1_dir)
    assert os.path.isdir(ds2_dir)

    # index.md present
    assert os.path.isfile(os.path.join(ds1_dir, "index.md"))
    assert os.path.isfile(os.path.join(ds2_dir, "index.md"))

    # File markdowns exist; collision handling: two files with similar base
    ds1_files = [f for f in os.listdir(ds1_dir) if f.endswith(".md") and f != "index.md"]
    assert len(ds1_files) == 2
    assert any(f == "report.txt.md" for f in ds1_files)
    assert any(f.startswith("report.txt__") and f.endswith(".md") for f in ds1_files)

    # Content sanity: every per-item file must carry the expected sections.
    # os.listdir order is unspecified, so checking only ds1_files[0] would
    # verify an arbitrary one of the two files.
    for file_name in ds1_files:
        with open(os.path.join(ds1_dir, file_name), "r", encoding="utf-8") as fh:
            content = fh.read()
        assert "## Question ideas" in content, file_name
        assert "## Searches" in content, file_name
|
cognee/tests/test_search_db.py
CHANGED
|
@@ -144,13 +144,16 @@ async def main():
|
|
|
144
144
|
("GRAPH_COMPLETION_CONTEXT_EXTENSION", completion_ext),
|
|
145
145
|
("GRAPH_SUMMARY_COMPLETION", completion_sum),
|
|
146
146
|
]:
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
147
|
+
assert isinstance(search_results, list), f"{name}: should return a list"
|
|
148
|
+
assert len(search_results) == 1, (
|
|
149
|
+
f"{name}: expected single-element list, got {len(search_results)}"
|
|
150
|
+
)
|
|
151
|
+
text = search_results[0]
|
|
152
|
+
assert isinstance(text, str), f"{name}: element should be a string"
|
|
153
|
+
assert text.strip(), f"{name}: string should not be empty"
|
|
154
|
+
assert "netherlands" in text.lower(), (
|
|
155
|
+
f"{name}: expected 'netherlands' in result, got: {text!r}"
|
|
156
|
+
)
|
|
154
157
|
|
|
155
158
|
graph_engine = await get_graph_engine()
|
|
156
159
|
graph = await graph_engine.get_graph_data()
|
|
@@ -59,8 +59,10 @@ class TestGraphCompletionWithContextExtensionRetriever:
|
|
|
59
59
|
|
|
60
60
|
answer = await retriever.get_completion("Who works at Canva?")
|
|
61
61
|
|
|
62
|
-
assert isinstance(answer,
|
|
63
|
-
assert
|
|
62
|
+
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
|
|
63
|
+
assert all(isinstance(item, str) and item.strip() for item in answer), (
|
|
64
|
+
"Answer must contain only non-empty strings"
|
|
65
|
+
)
|
|
64
66
|
|
|
65
67
|
@pytest.mark.asyncio
|
|
66
68
|
async def test_graph_completion_extension_context_complex(self):
|
|
@@ -140,8 +142,10 @@ class TestGraphCompletionWithContextExtensionRetriever:
|
|
|
140
142
|
|
|
141
143
|
answer = await retriever.get_completion("Who works at Figma?")
|
|
142
144
|
|
|
143
|
-
assert isinstance(answer,
|
|
144
|
-
assert
|
|
145
|
+
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
|
|
146
|
+
assert all(isinstance(item, str) and item.strip() for item in answer), (
|
|
147
|
+
"Answer must contain only non-empty strings"
|
|
148
|
+
)
|
|
145
149
|
|
|
146
150
|
@pytest.mark.asyncio
|
|
147
151
|
async def test_get_graph_completion_extension_context_on_empty_graph(self):
|
|
@@ -171,5 +175,7 @@ class TestGraphCompletionWithContextExtensionRetriever:
|
|
|
171
175
|
|
|
172
176
|
answer = await retriever.get_completion("Who works at Figma?")
|
|
173
177
|
|
|
174
|
-
assert isinstance(answer,
|
|
175
|
-
assert
|
|
178
|
+
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
|
|
179
|
+
assert all(isinstance(item, str) and item.strip() for item in answer), (
|
|
180
|
+
"Answer must contain only non-empty strings"
|
|
181
|
+
)
|
|
@@ -55,8 +55,10 @@ class TestGraphCompletionCoTRetriever:
|
|
|
55
55
|
|
|
56
56
|
answer = await retriever.get_completion("Who works at Canva?")
|
|
57
57
|
|
|
58
|
-
assert isinstance(answer,
|
|
59
|
-
assert
|
|
58
|
+
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
|
|
59
|
+
assert all(isinstance(item, str) and item.strip() for item in answer), (
|
|
60
|
+
"Answer must contain only non-empty strings"
|
|
61
|
+
)
|
|
60
62
|
|
|
61
63
|
@pytest.mark.asyncio
|
|
62
64
|
async def test_graph_completion_cot_context_complex(self):
|
|
@@ -133,8 +135,10 @@ class TestGraphCompletionCoTRetriever:
|
|
|
133
135
|
|
|
134
136
|
answer = await retriever.get_completion("Who works at Figma?")
|
|
135
137
|
|
|
136
|
-
assert isinstance(answer,
|
|
137
|
-
assert
|
|
138
|
+
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
|
|
139
|
+
assert all(isinstance(item, str) and item.strip() for item in answer), (
|
|
140
|
+
"Answer must contain only non-empty strings"
|
|
141
|
+
)
|
|
138
142
|
|
|
139
143
|
@pytest.mark.asyncio
|
|
140
144
|
async def test_get_graph_completion_cot_context_on_empty_graph(self):
|
|
@@ -164,5 +168,7 @@ class TestGraphCompletionCoTRetriever:
|
|
|
164
168
|
|
|
165
169
|
answer = await retriever.get_completion("Who works at Figma?")
|
|
166
170
|
|
|
167
|
-
assert isinstance(answer,
|
|
168
|
-
assert
|
|
171
|
+
assert isinstance(answer, list), f"Expected list, got {type(answer).__name__}"
|
|
172
|
+
assert all(isinstance(item, str) and item.strip() for item in answer), (
|
|
173
|
+
"Answer must contain only non-empty strings"
|
|
174
|
+
)
|
|
@@ -82,7 +82,7 @@ class TestInsightsRetriever:
|
|
|
82
82
|
|
|
83
83
|
context = await retriever.get_context("Mike")
|
|
84
84
|
|
|
85
|
-
assert context[0]
|
|
85
|
+
assert context[0][0]["name"] == "Mike Broski", "Failed to get Mike Broski"
|
|
86
86
|
|
|
87
87
|
@pytest.mark.asyncio
|
|
88
88
|
async def test_insights_context_complex(self):
|
|
@@ -222,9 +222,7 @@ class TestInsightsRetriever:
|
|
|
222
222
|
|
|
223
223
|
context = await retriever.get_context("Christina")
|
|
224
224
|
|
|
225
|
-
assert context[0]
|
|
226
|
-
"Failed to get Christina Mayer"
|
|
227
|
-
)
|
|
225
|
+
assert context[0][0]["name"] == "Christina Mayer", "Failed to get Christina Mayer"
|
|
228
226
|
|
|
229
227
|
@pytest.mark.asyncio
|
|
230
228
|
async def test_insights_context_on_empty_graph(self):
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: cognee
|
|
3
|
-
Version: 0.3.
|
|
3
|
+
Version: 0.3.2
|
|
4
4
|
Summary: Cognee - is a library for enriching LLM context with a semantic layer for better understanding and reasoning.
|
|
5
5
|
Project-URL: Homepage, https://www.cognee.ai
|
|
6
6
|
Project-URL: Repository, https://github.com/topoteretes/cognee
|
|
@@ -44,7 +44,7 @@ Requires-Dist: pydantic-settings<3,>=2.2.1
|
|
|
44
44
|
Requires-Dist: pydantic<3.0.0,>=2.10.5
|
|
45
45
|
Requires-Dist: pylance<1.0.0,>=0.22.0
|
|
46
46
|
Requires-Dist: pympler<2.0.0,>=1.1
|
|
47
|
-
Requires-Dist: pypdf<
|
|
47
|
+
Requires-Dist: pypdf<7.0.0,>=4.1.0
|
|
48
48
|
Requires-Dist: python-dotenv<2.0.0,>=1.0.1
|
|
49
49
|
Requires-Dist: python-magic-bin<0.5; platform_system == 'Windows'
|
|
50
50
|
Requires-Dist: python-multipart<1.0.0,>=0.0.20
|