basic_memory-0.0.0-py3-none-any.whl
This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their public registries.
Potentially problematic release: this version of basic-memory might be problematic.
- basic_memory/__init__.py +3 -0
- basic_memory/api/__init__.py +4 -0
- basic_memory/api/app.py +42 -0
- basic_memory/api/routers/__init__.py +8 -0
- basic_memory/api/routers/knowledge_router.py +168 -0
- basic_memory/api/routers/memory_router.py +123 -0
- basic_memory/api/routers/resource_router.py +34 -0
- basic_memory/api/routers/search_router.py +34 -0
- basic_memory/cli/__init__.py +1 -0
- basic_memory/cli/app.py +4 -0
- basic_memory/cli/commands/__init__.py +9 -0
- basic_memory/cli/commands/init.py +38 -0
- basic_memory/cli/commands/status.py +152 -0
- basic_memory/cli/commands/sync.py +254 -0
- basic_memory/cli/main.py +48 -0
- basic_memory/config.py +53 -0
- basic_memory/db.py +135 -0
- basic_memory/deps.py +182 -0
- basic_memory/file_utils.py +248 -0
- basic_memory/markdown/__init__.py +19 -0
- basic_memory/markdown/entity_parser.py +137 -0
- basic_memory/markdown/markdown_processor.py +153 -0
- basic_memory/markdown/plugins.py +236 -0
- basic_memory/markdown/schemas.py +73 -0
- basic_memory/markdown/utils.py +144 -0
- basic_memory/mcp/__init__.py +1 -0
- basic_memory/mcp/async_client.py +10 -0
- basic_memory/mcp/main.py +21 -0
- basic_memory/mcp/server.py +39 -0
- basic_memory/mcp/tools/__init__.py +34 -0
- basic_memory/mcp/tools/ai_edit.py +84 -0
- basic_memory/mcp/tools/knowledge.py +56 -0
- basic_memory/mcp/tools/memory.py +142 -0
- basic_memory/mcp/tools/notes.py +122 -0
- basic_memory/mcp/tools/search.py +28 -0
- basic_memory/mcp/tools/utils.py +154 -0
- basic_memory/models/__init__.py +12 -0
- basic_memory/models/base.py +9 -0
- basic_memory/models/knowledge.py +204 -0
- basic_memory/models/search.py +34 -0
- basic_memory/repository/__init__.py +7 -0
- basic_memory/repository/entity_repository.py +156 -0
- basic_memory/repository/observation_repository.py +40 -0
- basic_memory/repository/relation_repository.py +78 -0
- basic_memory/repository/repository.py +303 -0
- basic_memory/repository/search_repository.py +259 -0
- basic_memory/schemas/__init__.py +73 -0
- basic_memory/schemas/base.py +216 -0
- basic_memory/schemas/delete.py +38 -0
- basic_memory/schemas/discovery.py +25 -0
- basic_memory/schemas/memory.py +111 -0
- basic_memory/schemas/request.py +77 -0
- basic_memory/schemas/response.py +220 -0
- basic_memory/schemas/search.py +117 -0
- basic_memory/services/__init__.py +11 -0
- basic_memory/services/context_service.py +274 -0
- basic_memory/services/entity_service.py +281 -0
- basic_memory/services/exceptions.py +15 -0
- basic_memory/services/file_service.py +213 -0
- basic_memory/services/link_resolver.py +126 -0
- basic_memory/services/search_service.py +218 -0
- basic_memory/services/service.py +36 -0
- basic_memory/sync/__init__.py +5 -0
- basic_memory/sync/file_change_scanner.py +162 -0
- basic_memory/sync/sync_service.py +140 -0
- basic_memory/sync/utils.py +66 -0
- basic_memory/sync/watch_service.py +197 -0
- basic_memory/utils.py +78 -0
- basic_memory-0.0.0.dist-info/METADATA +71 -0
- basic_memory-0.0.0.dist-info/RECORD +73 -0
- basic_memory-0.0.0.dist-info/WHEEL +4 -0
- basic_memory-0.0.0.dist-info/entry_points.txt +2 -0
- basic_memory-0.0.0.dist-info/licenses/LICENSE +661 -0
basic_memory/cli/commands/sync.py
ADDED
@@ -0,0 +1,254 @@
"""Command module for basic-memory sync operations."""

import asyncio
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from typing import List, Dict

import typer
from loguru import logger
from rich.console import Console
from rich.padding import Padding
from rich.panel import Panel
from rich.text import Text
from rich.tree import Tree

from basic_memory import db
from basic_memory.cli.app import app
from basic_memory.config import config
from basic_memory.db import DatabaseType
from basic_memory.markdown import EntityParser
from basic_memory.markdown.markdown_processor import MarkdownProcessor
from basic_memory.repository import (
    EntityRepository,
    ObservationRepository,
    RelationRepository,
)
from basic_memory.repository.search_repository import SearchRepository
from basic_memory.services import EntityService, FileService
from basic_memory.services.link_resolver import LinkResolver
from basic_memory.services.search_service import SearchService
from basic_memory.sync import SyncService, FileChangeScanner
from basic_memory.sync.utils import SyncReport
from basic_memory.sync.watch_service import WatchService

console = Console()


@dataclass
class ValidationIssue:
    file_path: str
    error: str


async def get_sync_service(db_type=DatabaseType.FILESYSTEM):
    """Get sync service instance with all dependencies."""
    async with db.engine_session_factory(db_path=config.database_path, db_type=db_type) as (
        engine,
        session_maker,
    ):
        entity_parser = EntityParser(config.home)
        markdown_processor = MarkdownProcessor(entity_parser)
        file_service = FileService(config.home, markdown_processor)

        # Initialize repositories
        entity_repository = EntityRepository(session_maker)
        observation_repository = ObservationRepository(session_maker)
        relation_repository = RelationRepository(session_maker)
        search_repository = SearchRepository(session_maker)

        # Initialize services
        search_service = SearchService(search_repository, entity_repository, file_service)
        link_resolver = LinkResolver(entity_repository, search_service)

        # Initialize scanner
        file_change_scanner = FileChangeScanner(entity_repository)

        # Initialize services
        entity_service = EntityService(
            entity_parser,
            entity_repository,
            observation_repository,
            relation_repository,
            file_service,
            link_resolver,
        )

        # Create sync service
        sync_service = SyncService(
            scanner=file_change_scanner,
            entity_service=entity_service,
            entity_parser=entity_parser,
            entity_repository=entity_repository,
            relation_repository=relation_repository,
            search_service=search_service,
        )

        return sync_service


def group_issues_by_directory(issues: List[ValidationIssue]) -> Dict[str, List[ValidationIssue]]:
    """Group validation issues by directory."""
    grouped = defaultdict(list)
    for issue in issues:
        dir_name = Path(issue.file_path).parent.name
        grouped[dir_name].append(issue)
    return dict(grouped)


def display_validation_errors(issues: List[ValidationIssue]):
    """Display validation errors in a rich, organized format."""
    # Create header
    console.print()
    console.print(
        Panel("[red bold]Error:[/red bold] Invalid frontmatter in knowledge files", expand=False)
    )
    console.print()

    # Group issues by directory
    grouped_issues = group_issues_by_directory(issues)

    # Create tree structure
    tree = Tree("Knowledge Files")
    for dir_name, dir_issues in sorted(grouped_issues.items()):
        # Create branch for directory
        branch = tree.add(
            f"[bold blue]{dir_name}/[/bold blue] ([yellow]{len(dir_issues)} files[/yellow])"
        )

        # Add each file issue
        for issue in sorted(dir_issues, key=lambda x: x.file_path):
            file_name = Path(issue.file_path).name
            branch.add(
                Text.assemble(("└─ ", "dim"), (file_name, "yellow"), ": ", (issue.error, "red"))
            )

    # Display tree
    console.print(Padding(tree, (1, 2)))

    # Add help text
    console.print()
    console.print(
        Panel(
            Text.assemble(
                ("To fix:", "bold"),
                "\n1. Add required frontmatter fields to each file",
                "\n2. Run ",
                ("basic-memory sync", "bold cyan"),
                " again",
            ),
            expand=False,
        )
    )
    console.print()


def display_sync_summary(knowledge: SyncReport):
    """Display a one-line summary of sync changes."""
    total_changes = knowledge.total_changes
    if total_changes == 0:
        console.print("[green]Everything up to date[/green]")
        return

    # Format as: "Synced X files (A new, B modified, C moved, D deleted)"
    changes = []
    new_count = len(knowledge.new)
    mod_count = len(knowledge.modified)
    move_count = len(knowledge.moves)
    del_count = len(knowledge.deleted)

    if new_count:
        changes.append(f"[green]{new_count} new[/green]")
    if mod_count:
        changes.append(f"[yellow]{mod_count} modified[/yellow]")
    if move_count:
        changes.append(f"[blue]{move_count} moved[/blue]")
    if del_count:
        changes.append(f"[red]{del_count} deleted[/red]")

    console.print(f"Synced {total_changes} files ({', '.join(changes)})")


def display_detailed_sync_results(knowledge: SyncReport):
    """Display detailed sync results with trees."""
    if knowledge.total_changes == 0:
        console.print("\n[green]Everything up to date[/green]")
        return

    console.print("\n[bold]Sync Results[/bold]")

    if knowledge.total_changes > 0:
        knowledge_tree = Tree("[bold]Knowledge Files[/bold]")
        if knowledge.new:
            created = knowledge_tree.add("[green]Created[/green]")
            for path in sorted(knowledge.new):
                checksum = knowledge.checksums.get(path, "")
                created.add(f"[green]{path}[/green] ({checksum[:8]})")
        if knowledge.modified:
            modified = knowledge_tree.add("[yellow]Modified[/yellow]")
            for path in sorted(knowledge.modified):
                checksum = knowledge.checksums.get(path, "")
                modified.add(f"[yellow]{path}[/yellow] ({checksum[:8]})")
        if knowledge.moves:
            moved = knowledge_tree.add("[blue]Moved[/blue]")
            for old_path, new_path in sorted(knowledge.moves.items()):
                checksum = knowledge.checksums.get(new_path, "")
                moved.add(f"[blue]{old_path}[/blue] → [blue]{new_path}[/blue] ({checksum[:8]})")
        if knowledge.deleted:
            deleted = knowledge_tree.add("[red]Deleted[/red]")
            for path in sorted(knowledge.deleted):
                deleted.add(f"[red]{path}[/red]")
        console.print(knowledge_tree)


async def run_sync(verbose: bool = False, watch: bool = False):
    """Run sync operation."""

    sync_service = await get_sync_service()

    # Start watching if requested
    if watch:
        watch_service = WatchService(
            sync_service=sync_service,
            file_service=sync_service.entity_service.file_service,
            config=config
        )
        await watch_service.handle_changes(config.home)
        await watch_service.run()
    else:
        # one time sync
        knowledge_changes = await sync_service.sync(config.home)
        # Display results
        if verbose:
            display_detailed_sync_results(knowledge_changes)
        else:
            display_sync_summary(knowledge_changes)


@app.command()
def sync(
    verbose: bool = typer.Option(
        False,
        "--verbose",
        "-v",
        help="Show detailed sync information.",
    ),
    watch: bool = typer.Option(
        False,
        "--watch",
        "-w",
        help="Start watching for changes after sync.",
    ),
) -> None:
    """Sync knowledge files with the database."""
    try:
        # Run sync
        asyncio.run(run_sync(verbose=verbose, watch=watch))

    except Exception as e:
        if not isinstance(e, typer.Exit):
            logger.exception("Sync failed")
            typer.echo(f"Error during sync: {e}", err=True)
            raise typer.Exit(1)
        raise
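A minimal usage sketch of the one-shot sync path above (not part of the released diff), assuming the wheel is installed and config.home points at an existing knowledge directory:

import asyncio

from basic_memory.cli.commands.sync import run_sync

# One-shot sync of config.home; with verbose=False and watch=False this
# prints the one-line summary produced by display_sync_summary().
asyncio.run(run_sync(verbose=False, watch=False))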
basic_memory/cli/main.py
ADDED
@@ -0,0 +1,48 @@
"""Main CLI entry point for basic-memory."""
import sys

import typer
from loguru import logger

from basic_memory.cli.app import app
from basic_memory.cli.commands.init import init

# Register commands
from basic_memory.cli.commands import init, status, sync
__all__ = ["init", "status", "sync"]

from basic_memory.config import config


def setup_logging(home_dir: str = config.home, log_file: str = "./basic-memory/basic-memory-tools.log"):
    """Configure logging for the application."""

    # Remove default handler and any existing handlers
    logger.remove()

    # Add file handler for debug level logs
    log = f"{home_dir}/{log_file}"
    logger.add(
        log,
        level="DEBUG",
        rotation="100 MB",
        retention="10 days",
        backtrace=True,
        diagnose=True,
        enqueue=True,
        colorize=False,
    )

    # Add stderr handler for warnings and errors only
    logger.add(
        sys.stderr,
        level="WARNING",
        backtrace=True,
        diagnose=True
    )

# Set up logging when module is imported
setup_logging()

if __name__ == "__main__":  # pragma: no cover
    app()
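A small sketch of exercising the registered commands through Typer's test runner (an assumption about usage, not part of the diff; typer is a dependency of the CLI, and "--help" avoids touching the filesystem):

from typer.testing import CliRunner

import basic_memory.cli.main  # noqa: F401  # registers init/status/sync and configures logging
from basic_memory.cli.app import app

runner = CliRunner()
result = runner.invoke(app, ["sync", "--help"])
print(result.output)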
basic_memory/config.py
ADDED
@@ -0,0 +1,53 @@
"""Configuration management for basic-memory."""

from pathlib import Path

from loguru import logger
from pydantic import Field, field_validator
from pydantic_settings import BaseSettings, SettingsConfigDict

DATABASE_NAME = "memory.db"
DATA_DIR_NAME = ".basic-memory"


class ProjectConfig(BaseSettings):
    """Configuration for a specific basic-memory project."""

    # Default to ~/basic-memory but allow override with env var: BASIC_MEMORY_HOME
    home: Path = Field(
        default_factory=lambda: Path.home() / "basic-memory",
        description="Base path for basic-memory files",
    )

    # Name of the project
    project: str = Field(default="default", description="Project name")

    # Watch service configuration
    sync_delay: int = Field(
        default=500, description="Milliseconds to wait after changes before syncing", gt=0
    )

    model_config = SettingsConfigDict(
        env_prefix="BASIC_MEMORY_",
        extra="ignore",
        env_file=".env",
        env_file_encoding="utf-8",
    )

    @property
    def database_path(self) -> Path:
        """Get SQLite database path."""
        return self.home / DATA_DIR_NAME / DATABASE_NAME

    @field_validator("home")
    @classmethod
    def ensure_path_exists(cls, v: Path) -> Path:
        """Ensure project path exists."""
        if not v.exists():
            v.mkdir(parents=True)
        return v


# Load project config
config = ProjectConfig()
logger.info(f"project config home: {config.home}")
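Because ProjectConfig is a pydantic-settings model with env_prefix "BASIC_MEMORY_", the home directory can be overridden through the environment. A short sketch (the /tmp path is an arbitrary example, not something the package prescribes):

import os

os.environ["BASIC_MEMORY_HOME"] = "/tmp/basic-memory-demo"  # example override path

from basic_memory.config import ProjectConfig

cfg = ProjectConfig()
print(cfg.home)           # /tmp/basic-memory-demo (created by ensure_path_exists if missing)
print(cfg.database_path)  # /tmp/basic-memory-demo/.basic-memory/memory.db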
basic_memory/db.py
ADDED
@@ -0,0 +1,135 @@
import asyncio
from contextlib import asynccontextmanager
from enum import Enum, auto
from pathlib import Path
from typing import AsyncGenerator, Optional

from loguru import logger
from sqlalchemy import text
from sqlalchemy.ext.asyncio import (
    create_async_engine,
    async_sessionmaker,
    AsyncSession,
    AsyncEngine,
    async_scoped_session,
)

from basic_memory.models import Base


# Module level state
_engine: Optional[AsyncEngine] = None
_session_maker: Optional[async_sessionmaker[AsyncSession]] = None


class DatabaseType(Enum):
    """Types of supported databases."""

    MEMORY = auto()
    FILESYSTEM = auto()

    @classmethod
    def get_db_url(cls, db_path: Path, db_type: "DatabaseType") -> str:
        """Get SQLAlchemy URL for database path."""
        if db_type == cls.MEMORY:
            logger.info("Using in-memory SQLite database")
            return "sqlite+aiosqlite://"

        return f"sqlite+aiosqlite:///{db_path}"


def get_scoped_session_factory(
    session_maker: async_sessionmaker[AsyncSession],
) -> async_scoped_session:
    """Create a scoped session factory scoped to current task."""
    return async_scoped_session(session_maker, scopefunc=asyncio.current_task)


@asynccontextmanager
async def scoped_session(
    session_maker: async_sessionmaker[AsyncSession],
) -> AsyncGenerator[AsyncSession, None]:
    """
    Get a scoped session with proper lifecycle management.

    Args:
        session_maker: Session maker to create scoped sessions from
    """
    factory = get_scoped_session_factory(session_maker)
    session = factory()
    try:
        await session.execute(text("PRAGMA foreign_keys=ON"))
        yield session
        await session.commit()
    except Exception:
        await session.rollback()
        raise
    finally:
        await session.close()
        await factory.remove()


async def init_db(session: AsyncSession):
    """Initialize database with required tables."""
    await session.execute(text("PRAGMA foreign_keys=ON"))
    conn = await session.connection()
    await conn.run_sync(Base.metadata.create_all)
    await session.commit()


async def get_or_create_db(
    db_path: Path,
    db_type: DatabaseType = DatabaseType.FILESYSTEM,
) -> tuple[AsyncEngine, async_sessionmaker[AsyncSession]]:
    """Get or create database engine and session maker."""
    global _engine, _session_maker

    if _engine is None:
        db_url = DatabaseType.get_db_url(db_path, db_type)
        logger.debug(f"Creating engine for db_url: {db_url}")
        _engine = create_async_engine(db_url, connect_args={"check_same_thread": False})
        _session_maker = async_sessionmaker(_engine, expire_on_commit=False)

        # Initialize database
        logger.debug("Initializing database...")
        async with scoped_session(_session_maker) as db_session:
            await init_db(db_session)

    return _engine, _session_maker


async def shutdown_db():
    """Clean up database connections."""
    global _engine, _session_maker

    if _engine:
        await _engine.dispose()
        _engine = None
        _session_maker = None


@asynccontextmanager
async def engine_session_factory(
    db_path: Path,
    db_type: DatabaseType = DatabaseType.FILESYSTEM,
    init: bool = True,
) -> AsyncGenerator[tuple[AsyncEngine, async_sessionmaker[AsyncSession]], None]:
    """Create engine and session factory.

    Note: This is primarily used for testing where we want a fresh database
    for each test. For production use, use get_or_create_db() instead.
    """
    db_url = DatabaseType.get_db_url(db_path, db_type)
    logger.debug(f"Creating engine for db_url: {db_url}")
    engine = create_async_engine(db_url, connect_args={"check_same_thread": False})
    try:
        factory = async_sessionmaker(engine, expire_on_commit=False)

        if init:
            logger.debug("Initializing database...")
            async with scoped_session(factory) as db_session:
                await init_db(db_session)

        yield engine, factory
    finally:
        await engine.dispose()
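A minimal sketch of the test-oriented path above, using an in-memory database so nothing is written to disk (assumes the aiosqlite driver required by the sqlite+aiosqlite URLs is installed):

import asyncio
from pathlib import Path

from sqlalchemy import text

from basic_memory import db
from basic_memory.db import DatabaseType


async def main():
    # For DatabaseType.MEMORY the db_path argument is ignored by get_db_url.
    async with db.engine_session_factory(
        db_path=Path("ignored.db"), db_type=DatabaseType.MEMORY
    ) as (engine, session_maker):
        async with db.scoped_session(session_maker) as session:
            result = await session.execute(text("SELECT 1"))
            print(result.scalar())  # 1


asyncio.run(main())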
basic_memory/deps.py
ADDED
@@ -0,0 +1,182 @@
"""Dependency injection functions for basic-memory services."""

from typing import Annotated

from fastapi import Depends
from sqlalchemy.ext.asyncio import (
    AsyncSession,
    AsyncEngine,
    async_sessionmaker,
)

from basic_memory import db
from basic_memory.config import ProjectConfig, config
from basic_memory.markdown import EntityParser
from basic_memory.markdown.markdown_processor import MarkdownProcessor
from basic_memory.repository.entity_repository import EntityRepository
from basic_memory.repository.observation_repository import ObservationRepository
from basic_memory.repository.relation_repository import RelationRepository
from basic_memory.repository.search_repository import SearchRepository
from basic_memory.services import (
    EntityService,
)
from basic_memory.services.context_service import ContextService
from basic_memory.services.file_service import FileService
from basic_memory.services.link_resolver import LinkResolver
from basic_memory.services.search_service import SearchService


## project


def get_project_config() -> ProjectConfig:
    return config


ProjectConfigDep = Annotated[ProjectConfig, Depends(get_project_config)]


## sqlalchemy


async def get_engine_factory(
    project_config: ProjectConfigDep,
) -> tuple[AsyncEngine, async_sessionmaker[AsyncSession]]:
    """Get engine and session maker."""
    return await db.get_or_create_db(project_config.database_path)


EngineFactoryDep = Annotated[
    tuple[AsyncEngine, async_sessionmaker[AsyncSession]], Depends(get_engine_factory)
]


async def get_session_maker(engine_factory: EngineFactoryDep) -> async_sessionmaker[AsyncSession]:
    """Get session maker."""
    _, session_maker = engine_factory
    return session_maker


SessionMakerDep = Annotated[async_sessionmaker, Depends(get_session_maker)]


## repositories


async def get_entity_repository(
    session_maker: SessionMakerDep,
) -> EntityRepository:
    """Create an EntityRepository instance."""
    return EntityRepository(session_maker)


EntityRepositoryDep = Annotated[EntityRepository, Depends(get_entity_repository)]


async def get_observation_repository(
    session_maker: SessionMakerDep,
) -> ObservationRepository:
    """Create an ObservationRepository instance."""
    return ObservationRepository(session_maker)


ObservationRepositoryDep = Annotated[ObservationRepository, Depends(get_observation_repository)]


async def get_relation_repository(
    session_maker: SessionMakerDep,
) -> RelationRepository:
    """Create a RelationRepository instance."""
    return RelationRepository(session_maker)


RelationRepositoryDep = Annotated[RelationRepository, Depends(get_relation_repository)]


async def get_search_repository(
    session_maker: SessionMakerDep,
) -> SearchRepository:
    """Create a SearchRepository instance."""
    return SearchRepository(session_maker)


SearchRepositoryDep = Annotated[SearchRepository, Depends(get_search_repository)]


## services


async def get_entity_parser(project_config: ProjectConfigDep) -> EntityParser:
    return EntityParser(project_config.home)


EntityParserDep = Annotated["EntityParser", Depends(get_entity_parser)]


async def get_markdown_processor(entity_parser: EntityParserDep) -> MarkdownProcessor:
    return MarkdownProcessor(entity_parser)


MarkdownProcessorDep = Annotated[MarkdownProcessor, Depends(get_markdown_processor)]


async def get_file_service(
    project_config: ProjectConfigDep, markdown_processor: MarkdownProcessorDep
) -> FileService:
    return FileService(project_config.home, markdown_processor)


FileServiceDep = Annotated[FileService, Depends(get_file_service)]


async def get_entity_service(
    entity_repository: EntityRepositoryDep,
    observation_repository: ObservationRepositoryDep,
    relation_repository: RelationRepositoryDep,
    entity_parser: EntityParserDep,
    file_service: FileServiceDep,
    link_resolver: "LinkResolverDep",
) -> EntityService:
    """Create EntityService with repository."""
    return EntityService(
        entity_repository=entity_repository,
        observation_repository=observation_repository,
        relation_repository=relation_repository,
        entity_parser=entity_parser,
        file_service=file_service,
        link_resolver=link_resolver,
    )


EntityServiceDep = Annotated[EntityService, Depends(get_entity_service)]


async def get_search_service(
    search_repository: SearchRepositoryDep,
    entity_repository: EntityRepositoryDep,
    file_service: FileServiceDep,
) -> SearchService:
    """Create SearchService with dependencies."""
    return SearchService(search_repository, entity_repository, file_service)


SearchServiceDep = Annotated[SearchService, Depends(get_search_service)]


async def get_link_resolver(
    entity_repository: EntityRepositoryDep, search_service: SearchServiceDep
) -> LinkResolver:
    return LinkResolver(entity_repository=entity_repository, search_service=search_service)


LinkResolverDep = Annotated[LinkResolver, Depends(get_link_resolver)]


async def get_context_service(
    search_repository: SearchRepositoryDep, entity_repository: EntityRepositoryDep
) -> ContextService:
    return ContextService(search_repository, entity_repository)


ContextServiceDep = Annotated[ContextService, Depends(get_context_service)]