basic-memory 0.17.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- basic_memory/__init__.py +7 -0
- basic_memory/alembic/alembic.ini +119 -0
- basic_memory/alembic/env.py +185 -0
- basic_memory/alembic/migrations.py +24 -0
- basic_memory/alembic/script.py.mako +26 -0
- basic_memory/alembic/versions/314f1ea54dc4_add_postgres_full_text_search_support_.py +131 -0
- basic_memory/alembic/versions/3dae7c7b1564_initial_schema.py +93 -0
- basic_memory/alembic/versions/502b60eaa905_remove_required_from_entity_permalink.py +51 -0
- basic_memory/alembic/versions/5fe1ab1ccebe_add_projects_table.py +120 -0
- basic_memory/alembic/versions/647e7a75e2cd_project_constraint_fix.py +112 -0
- basic_memory/alembic/versions/9d9c1cb7d8f5_add_mtime_and_size_columns_to_entity_.py +49 -0
- basic_memory/alembic/versions/a1b2c3d4e5f6_fix_project_foreign_keys.py +49 -0
- basic_memory/alembic/versions/a2b3c4d5e6f7_add_search_index_entity_cascade.py +56 -0
- basic_memory/alembic/versions/b3c3938bacdb_relation_to_name_unique_index.py +44 -0
- basic_memory/alembic/versions/cc7172b46608_update_search_index_schema.py +113 -0
- basic_memory/alembic/versions/e7e1f4367280_add_scan_watermark_tracking_to_project.py +37 -0
- basic_memory/alembic/versions/f8a9b2c3d4e5_add_pg_trgm_for_fuzzy_link_resolution.py +239 -0
- basic_memory/api/__init__.py +5 -0
- basic_memory/api/app.py +131 -0
- basic_memory/api/routers/__init__.py +11 -0
- basic_memory/api/routers/directory_router.py +84 -0
- basic_memory/api/routers/importer_router.py +152 -0
- basic_memory/api/routers/knowledge_router.py +318 -0
- basic_memory/api/routers/management_router.py +80 -0
- basic_memory/api/routers/memory_router.py +90 -0
- basic_memory/api/routers/project_router.py +448 -0
- basic_memory/api/routers/prompt_router.py +260 -0
- basic_memory/api/routers/resource_router.py +249 -0
- basic_memory/api/routers/search_router.py +36 -0
- basic_memory/api/routers/utils.py +169 -0
- basic_memory/api/template_loader.py +292 -0
- basic_memory/api/v2/__init__.py +35 -0
- basic_memory/api/v2/routers/__init__.py +21 -0
- basic_memory/api/v2/routers/directory_router.py +93 -0
- basic_memory/api/v2/routers/importer_router.py +182 -0
- basic_memory/api/v2/routers/knowledge_router.py +413 -0
- basic_memory/api/v2/routers/memory_router.py +130 -0
- basic_memory/api/v2/routers/project_router.py +342 -0
- basic_memory/api/v2/routers/prompt_router.py +270 -0
- basic_memory/api/v2/routers/resource_router.py +286 -0
- basic_memory/api/v2/routers/search_router.py +73 -0
- basic_memory/cli/__init__.py +1 -0
- basic_memory/cli/app.py +84 -0
- basic_memory/cli/auth.py +277 -0
- basic_memory/cli/commands/__init__.py +18 -0
- basic_memory/cli/commands/cloud/__init__.py +6 -0
- basic_memory/cli/commands/cloud/api_client.py +112 -0
- basic_memory/cli/commands/cloud/bisync_commands.py +110 -0
- basic_memory/cli/commands/cloud/cloud_utils.py +101 -0
- basic_memory/cli/commands/cloud/core_commands.py +195 -0
- basic_memory/cli/commands/cloud/rclone_commands.py +371 -0
- basic_memory/cli/commands/cloud/rclone_config.py +110 -0
- basic_memory/cli/commands/cloud/rclone_installer.py +263 -0
- basic_memory/cli/commands/cloud/upload.py +233 -0
- basic_memory/cli/commands/cloud/upload_command.py +124 -0
- basic_memory/cli/commands/command_utils.py +77 -0
- basic_memory/cli/commands/db.py +44 -0
- basic_memory/cli/commands/format.py +198 -0
- basic_memory/cli/commands/import_chatgpt.py +84 -0
- basic_memory/cli/commands/import_claude_conversations.py +87 -0
- basic_memory/cli/commands/import_claude_projects.py +86 -0
- basic_memory/cli/commands/import_memory_json.py +87 -0
- basic_memory/cli/commands/mcp.py +76 -0
- basic_memory/cli/commands/project.py +889 -0
- basic_memory/cli/commands/status.py +174 -0
- basic_memory/cli/commands/telemetry.py +81 -0
- basic_memory/cli/commands/tool.py +341 -0
- basic_memory/cli/main.py +28 -0
- basic_memory/config.py +616 -0
- basic_memory/db.py +394 -0
- basic_memory/deps.py +705 -0
- basic_memory/file_utils.py +478 -0
- basic_memory/ignore_utils.py +297 -0
- basic_memory/importers/__init__.py +27 -0
- basic_memory/importers/base.py +79 -0
- basic_memory/importers/chatgpt_importer.py +232 -0
- basic_memory/importers/claude_conversations_importer.py +180 -0
- basic_memory/importers/claude_projects_importer.py +148 -0
- basic_memory/importers/memory_json_importer.py +108 -0
- basic_memory/importers/utils.py +61 -0
- basic_memory/markdown/__init__.py +21 -0
- basic_memory/markdown/entity_parser.py +279 -0
- basic_memory/markdown/markdown_processor.py +160 -0
- basic_memory/markdown/plugins.py +242 -0
- basic_memory/markdown/schemas.py +70 -0
- basic_memory/markdown/utils.py +117 -0
- basic_memory/mcp/__init__.py +1 -0
- basic_memory/mcp/async_client.py +139 -0
- basic_memory/mcp/project_context.py +141 -0
- basic_memory/mcp/prompts/__init__.py +19 -0
- basic_memory/mcp/prompts/ai_assistant_guide.py +70 -0
- basic_memory/mcp/prompts/continue_conversation.py +62 -0
- basic_memory/mcp/prompts/recent_activity.py +188 -0
- basic_memory/mcp/prompts/search.py +57 -0
- basic_memory/mcp/prompts/utils.py +162 -0
- basic_memory/mcp/resources/ai_assistant_guide.md +283 -0
- basic_memory/mcp/resources/project_info.py +71 -0
- basic_memory/mcp/server.py +81 -0
- basic_memory/mcp/tools/__init__.py +48 -0
- basic_memory/mcp/tools/build_context.py +120 -0
- basic_memory/mcp/tools/canvas.py +152 -0
- basic_memory/mcp/tools/chatgpt_tools.py +190 -0
- basic_memory/mcp/tools/delete_note.py +242 -0
- basic_memory/mcp/tools/edit_note.py +324 -0
- basic_memory/mcp/tools/list_directory.py +168 -0
- basic_memory/mcp/tools/move_note.py +551 -0
- basic_memory/mcp/tools/project_management.py +201 -0
- basic_memory/mcp/tools/read_content.py +281 -0
- basic_memory/mcp/tools/read_note.py +267 -0
- basic_memory/mcp/tools/recent_activity.py +534 -0
- basic_memory/mcp/tools/search.py +385 -0
- basic_memory/mcp/tools/utils.py +540 -0
- basic_memory/mcp/tools/view_note.py +78 -0
- basic_memory/mcp/tools/write_note.py +230 -0
- basic_memory/models/__init__.py +15 -0
- basic_memory/models/base.py +10 -0
- basic_memory/models/knowledge.py +226 -0
- basic_memory/models/project.py +87 -0
- basic_memory/models/search.py +85 -0
- basic_memory/repository/__init__.py +11 -0
- basic_memory/repository/entity_repository.py +503 -0
- basic_memory/repository/observation_repository.py +73 -0
- basic_memory/repository/postgres_search_repository.py +379 -0
- basic_memory/repository/project_info_repository.py +10 -0
- basic_memory/repository/project_repository.py +128 -0
- basic_memory/repository/relation_repository.py +146 -0
- basic_memory/repository/repository.py +385 -0
- basic_memory/repository/search_index_row.py +95 -0
- basic_memory/repository/search_repository.py +94 -0
- basic_memory/repository/search_repository_base.py +241 -0
- basic_memory/repository/sqlite_search_repository.py +439 -0
- basic_memory/schemas/__init__.py +86 -0
- basic_memory/schemas/base.py +297 -0
- basic_memory/schemas/cloud.py +50 -0
- basic_memory/schemas/delete.py +37 -0
- basic_memory/schemas/directory.py +30 -0
- basic_memory/schemas/importer.py +35 -0
- basic_memory/schemas/memory.py +285 -0
- basic_memory/schemas/project_info.py +212 -0
- basic_memory/schemas/prompt.py +90 -0
- basic_memory/schemas/request.py +112 -0
- basic_memory/schemas/response.py +229 -0
- basic_memory/schemas/search.py +117 -0
- basic_memory/schemas/sync_report.py +72 -0
- basic_memory/schemas/v2/__init__.py +27 -0
- basic_memory/schemas/v2/entity.py +129 -0
- basic_memory/schemas/v2/resource.py +46 -0
- basic_memory/services/__init__.py +8 -0
- basic_memory/services/context_service.py +601 -0
- basic_memory/services/directory_service.py +308 -0
- basic_memory/services/entity_service.py +864 -0
- basic_memory/services/exceptions.py +37 -0
- basic_memory/services/file_service.py +541 -0
- basic_memory/services/initialization.py +216 -0
- basic_memory/services/link_resolver.py +121 -0
- basic_memory/services/project_service.py +880 -0
- basic_memory/services/search_service.py +404 -0
- basic_memory/services/service.py +15 -0
- basic_memory/sync/__init__.py +6 -0
- basic_memory/sync/background_sync.py +26 -0
- basic_memory/sync/sync_service.py +1259 -0
- basic_memory/sync/watch_service.py +510 -0
- basic_memory/telemetry.py +249 -0
- basic_memory/templates/prompts/continue_conversation.hbs +110 -0
- basic_memory/templates/prompts/search.hbs +101 -0
- basic_memory/utils.py +468 -0
- basic_memory-0.17.1.dist-info/METADATA +617 -0
- basic_memory-0.17.1.dist-info/RECORD +171 -0
- basic_memory-0.17.1.dist-info/WHEEL +4 -0
- basic_memory-0.17.1.dist-info/entry_points.txt +3 -0
- basic_memory-0.17.1.dist-info/licenses/LICENSE +661 -0
|
@@ -0,0 +1,385 @@
|
|
|
1
|
+
"""Base repository implementation."""
|
|
2
|
+
|
|
3
|
+
from typing import Type, Optional, Any, Sequence, TypeVar, List, Dict
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
from loguru import logger
|
|
7
|
+
from sqlalchemy import (
|
|
8
|
+
select,
|
|
9
|
+
func,
|
|
10
|
+
Select,
|
|
11
|
+
Executable,
|
|
12
|
+
inspect,
|
|
13
|
+
Result,
|
|
14
|
+
and_,
|
|
15
|
+
delete,
|
|
16
|
+
)
|
|
17
|
+
from sqlalchemy.exc import NoResultFound
|
|
18
|
+
from sqlalchemy.ext.asyncio import async_sessionmaker, AsyncSession
|
|
19
|
+
from sqlalchemy.orm.interfaces import LoaderOption
|
|
20
|
+
from sqlalchemy.sql.elements import ColumnElement
|
|
21
|
+
|
|
22
|
+
from basic_memory import db
|
|
23
|
+
from basic_memory.models import Base
|
|
24
|
+
|
|
25
|
+
T = TypeVar("T", bound=Base)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class Repository[T: Base]:
|
|
29
|
+
"""Base repository implementation with generic CRUD operations."""
|
|
30
|
+
|
|
31
|
+
def __init__(
|
|
32
|
+
self,
|
|
33
|
+
session_maker: async_sessionmaker[AsyncSession],
|
|
34
|
+
Model: Type[T],
|
|
35
|
+
project_id: Optional[int] = None,
|
|
36
|
+
):
|
|
37
|
+
self.session_maker = session_maker
|
|
38
|
+
self.project_id = project_id
|
|
39
|
+
if Model:
|
|
40
|
+
self.Model = Model
|
|
41
|
+
self.mapper = inspect(self.Model).mapper
|
|
42
|
+
self.primary_key: ColumnElement[Any] = self.mapper.primary_key[0]
|
|
43
|
+
self.valid_columns = [column.key for column in self.mapper.columns]
|
|
44
|
+
# Check if this model has a project_id column
|
|
45
|
+
self.has_project_id = "project_id" in self.valid_columns
|
|
46
|
+
|
|
47
|
+
def _set_project_id_if_needed(self, model: T) -> None:
|
|
48
|
+
"""Set project_id on model if needed and available."""
|
|
49
|
+
if (
|
|
50
|
+
self.has_project_id
|
|
51
|
+
and self.project_id is not None
|
|
52
|
+
and getattr(model, "project_id", None) is None
|
|
53
|
+
):
|
|
54
|
+
setattr(model, "project_id", self.project_id)
|
|
55
|
+
|
|
56
|
+
def get_model_data(self, entity_data):
|
|
57
|
+
model_data = {
|
|
58
|
+
k: v for k, v in entity_data.items() if k in self.valid_columns and v is not None
|
|
59
|
+
}
|
|
60
|
+
return model_data
|
|
61
|
+
|
|
62
|
+
def _add_project_filter(self, query: Select) -> Select:
|
|
63
|
+
"""Add project_id filter to query if applicable.
|
|
64
|
+
|
|
65
|
+
Args:
|
|
66
|
+
query: The SQLAlchemy query to modify
|
|
67
|
+
|
|
68
|
+
Returns:
|
|
69
|
+
Updated query with project filter if applicable
|
|
70
|
+
"""
|
|
71
|
+
if self.has_project_id and self.project_id is not None:
|
|
72
|
+
query = query.filter(getattr(self.Model, "project_id") == self.project_id)
|
|
73
|
+
return query
|
|
74
|
+
|
|
75
|
+
async def select_by_id(self, session: AsyncSession, entity_id: int) -> Optional[T]:
|
|
76
|
+
"""Select an entity by ID using an existing session."""
|
|
77
|
+
query = (
|
|
78
|
+
select(self.Model)
|
|
79
|
+
.filter(self.primary_key == entity_id)
|
|
80
|
+
.options(*self.get_load_options())
|
|
81
|
+
)
|
|
82
|
+
# Add project filter if applicable
|
|
83
|
+
query = self._add_project_filter(query)
|
|
84
|
+
|
|
85
|
+
result = await session.execute(query)
|
|
86
|
+
return result.scalars().one_or_none()
|
|
87
|
+
|
|
88
|
+
async def select_by_ids(self, session: AsyncSession, ids: List[int]) -> Sequence[T]:
|
|
89
|
+
"""Select multiple entities by IDs using an existing session."""
|
|
90
|
+
query = (
|
|
91
|
+
select(self.Model).where(self.primary_key.in_(ids)).options(*self.get_load_options())
|
|
92
|
+
)
|
|
93
|
+
# Add project filter if applicable
|
|
94
|
+
query = self._add_project_filter(query)
|
|
95
|
+
|
|
96
|
+
result = await session.execute(query)
|
|
97
|
+
return result.scalars().all()
|
|
98
|
+
|
|
99
|
+
async def add(self, model: T) -> T:
|
|
100
|
+
"""
|
|
101
|
+
Add a model to the repository. This will also add related objects
|
|
102
|
+
:param model: the model to add
|
|
103
|
+
:return: the added model instance
|
|
104
|
+
"""
|
|
105
|
+
async with db.scoped_session(self.session_maker) as session:
|
|
106
|
+
# Set project_id if applicable and not already set
|
|
107
|
+
self._set_project_id_if_needed(model)
|
|
108
|
+
|
|
109
|
+
session.add(model)
|
|
110
|
+
await session.flush()
|
|
111
|
+
|
|
112
|
+
# Query within same session
|
|
113
|
+
found = await self.select_by_id(session, model.id) # pyright: ignore [reportAttributeAccessIssue]
|
|
114
|
+
if found is None: # pragma: no cover
|
|
115
|
+
logger.error(
|
|
116
|
+
"Failed to retrieve model after add",
|
|
117
|
+
model_type=self.Model.__name__,
|
|
118
|
+
model_id=model.id, # pyright: ignore
|
|
119
|
+
)
|
|
120
|
+
raise ValueError(
|
|
121
|
+
f"Can't find {self.Model.__name__} with ID {model.id} after session.add" # pyright: ignore
|
|
122
|
+
)
|
|
123
|
+
return found
|
|
124
|
+
|
|
125
|
+
async def add_all(self, models: List[T]) -> Sequence[T]:
|
|
126
|
+
"""
|
|
127
|
+
Add a list of models to the repository. This will also add related objects
|
|
128
|
+
:param models: the models to add
|
|
129
|
+
:return: the added models instances
|
|
130
|
+
"""
|
|
131
|
+
async with db.scoped_session(self.session_maker) as session:
|
|
132
|
+
# set the project id if not present in models
|
|
133
|
+
for model in models:
|
|
134
|
+
self._set_project_id_if_needed(model)
|
|
135
|
+
|
|
136
|
+
session.add_all(models)
|
|
137
|
+
await session.flush()
|
|
138
|
+
|
|
139
|
+
# Query within same session
|
|
140
|
+
return await self.select_by_ids(session, [m.id for m in models]) # pyright: ignore [reportAttributeAccessIssue]
|
|
141
|
+
|
|
142
|
+
def select(self, *entities: Any) -> Select:
|
|
143
|
+
"""Create a new SELECT statement.
|
|
144
|
+
|
|
145
|
+
Returns:
|
|
146
|
+
A SQLAlchemy Select object configured with the provided entities
|
|
147
|
+
or this repository's model if no entities provided.
|
|
148
|
+
"""
|
|
149
|
+
if not entities:
|
|
150
|
+
entities = (self.Model,)
|
|
151
|
+
query = select(*entities)
|
|
152
|
+
|
|
153
|
+
# Add project filter if applicable
|
|
154
|
+
return self._add_project_filter(query)
|
|
155
|
+
|
|
156
|
+
async def find_all(
|
|
157
|
+
self, skip: int = 0, limit: Optional[int] = None, use_load_options: bool = True
|
|
158
|
+
) -> Sequence[T]:
|
|
159
|
+
"""Fetch records from the database with pagination.
|
|
160
|
+
|
|
161
|
+
Args:
|
|
162
|
+
skip: Number of records to skip
|
|
163
|
+
limit: Maximum number of records to return
|
|
164
|
+
use_load_options: Whether to apply eager loading options (default: True)
|
|
165
|
+
"""
|
|
166
|
+
logger.debug(f"Finding all {self.Model.__name__} (skip={skip}, limit={limit})")
|
|
167
|
+
|
|
168
|
+
async with db.scoped_session(self.session_maker) as session:
|
|
169
|
+
query = select(self.Model).offset(skip)
|
|
170
|
+
|
|
171
|
+
# Only apply load options if requested
|
|
172
|
+
if use_load_options:
|
|
173
|
+
query = query.options(*self.get_load_options())
|
|
174
|
+
|
|
175
|
+
# Add project filter if applicable
|
|
176
|
+
query = self._add_project_filter(query)
|
|
177
|
+
|
|
178
|
+
if limit:
|
|
179
|
+
query = query.limit(limit)
|
|
180
|
+
|
|
181
|
+
result = await session.execute(query)
|
|
182
|
+
|
|
183
|
+
items = result.scalars().all()
|
|
184
|
+
logger.debug(f"Found {len(items)} {self.Model.__name__} records")
|
|
185
|
+
return items
|
|
186
|
+
|
|
187
|
+
async def find_by_id(self, entity_id: int) -> Optional[T]:
|
|
188
|
+
"""Fetch an entity by its unique identifier."""
|
|
189
|
+
logger.debug(f"Finding {self.Model.__name__} by ID: {entity_id}")
|
|
190
|
+
|
|
191
|
+
async with db.scoped_session(self.session_maker) as session:
|
|
192
|
+
return await self.select_by_id(session, entity_id)
|
|
193
|
+
|
|
194
|
+
async def find_by_ids(self, ids: List[int]) -> Sequence[T]:
|
|
195
|
+
"""Fetch multiple entities by their identifiers in a single query."""
|
|
196
|
+
logger.debug(f"Finding {self.Model.__name__} by IDs: {ids}")
|
|
197
|
+
|
|
198
|
+
async with db.scoped_session(self.session_maker) as session:
|
|
199
|
+
return await self.select_by_ids(session, ids)
|
|
200
|
+
|
|
201
|
+
async def find_one(self, query: Select[tuple[T]]) -> Optional[T]:
|
|
202
|
+
"""Execute a query and retrieve a single record."""
|
|
203
|
+
# add in load options
|
|
204
|
+
query = query.options(*self.get_load_options())
|
|
205
|
+
result = await self.execute_query(query)
|
|
206
|
+
entity = result.scalars().one_or_none()
|
|
207
|
+
|
|
208
|
+
if entity:
|
|
209
|
+
logger.trace(f"Found {self.Model.__name__}: {getattr(entity, 'id', None)}")
|
|
210
|
+
else:
|
|
211
|
+
logger.trace(f"No {self.Model.__name__} found")
|
|
212
|
+
return entity
|
|
213
|
+
|
|
214
|
+
async def create(self, data: dict) -> T:
|
|
215
|
+
"""Create a new record from a model instance."""
|
|
216
|
+
logger.debug(f"Creating {self.Model.__name__} from entity_data: {data}")
|
|
217
|
+
async with db.scoped_session(self.session_maker) as session:
|
|
218
|
+
# Only include valid columns that are provided in entity_data
|
|
219
|
+
model_data = self.get_model_data(data)
|
|
220
|
+
|
|
221
|
+
# Add project_id if applicable and not already provided
|
|
222
|
+
if (
|
|
223
|
+
self.has_project_id
|
|
224
|
+
and self.project_id is not None
|
|
225
|
+
and "project_id" not in model_data
|
|
226
|
+
):
|
|
227
|
+
model_data["project_id"] = self.project_id
|
|
228
|
+
|
|
229
|
+
model = self.Model(**model_data)
|
|
230
|
+
session.add(model)
|
|
231
|
+
await session.flush()
|
|
232
|
+
|
|
233
|
+
return_instance = await self.select_by_id(session, model.id) # pyright: ignore [reportAttributeAccessIssue]
|
|
234
|
+
if return_instance is None: # pragma: no cover
|
|
235
|
+
logger.error(
|
|
236
|
+
"Failed to retrieve model after create",
|
|
237
|
+
model_type=self.Model.__name__,
|
|
238
|
+
model_id=model.id, # pyright: ignore
|
|
239
|
+
)
|
|
240
|
+
raise ValueError(
|
|
241
|
+
f"Can't find {self.Model.__name__} with ID {model.id} after session.add" # pyright: ignore
|
|
242
|
+
)
|
|
243
|
+
return return_instance
|
|
244
|
+
|
|
245
|
+
async def create_all(self, data_list: List[dict]) -> Sequence[T]:
|
|
246
|
+
"""Create multiple records in a single transaction."""
|
|
247
|
+
logger.debug(f"Bulk creating {len(data_list)} {self.Model.__name__} instances")
|
|
248
|
+
|
|
249
|
+
async with db.scoped_session(self.session_maker) as session:
|
|
250
|
+
# Only include valid columns that are provided in entity_data
|
|
251
|
+
model_list = []
|
|
252
|
+
for d in data_list:
|
|
253
|
+
model_data = self.get_model_data(d)
|
|
254
|
+
|
|
255
|
+
# Add project_id if applicable and not already provided
|
|
256
|
+
if (
|
|
257
|
+
self.has_project_id
|
|
258
|
+
and self.project_id is not None
|
|
259
|
+
and "project_id" not in model_data
|
|
260
|
+
):
|
|
261
|
+
model_data["project_id"] = self.project_id # pragma: no cover
|
|
262
|
+
|
|
263
|
+
model_list.append(self.Model(**model_data))
|
|
264
|
+
|
|
265
|
+
session.add_all(model_list)
|
|
266
|
+
await session.flush()
|
|
267
|
+
|
|
268
|
+
return await self.select_by_ids(session, [model.id for model in model_list]) # pyright: ignore [reportAttributeAccessIssue]
|
|
269
|
+
|
|
270
|
+
async def update(self, entity_id: int, entity_data: dict | T) -> Optional[T]:
|
|
271
|
+
"""Update an entity with the given data."""
|
|
272
|
+
logger.debug(f"Updating {self.Model.__name__} {entity_id} with data: {entity_data}")
|
|
273
|
+
async with db.scoped_session(self.session_maker) as session:
|
|
274
|
+
try:
|
|
275
|
+
result = await session.execute(
|
|
276
|
+
select(self.Model).filter(self.primary_key == entity_id)
|
|
277
|
+
)
|
|
278
|
+
entity = result.scalars().one()
|
|
279
|
+
|
|
280
|
+
if isinstance(entity_data, dict):
|
|
281
|
+
for key, value in entity_data.items():
|
|
282
|
+
if key in self.valid_columns:
|
|
283
|
+
setattr(entity, key, value)
|
|
284
|
+
|
|
285
|
+
elif isinstance(entity_data, self.Model):
|
|
286
|
+
for column in self.Model.__table__.columns.keys():
|
|
287
|
+
setattr(entity, column, getattr(entity_data, column))
|
|
288
|
+
|
|
289
|
+
await session.flush() # Make sure changes are flushed
|
|
290
|
+
await session.refresh(entity) # Refresh
|
|
291
|
+
|
|
292
|
+
logger.debug(f"Updated {self.Model.__name__}: {entity_id}")
|
|
293
|
+
return await self.select_by_id(session, entity.id) # pyright: ignore [reportAttributeAccessIssue]
|
|
294
|
+
|
|
295
|
+
except NoResultFound:
|
|
296
|
+
logger.debug(f"No {self.Model.__name__} found to update: {entity_id}")
|
|
297
|
+
return None
|
|
298
|
+
|
|
299
|
+
async def delete(self, entity_id: int) -> bool:
|
|
300
|
+
"""Delete an entity from the database."""
|
|
301
|
+
logger.debug(f"Deleting {self.Model.__name__}: {entity_id}")
|
|
302
|
+
async with db.scoped_session(self.session_maker) as session:
|
|
303
|
+
try:
|
|
304
|
+
result = await session.execute(
|
|
305
|
+
select(self.Model).filter(self.primary_key == entity_id)
|
|
306
|
+
)
|
|
307
|
+
entity = result.scalars().one()
|
|
308
|
+
await session.delete(entity)
|
|
309
|
+
|
|
310
|
+
logger.debug(f"Deleted {self.Model.__name__}: {entity_id}")
|
|
311
|
+
return True
|
|
312
|
+
except NoResultFound:
|
|
313
|
+
logger.debug(f"No {self.Model.__name__} found to delete: {entity_id}")
|
|
314
|
+
return False
|
|
315
|
+
|
|
316
|
+
async def delete_by_ids(self, ids: List[int]) -> int:
|
|
317
|
+
"""Delete records matching given IDs."""
|
|
318
|
+
logger.debug(f"Deleting {self.Model.__name__} by ids: {ids}")
|
|
319
|
+
async with db.scoped_session(self.session_maker) as session:
|
|
320
|
+
conditions = [self.primary_key.in_(ids)]
|
|
321
|
+
|
|
322
|
+
# Add project_id filter if applicable
|
|
323
|
+
if self.has_project_id and self.project_id is not None: # pragma: no cover
|
|
324
|
+
conditions.append(getattr(self.Model, "project_id") == self.project_id)
|
|
325
|
+
|
|
326
|
+
query = delete(self.Model).where(and_(*conditions))
|
|
327
|
+
result = await session.execute(query)
|
|
328
|
+
logger.debug(f"Deleted {result.rowcount} records")
|
|
329
|
+
return result.rowcount
|
|
330
|
+
|
|
331
|
+
async def delete_by_fields(self, **filters: Any) -> bool:
|
|
332
|
+
"""Delete records matching given field values."""
|
|
333
|
+
logger.debug(f"Deleting {self.Model.__name__} by fields: {filters}")
|
|
334
|
+
async with db.scoped_session(self.session_maker) as session:
|
|
335
|
+
conditions = [getattr(self.Model, field) == value for field, value in filters.items()]
|
|
336
|
+
|
|
337
|
+
# Add project_id filter if applicable
|
|
338
|
+
if self.has_project_id and self.project_id is not None:
|
|
339
|
+
conditions.append(getattr(self.Model, "project_id") == self.project_id)
|
|
340
|
+
|
|
341
|
+
query = delete(self.Model).where(and_(*conditions))
|
|
342
|
+
result = await session.execute(query)
|
|
343
|
+
deleted = result.rowcount > 0
|
|
344
|
+
logger.debug(f"Deleted {result.rowcount} records")
|
|
345
|
+
return deleted
|
|
346
|
+
|
|
347
|
+
async def count(self, query: Executable | None = None) -> int:
|
|
348
|
+
"""Count entities in the database table."""
|
|
349
|
+
async with db.scoped_session(self.session_maker) as session:
|
|
350
|
+
if query is None:
|
|
351
|
+
query = select(func.count()).select_from(self.Model)
|
|
352
|
+
# Add project filter if applicable
|
|
353
|
+
if (
|
|
354
|
+
isinstance(query, Select)
|
|
355
|
+
and self.has_project_id
|
|
356
|
+
and self.project_id is not None
|
|
357
|
+
):
|
|
358
|
+
query = query.where(
|
|
359
|
+
getattr(self.Model, "project_id") == self.project_id
|
|
360
|
+
) # pragma: no cover
|
|
361
|
+
|
|
362
|
+
result = await session.execute(query)
|
|
363
|
+
scalar = result.scalar()
|
|
364
|
+
count = scalar if scalar is not None else 0
|
|
365
|
+
logger.debug(f"Counted {count} {self.Model.__name__} records")
|
|
366
|
+
return count
|
|
367
|
+
|
|
368
|
+
async def execute_query(
|
|
369
|
+
self,
|
|
370
|
+
query: Executable,
|
|
371
|
+
params: Optional[Dict[str, Any]] = None,
|
|
372
|
+
use_query_options: bool = True,
|
|
373
|
+
) -> Result[Any]:
|
|
374
|
+
"""Execute a query asynchronously."""
|
|
375
|
+
|
|
376
|
+
query = query.options(*self.get_load_options()) if use_query_options else query
|
|
377
|
+
logger.trace(f"Executing query: {query}, params: {params}")
|
|
378
|
+
async with db.scoped_session(self.session_maker) as session:
|
|
379
|
+
result = await session.execute(query, params)
|
|
380
|
+
return result
|
|
381
|
+
|
|
382
|
+
def get_load_options(self) -> List[LoaderOption]:
|
|
383
|
+
"""Get list of loader options for eager loading relationships.
|
|
384
|
+
Override in subclasses to specify what to load."""
|
|
385
|
+
return []
|
|
@@ -0,0 +1,95 @@
|
|
|
1
|
+
"""Search index data structures."""
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
from datetime import datetime
|
|
6
|
+
from typing import Optional
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
|
|
9
|
+
from basic_memory.schemas.search import SearchItemType
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@dataclass
class SearchIndexRow:
    """Search result with score and metadata."""

    project_id: int
    id: int
    type: str
    file_path: str

    # date values
    created_at: datetime
    updated_at: datetime

    permalink: Optional[str] = None
    metadata: Optional[dict] = None

    # assigned in result
    score: Optional[float] = None

    # Type-specific fields
    title: Optional[str] = None  # entity
    content_stems: Optional[str] = None  # entity, observation
    content_snippet: Optional[str] = None  # entity, observation
    entity_id: Optional[int] = None  # observations
    category: Optional[str] = None  # observations
    from_id: Optional[int] = None  # relations
    to_id: Optional[int] = None  # relations
    relation_type: Optional[str] = None  # relations

    @property
    def content(self) -> Optional[str]:
        """Alias for ``content_snippet``."""
        return self.content_snippet

    @property
    def directory(self) -> str:
        """Extract directory part from file_path.

        For a file at "projects/notes/ideas.md", returns "/projects/notes"
        For a file at root level "README.md", returns "/"

        Backslash-separated paths and paths with a leading slash yield the
        same result as their relative, slash-separated form. A non-entity row
        without a file path returns "".
        """
        if not self.file_path and self.type != SearchItemType.ENTITY.value:
            return ""

        # Path.as_posix() only converts backslashes on Windows hosts, so the
        # separators are normalized explicitly to get the same answer everywhere.
        normalized_path = Path(self.file_path.replace("\\", "/")).as_posix()

        # Everything before the last slash is the directory; stripping slashes
        # keeps an absolute input from producing a doubled leading "/".
        parent = normalized_path.rpartition("/")[0].strip("/")
        return f"/{parent}" if parent else "/"

    def to_insert(self, serialize_json: bool = True):
        """Convert to dict for database insertion.

        Args:
            serialize_json: If True, converts metadata dict to JSON string (for SQLite).
                           If False, keeps metadata as dict (for Postgres JSONB).

        Returns:
            A dict keyed by column name. ``score`` is not included.
        """
        metadata = self.metadata
        # Compare against None rather than truthiness so an empty dict is
        # serialized to "{}" instead of being passed through as a dict.
        if serialize_json and metadata is not None:
            metadata = json.dumps(metadata)

        return {
            "id": self.id,
            "title": self.title,
            "content_stems": self.content_stems,
            "content_snippet": self.content_snippet,
            "permalink": self.permalink,
            "file_path": self.file_path,
            "type": self.type,
            "metadata": metadata,
            "from_id": self.from_id,
            "to_id": self.to_id,
            "relation_type": self.relation_type,
            "entity_id": self.entity_id,
            "category": self.category,
            "created_at": self.created_at if self.created_at else None,
            "updated_at": self.updated_at if self.updated_at else None,
            "project_id": self.project_id,
        }
|
|
@@ -0,0 +1,94 @@
|
|
|
1
|
+
"""Repository for search operations.
|
|
2
|
+
|
|
3
|
+
This module provides the search repository interface.
|
|
4
|
+
The actual repository implementations are backend-specific:
|
|
5
|
+
- SQLiteSearchRepository: Uses FTS5 virtual tables
|
|
6
|
+
- PostgresSearchRepository: Uses tsvector/tsquery with GIN indexes
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from datetime import datetime
|
|
10
|
+
from typing import List, Optional, Protocol
|
|
11
|
+
|
|
12
|
+
from sqlalchemy import Result
|
|
13
|
+
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
|
14
|
+
|
|
15
|
+
from basic_memory.config import ConfigManager, DatabaseBackend
|
|
16
|
+
from basic_memory.repository.postgres_search_repository import PostgresSearchRepository
|
|
17
|
+
from basic_memory.repository.search_index_row import SearchIndexRow
|
|
18
|
+
from basic_memory.repository.sqlite_search_repository import SQLiteSearchRepository
|
|
19
|
+
from basic_memory.schemas.search import SearchItemType
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class SearchRepository(Protocol):
    """Structural interface every search backend has to provide.

    The SQLite and the Postgres search repositories are both expected to
    match this protocol.
    """

    project_id: int

    async def init_search_index(self) -> None:
        """Set up the schema backing the search index."""
        ...

    async def search(
        self,
        search_text: Optional[str] = None,
        permalink: Optional[str] = None,
        permalink_match: Optional[str] = None,
        title: Optional[str] = None,
        types: Optional[List[str]] = None,
        after_date: Optional[datetime] = None,
        search_item_types: Optional[List[SearchItemType]] = None,
        limit: int = 10,
        offset: int = 0,
    ) -> List[SearchIndexRow]:
        """Query the indexed content and return the matching rows.

        All filter arguments are optional; ``limit`` and ``offset`` page the
        result. How each filter is matched is defined by the implementation.
        """
        ...

    async def index_item(self, search_index_row: SearchIndexRow) -> None:
        """Add one row to the index."""
        ...

    async def bulk_index_items(self, search_index_rows: List[SearchIndexRow]) -> None:
        """Add several rows to the index as one batch."""
        ...

    async def delete_by_permalink(self, permalink: str) -> None:
        """Remove the indexed item with the given permalink."""
        ...

    async def delete_by_entity_id(self, entity_id: int) -> None:
        """Remove the indexed items belonging to the given entity ID."""
        ...

    async def execute_query(self, query, params: dict) -> Result:
        """Run a raw SQL query with the given parameters."""
        ...
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def create_search_repository(
    session_maker: async_sessionmaker[AsyncSession], project_id: int
) -> SearchRepository:
    """Build the search repository matching the configured database backend.

    Args:
        session_maker: SQLAlchemy async session maker
        project_id: Project ID for the repository

    Returns:
        SearchRepository: A ``PostgresSearchRepository`` when the configured
        backend is Postgres, otherwise a ``SQLiteSearchRepository``.
    """
    backend = ConfigManager().config.database_backend

    # Any backend other than Postgres falls back to the SQLite implementation.
    repository_cls = (
        PostgresSearchRepository
        if backend == DatabaseBackend.POSTGRES
        else SQLiteSearchRepository
    )
    return repository_cls(session_maker, project_id=project_id)
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
__all__ = [
|
|
91
|
+
"SearchRepository",
|
|
92
|
+
"SearchIndexRow",
|
|
93
|
+
"create_search_repository",
|
|
94
|
+
]
|