basic-memory 0.12.3__py3-none-any.whl → 0.13.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of basic-memory has been flagged as potentially problematic; see the registry's advisory for this release for details.
- basic_memory/__init__.py +2 -1
- basic_memory/alembic/env.py +1 -1
- basic_memory/alembic/versions/5fe1ab1ccebe_add_projects_table.py +108 -0
- basic_memory/alembic/versions/647e7a75e2cd_project_constraint_fix.py +104 -0
- basic_memory/alembic/versions/cc7172b46608_update_search_index_schema.py +0 -6
- basic_memory/api/app.py +43 -13
- basic_memory/api/routers/__init__.py +4 -2
- basic_memory/api/routers/directory_router.py +63 -0
- basic_memory/api/routers/importer_router.py +152 -0
- basic_memory/api/routers/knowledge_router.py +139 -37
- basic_memory/api/routers/management_router.py +78 -0
- basic_memory/api/routers/memory_router.py +6 -62
- basic_memory/api/routers/project_router.py +234 -0
- basic_memory/api/routers/prompt_router.py +260 -0
- basic_memory/api/routers/search_router.py +3 -21
- basic_memory/api/routers/utils.py +130 -0
- basic_memory/api/template_loader.py +292 -0
- basic_memory/cli/app.py +20 -21
- basic_memory/cli/commands/__init__.py +2 -1
- basic_memory/cli/commands/auth.py +136 -0
- basic_memory/cli/commands/db.py +3 -3
- basic_memory/cli/commands/import_chatgpt.py +31 -207
- basic_memory/cli/commands/import_claude_conversations.py +16 -142
- basic_memory/cli/commands/import_claude_projects.py +33 -143
- basic_memory/cli/commands/import_memory_json.py +26 -83
- basic_memory/cli/commands/mcp.py +71 -18
- basic_memory/cli/commands/project.py +102 -70
- basic_memory/cli/commands/status.py +19 -9
- basic_memory/cli/commands/sync.py +44 -58
- basic_memory/cli/commands/tool.py +6 -6
- basic_memory/cli/main.py +1 -5
- basic_memory/config.py +143 -87
- basic_memory/db.py +6 -4
- basic_memory/deps.py +227 -30
- basic_memory/importers/__init__.py +27 -0
- basic_memory/importers/base.py +79 -0
- basic_memory/importers/chatgpt_importer.py +222 -0
- basic_memory/importers/claude_conversations_importer.py +172 -0
- basic_memory/importers/claude_projects_importer.py +148 -0
- basic_memory/importers/memory_json_importer.py +93 -0
- basic_memory/importers/utils.py +58 -0
- basic_memory/markdown/entity_parser.py +5 -2
- basic_memory/mcp/auth_provider.py +270 -0
- basic_memory/mcp/external_auth_provider.py +321 -0
- basic_memory/mcp/project_session.py +103 -0
- basic_memory/mcp/prompts/__init__.py +2 -0
- basic_memory/mcp/prompts/continue_conversation.py +18 -68
- basic_memory/mcp/prompts/recent_activity.py +20 -4
- basic_memory/mcp/prompts/search.py +14 -140
- basic_memory/mcp/prompts/sync_status.py +116 -0
- basic_memory/mcp/prompts/utils.py +3 -3
- basic_memory/mcp/{tools → resources}/project_info.py +6 -2
- basic_memory/mcp/server.py +86 -13
- basic_memory/mcp/supabase_auth_provider.py +463 -0
- basic_memory/mcp/tools/__init__.py +24 -0
- basic_memory/mcp/tools/build_context.py +43 -8
- basic_memory/mcp/tools/canvas.py +17 -3
- basic_memory/mcp/tools/delete_note.py +168 -5
- basic_memory/mcp/tools/edit_note.py +303 -0
- basic_memory/mcp/tools/list_directory.py +154 -0
- basic_memory/mcp/tools/move_note.py +299 -0
- basic_memory/mcp/tools/project_management.py +332 -0
- basic_memory/mcp/tools/read_content.py +15 -6
- basic_memory/mcp/tools/read_note.py +26 -7
- basic_memory/mcp/tools/recent_activity.py +11 -2
- basic_memory/mcp/tools/search.py +189 -8
- basic_memory/mcp/tools/sync_status.py +254 -0
- basic_memory/mcp/tools/utils.py +184 -12
- basic_memory/mcp/tools/view_note.py +66 -0
- basic_memory/mcp/tools/write_note.py +24 -17
- basic_memory/models/__init__.py +3 -2
- basic_memory/models/knowledge.py +16 -4
- basic_memory/models/project.py +78 -0
- basic_memory/models/search.py +8 -5
- basic_memory/repository/__init__.py +2 -0
- basic_memory/repository/entity_repository.py +8 -3
- basic_memory/repository/observation_repository.py +35 -3
- basic_memory/repository/project_info_repository.py +3 -2
- basic_memory/repository/project_repository.py +85 -0
- basic_memory/repository/relation_repository.py +8 -2
- basic_memory/repository/repository.py +107 -15
- basic_memory/repository/search_repository.py +192 -54
- basic_memory/schemas/__init__.py +6 -0
- basic_memory/schemas/base.py +33 -5
- basic_memory/schemas/directory.py +30 -0
- basic_memory/schemas/importer.py +34 -0
- basic_memory/schemas/memory.py +84 -13
- basic_memory/schemas/project_info.py +112 -2
- basic_memory/schemas/prompt.py +90 -0
- basic_memory/schemas/request.py +56 -2
- basic_memory/schemas/search.py +1 -1
- basic_memory/services/__init__.py +2 -1
- basic_memory/services/context_service.py +208 -95
- basic_memory/services/directory_service.py +167 -0
- basic_memory/services/entity_service.py +399 -6
- basic_memory/services/exceptions.py +6 -0
- basic_memory/services/file_service.py +14 -15
- basic_memory/services/initialization.py +170 -66
- basic_memory/services/link_resolver.py +35 -12
- basic_memory/services/migration_service.py +168 -0
- basic_memory/services/project_service.py +671 -0
- basic_memory/services/search_service.py +77 -2
- basic_memory/services/sync_status_service.py +181 -0
- basic_memory/sync/background_sync.py +25 -0
- basic_memory/sync/sync_service.py +102 -21
- basic_memory/sync/watch_service.py +63 -39
- basic_memory/templates/prompts/continue_conversation.hbs +110 -0
- basic_memory/templates/prompts/search.hbs +101 -0
- {basic_memory-0.12.3.dist-info → basic_memory-0.13.0.dist-info}/METADATA +24 -2
- basic_memory-0.13.0.dist-info/RECORD +138 -0
- basic_memory/api/routers/project_info_router.py +0 -274
- basic_memory/mcp/main.py +0 -24
- basic_memory-0.12.3.dist-info/RECORD +0 -100
- {basic_memory-0.12.3.dist-info → basic_memory-0.13.0.dist-info}/WHEEL +0 -0
- {basic_memory-0.12.3.dist-info → basic_memory-0.13.0.dist-info}/entry_points.txt +0 -0
- {basic_memory-0.12.3.dist-info → basic_memory-0.13.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
"""Repository for managing Observation objects."""
|
|
2
2
|
|
|
3
|
-
from typing import Sequence
|
|
3
|
+
from typing import Dict, List, Sequence
|
|
4
4
|
|
|
5
5
|
from sqlalchemy import select
|
|
6
6
|
from sqlalchemy.ext.asyncio import async_sessionmaker
|
|
@@ -12,8 +12,14 @@ from basic_memory.repository.repository import Repository
|
|
|
12
12
|
class ObservationRepository(Repository[Observation]):
|
|
13
13
|
"""Repository for Observation model with memory-specific operations."""
|
|
14
14
|
|
|
15
|
-
def __init__(self, session_maker: async_sessionmaker, project_id: int):
    """Create an observation repository bound to one project.

    Args:
        session_maker: SQLAlchemy async session factory handed to the base Repository.
        project_id: ID of the project this repository is scoped to; the base
            Repository applies it only when the model has a ``project_id`` column.
    """
    super().__init__(
        session_maker=session_maker,
        Model=Observation,
        project_id=project_id,
    )
|
|
17
23
|
|
|
18
24
|
async def find_by_entity(self, entity_id: int) -> Sequence[Observation]:
|
|
19
25
|
"""Find all observations for a specific entity."""
|
|
@@ -38,3 +44,29 @@ class ObservationRepository(Repository[Observation]):
|
|
|
38
44
|
query = select(Observation.category).distinct()
|
|
39
45
|
result = await self.execute_query(query, use_query_options=False)
|
|
40
46
|
return result.scalars().all()
|
|
47
|
+
|
|
48
|
+
async def find_by_entities(self, entity_ids: List[int]) -> Dict[int, List[Observation]]:
    """Find all observations for multiple entities in a single query.

    Args:
        entity_ids: List of entity IDs to fetch observations for

    Returns:
        Dictionary mapping entity_id to the list of its observations. Entities
        that have no observations do not appear as keys, so callers should use
        ``.get(entity_id, [])``. An empty ``entity_ids`` returns ``{}`` without
        querying the database.
    """
    from collections import defaultdict

    if not entity_ids:  # pragma: no cover
        return {}

    # One IN (...) query for all entities instead of one query per entity.
    # NOTE(review): this uses a bare select() rather than self.select(), so the
    # base Repository's project_id filter is not applied here; confirm that
    # scoping by entity_ids alone is sufficient.
    query = select(Observation).filter(Observation.entity_id.in_(entity_ids))
    result = await self.execute_query(query)

    # Group the rows by the entity that owns them.
    observations_by_entity: Dict[int, List[Observation]] = defaultdict(list)
    for obs in result.scalars().all():
        observations_by_entity[obs.entity_id].append(obs)

    # Hand back a plain dict so a missing key raises instead of silently inserting [].
    return dict(observations_by_entity)
|
|
@@ -1,9 +1,10 @@
|
|
|
1
1
|
from basic_memory.repository.repository import Repository
|
|
2
|
+
from basic_memory.models.project import Project
|
|
2
3
|
|
|
3
4
|
|
|
4
5
|
class ProjectInfoRepository(Repository):
    """Repository used for project statistics queries."""

    def __init__(self, session_maker):
        """Create the repository.

        ``Project`` is supplied only so the base Repository has a concrete model
        to inspect; no project_id is given, so no project scoping is applied.
        """
        super().__init__(session_maker=session_maker, Model=Project)
|
|
@@ -0,0 +1,85 @@
|
|
|
1
|
+
"""Repository for managing projects in Basic Memory."""
|
|
2
|
+
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import Optional, Sequence, Union
|
|
5
|
+
|
|
6
|
+
from sqlalchemy import text
|
|
7
|
+
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
|
8
|
+
|
|
9
|
+
from basic_memory import db
|
|
10
|
+
from basic_memory.models.project import Project
|
|
11
|
+
from basic_memory.repository.repository import Repository
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class ProjectRepository(Repository[Project]):
    """Repository for Project model.

    Projects represent collections of knowledge entities grouped together.
    Each entity, observation, and relation belongs to a specific project.

    The default project is tracked with ``is_default``: ``set_as_default``
    stores ``True`` on the default row and ``NULL`` on every other row.
    """

    def __init__(self, session_maker: async_sessionmaker[AsyncSession]):
        """Initialize with session maker.

        No project_id is passed to the base Repository, so queries issued by
        this repository are not project-scoped.
        """
        super().__init__(session_maker, Project)

    async def get_by_name(self, name: str) -> Optional[Project]:
        """Get project by name.

        Args:
            name: Unique name of the project

        Returns:
            The matching project, or None if there is none.
        """
        query = self.select().where(Project.name == name)
        return await self.find_one(query)

    async def get_by_permalink(self, permalink: str) -> Optional[Project]:
        """Get project by permalink.

        Args:
            permalink: URL-friendly identifier for the project

        Returns:
            The matching project, or None if there is none.
        """
        query = self.select().where(Project.permalink == permalink)
        return await self.find_one(query)

    async def get_by_path(self, path: Union[Path, str]) -> Optional[Project]:
        """Get project by filesystem path.

        Args:
            path: Path to the project directory (will be converted to string internally)

        Returns:
            The matching project, or None if there is none.
        """
        # Paths are compared as strings, so the stored value and the argument
        # must use the same textual form.
        query = self.select().where(Project.path == str(path))
        return await self.find_one(query)

    async def get_default_project(self) -> Optional[Project]:
        """Get the default project.

        Non-default projects hold ``is_default`` NULL (see ``set_as_default``),
        so any non-NULL value marks the default.

        Returns:
            The default project, or None if no project is marked as default.
        """
        query = self.select().where(Project.is_default.is_not(None))
        return await self.find_one(query)

    async def get_active_projects(self) -> Sequence[Project]:
        """Get all active projects.

        Returns:
            A list of every project whose ``is_active`` column is true.
        """
        query = self.select().where(Project.is_active == True)  # noqa: E712
        result = await self.execute_query(query)
        return list(result.scalars().all())

    async def set_as_default(self, project_id: int) -> Optional[Project]:
        """Set a project as the default and unset previous default.

        If no project with ``project_id`` exists, nothing is modified, so the
        current default (if any) is preserved.

        Args:
            project_id: ID of the project to set as default

        Returns:
            The updated project if found, None otherwise
        """
        async with db.scoped_session(self.session_maker) as session:
            # Look the target up first: clearing the flag before knowing the
            # project exists would leave the database with no default at all.
            target_project = await self.select_by_id(session, project_id)
            if target_project is None:
                return None  # pragma: no cover

            # Clear the default flag for all projects using direct SQL.
            # NULL (not False) is used for non-defaults, presumably so a unique
            # constraint on is_default tolerates many non-default rows — confirm
            # against the Project model / migrations.
            await session.execute(
                text("UPDATE project SET is_default = NULL WHERE is_default IS NOT NULL")
            )

            # The raw UPDATE bypasses the ORM. Reload the flag so that assigning
            # True below is recorded as a change even when this project was
            # already the default; otherwise no UPDATE would be flushed and the
            # row would stay NULL.
            await session.refresh(target_project, attribute_names=["is_default"])

            target_project.is_default = True
            await session.flush()
            return target_project
|
|
@@ -16,8 +16,14 @@ from basic_memory.repository.repository import Repository
|
|
|
16
16
|
class RelationRepository(Repository[Relation]):
|
|
17
17
|
"""Repository for Relation model with memory-specific operations."""
|
|
18
18
|
|
|
19
|
-
def __init__(self, session_maker: async_sessionmaker, project_id: int):
    """Create a relation repository bound to one project.

    Args:
        session_maker: SQLAlchemy async session factory handed to the base Repository.
        project_id: ID of the project this repository is scoped to; the base
            Repository applies it only when the model has a ``project_id`` column.
    """
    super().__init__(
        session_maker=session_maker,
        Model=Relation,
        project_id=project_id,
    )
|
|
21
27
|
|
|
22
28
|
async def find_relation(
|
|
23
29
|
self, from_permalink: str, to_permalink: str, relation_type: str
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
"""Base repository implementation."""
|
|
2
2
|
|
|
3
|
-
from typing import Type, Optional, Any, Sequence, TypeVar, List
|
|
3
|
+
from typing import Type, Optional, Any, Sequence, TypeVar, List, Dict
|
|
4
4
|
|
|
5
5
|
from loguru import logger
|
|
6
6
|
from sqlalchemy import (
|
|
@@ -27,13 +27,30 @@ T = TypeVar("T", bound=Base)
|
|
|
27
27
|
class Repository[T: Base]:
|
|
28
28
|
"""Base repository implementation with generic CRUD operations."""
|
|
29
29
|
|
|
30
|
-
def __init__(
    self,
    session_maker: async_sessionmaker[AsyncSession],
    Model: Type[T],
    project_id: Optional[int] = None,
):
    """Create a repository for ``Model``.

    Args:
        session_maker: Async session factory used for database operations.
        Model: ORM model class managed by this repository. When falsy, the
            mapper-derived attributes (``Model``, ``mapper``, ``primary_key``,
            ``valid_columns``) are not set.
        project_id: Optional project to scope queries and new rows to. It is
            only applied when ``Model`` has a ``project_id`` column.
    """
    self.session_maker = session_maker
    self.project_id = project_id
    # Always define the flag, so the project-scoping helpers (which check it
    # first) stay safe even when no Model is supplied.
    self.has_project_id = False
    if Model:
        self.Model = Model
        self.mapper = inspect(self.Model).mapper
        self.primary_key: Column[Any] = self.mapper.primary_key[0]
        self.valid_columns = [column.key for column in self.mapper.columns]
        # Check if this model has a project_id column
        self.has_project_id = "project_id" in self.valid_columns
|
|
45
|
+
|
|
46
|
+
def _set_project_id_if_needed(self, model: T) -> None:
    """Stamp this repository's project_id onto ``model`` when appropriate.

    Nothing happens unless the model class has a ``project_id`` column and the
    repository was created with a project_id. A project_id already present on
    the model is never overwritten.
    """
    if not self.has_project_id or self.project_id is None:
        return
    if getattr(model, "project_id", None) is not None:
        return
    setattr(model, "project_id", self.project_id)
|
|
37
54
|
|
|
38
55
|
def get_model_data(self, entity_data):
|
|
39
56
|
model_data = {
|
|
@@ -41,6 +58,19 @@ class Repository[T: Base]:
|
|
|
41
58
|
}
|
|
42
59
|
return model_data
|
|
43
60
|
|
|
61
|
+
def _add_project_filter(self, query: Select) -> Select:
    """Narrow ``query`` to this repository's project, when it has one.

    Args:
        query: The SQLAlchemy select to narrow.

    Returns:
        ``query`` with a ``project_id == self.project_id`` filter appended if
        the model has a ``project_id`` column and a project_id was configured;
        otherwise the same ``query`` object, untouched.
    """
    scoped = self.has_project_id and self.project_id is not None
    if not scoped:
        return query
    project_column = getattr(self.Model, "project_id")
    return query.filter(project_column == self.project_id)
|
|
73
|
+
|
|
44
74
|
async def select_by_id(self, session: AsyncSession, entity_id: int) -> Optional[T]:
|
|
45
75
|
"""Select an entity by ID using an existing session."""
|
|
46
76
|
query = (
|
|
@@ -48,6 +78,9 @@ class Repository[T: Base]:
|
|
|
48
78
|
.filter(self.primary_key == entity_id)
|
|
49
79
|
.options(*self.get_load_options())
|
|
50
80
|
)
|
|
81
|
+
# Add project filter if applicable
|
|
82
|
+
query = self._add_project_filter(query)
|
|
83
|
+
|
|
51
84
|
result = await session.execute(query)
|
|
52
85
|
return result.scalars().one_or_none()
|
|
53
86
|
|
|
@@ -56,6 +89,9 @@ class Repository[T: Base]:
|
|
|
56
89
|
query = (
|
|
57
90
|
select(self.Model).where(self.primary_key.in_(ids)).options(*self.get_load_options())
|
|
58
91
|
)
|
|
92
|
+
# Add project filter if applicable
|
|
93
|
+
query = self._add_project_filter(query)
|
|
94
|
+
|
|
59
95
|
result = await session.execute(query)
|
|
60
96
|
return result.scalars().all()
|
|
61
97
|
|
|
@@ -66,6 +102,9 @@ class Repository[T: Base]:
|
|
|
66
102
|
:return: the added model instance
|
|
67
103
|
"""
|
|
68
104
|
async with db.scoped_session(self.session_maker) as session:
|
|
105
|
+
# Set project_id if applicable and not already set
|
|
106
|
+
self._set_project_id_if_needed(model)
|
|
107
|
+
|
|
69
108
|
session.add(model)
|
|
70
109
|
await session.flush()
|
|
71
110
|
|
|
@@ -89,6 +128,10 @@ class Repository[T: Base]:
|
|
|
89
128
|
:return: the added models instances
|
|
90
129
|
"""
|
|
91
130
|
async with db.scoped_session(self.session_maker) as session:
|
|
131
|
+
# set the project id if not present in models
|
|
132
|
+
for model in models:
|
|
133
|
+
self._set_project_id_if_needed(model)
|
|
134
|
+
|
|
92
135
|
session.add_all(models)
|
|
93
136
|
await session.flush()
|
|
94
137
|
|
|
@@ -104,7 +147,10 @@ class Repository[T: Base]:
|
|
|
104
147
|
"""
|
|
105
148
|
if not entities:
|
|
106
149
|
entities = (self.Model,)
|
|
107
|
-
|
|
150
|
+
query = select(*entities)
|
|
151
|
+
|
|
152
|
+
# Add project filter if applicable
|
|
153
|
+
return self._add_project_filter(query)
|
|
108
154
|
|
|
109
155
|
async def find_all(self, skip: int = 0, limit: Optional[int] = None) -> Sequence[T]:
|
|
110
156
|
"""Fetch records from the database with pagination."""
|
|
@@ -112,6 +158,9 @@ class Repository[T: Base]:
|
|
|
112
158
|
|
|
113
159
|
async with db.scoped_session(self.session_maker) as session:
|
|
114
160
|
query = select(self.Model).offset(skip).options(*self.get_load_options())
|
|
161
|
+
# Add project filter if applicable
|
|
162
|
+
query = self._add_project_filter(query)
|
|
163
|
+
|
|
115
164
|
if limit:
|
|
116
165
|
query = query.limit(limit)
|
|
117
166
|
|
|
@@ -143,9 +192,9 @@ class Repository[T: Base]:
|
|
|
143
192
|
entity = result.scalars().one_or_none()
|
|
144
193
|
|
|
145
194
|
if entity:
|
|
146
|
-
logger.
|
|
195
|
+
logger.trace(f"Found {self.Model.__name__}: {getattr(entity, 'id', None)}")
|
|
147
196
|
else:
|
|
148
|
-
logger.
|
|
197
|
+
logger.trace(f"No {self.Model.__name__} found")
|
|
149
198
|
return entity
|
|
150
199
|
|
|
151
200
|
async def create(self, data: dict) -> T:
|
|
@@ -154,6 +203,15 @@ class Repository[T: Base]:
|
|
|
154
203
|
async with db.scoped_session(self.session_maker) as session:
|
|
155
204
|
# Only include valid columns that are provided in entity_data
|
|
156
205
|
model_data = self.get_model_data(data)
|
|
206
|
+
|
|
207
|
+
# Add project_id if applicable and not already provided
|
|
208
|
+
if (
|
|
209
|
+
self.has_project_id
|
|
210
|
+
and self.project_id is not None
|
|
211
|
+
and "project_id" not in model_data
|
|
212
|
+
):
|
|
213
|
+
model_data["project_id"] = self.project_id
|
|
214
|
+
|
|
157
215
|
model = self.Model(**model_data)
|
|
158
216
|
session.add(model)
|
|
159
217
|
await session.flush()
|
|
@@ -176,12 +234,20 @@ class Repository[T: Base]:
|
|
|
176
234
|
|
|
177
235
|
async with db.scoped_session(self.session_maker) as session:
|
|
178
236
|
# Only include valid columns that are provided in entity_data
|
|
179
|
-
model_list = [
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
237
|
+
model_list = []
|
|
238
|
+
for d in data_list:
|
|
239
|
+
model_data = self.get_model_data(d)
|
|
240
|
+
|
|
241
|
+
# Add project_id if applicable and not already provided
|
|
242
|
+
if (
|
|
243
|
+
self.has_project_id
|
|
244
|
+
and self.project_id is not None
|
|
245
|
+
and "project_id" not in model_data
|
|
246
|
+
):
|
|
247
|
+
model_data["project_id"] = self.project_id # pragma: no cover
|
|
248
|
+
|
|
249
|
+
model_list.append(self.Model(**model_data))
|
|
250
|
+
|
|
185
251
|
session.add_all(model_list)
|
|
186
252
|
await session.flush()
|
|
187
253
|
|
|
@@ -237,7 +303,13 @@ class Repository[T: Base]:
|
|
|
237
303
|
"""Delete records matching given IDs."""
|
|
238
304
|
logger.debug(f"Deleting {self.Model.__name__} by ids: {ids}")
|
|
239
305
|
async with db.scoped_session(self.session_maker) as session:
|
|
240
|
-
|
|
306
|
+
conditions = [self.primary_key.in_(ids)]
|
|
307
|
+
|
|
308
|
+
# Add project_id filter if applicable
|
|
309
|
+
if self.has_project_id and self.project_id is not None: # pragma: no cover
|
|
310
|
+
conditions.append(getattr(self.Model, "project_id") == self.project_id)
|
|
311
|
+
|
|
312
|
+
query = delete(self.Model).where(and_(*conditions))
|
|
241
313
|
result = await session.execute(query)
|
|
242
314
|
logger.debug(f"Deleted {result.rowcount} records")
|
|
243
315
|
return result.rowcount
|
|
@@ -247,6 +319,11 @@ class Repository[T: Base]:
|
|
|
247
319
|
logger.debug(f"Deleting {self.Model.__name__} by fields: {filters}")
|
|
248
320
|
async with db.scoped_session(self.session_maker) as session:
|
|
249
321
|
conditions = [getattr(self.Model, field) == value for field, value in filters.items()]
|
|
322
|
+
|
|
323
|
+
# Add project_id filter if applicable
|
|
324
|
+
if self.has_project_id and self.project_id is not None:
|
|
325
|
+
conditions.append(getattr(self.Model, "project_id") == self.project_id)
|
|
326
|
+
|
|
250
327
|
query = delete(self.Model).where(and_(*conditions))
|
|
251
328
|
result = await session.execute(query)
|
|
252
329
|
deleted = result.rowcount > 0
|
|
@@ -258,19 +335,34 @@ class Repository[T: Base]:
|
|
|
258
335
|
async with db.scoped_session(self.session_maker) as session:
|
|
259
336
|
if query is None:
|
|
260
337
|
query = select(func.count()).select_from(self.Model)
|
|
338
|
+
# Add project filter if applicable
|
|
339
|
+
if (
|
|
340
|
+
isinstance(query, Select)
|
|
341
|
+
and self.has_project_id
|
|
342
|
+
and self.project_id is not None
|
|
343
|
+
):
|
|
344
|
+
query = query.where(
|
|
345
|
+
getattr(self.Model, "project_id") == self.project_id
|
|
346
|
+
) # pragma: no cover
|
|
347
|
+
|
|
261
348
|
result = await session.execute(query)
|
|
262
349
|
scalar = result.scalar()
|
|
263
350
|
count = scalar if scalar is not None else 0
|
|
264
351
|
logger.debug(f"Counted {count} {self.Model.__name__} records")
|
|
265
352
|
return count
|
|
266
353
|
|
|
267
|
-
async def execute_query(
    self,
    query: Executable,
    params: Optional[Dict[str, Any]] = None,
    use_query_options: bool = True,
) -> Result[Any]:
    """Execute a query asynchronously in its own scoped session.

    Args:
        query: Statement to execute.
        params: Optional bound parameters forwarded to ``session.execute``.
        use_query_options: When True, ``get_load_options()`` is applied to the
            query through ``query.options(...)`` before execution.

    Returns:
        The ``Result`` returned by ``session.execute``.
    """
    if use_query_options:
        query = query.options(*self.get_load_options())
    # Pass query/params as arguments instead of an f-string, so loguru only
    # stringifies them (which compiles the SQL text) when a TRACE record is
    # actually emitted, rather than on every call.
    logger.trace("Executing query: {}, params: {}", query, params)
    async with db.scoped_session(self.session_maker) as session:
        result = await session.execute(query, params)
        return result
|
|
275
367
|
|
|
276
368
|
def get_load_options(self) -> List[LoaderOption]:
|