cognee 0.3.4.dev2__py3-none-any.whl → 0.3.4.dev3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cognee/api/v1/prune/prune.py +2 -2
- cognee/api/v1/sync/sync.py +16 -5
- cognee/base_config.py +15 -0
- cognee/infrastructure/databases/graph/kuzu/remote_kuzu_adapter.py +4 -1
- cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py +4 -1
- cognee/infrastructure/files/storage/LocalFileStorage.py +50 -0
- cognee/infrastructure/files/storage/S3FileStorage.py +56 -9
- cognee/infrastructure/files/storage/StorageManager.py +18 -0
- cognee/modules/cloud/operations/check_api_key.py +4 -1
- cognee/modules/data/deletion/prune_system.py +5 -1
- cognee/modules/notebooks/methods/create_notebook.py +32 -0
- cognee/modules/notebooks/models/Notebook.py +206 -1
- cognee/modules/users/methods/create_user.py +5 -23
- cognee/shared/cache.py +346 -0
- cognee/shared/utils.py +12 -0
- cognee/tasks/ingestion/save_data_item_to_storage.py +1 -0
- cognee/tests/unit/modules/users/test_tutorial_notebook_creation.py +399 -0
- {cognee-0.3.4.dev2.dist-info → cognee-0.3.4.dev3.dist-info}/METADATA +2 -1
- {cognee-0.3.4.dev2.dist-info → cognee-0.3.4.dev3.dist-info}/RECORD +23 -21
- {cognee-0.3.4.dev2.dist-info → cognee-0.3.4.dev3.dist-info}/WHEEL +0 -0
- {cognee-0.3.4.dev2.dist-info → cognee-0.3.4.dev3.dist-info}/entry_points.txt +0 -0
- {cognee-0.3.4.dev2.dist-info → cognee-0.3.4.dev3.dist-info}/licenses/LICENSE +0 -0
- {cognee-0.3.4.dev2.dist-info → cognee-0.3.4.dev3.dist-info}/licenses/NOTICE.md +0 -0
cognee/api/v1/prune/prune.py
CHANGED
@@ -7,8 +7,8 @@ class prune:
         await _prune_data()
 
     @staticmethod
-    async def prune_system(graph=True, vector=True, metadata=False):
-        await _prune_system(graph, vector, metadata)
+    async def prune_system(graph=True, vector=True, metadata=False, cache=True):
+        await _prune_system(graph, vector, metadata, cache)
 
 
 if __name__ == "__main__":
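The new `cache` flag defaults to `True`, so cache pruning piggybacks on existing calls. A minimal usage sketch, assuming the published `cognee.prune` API still re-exports this class as in earlier releases:

```python
import asyncio

import cognee


async def reset() -> None:
    # Drop ingested data, then the graph/vector stores plus the new cache;
    # relational metadata is kept because metadata defaults to False here.
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(graph=True, vector=True, cache=True)


asyncio.run(reset())
```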
cognee/api/v1/sync/sync.py
CHANGED
@@ -23,6 +23,7 @@ from cognee.modules.sync.methods import (
     mark_sync_completed,
     mark_sync_failed,
 )
+from cognee.shared.utils import create_secure_ssl_context
 
 logger = get_logger("sync")
 
@@ -583,7 +584,9 @@ async def _check_hashes_diff(
     logger.info(f"Checking missing hashes on cloud for dataset {dataset.id}")
 
     try:
-        async with aiohttp.ClientSession() as session:
+        ssl_context = create_secure_ssl_context()
+        connector = aiohttp.TCPConnector(ssl=ssl_context)
+        async with aiohttp.ClientSession(connector=connector) as session:
             async with session.post(url, json=payload.dict(), headers=headers) as response:
                 if response.status == 200:
                     data = await response.json()
@@ -630,7 +633,9 @@ async def _download_missing_files(
 
     headers = {"X-Api-Key": auth_token}
 
-    async with aiohttp.ClientSession() as session:
+    ssl_context = create_secure_ssl_context()
+    connector = aiohttp.TCPConnector(ssl=ssl_context)
+    async with aiohttp.ClientSession(connector=connector) as session:
         for file_hash in hashes_missing_on_local:
             try:
                 # Download file from cloud by hash
@@ -749,7 +754,9 @@ async def _upload_missing_files(
 
     headers = {"X-Api-Key": auth_token}
 
-    async with aiohttp.ClientSession() as session:
+    ssl_context = create_secure_ssl_context()
+    connector = aiohttp.TCPConnector(ssl=ssl_context)
+    async with aiohttp.ClientSession(connector=connector) as session:
         for file_info in files_to_upload:
             try:
                 file_dir = os.path.dirname(file_info.raw_data_location)
@@ -809,7 +816,9 @@ async def _prune_cloud_dataset(
    logger.info("Pruning cloud dataset to match local state")
 
    try:
-        async with aiohttp.ClientSession() as session:
+        ssl_context = create_secure_ssl_context()
+        connector = aiohttp.TCPConnector(ssl=ssl_context)
+        async with aiohttp.ClientSession(connector=connector) as session:
             async with session.put(url, json=payload.dict(), headers=headers) as response:
                 if response.status == 200:
                     data = await response.json()
@@ -852,7 +861,9 @@ async def _trigger_remote_cognify(
     logger.info(f"Triggering cognify processing for dataset {dataset_id}")
 
     try:
-        async with aiohttp.ClientSession() as session:
+        ssl_context = create_secure_ssl_context()
+        connector = aiohttp.TCPConnector(ssl=ssl_context)
+        async with aiohttp.ClientSession(connector=connector) as session:
             async with session.post(url, json=payload, headers=headers) as response:
                 if response.status == 200:
                     data = await response.json()
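`cognee/shared/utils.py` gains `create_secure_ssl_context` (its 12-line body is not shown in this diff), and every bare `aiohttp.ClientSession()` above now routes through it. Judging from the call sites it returns a standard `ssl.SSLContext`; a plausible sketch, not the actual implementation:

```python
import ssl

import certifi  # assumption: verification is pinned to a certifi-style CA bundle


def create_secure_ssl_context() -> ssl.SSLContext:
    """Build a client-side SSL context that verifies server certificates."""
    context = ssl.create_default_context(cafile=certifi.where())
    context.check_hostname = True  # reject certificates that don't match the host
    context.verify_mode = ssl.CERT_REQUIRED
    return context
```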
cognee/base_config.py
CHANGED
@@ -10,13 +10,27 @@ import pydantic
 class BaseConfig(BaseSettings):
     data_root_directory: str = get_absolute_path(".data_storage")
     system_root_directory: str = get_absolute_path(".cognee_system")
+    cache_root_directory: str = get_absolute_path(".cognee_cache")
     monitoring_tool: object = Observer.LANGFUSE
 
     @pydantic.model_validator(mode="after")
     def validate_paths(self):
+        # Adding this here temporarily to ensure that the cache root directory is set correctly for S3 storage automatically
+        # I'll remove this after we update documentation for S3 storage
+        # Auto-configure cache root directory for S3 storage if not explicitly set
+        storage_backend = os.getenv("STORAGE_BACKEND", "").lower()
+        cache_root_env = os.getenv("CACHE_ROOT_DIRECTORY")
+
+        if storage_backend == "s3" and not cache_root_env:
+            # Auto-generate S3 cache path when using S3 storage
+            bucket_name = os.getenv("STORAGE_BUCKET_NAME")
+            if bucket_name:
+                self.cache_root_directory = f"s3://{bucket_name}/cognee/cache"
+
         # Require absolute paths for root directories
         self.data_root_directory = ensure_absolute_path(self.data_root_directory)
         self.system_root_directory = ensure_absolute_path(self.system_root_directory)
+        self.cache_root_directory = ensure_absolute_path(self.cache_root_directory)
         return self
 
     langfuse_public_key: Optional[str] = os.getenv("LANGFUSE_PUBLIC_KEY")
@@ -31,6 +45,7 @@ class BaseConfig(BaseSettings):
             "data_root_directory": self.data_root_directory,
             "system_root_directory": self.system_root_directory,
             "monitoring_tool": self.monitoring_tool,
+            "cache_root_directory": self.cache_root_directory,
         }
 
 
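In effect, pointing cognee at S3 should now relocate the cache automatically. A hedged sketch of the intended behavior (the bucket name is hypothetical; `get_base_config` is the accessor imported elsewhere in this diff):

```python
import os

# Assumed environment, matching the variable names read in validate_paths()
os.environ["STORAGE_BACKEND"] = "s3"
os.environ["STORAGE_BUCKET_NAME"] = "my-cognee-bucket"  # hypothetical bucket
os.environ.pop("CACHE_ROOT_DIRECTORY", None)  # unset, so auto-config kicks in

from cognee.base_config import get_base_config

config = get_base_config()
# Should yield "s3://my-cognee-bucket/cognee/cache", assuming
# ensure_absolute_path passes s3:// URLs through unchanged.
print(config.cache_root_directory)
```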
cognee/infrastructure/databases/graph/kuzu/remote_kuzu_adapter.py
CHANGED
@@ -7,6 +7,7 @@ import aiohttp
 from uuid import UUID
 
 from cognee.infrastructure.databases.graph.kuzu.adapter import KuzuAdapter
+from cognee.shared.utils import create_secure_ssl_context
 
 logger = get_logger()
 
@@ -42,7 +43,9 @@ class RemoteKuzuAdapter(KuzuAdapter):
     async def _get_session(self) -> aiohttp.ClientSession:
         """Get or create an aiohttp session."""
         if self._session is None or self._session.closed:
-            self._session = aiohttp.ClientSession()
+            ssl_context = create_secure_ssl_context()
+            connector = aiohttp.TCPConnector(ssl=ssl_context)
+            self._session = aiohttp.ClientSession(connector=connector)
         return self._session
 
     async def close(self):
cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py
CHANGED
@@ -14,6 +14,7 @@ from cognee.infrastructure.databases.vector.embeddings.embedding_rate_limiter im
     embedding_rate_limit_async,
     embedding_sleep_and_retry_async,
 )
+from cognee.shared.utils import create_secure_ssl_context
 
 logger = get_logger("OllamaEmbeddingEngine")
 
@@ -101,7 +102,9 @@ class OllamaEmbeddingEngine(EmbeddingEngine):
         if api_key:
             headers["Authorization"] = f"Bearer {api_key}"
 
-        async with aiohttp.ClientSession() as session:
+        ssl_context = create_secure_ssl_context()
+        connector = aiohttp.TCPConnector(ssl=ssl_context)
+        async with aiohttp.ClientSession(connector=connector) as session:
             async with session.post(
                 self.endpoint, json=payload, headers=headers, timeout=60.0
             ) as response:
cognee/infrastructure/files/storage/LocalFileStorage.py
CHANGED
@@ -253,6 +253,56 @@ class LocalFileStorage(Storage):
         if os.path.exists(full_file_path):
             os.remove(full_file_path)
 
+    def list_files(self, directory_path: str, recursive: bool = False) -> list[str]:
+        """
+        List all files in the specified directory.
+
+        Parameters:
+        -----------
+        - directory_path (str): The directory path to list files from
+        - recursive (bool): If True, list files recursively in subdirectories
+
+        Returns:
+        --------
+        - list[str]: List of file paths relative to the storage root
+        """
+        from pathlib import Path
+
+        parsed_storage_path = get_parsed_path(self.storage_path)
+
+        if directory_path:
+            full_directory_path = os.path.join(parsed_storage_path, directory_path)
+        else:
+            full_directory_path = parsed_storage_path
+
+        directory_pathlib = Path(full_directory_path)
+
+        if not directory_pathlib.exists() or not directory_pathlib.is_dir():
+            return []
+
+        files = []
+
+        if recursive:
+            # Use rglob for recursive search
+            for file_path in directory_pathlib.rglob("*"):
+                if file_path.is_file():
+                    # Get relative path from storage root
+                    relative_path = os.path.relpath(str(file_path), parsed_storage_path)
+                    # Normalize path separators for consistency
+                    relative_path = relative_path.replace(os.sep, "/")
+                    files.append(relative_path)
+        else:
+            # Use iterdir for just immediate directory
+            for file_path in directory_pathlib.iterdir():
+                if file_path.is_file():
+                    # Get relative path from storage root
+                    relative_path = os.path.relpath(str(file_path), parsed_storage_path)
+                    # Normalize path separators for consistency
+                    relative_path = relative_path.replace(os.sep, "/")
+                    files.append(relative_path)
+
+        return files
+
     def remove_all(self, tree_path: str = None):
         """
         Remove an entire directory tree at the specified path, including all files and
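The local implementation normalizes every result to a "/"-separated path relative to the storage root. A standalone sketch of the same traversal logic, using a hypothetical directory:

```python
import os
from pathlib import Path

root = Path("/tmp/demo_storage")  # hypothetical storage root
(root / "data").mkdir(parents=True, exist_ok=True)
(root / "data" / "a.txt").write_text("hello")

files = []
for path in root.rglob("*"):  # rglob recurses; iterdir would stay one level deep
    if path.is_file():
        relative = os.path.relpath(str(path), str(root))
        files.append(relative.replace(os.sep, "/"))  # always "/"-separated

print(files)  # ['data/a.txt']
```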
cognee/infrastructure/files/storage/S3FileStorage.py
CHANGED
@@ -155,21 +155,19 @@ class S3FileStorage(Storage):
         """
         Ensure that the specified directory exists, creating it if necessary.
 
-
+        For S3 storage, this is a no-op since directories are created implicitly
+        when files are written to paths. S3 doesn't have actual directories,
+        just object keys with prefixes that appear as directories.
 
         Parameters:
         -----------
 
         - directory_path (str): The path of the directory to check or create.
         """
-
-
-
-        def ensure_directory():
-            if not self.s3.exists(directory_path):
-                self.s3.makedirs(directory_path, exist_ok=True)
-
-        await run_async(ensure_directory)
+        # In S3, directories don't exist as separate entities - they're just prefixes
+        # When you write a file to s3://bucket/path/to/file.txt, the "directories"
+        # path/ and path/to/ are implicitly created. No explicit action needed.
+        pass
 
     async def copy_file(self, source_file_path: str, destination_file_path: str):
         """
@@ -213,6 +211,55 @@ class S3FileStorage(Storage):
 
         await run_async(remove_file)
 
+    async def list_files(self, directory_path: str, recursive: bool = False) -> list[str]:
+        """
+        List all files in the specified directory.
+
+        Parameters:
+        -----------
+        - directory_path (str): The directory path to list files from
+        - recursive (bool): If True, list files recursively in subdirectories
+
+        Returns:
+        --------
+        - list[str]: List of file paths relative to the storage root
+        """
+
+        def list_files_sync():
+            if directory_path:
+                # Combine storage path with directory path
+                full_path = os.path.join(self.storage_path.replace("s3://", ""), directory_path)
+            else:
+                full_path = self.storage_path.replace("s3://", "")
+
+            if recursive:
+                # Use ** for recursive search
+                pattern = f"{full_path}/**"
+            else:
+                # Just files in the immediate directory
+                pattern = f"{full_path}/*"
+
+            # Use s3fs glob to find files
+            try:
+                all_paths = self.s3.glob(pattern)
+                # Filter to only files (not directories)
+                files = [path for path in all_paths if self.s3.isfile(path)]
+
+                # Convert back to relative paths from storage root
+                storage_prefix = self.storage_path.replace("s3://", "")
+                relative_files = []
+                for file_path in files:
+                    if file_path.startswith(storage_prefix):
+                        relative_path = file_path[len(storage_prefix) :].lstrip("/")
+                        relative_files.append(relative_path)
+
+                return relative_files
+            except Exception:
+                # If directory doesn't exist or other error, return empty list
+                return []
+
+        return await run_async(list_files_sync)
+
     async def remove_all(self, tree_path: str):
         """
         Remove an entire directory tree at the specified path, including all files and
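The S3 variant leans on s3fs's `glob`, where `prefix/*` matches one level and `prefix/**` recurses, and on `isfile` to drop pseudo-directories. A hedged standalone sketch (bucket and credentials are hypothetical):

```python
import s3fs

# Hypothetical anonymous setup; real use needs credentials and an existing bucket.
s3 = s3fs.S3FileSystem(anon=True)

pattern = "my-bucket/cognee/data/**"  # "**" recurses, "*" stays one level deep
all_paths = s3.glob(pattern)  # s3fs returns keys without the "s3://" scheme
files = [path for path in all_paths if s3.isfile(path)]
print(files)
```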
cognee/infrastructure/files/storage/StorageManager.py
CHANGED
@@ -135,6 +135,24 @@ class StorageManager:
         else:
             return self.storage.remove(file_path)
 
+    async def list_files(self, directory_path: str, recursive: bool = False) -> list[str]:
+        """
+        List all files in the specified directory.
+
+        Parameters:
+        -----------
+        - directory_path (str): The directory path to list files from
+        - recursive (bool): If True, list files recursively in subdirectories
+
+        Returns:
+        --------
+        - list[str]: List of file paths relative to the storage root
+        """
+        if inspect.iscoroutinefunction(self.storage.list_files):
+            return await self.storage.list_files(directory_path, recursive)
+        else:
+            return self.storage.list_files(directory_path, recursive)
+
     async def remove_all(self, tree_path: str = None):
         """
         Remove an entire directory tree at the specified path, including all files and
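`StorageManager.list_files` mirrors the manager's other methods: `inspect.iscoroutinefunction` decides whether the backend call must be awaited. A minimal self-contained sketch of that dispatch pattern:

```python
import asyncio
import inspect


class LocalBackend:
    def list_files(self, directory_path, recursive=False):
        return ["a.txt"]  # synchronous, like LocalFileStorage


class S3Backend:
    async def list_files(self, directory_path, recursive=False):
        return ["b.txt"]  # coroutine, like S3FileStorage


async def list_files(storage, directory_path, recursive=False):
    # Await coroutine backends, call plain ones directly.
    if inspect.iscoroutinefunction(storage.list_files):
        return await storage.list_files(directory_path, recursive)
    return storage.list_files(directory_path, recursive)


async def main():
    print(await list_files(LocalBackend(), "data"))  # ['a.txt']
    print(await list_files(S3Backend(), "data"))     # ['b.txt']


asyncio.run(main())
```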
cognee/modules/cloud/operations/check_api_key.py
CHANGED
@@ -1,6 +1,7 @@
 import aiohttp
 
 from cognee.modules.cloud.exceptions import CloudConnectionError
+from cognee.shared.utils import create_secure_ssl_context
 
 
 async def check_api_key(auth_token: str):
@@ -10,7 +11,9 @@ async def check_api_key(auth_token: str):
     headers = {"X-Api-Key": auth_token}
 
     try:
-        async with aiohttp.ClientSession() as session:
+        ssl_context = create_secure_ssl_context()
+        connector = aiohttp.TCPConnector(ssl=ssl_context)
+        async with aiohttp.ClientSession(connector=connector) as session:
             async with session.post(url, headers=headers) as response:
                 if response.status == 200:
                     return
cognee/modules/data/deletion/prune_system.py
CHANGED
@@ -1,9 +1,10 @@
 from cognee.infrastructure.databases.vector import get_vector_engine
 from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine
 from cognee.infrastructure.databases.relational import get_relational_engine
+from cognee.shared.cache import delete_cache
 
 
-async def prune_system(graph=True, vector=True, metadata=True):
+async def prune_system(graph=True, vector=True, metadata=True, cache=True):
     if graph:
         graph_engine = await get_graph_engine()
         await graph_engine.delete_graph()
@@ -15,3 +16,6 @@ async def prune_system(graph=True, vector=True, metadata=True):
     if metadata:
         db_engine = get_relational_engine()
         await db_engine.delete_database()
+
+    if cache:
+        await delete_cache()
cognee/modules/notebooks/methods/create_notebook.py
CHANGED
@@ -7,6 +7,38 @@ from cognee.infrastructure.databases.relational import with_async_session
 from ..models.Notebook import Notebook, NotebookCell
 
 
+async def _create_tutorial_notebook(
+    user_id: UUID, session: AsyncSession, force_refresh: bool = False
+) -> None:
+    """
+    Create the default tutorial notebook for new users.
+    Dynamically fetches from: https://github.com/topoteretes/cognee/blob/notebook_tutorial/notebooks/starter_tutorial.zip
+    """
+    TUTORIAL_ZIP_URL = (
+        "https://github.com/topoteretes/cognee/raw/notebook_tutorial/notebooks/starter_tutorial.zip"
+    )
+
+    try:
+        # Create notebook from remote zip file (includes notebook + data files)
+        notebook = await Notebook.from_ipynb_zip_url(
+            zip_url=TUTORIAL_ZIP_URL,
+            owner_id=user_id,
+            notebook_filename="tutorial.ipynb",
+            name="Python Development with Cognee Tutorial 🧠",
+            deletable=False,
+            force=force_refresh,
+        )
+
+        # Add to session and commit
+        session.add(notebook)
+        await session.commit()
+
+    except Exception as e:
+        print(f"Failed to fetch tutorial notebook from {TUTORIAL_ZIP_URL}: {e}")
+
+        raise e
+
+
 @with_async_session
 async def create_notebook(
     user_id: UUID,
cognee/modules/notebooks/models/Notebook.py
CHANGED
@@ -1,13 +1,24 @@
 import json
-
+import nbformat
+import asyncio
+from nbformat.notebooknode import NotebookNode
+from typing import List, Literal, Optional, cast, Tuple
 from uuid import uuid4, UUID as UUID_t
 from pydantic import BaseModel, ConfigDict
 from datetime import datetime, timezone
 from fastapi.encoders import jsonable_encoder
 from sqlalchemy import Boolean, Column, DateTime, JSON, UUID, String, TypeDecorator
 from sqlalchemy.orm import mapped_column, Mapped
+from pathlib import Path
 
 from cognee.infrastructure.databases.relational import Base
+from cognee.shared.cache import (
+    download_and_extract_zip,
+    get_tutorial_data_dir,
+    generate_content_hash,
+)
+from cognee.infrastructure.files.storage.get_file_storage import get_file_storage
+from cognee.base_config import get_base_config
 
 
 class NotebookCell(BaseModel):
@@ -51,3 +62,197 @@ class Notebook(Base):
     deletable: Mapped[bool] = mapped_column(Boolean, default=True)
 
     created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
+
+    @classmethod
+    async def from_ipynb_zip_url(
+        cls,
+        zip_url: str,
+        owner_id: UUID_t,
+        notebook_filename: str = "tutorial.ipynb",
+        name: Optional[str] = None,
+        deletable: bool = True,
+        force: bool = False,
+    ) -> "Notebook":
+        """
+        Create a Notebook instance from a remote zip file containing notebook + data files.
+
+        Args:
+            zip_url: Remote URL to fetch the .zip file from
+            owner_id: UUID of the notebook owner
+            notebook_filename: Name of the .ipynb file within the zip
+            name: Optional custom name for the notebook
+            deletable: Whether the notebook can be deleted
+            force: If True, re-download even if already cached
+
+        Returns:
+            Notebook instance
+        """
+        # Generate a cache key based on the zip URL
+        content_hash = generate_content_hash(zip_url, notebook_filename)
+
+        # Download and extract the zip file to tutorial_data/{content_hash}
+        try:
+            extracted_cache_dir = await download_and_extract_zip(
+                url=zip_url,
+                cache_dir_name=f"tutorial_data/{content_hash}",
+                version_or_hash=content_hash,
+                force=force,
+            )
+        except Exception as e:
+            raise RuntimeError(f"Failed to download tutorial zip from {zip_url}") from e
+
+        # Use cache system to access the notebook file
+        from cognee.shared.cache import cache_file_exists, read_cache_file
+
+        notebook_file_path = f"{extracted_cache_dir}/{notebook_filename}"
+
+        # Check if the notebook file exists in cache
+        if not await cache_file_exists(notebook_file_path):
+            raise FileNotFoundError(f"Notebook file '{notebook_filename}' not found in zip")
+
+        # Read and parse the notebook using cache system
+        async with await read_cache_file(notebook_file_path, encoding="utf-8") as f:
+            notebook_content = await asyncio.to_thread(f.read)
+            notebook = cls.from_ipynb_string(notebook_content, owner_id, name, deletable)
+
+        # Update file paths in notebook cells to point to actual cached data files
+        await cls._update_file_paths_in_cells(notebook, extracted_cache_dir)
+
+        return notebook
+
+    @staticmethod
+    async def _update_file_paths_in_cells(notebook: "Notebook", cache_dir: str) -> None:
+        """
+        Update file paths in code cells to use actual cached data files.
+        Works with both local filesystem and S3 storage.
+
+        Args:
+            notebook: Parsed Notebook instance with cells to update
+            cache_dir: Path to the cached tutorial directory containing data files
+        """
+        import re
+        from cognee.shared.cache import list_cache_files, cache_file_exists
+        from cognee.shared.logging_utils import get_logger
+
+        logger = get_logger()
+
+        # Look for data files in the data subdirectory
+        data_dir = f"{cache_dir}/data"
+
+        try:
+            # Get all data files in the cache directory using cache system
+            data_files = {}
+            if await cache_file_exists(data_dir):
+                file_list = await list_cache_files(data_dir)
+            else:
+                file_list = []
+
+            for file_path in file_list:
+                # Extract just the filename
+                filename = file_path.split("/")[-1]
+                # Use the file path as provided by cache system
+                data_files[filename] = file_path
+
+        except Exception as e:
+            # If we can't list files, skip updating paths
+            logger.error(f"Error listing data files in {data_dir}: {e}")
+            return
+
+        # Pattern to match file://data/filename patterns in code cells
+        file_pattern = r'"file://data/([^"]+)"'
+
+        def replace_path(match):
+            filename = match.group(1)
+            if filename in data_files:
+                file_path = data_files[filename]
+                # For local filesystem, preserve file:// prefix
+                if not file_path.startswith("s3://"):
+                    return f'"file://{file_path}"'
+                else:
+                    # For S3, return the S3 URL as-is
+                    return f'"{file_path}"'
+            return match.group(0)  # Keep original if file not found
+
+        # Update only code cells
+        updated_cells = 0
+        for cell in notebook.cells:
+            if cell.type == "code":
+                original_content = cell.content
+                # Update file paths in the cell content
+                cell.content = re.sub(file_pattern, replace_path, cell.content)
+                if original_content != cell.content:
+                    updated_cells += 1
+
+        # Log summary of updates (useful for monitoring)
+        if updated_cells > 0:
+            logger.info(f"Updated file paths in {updated_cells} notebook cells")
+
+    @classmethod
+    def from_ipynb_string(
+        cls,
+        notebook_content: str,
+        owner_id: UUID_t,
+        name: Optional[str] = None,
+        deletable: bool = True,
+    ) -> "Notebook":
+        """
+        Create a Notebook instance from Jupyter notebook string content.
+
+        Args:
+            notebook_content: Raw Jupyter notebook content as string
+            owner_id: UUID of the notebook owner
+            name: Optional custom name for the notebook
+            deletable: Whether the notebook can be deleted
+
+        Returns:
+            Notebook instance ready to be saved to database
+        """
+        # Parse and validate the Jupyter notebook using nbformat
+        # Note: nbformat.reads() has loose typing, so we cast to NotebookNode
+        jupyter_nb = cast(
+            NotebookNode, nbformat.reads(notebook_content, as_version=nbformat.NO_CONVERT)
+        )
+
+        # Convert Jupyter cells to NotebookCell objects
+        cells = []
+        for jupyter_cell in jupyter_nb.cells:
+            # Each cell is also a NotebookNode with dynamic attributes
+            cell = cast(NotebookNode, jupyter_cell)
+            # Skip raw cells as they're not supported in our model
+            if cell.cell_type == "raw":
+                continue
+
+            # Get the source content
+            content = cell.source
+
+            # Generate a name based on content or cell index
+            cell_name = cls._generate_cell_name(cell)
+
+            # Map cell types (jupyter uses "code"/"markdown", we use same)
+            cell_type = "code" if cell.cell_type == "code" else "markdown"
+
+            cells.append(NotebookCell(id=uuid4(), type=cell_type, name=cell_name, content=content))
+
+        # Extract notebook name from metadata if not provided
+        if name is None:
+            kernelspec = jupyter_nb.metadata.get("kernelspec", {})
+            name = kernelspec.get("display_name") or kernelspec.get("name", "Imported Notebook")
+
+        return cls(id=uuid4(), owner_id=owner_id, name=name, cells=cells, deletable=deletable)
+
+    @staticmethod
+    def _generate_cell_name(jupyter_cell: NotebookNode) -> str:
+        """Generate a meaningful name for a notebook cell using nbformat cell."""
+        if jupyter_cell.cell_type == "markdown":
+            # Try to extract a title from markdown headers
+            content = jupyter_cell.source
+
+            lines = content.strip().split("\n")
+            if lines and lines[0].startswith("#"):
+                # Extract header text, clean it up
+                header = lines[0].lstrip("#").strip()
+                return header[:50] if len(header) > 50 else header
+            else:
+                return "Markdown Cell"
+        else:
+            return "Code Cell"
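`from_ipynb_string` accepts any nbformat v4 document, skips raw cells, and derives cell names from markdown headers. A hedged example with a minimal hand-built notebook (the owner UUID is arbitrary):

```python
import json
from uuid import uuid4

from cognee.modules.notebooks.models.Notebook import Notebook

minimal_nb = json.dumps(
    {
        "nbformat": 4,
        "nbformat_minor": 4,
        "metadata": {"kernelspec": {"display_name": "Python 3", "name": "python3"}},
        "cells": [
            {"cell_type": "markdown", "metadata": {}, "source": "# Intro"},
            {
                "cell_type": "code",
                "metadata": {},
                "source": "print('hi')",
                "outputs": [],
                "execution_count": None,
            },
        ],
    }
)

notebook = Notebook.from_ipynb_string(minimal_nb, owner_id=uuid4())
print([cell.name for cell in notebook.cells])  # ['Intro', 'Code Cell']
```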
cognee/modules/users/methods/create_user.py
CHANGED
@@ -1,9 +1,10 @@
-from uuid import uuid4
+from uuid import UUID, uuid4
 from fastapi_users.exceptions import UserAlreadyExists
+from sqlalchemy.ext.asyncio import AsyncSession
 
 from cognee.infrastructure.databases.relational import get_relational_engine
-from cognee.modules.notebooks.
-from cognee.modules.notebooks.
+from cognee.modules.notebooks.models.Notebook import Notebook
+from cognee.modules.notebooks.methods.create_notebook import _create_tutorial_notebook
 from cognee.modules.users.exceptions import TenantNotFoundError
 from cognee.modules.users.get_user_manager import get_user_manager_context
 from cognee.modules.users.get_user_db import get_user_db_context
@@ -60,26 +61,7 @@ async def create_user(
     if auto_login:
         await session.refresh(user)
 
-        await create_notebook(
-            user_id=user.id,
-            notebook_name="Welcome to cognee 🧠",
-            cells=[
-                NotebookCell(
-                    id=uuid4(),
-                    name="Welcome",
-                    content="Cognee is your toolkit for turning text into a structured knowledge graph, optionally enhanced by ontologies, and then querying it with advanced retrieval techniques. This notebook will guide you through a simple example.",
-                    type="markdown",
-                ),
-                NotebookCell(
-                    id=uuid4(),
-                    name="Example",
-                    content="",
-                    type="code",
-                ),
-            ],
-            deletable=False,
-            session=session,
-        )
+        await _create_tutorial_notebook(user.id, session)
 
     return user
 except UserAlreadyExists as error: