cognee 0.3.4.dev1__py3-none-any.whl → 0.3.4.dev3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (26)
  1. cognee/api/v1/cloud/routers/get_checks_router.py +1 -1
  2. cognee/api/v1/prune/prune.py +2 -2
  3. cognee/api/v1/sync/sync.py +16 -5
  4. cognee/base_config.py +15 -0
  5. cognee/infrastructure/databases/graph/kuzu/remote_kuzu_adapter.py +4 -1
  6. cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py +4 -1
  7. cognee/infrastructure/files/storage/LocalFileStorage.py +50 -0
  8. cognee/infrastructure/files/storage/S3FileStorage.py +56 -9
  9. cognee/infrastructure/files/storage/StorageManager.py +18 -0
  10. cognee/modules/cloud/operations/check_api_key.py +4 -1
  11. cognee/modules/data/deletion/prune_system.py +5 -1
  12. cognee/modules/notebooks/methods/create_notebook.py +32 -0
  13. cognee/modules/notebooks/models/Notebook.py +206 -1
  14. cognee/modules/retrieval/temporal_retriever.py +2 -2
  15. cognee/modules/users/methods/create_user.py +5 -23
  16. cognee/root_dir.py +5 -0
  17. cognee/shared/cache.py +346 -0
  18. cognee/shared/utils.py +12 -0
  19. cognee/tasks/ingestion/save_data_item_to_storage.py +1 -0
  20. cognee/tests/unit/modules/users/test_tutorial_notebook_creation.py +399 -0
  21. {cognee-0.3.4.dev1.dist-info → cognee-0.3.4.dev3.dist-info}/METADATA +2 -1
  22. {cognee-0.3.4.dev1.dist-info → cognee-0.3.4.dev3.dist-info}/RECORD +26 -24
  23. {cognee-0.3.4.dev1.dist-info → cognee-0.3.4.dev3.dist-info}/WHEEL +0 -0
  24. {cognee-0.3.4.dev1.dist-info → cognee-0.3.4.dev3.dist-info}/entry_points.txt +0 -0
  25. {cognee-0.3.4.dev1.dist-info → cognee-0.3.4.dev3.dist-info}/licenses/LICENSE +0 -0
  26. {cognee-0.3.4.dev1.dist-info → cognee-0.3.4.dev3.dist-info}/licenses/NOTICE.md +0 -0
cognee/api/v1/cloud/routers/get_checks_router.py CHANGED
@@ -16,7 +16,7 @@ def get_checks_router():
         api_token = request.headers.get("X-Api-Key")
 
         if api_token is None:
-            return CloudApiKeyMissingError()
+            raise CloudApiKeyMissingError()
 
         return await check_api_key(api_token)
 
cognee/api/v1/prune/prune.py CHANGED
@@ -7,8 +7,8 @@ class prune:
         await _prune_data()
 
     @staticmethod
-    async def prune_system(graph=True, vector=True, metadata=False):
-        await _prune_system(graph, vector, metadata)
+    async def prune_system(graph=True, vector=True, metadata=False, cache=True):
+        await _prune_system(graph, vector, metadata, cache)
 
 
 if __name__ == "__main__":
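
The hunk above adds a cache flag (default True) to the public prune API, wired through to the new cache cleanup in prune_system. A minimal usage sketch, assuming the class is exposed as cognee.prune as in previous releases:

    import asyncio
    import cognee

    async def reset_cognee():
        await cognee.prune.prune_data()
        # cache=True (the new default) also wipes the local cognee cache directory
        await cognee.prune.prune_system(graph=True, vector=True, metadata=False, cache=True)

    asyncio.run(reset_cognee())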
cognee/api/v1/sync/sync.py CHANGED
@@ -23,6 +23,7 @@ from cognee.modules.sync.methods import (
     mark_sync_completed,
     mark_sync_failed,
 )
+from cognee.shared.utils import create_secure_ssl_context
 
 logger = get_logger("sync")
 
@@ -583,7 +584,9 @@ async def _check_hashes_diff(
     logger.info(f"Checking missing hashes on cloud for dataset {dataset.id}")
 
     try:
-        async with aiohttp.ClientSession() as session:
+        ssl_context = create_secure_ssl_context()
+        connector = aiohttp.TCPConnector(ssl=ssl_context)
+        async with aiohttp.ClientSession(connector=connector) as session:
             async with session.post(url, json=payload.dict(), headers=headers) as response:
                 if response.status == 200:
                     data = await response.json()
@@ -630,7 +633,9 @@ async def _download_missing_files(
 
     headers = {"X-Api-Key": auth_token}
 
-    async with aiohttp.ClientSession() as session:
+    ssl_context = create_secure_ssl_context()
+    connector = aiohttp.TCPConnector(ssl=ssl_context)
+    async with aiohttp.ClientSession(connector=connector) as session:
         for file_hash in hashes_missing_on_local:
             try:
                 # Download file from cloud by hash
@@ -749,7 +754,9 @@ async def _upload_missing_files(
 
     headers = {"X-Api-Key": auth_token}
 
-    async with aiohttp.ClientSession() as session:
+    ssl_context = create_secure_ssl_context()
+    connector = aiohttp.TCPConnector(ssl=ssl_context)
+    async with aiohttp.ClientSession(connector=connector) as session:
         for file_info in files_to_upload:
             try:
                 file_dir = os.path.dirname(file_info.raw_data_location)
@@ -809,7 +816,9 @@ async def _prune_cloud_dataset(
     logger.info("Pruning cloud dataset to match local state")
 
     try:
-        async with aiohttp.ClientSession() as session:
+        ssl_context = create_secure_ssl_context()
+        connector = aiohttp.TCPConnector(ssl=ssl_context)
+        async with aiohttp.ClientSession(connector=connector) as session:
             async with session.put(url, json=payload.dict(), headers=headers) as response:
                 if response.status == 200:
                     data = await response.json()
@@ -852,7 +861,9 @@ async def _trigger_remote_cognify(
     logger.info(f"Triggering cognify processing for dataset {dataset_id}")
 
     try:
-        async with aiohttp.ClientSession() as session:
+        ssl_context = create_secure_ssl_context()
+        connector = aiohttp.TCPConnector(ssl=ssl_context)
+        async with aiohttp.ClientSession(connector=connector) as session:
             async with session.post(url, json=payload, headers=headers) as response:
                 if response.status == 200:
                     data = await response.json()
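
All of the aiohttp sessions above now go through an explicit TCPConnector built from create_secure_ssl_context(). The helper itself lands in cognee/shared/utils.py (+12 lines) and is not shown in this diff; a plausible sketch of such a helper, verifying server certificates against certifi's CA bundle, looks like this (the exact implementation is an assumption):

    import ssl

    import aiohttp
    import certifi

    def create_secure_ssl_context() -> ssl.SSLContext:
        # Hypothetical stand-in for cognee.shared.utils.create_secure_ssl_context:
        # a default client-side context that verifies certs against certifi's bundle.
        return ssl.create_default_context(cafile=certifi.where())

    async def fetch_status(url: str) -> int:
        connector = aiohttp.TCPConnector(ssl=create_secure_ssl_context())
        async with aiohttp.ClientSession(connector=connector) as session:
            async with session.get(url) as response:
                return response.status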
cognee/base_config.py CHANGED
@@ -10,13 +10,27 @@ import pydantic
 class BaseConfig(BaseSettings):
     data_root_directory: str = get_absolute_path(".data_storage")
     system_root_directory: str = get_absolute_path(".cognee_system")
+    cache_root_directory: str = get_absolute_path(".cognee_cache")
     monitoring_tool: object = Observer.LANGFUSE
 
     @pydantic.model_validator(mode="after")
     def validate_paths(self):
+        # Adding this here temporarily to ensure that the cache root directory is set correctly for S3 storage automatically
+        # I'll remove this after we update documentation for S3 storage
+        # Auto-configure cache root directory for S3 storage if not explicitly set
+        storage_backend = os.getenv("STORAGE_BACKEND", "").lower()
+        cache_root_env = os.getenv("CACHE_ROOT_DIRECTORY")
+
+        if storage_backend == "s3" and not cache_root_env:
+            # Auto-generate S3 cache path when using S3 storage
+            bucket_name = os.getenv("STORAGE_BUCKET_NAME")
+            if bucket_name:
+                self.cache_root_directory = f"s3://{bucket_name}/cognee/cache"
+
         # Require absolute paths for root directories
         self.data_root_directory = ensure_absolute_path(self.data_root_directory)
         self.system_root_directory = ensure_absolute_path(self.system_root_directory)
+        self.cache_root_directory = ensure_absolute_path(self.cache_root_directory)
         return self
 
     langfuse_public_key: Optional[str] = os.getenv("LANGFUSE_PUBLIC_KEY")
@@ -31,6 +45,7 @@ class BaseConfig(BaseSettings):
             "data_root_directory": self.data_root_directory,
             "system_root_directory": self.system_root_directory,
             "monitoring_tool": self.monitoring_tool,
+            "cache_root_directory": self.cache_root_directory,
         }
 
 
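The validator change means that when STORAGE_BACKEND=s3 and no CACHE_ROOT_DIRECTORY is set, the cache root is redirected into the storage bucket. A rough illustration of the intended behaviour (bucket name hypothetical; this assumes get_base_config builds the config from the current environment rather than returning an earlier cached instance):

    import os

    os.environ["STORAGE_BACKEND"] = "s3"
    os.environ["STORAGE_BUCKET_NAME"] = "my-cognee-bucket"
    os.environ.pop("CACHE_ROOT_DIRECTORY", None)  # leave unset so auto-configuration applies

    from cognee.base_config import get_base_config

    config = get_base_config()
    # Expected to resolve to s3://my-cognee-bucket/cognee/cache (then passed through ensure_absolute_path)
    print(config.cache_root_directory)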
cognee/infrastructure/databases/graph/kuzu/remote_kuzu_adapter.py CHANGED
@@ -7,6 +7,7 @@ import aiohttp
 from uuid import UUID
 
 from cognee.infrastructure.databases.graph.kuzu.adapter import KuzuAdapter
+from cognee.shared.utils import create_secure_ssl_context
 
 logger = get_logger()
 
@@ -42,7 +43,9 @@ class RemoteKuzuAdapter(KuzuAdapter):
     async def _get_session(self) -> aiohttp.ClientSession:
         """Get or create an aiohttp session."""
         if self._session is None or self._session.closed:
-            self._session = aiohttp.ClientSession()
+            ssl_context = create_secure_ssl_context()
+            connector = aiohttp.TCPConnector(ssl=ssl_context)
+            self._session = aiohttp.ClientSession(connector=connector)
         return self._session
 
     async def close(self):
cognee/infrastructure/databases/vector/embeddings/OllamaEmbeddingEngine.py CHANGED
@@ -14,6 +14,7 @@ from cognee.infrastructure.databases.vector.embeddings.embedding_rate_limiter im
     embedding_rate_limit_async,
     embedding_sleep_and_retry_async,
 )
+from cognee.shared.utils import create_secure_ssl_context
 
 logger = get_logger("OllamaEmbeddingEngine")
 
@@ -101,7 +102,9 @@ class OllamaEmbeddingEngine(EmbeddingEngine):
         if api_key:
             headers["Authorization"] = f"Bearer {api_key}"
 
-        async with aiohttp.ClientSession() as session:
+        ssl_context = create_secure_ssl_context()
+        connector = aiohttp.TCPConnector(ssl=ssl_context)
+        async with aiohttp.ClientSession(connector=connector) as session:
             async with session.post(
                 self.endpoint, json=payload, headers=headers, timeout=60.0
             ) as response:
cognee/infrastructure/files/storage/LocalFileStorage.py CHANGED
@@ -253,6 +253,56 @@ class LocalFileStorage(Storage):
         if os.path.exists(full_file_path):
             os.remove(full_file_path)
 
+    def list_files(self, directory_path: str, recursive: bool = False) -> list[str]:
+        """
+        List all files in the specified directory.
+
+        Parameters:
+        -----------
+        - directory_path (str): The directory path to list files from
+        - recursive (bool): If True, list files recursively in subdirectories
+
+        Returns:
+        --------
+        - list[str]: List of file paths relative to the storage root
+        """
+        from pathlib import Path
+
+        parsed_storage_path = get_parsed_path(self.storage_path)
+
+        if directory_path:
+            full_directory_path = os.path.join(parsed_storage_path, directory_path)
+        else:
+            full_directory_path = parsed_storage_path
+
+        directory_pathlib = Path(full_directory_path)
+
+        if not directory_pathlib.exists() or not directory_pathlib.is_dir():
+            return []
+
+        files = []
+
+        if recursive:
+            # Use rglob for recursive search
+            for file_path in directory_pathlib.rglob("*"):
+                if file_path.is_file():
+                    # Get relative path from storage root
+                    relative_path = os.path.relpath(str(file_path), parsed_storage_path)
+                    # Normalize path separators for consistency
+                    relative_path = relative_path.replace(os.sep, "/")
+                    files.append(relative_path)
+        else:
+            # Use iterdir for just immediate directory
+            for file_path in directory_pathlib.iterdir():
+                if file_path.is_file():
+                    # Get relative path from storage root
+                    relative_path = os.path.relpath(str(file_path), parsed_storage_path)
+                    # Normalize path separators for consistency
+                    relative_path = relative_path.replace(os.sep, "/")
+                    files.append(relative_path)
+
+        return files
+
     def remove_all(self, tree_path: str = None):
         """
         Remove an entire directory tree at the specified path, including all files and
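
A short usage sketch of the new synchronous listing API; the LocalFileStorage constructor argument is assumed to be the storage root path, which this diff does not show:

    from cognee.infrastructure.files.storage.LocalFileStorage import LocalFileStorage

    storage = LocalFileStorage("/tmp/cognee_storage")  # constructor signature assumed
    # Paths are returned relative to the storage root, normalized to "/" separators
    print(storage.list_files("", recursive=True))
    print(storage.list_files("datasets", recursive=False))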
cognee/infrastructure/files/storage/S3FileStorage.py CHANGED
@@ -155,21 +155,19 @@ class S3FileStorage(Storage):
         """
         Ensure that the specified directory exists, creating it if necessary.
 
-        If the directory already exists, no action is taken.
+        For S3 storage, this is a no-op since directories are created implicitly
+        when files are written to paths. S3 doesn't have actual directories,
+        just object keys with prefixes that appear as directories.
 
         Parameters:
         -----------
 
         - directory_path (str): The path of the directory to check or create.
         """
-        if not directory_path.strip():
-            directory_path = self.storage_path.replace("s3://", "")
-
-        def ensure_directory():
-            if not self.s3.exists(directory_path):
-                self.s3.makedirs(directory_path, exist_ok=True)
-
-        await run_async(ensure_directory)
+        # In S3, directories don't exist as separate entities - they're just prefixes
+        # When you write a file to s3://bucket/path/to/file.txt, the "directories"
+        # path/ and path/to/ are implicitly created. No explicit action needed.
+        pass
 
     async def copy_file(self, source_file_path: str, destination_file_path: str):
         """
@@ -213,6 +211,55 @@ class S3FileStorage(Storage):
 
         await run_async(remove_file)
 
+    async def list_files(self, directory_path: str, recursive: bool = False) -> list[str]:
+        """
+        List all files in the specified directory.
+
+        Parameters:
+        -----------
+        - directory_path (str): The directory path to list files from
+        - recursive (bool): If True, list files recursively in subdirectories
+
+        Returns:
+        --------
+        - list[str]: List of file paths relative to the storage root
+        """
+
+        def list_files_sync():
+            if directory_path:
+                # Combine storage path with directory path
+                full_path = os.path.join(self.storage_path.replace("s3://", ""), directory_path)
+            else:
+                full_path = self.storage_path.replace("s3://", "")
+
+            if recursive:
+                # Use ** for recursive search
+                pattern = f"{full_path}/**"
+            else:
+                # Just files in the immediate directory
+                pattern = f"{full_path}/*"
+
+            # Use s3fs glob to find files
+            try:
+                all_paths = self.s3.glob(pattern)
+                # Filter to only files (not directories)
+                files = [path for path in all_paths if self.s3.isfile(path)]
+
+                # Convert back to relative paths from storage root
+                storage_prefix = self.storage_path.replace("s3://", "")
+                relative_files = []
+                for file_path in files:
+                    if file_path.startswith(storage_prefix):
+                        relative_path = file_path[len(storage_prefix) :].lstrip("/")
+                        relative_files.append(relative_path)
+
+                return relative_files
+            except Exception:
+                # If directory doesn't exist or other error, return empty list
+                return []
+
+        return await run_async(list_files_sync)
+
     async def remove_all(self, tree_path: str):
         """
         Remove an entire directory tree at the specified path, including all files and
cognee/infrastructure/files/storage/StorageManager.py CHANGED
@@ -135,6 +135,24 @@ class StorageManager:
         else:
             return self.storage.remove(file_path)
 
+    async def list_files(self, directory_path: str, recursive: bool = False) -> list[str]:
+        """
+        List all files in the specified directory.
+
+        Parameters:
+        -----------
+        - directory_path (str): The directory path to list files from
+        - recursive (bool): If True, list files recursively in subdirectories
+
+        Returns:
+        --------
+        - list[str]: List of file paths relative to the storage root
+        """
+        if inspect.iscoroutinefunction(self.storage.list_files):
+            return await self.storage.list_files(directory_path, recursive)
+        else:
+            return self.storage.list_files(directory_path, recursive)
+
     async def remove_all(self, tree_path: str = None):
         """
         Remove an entire directory tree at the specified path, including all files and
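
StorageManager.list_files follows the same dispatch pattern as the other StorageManager methods: await the backend implementation when it is a coroutine function (S3FileStorage), call it directly when it is synchronous (LocalFileStorage). The idea in isolation:

    import asyncio
    import inspect

    async def call_storage_method(method, *args, **kwargs):
        # Mirrors StorageManager.list_files: await coroutine functions,
        # call plain functions synchronously and return their result.
        if inspect.iscoroutinefunction(method):
            return await method(*args, **kwargs)
        return method(*args, **kwargs)

    # A plain function stands in here for LocalFileStorage.list_files
    print(asyncio.run(call_storage_method(lambda d, recursive=False: [f"{d}/a.txt"], "datasets")))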
cognee/modules/cloud/operations/check_api_key.py CHANGED
@@ -1,6 +1,7 @@
 import aiohttp
 
 from cognee.modules.cloud.exceptions import CloudConnectionError
+from cognee.shared.utils import create_secure_ssl_context
 
 
 async def check_api_key(auth_token: str):
@@ -10,7 +11,9 @@ async def check_api_key(auth_token: str):
     headers = {"X-Api-Key": auth_token}
 
     try:
-        async with aiohttp.ClientSession() as session:
+        ssl_context = create_secure_ssl_context()
+        connector = aiohttp.TCPConnector(ssl=ssl_context)
+        async with aiohttp.ClientSession(connector=connector) as session:
             async with session.post(url, headers=headers) as response:
                 if response.status == 200:
                     return
cognee/modules/data/deletion/prune_system.py CHANGED
@@ -1,9 +1,10 @@
 from cognee.infrastructure.databases.vector import get_vector_engine
 from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine
 from cognee.infrastructure.databases.relational import get_relational_engine
+from cognee.shared.cache import delete_cache
 
 
-async def prune_system(graph=True, vector=True, metadata=True):
+async def prune_system(graph=True, vector=True, metadata=True, cache=True):
     if graph:
         graph_engine = await get_graph_engine()
         await graph_engine.delete_graph()
@@ -15,3 +16,6 @@ async def prune_system(graph=True, vector=True, metadata=True):
     if metadata:
         db_engine = get_relational_engine()
         await db_engine.delete_database()
+
+    if cache:
+        await delete_cache()
cognee/modules/notebooks/methods/create_notebook.py CHANGED
@@ -7,6 +7,38 @@ from cognee.infrastructure.databases.relational import with_async_session
 from ..models.Notebook import Notebook, NotebookCell
 
 
+async def _create_tutorial_notebook(
+    user_id: UUID, session: AsyncSession, force_refresh: bool = False
+) -> None:
+    """
+    Create the default tutorial notebook for new users.
+    Dynamically fetches from: https://github.com/topoteretes/cognee/blob/notebook_tutorial/notebooks/starter_tutorial.zip
+    """
+    TUTORIAL_ZIP_URL = (
+        "https://github.com/topoteretes/cognee/raw/notebook_tutorial/notebooks/starter_tutorial.zip"
+    )
+
+    try:
+        # Create notebook from remote zip file (includes notebook + data files)
+        notebook = await Notebook.from_ipynb_zip_url(
+            zip_url=TUTORIAL_ZIP_URL,
+            owner_id=user_id,
+            notebook_filename="tutorial.ipynb",
+            name="Python Development with Cognee Tutorial 🧠",
+            deletable=False,
+            force=force_refresh,
+        )
+
+        # Add to session and commit
+        session.add(notebook)
+        await session.commit()
+
+    except Exception as e:
+        print(f"Failed to fetch tutorial notebook from {TUTORIAL_ZIP_URL}: {e}")
+
+        raise e
+
+
 @with_async_session
 async def create_notebook(
     user_id: UUID,
cognee/modules/notebooks/models/Notebook.py CHANGED
@@ -1,13 +1,24 @@
 import json
-from typing import List, Literal
+import nbformat
+import asyncio
+from nbformat.notebooknode import NotebookNode
+from typing import List, Literal, Optional, cast, Tuple
 from uuid import uuid4, UUID as UUID_t
 from pydantic import BaseModel, ConfigDict
 from datetime import datetime, timezone
 from fastapi.encoders import jsonable_encoder
 from sqlalchemy import Boolean, Column, DateTime, JSON, UUID, String, TypeDecorator
 from sqlalchemy.orm import mapped_column, Mapped
+from pathlib import Path
 
 from cognee.infrastructure.databases.relational import Base
+from cognee.shared.cache import (
+    download_and_extract_zip,
+    get_tutorial_data_dir,
+    generate_content_hash,
+)
+from cognee.infrastructure.files.storage.get_file_storage import get_file_storage
+from cognee.base_config import get_base_config
 
 
 class NotebookCell(BaseModel):
@@ -51,3 +62,197 @@ class Notebook(Base):
     deletable: Mapped[bool] = mapped_column(Boolean, default=True)
 
     created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
+
+    @classmethod
+    async def from_ipynb_zip_url(
+        cls,
+        zip_url: str,
+        owner_id: UUID_t,
+        notebook_filename: str = "tutorial.ipynb",
+        name: Optional[str] = None,
+        deletable: bool = True,
+        force: bool = False,
+    ) -> "Notebook":
+        """
+        Create a Notebook instance from a remote zip file containing notebook + data files.
+
+        Args:
+            zip_url: Remote URL to fetch the .zip file from
+            owner_id: UUID of the notebook owner
+            notebook_filename: Name of the .ipynb file within the zip
+            name: Optional custom name for the notebook
+            deletable: Whether the notebook can be deleted
+            force: If True, re-download even if already cached
+
+        Returns:
+            Notebook instance
+        """
+        # Generate a cache key based on the zip URL
+        content_hash = generate_content_hash(zip_url, notebook_filename)
+
+        # Download and extract the zip file to tutorial_data/{content_hash}
+        try:
+            extracted_cache_dir = await download_and_extract_zip(
+                url=zip_url,
+                cache_dir_name=f"tutorial_data/{content_hash}",
+                version_or_hash=content_hash,
+                force=force,
+            )
+        except Exception as e:
+            raise RuntimeError(f"Failed to download tutorial zip from {zip_url}") from e
+
+        # Use cache system to access the notebook file
+        from cognee.shared.cache import cache_file_exists, read_cache_file
+
+        notebook_file_path = f"{extracted_cache_dir}/{notebook_filename}"
+
+        # Check if the notebook file exists in cache
+        if not await cache_file_exists(notebook_file_path):
+            raise FileNotFoundError(f"Notebook file '{notebook_filename}' not found in zip")
+
+        # Read and parse the notebook using cache system
+        async with await read_cache_file(notebook_file_path, encoding="utf-8") as f:
+            notebook_content = await asyncio.to_thread(f.read)
+            notebook = cls.from_ipynb_string(notebook_content, owner_id, name, deletable)
+
+        # Update file paths in notebook cells to point to actual cached data files
+        await cls._update_file_paths_in_cells(notebook, extracted_cache_dir)
+
+        return notebook
+
+    @staticmethod
+    async def _update_file_paths_in_cells(notebook: "Notebook", cache_dir: str) -> None:
+        """
+        Update file paths in code cells to use actual cached data files.
+        Works with both local filesystem and S3 storage.
+
+        Args:
+            notebook: Parsed Notebook instance with cells to update
+            cache_dir: Path to the cached tutorial directory containing data files
+        """
+        import re
+        from cognee.shared.cache import list_cache_files, cache_file_exists
+        from cognee.shared.logging_utils import get_logger
+
+        logger = get_logger()
+
+        # Look for data files in the data subdirectory
+        data_dir = f"{cache_dir}/data"
+
+        try:
+            # Get all data files in the cache directory using cache system
+            data_files = {}
+            if await cache_file_exists(data_dir):
+                file_list = await list_cache_files(data_dir)
+            else:
+                file_list = []
+
+            for file_path in file_list:
+                # Extract just the filename
+                filename = file_path.split("/")[-1]
+                # Use the file path as provided by cache system
+                data_files[filename] = file_path
+
+        except Exception as e:
+            # If we can't list files, skip updating paths
+            logger.error(f"Error listing data files in {data_dir}: {e}")
+            return
+
+        # Pattern to match file://data/filename patterns in code cells
+        file_pattern = r'"file://data/([^"]+)"'
+
+        def replace_path(match):
+            filename = match.group(1)
+            if filename in data_files:
+                file_path = data_files[filename]
+                # For local filesystem, preserve file:// prefix
+                if not file_path.startswith("s3://"):
+                    return f'"file://{file_path}"'
+                else:
+                    # For S3, return the S3 URL as-is
+                    return f'"{file_path}"'
+            return match.group(0)  # Keep original if file not found
+
+        # Update only code cells
+        updated_cells = 0
+        for cell in notebook.cells:
+            if cell.type == "code":
+                original_content = cell.content
+                # Update file paths in the cell content
+                cell.content = re.sub(file_pattern, replace_path, cell.content)
+                if original_content != cell.content:
+                    updated_cells += 1
+
+        # Log summary of updates (useful for monitoring)
+        if updated_cells > 0:
+            logger.info(f"Updated file paths in {updated_cells} notebook cells")
+
+    @classmethod
+    def from_ipynb_string(
+        cls,
+        notebook_content: str,
+        owner_id: UUID_t,
+        name: Optional[str] = None,
+        deletable: bool = True,
+    ) -> "Notebook":
+        """
+        Create a Notebook instance from Jupyter notebook string content.
+
+        Args:
+            notebook_content: Raw Jupyter notebook content as string
+            owner_id: UUID of the notebook owner
+            name: Optional custom name for the notebook
+            deletable: Whether the notebook can be deleted
+
+        Returns:
+            Notebook instance ready to be saved to database
+        """
+        # Parse and validate the Jupyter notebook using nbformat
+        # Note: nbformat.reads() has loose typing, so we cast to NotebookNode
+        jupyter_nb = cast(
+            NotebookNode, nbformat.reads(notebook_content, as_version=nbformat.NO_CONVERT)
+        )
+
+        # Convert Jupyter cells to NotebookCell objects
+        cells = []
+        for jupyter_cell in jupyter_nb.cells:
+            # Each cell is also a NotebookNode with dynamic attributes
+            cell = cast(NotebookNode, jupyter_cell)
+            # Skip raw cells as they're not supported in our model
+            if cell.cell_type == "raw":
+                continue
+
+            # Get the source content
+            content = cell.source
+
+            # Generate a name based on content or cell index
+            cell_name = cls._generate_cell_name(cell)
+
+            # Map cell types (jupyter uses "code"/"markdown", we use same)
+            cell_type = "code" if cell.cell_type == "code" else "markdown"
+
+            cells.append(NotebookCell(id=uuid4(), type=cell_type, name=cell_name, content=content))
+
+        # Extract notebook name from metadata if not provided
+        if name is None:
+            kernelspec = jupyter_nb.metadata.get("kernelspec", {})
+            name = kernelspec.get("display_name") or kernelspec.get("name", "Imported Notebook")
+
+        return cls(id=uuid4(), owner_id=owner_id, name=name, cells=cells, deletable=deletable)
+
+    @staticmethod
+    def _generate_cell_name(jupyter_cell: NotebookNode) -> str:
+        """Generate a meaningful name for a notebook cell using nbformat cell."""
+        if jupyter_cell.cell_type == "markdown":
+            # Try to extract a title from markdown headers
+            content = jupyter_cell.source
+
+            lines = content.strip().split("\n")
+            if lines and lines[0].startswith("#"):
+                # Extract header text, clean it up
+                header = lines[0].lstrip("#").strip()
+                return header[:50] if len(header) > 50 else header
+            else:
+                return "Markdown Cell"
+        else:
+            return "Code Cell"
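
The from_ipynb_string path can be exercised without any network or cache access. A small sketch that builds a notebook in memory with nbformat and converts it; only the signature shown above is relied on, and the notebook name falls back to "Imported Notebook" when no kernelspec metadata is present:

    import nbformat
    from uuid import uuid4

    from cognee.modules.notebooks.models.Notebook import Notebook

    nb = nbformat.v4.new_notebook()
    nb.cells = [
        nbformat.v4.new_markdown_cell("# Getting started with cognee"),
        nbformat.v4.new_code_cell('print("hello")'),
    ]

    notebook = Notebook.from_ipynb_string(nbformat.writes(nb), owner_id=uuid4())
    # Markdown headers become cell names; code cells fall back to "Code Cell"
    print([(cell.type, cell.name) for cell in notebook.cells])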
cognee/modules/retrieval/temporal_retriever.py CHANGED
@@ -113,7 +113,7 @@ class TemporalRetriever(GraphCompletionRetriever):
             logger.info(
                 "No timestamps identified based on the query, performing retrieval using triplet search on events and entities."
             )
-            triplets = await self.get_context(query)
+            triplets = await self.get_triplets(query)
             return await self.resolve_edges_to_text(triplets)
 
         if ids:
@@ -122,7 +122,7 @@ class TemporalRetriever(GraphCompletionRetriever):
             logger.info(
                 "No events identified based on timestamp filtering, performing retrieval using triplet search on events and entities."
             )
-            triplets = await self.get_context(query)
+            triplets = await self.get_triplets(query)
             return await self.resolve_edges_to_text(triplets)
 
         vector_engine = get_vector_engine()