remdb 0.3.163__py3-none-any.whl → 0.3.181__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of remdb might be problematic. Click here for more details.
- rem/agentic/agents/agent_manager.py +2 -1
- rem/agentic/context_builder.py +18 -6
- rem/agentic/mcp/tool_wrapper.py +43 -14
- rem/agentic/providers/pydantic_ai.py +76 -34
- rem/agentic/schema.py +4 -3
- rem/agentic/tools/rem_tools.py +11 -0
- rem/api/mcp_router/resources.py +75 -14
- rem/api/mcp_router/server.py +27 -24
- rem/api/mcp_router/tools.py +87 -2
- rem/api/routers/auth.py +11 -6
- rem/api/routers/chat/completions.py +1 -1
- rem/api/routers/chat/streaming.py +18 -0
- rem/auth/middleware.py +31 -28
- rem/cli/commands/ask.py +1 -1
- rem/cli/commands/db.py +118 -54
- rem/models/entities/ontology.py +93 -101
- rem/schemas/agents/core/agent-builder.yaml +143 -42
- rem/services/email/service.py +17 -6
- rem/services/embeddings/worker.py +26 -12
- rem/services/postgres/register_type.py +1 -1
- rem/services/postgres/repository.py +32 -21
- rem/services/postgres/schema_generator.py +5 -5
- rem/services/postgres/sql_builder.py +6 -5
- rem/services/user_service.py +12 -9
- rem/settings.py +7 -1
- rem/sql/background_indexes.sql +5 -0
- rem/sql/migrations/001_install.sql +33 -4
- rem/sql/migrations/002_install_models.sql +204 -186
- rem/utils/model_helpers.py +101 -0
- rem/utils/schema_loader.py +45 -7
- {remdb-0.3.163.dist-info → remdb-0.3.181.dist-info}/METADATA +1 -1
- {remdb-0.3.163.dist-info → remdb-0.3.181.dist-info}/RECORD +34 -34
- {remdb-0.3.163.dist-info → remdb-0.3.181.dist-info}/WHEEL +0 -0
- {remdb-0.3.163.dist-info → remdb-0.3.181.dist-info}/entry_points.txt +0 -0
rem/api/mcp_router/tools.py
CHANGED
|
@@ -116,7 +116,8 @@ def mcp_tool_error_handler(func: Callable) -> Callable:
|
|
|
116
116
|
# Otherwise wrap in success response
|
|
117
117
|
return {"status": "success", **result}
|
|
118
118
|
except Exception as e:
|
|
119
|
-
|
|
119
|
+
# Pass values as {} placeholder arguments (not an f-string) to avoid issues with curly braces in error messages
|
|
120
|
+
logger.opt(exception=True).error("{} failed: {}", func.__name__, str(e))
|
|
120
121
|
return {
|
|
121
122
|
"status": "error",
|
|
122
123
|
"error": str(e),
|
|
@@ -154,6 +155,10 @@ async def search_rem(
|
|
|
154
155
|
- Fast exact match across all tables
|
|
155
156
|
- Uses indexed label_vector for instant retrieval
|
|
156
157
|
- Example: LOOKUP "Sarah Chen" returns all entities named "Sarah Chen"
|
|
158
|
+
- **Ontology Note**: Ontology content may contain markdown links like
|
|
159
|
+
`[sertraline](../../drugs/antidepressants/sertraline.md)`. The link name
|
|
160
|
+
(e.g., "sertraline") can be used as a LOOKUP subject, while the relative
|
|
161
|
+
path provides semantic context (e.g., it's a drug, specifically an antidepressant).
|
|
157
162
|
|
|
158
163
|
**FUZZY** - Fuzzy text matching with similarity threshold:
|
|
159
164
|
- Finds partial matches and typos
|
|
@@ -380,9 +385,10 @@ async def ask_rem_agent(
|
|
|
380
385
|
from ...utils.schema_loader import load_agent_schema
|
|
381
386
|
|
|
382
387
|
# Create agent context
|
|
388
|
+
# Note: tenant_id falls back to "default" when user_id is None or empty
|
|
383
389
|
context = AgentContext(
|
|
384
390
|
user_id=user_id,
|
|
385
|
-
tenant_id=user_id, #
|
|
391
|
+
tenant_id=user_id or "default", # Use default tenant for anonymous users
|
|
386
392
|
default_model=settings.llm.default_model,
|
|
387
393
|
)
|
|
388
394
|
|
|
@@ -1130,3 +1136,82 @@ async def save_agent(
|
|
|
1130
1136
|
result["message"] = f"Agent '{name}' saved. Use `/custom-agent {name}` to chat with it."
|
|
1131
1137
|
|
|
1132
1138
|
return result
|
|
1139
|
+
|
|
1140
|
+
|
|
1141
|
+
# =============================================================================
|
|
1142
|
+
# Test/Debug Tools (for development only)
|
|
1143
|
+
# =============================================================================
|
|
1144
|
+
|
|
1145
|
+
@mcp_tool_error_handler
|
|
1146
|
+
async def test_error_handling(
|
|
1147
|
+
error_type: Literal["exception", "error_response", "timeout", "success"] = "success",
|
|
1148
|
+
delay_seconds: float = 0,
|
|
1149
|
+
error_message: str = "Test error occurred",
|
|
1150
|
+
) -> dict[str, Any]:
|
|
1151
|
+
"""
|
|
1152
|
+
Test tool for simulating different error scenarios.
|
|
1153
|
+
|
|
1154
|
+
**FOR DEVELOPMENT/TESTING ONLY** - This tool helps verify that error
|
|
1155
|
+
handling works correctly through the streaming layer.
|
|
1156
|
+
|
|
1157
|
+
Args:
|
|
1158
|
+
error_type: Type of error to simulate:
|
|
1159
|
+
- "success": Returns successful response (default)
|
|
1160
|
+
- "exception": Raises an exception (tests @mcp_tool_error_handler)
|
|
1161
|
+
- "error_response": Returns {"status": "error", ...} dict
|
|
1162
|
+
- "timeout": Delays for 60 seconds (simulates timeout)
|
|
1163
|
+
delay_seconds: Optional delay before responding, in seconds (values above 10 are capped at 10)
|
|
1164
|
+
error_message: Custom error message for error scenarios
|
|
1165
|
+
|
|
1166
|
+
Returns:
|
|
1167
|
+
Dict with test results or error information
|
|
1168
|
+
|
|
1169
|
+
Examples:
|
|
1170
|
+
# Test successful response
|
|
1171
|
+
test_error_handling(error_type="success")
|
|
1172
|
+
|
|
1173
|
+
# Test exception handling
|
|
1174
|
+
test_error_handling(error_type="exception", error_message="Database connection failed")
|
|
1175
|
+
|
|
1176
|
+
# Test error response format
|
|
1177
|
+
test_error_handling(error_type="error_response", error_message="Resource not found")
|
|
1178
|
+
|
|
1179
|
+
# Test with delay
|
|
1180
|
+
test_error_handling(error_type="success", delay_seconds=2)
|
|
1181
|
+
"""
|
|
1182
|
+
import asyncio
|
|
1183
|
+
|
|
1184
|
+
logger.info(f"test_error_handling called: type={error_type}, delay={delay_seconds}")
|
|
1185
|
+
|
|
1186
|
+
# Apply delay (capped at 10 seconds for safety)
|
|
1187
|
+
if delay_seconds > 0:
|
|
1188
|
+
await asyncio.sleep(min(delay_seconds, 10))
|
|
1189
|
+
|
|
1190
|
+
if error_type == "exception":
|
|
1191
|
+
# This tests the @mcp_tool_error_handler decorator
|
|
1192
|
+
raise RuntimeError(f"TEST EXCEPTION: {error_message}")
|
|
1193
|
+
|
|
1194
|
+
elif error_type == "error_response":
|
|
1195
|
+
# This tests how the streaming layer handles error status responses
|
|
1196
|
+
return {
|
|
1197
|
+
"status": "error",
|
|
1198
|
+
"error": error_message,
|
|
1199
|
+
"error_code": "TEST_ERROR",
|
|
1200
|
+
"recoverable": True,
|
|
1201
|
+
}
|
|
1202
|
+
|
|
1203
|
+
elif error_type == "timeout":
|
|
1204
|
+
# Simulate a very long operation (for testing client-side timeouts)
|
|
1205
|
+
await asyncio.sleep(60)
|
|
1206
|
+
return {"status": "success", "message": "Timeout test completed (should not reach here)"}
|
|
1207
|
+
|
|
1208
|
+
else: # success
|
|
1209
|
+
return {
|
|
1210
|
+
"status": "success",
|
|
1211
|
+
"message": "Test completed successfully",
|
|
1212
|
+
"test_data": {
|
|
1213
|
+
"error_type": error_type,
|
|
1214
|
+
"delay_applied": delay_seconds,
|
|
1215
|
+
"timestamp": str(asyncio.get_event_loop().time()),
|
|
1216
|
+
},
|
|
1217
|
+
}
|
rem/api/routers/auth.py
CHANGED
|
@@ -30,14 +30,17 @@ Access Control Flow (send-code):
|
|
|
30
30
|
│ ├── Yes → Check user.tier
|
|
31
31
|
│ │ ├── tier == BLOCKED → Reject "Account is blocked"
|
|
32
32
|
│ │ └── tier != BLOCKED → Allow (send code, existing users grandfathered)
|
|
33
|
-
│ └── No (new user) → Check
|
|
34
|
-
│ ├──
|
|
35
|
-
│
|
|
36
|
-
│
|
|
37
|
-
│
|
|
33
|
+
│ └── No (new user) → Check subscriber list first
|
|
34
|
+
│ ├── Email in subscribers table? → Allow (create user & send code)
|
|
35
|
+
│ └── Not a subscriber → Check EMAIL__TRUSTED_EMAIL_DOMAINS
|
|
36
|
+
│ ├── Setting configured → domain in trusted list?
|
|
37
|
+
│ │ ├── Yes → Create user & send code
|
|
38
|
+
│ │ └── No → Reject "Email domain not allowed for signup"
|
|
39
|
+
│ └── Not configured (empty) → Create user & send code (no restrictions)
|
|
38
40
|
|
|
39
41
|
Key Behaviors:
|
|
40
42
|
- Existing users: Always allowed to login (unless tier=BLOCKED)
|
|
43
|
+
- Subscribers: Always allowed to login (regardless of email domain)
|
|
41
44
|
- New users: Must have email from trusted domain (if EMAIL__TRUSTED_EMAIL_DOMAINS is set)
|
|
42
45
|
- No restrictions: Leave EMAIL__TRUSTED_EMAIL_DOMAINS empty to allow all domains
|
|
43
46
|
|
|
@@ -103,6 +106,7 @@ from ...services.postgres.service import PostgresService
|
|
|
103
106
|
from ...services.user_service import UserService
|
|
104
107
|
from ...auth.providers.email import EmailAuthProvider
|
|
105
108
|
from ...auth.jwt import JWTService, get_jwt_service
|
|
109
|
+
from ...utils.user_id import email_to_user_id
|
|
106
110
|
|
|
107
111
|
router = APIRouter(prefix="/api/auth", tags=["auth"])
|
|
108
112
|
|
|
@@ -429,8 +433,9 @@ async def callback(provider: str, request: Request):
|
|
|
429
433
|
await user_service.link_anonymous_session(user_entity, anon_id)
|
|
430
434
|
|
|
431
435
|
# Enrich session user with DB info
|
|
436
|
+
# user_id = UUID5 hash of email (deterministic: the same email always yields the same ID)
|
|
432
437
|
db_info = {
|
|
433
|
-
"id":
|
|
438
|
+
"id": email_to_user_id(user_info.get("email")),
|
|
434
439
|
"tenant_id": user_entity.tenant_id,
|
|
435
440
|
"tier": user_entity.tier.value if user_entity.tier else "free",
|
|
436
441
|
"roles": [user_entity.role] if user_entity.role else [],
|
|
@@ -97,7 +97,7 @@ Context Building Flow:
|
|
|
97
97
|
- Long messages include REM LOOKUP hints: "... [REM LOOKUP session-{id}-msg-{index}] ..."
|
|
98
98
|
- Agent can retrieve full content on-demand using REM LOOKUP
|
|
99
99
|
3. User profile provided as REM LOOKUP hint (on-demand by default)
|
|
100
|
-
- Agent receives: "User
|
|
100
|
+
- Agent receives: "User: {email}. To load user profile: Use REM LOOKUP \"{email}\""
|
|
101
101
|
- Agent decides whether to load profile based on query
|
|
102
102
|
4. If CHAT__AUTO_INJECT_USER_CONTEXT=true: User profile auto-loaded and injected
|
|
103
103
|
5. Combines: system context + compressed session history + new messages
|
|
@@ -835,3 +835,21 @@ async def stream_openai_response_with_save(
|
|
|
835
835
|
)
|
|
836
836
|
except Exception as e:
|
|
837
837
|
logger.error(f"Failed to save session messages: {e}", exc_info=True)
|
|
838
|
+
|
|
839
|
+
# Update session description with session_name (non-blocking, after all yields)
|
|
840
|
+
for tool_call in tool_calls:
|
|
841
|
+
if tool_call.get("tool_name") == "register_metadata" and tool_call.get("is_metadata"):
|
|
842
|
+
session_name = tool_call.get("arguments", {}).get("session_name")
|
|
843
|
+
if session_name:
|
|
844
|
+
try:
|
|
845
|
+
from ....models.entities import Session
|
|
846
|
+
from ....services.postgres import Repository
|
|
847
|
+
repo = Repository(Session, table_name="sessions")
|
|
848
|
+
session = await repo.get_by_id(session_id)
|
|
849
|
+
if session and session.description != session_name:
|
|
850
|
+
session.description = session_name
|
|
851
|
+
await repo.update(session)
|
|
852
|
+
logger.debug(f"Updated session {session_id} description to '{session_name}'")
|
|
853
|
+
except Exception as e:
|
|
854
|
+
logger.warning(f"Failed to update session description: {e}")
|
|
855
|
+
break
|
rem/auth/middleware.py
CHANGED
|
@@ -7,14 +7,22 @@ Anonymous access with rate limiting when allow_anonymous=True.
|
|
|
7
7
|
MCP endpoints are always protected unless explicitly disabled.
|
|
8
8
|
|
|
9
9
|
Design Pattern:
|
|
10
|
-
-
|
|
11
|
-
-
|
|
12
|
-
-
|
|
13
|
-
-
|
|
10
|
+
- API Key (X-API-Key): Access control guardrail, NOT user identity
|
|
11
|
+
- JWT (Authorization: Bearer): Primary method for user identity
|
|
12
|
+
- Dev token: Non-production testing (starts with "dev_")
|
|
13
|
+
- Session: Backward compatibility for browser-based auth
|
|
14
14
|
- MCP paths always require authentication (protected service)
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
15
|
+
|
|
16
|
+
Authentication Flow:
|
|
17
|
+
1. If API key enabled: Validate X-API-Key header (access gate)
|
|
18
|
+
2. Check JWT token for user identity (primary)
|
|
19
|
+
3. Check dev token for testing (non-production only)
|
|
20
|
+
4. Check session for user (backward compatibility)
|
|
21
|
+
5. If allow_anonymous=True: Allow as anonymous (rate-limited)
|
|
22
|
+
6. If allow_anonymous=False: Return 401 / redirect to login
|
|
23
|
+
|
|
24
|
+
IMPORTANT: API key validates ACCESS, JWT identifies USER.
|
|
25
|
+
Both can be required: API key for access + JWT for user identity.
|
|
18
26
|
|
|
19
27
|
Access Modes (configured in settings.auth):
|
|
20
28
|
- enabled=true, allow_anonymous=true: Auth available, anonymous gets rate-limited access
|
|
@@ -24,10 +32,9 @@ Access Modes (configured in settings.auth):
|
|
|
24
32
|
- mcp_requires_auth=false: MCP follows normal allow_anonymous rules (dev only)
|
|
25
33
|
|
|
26
34
|
API Key Authentication (configured in settings.api):
|
|
27
|
-
- api_key_enabled=true: Require X-API-Key header for
|
|
35
|
+
- api_key_enabled=true: Require X-API-Key header for access
|
|
28
36
|
- api_key: The secret key to validate against
|
|
29
|
-
-
|
|
30
|
-
- X-API-Key header takes precedence over session auth
|
|
37
|
+
- API key is an ACCESS GATE, not user identity - JWT still needed for user
|
|
31
38
|
|
|
32
39
|
Dev Token Support (non-production only):
|
|
33
40
|
- GET /api/auth/dev/token returns a Bearer token for test-user
|
|
@@ -212,32 +219,28 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
|
|
212
219
|
if not is_protected or is_excluded:
|
|
213
220
|
return await call_next(request)
|
|
214
221
|
|
|
215
|
-
#
|
|
216
|
-
|
|
217
|
-
if api_key_user:
|
|
218
|
-
request.state.user = api_key_user
|
|
219
|
-
request.state.is_anonymous = False
|
|
220
|
-
return await call_next(request)
|
|
221
|
-
|
|
222
|
-
# If API key auth is enabled but no valid key provided, reject immediately
|
|
222
|
+
# API key validation (access control, not user identity)
|
|
223
|
+
# API key is a guardrail for access - JWT identifies the actual user
|
|
223
224
|
if settings.api.api_key_enabled:
|
|
224
|
-
|
|
225
|
-
if
|
|
225
|
+
api_key = request.headers.get("x-api-key")
|
|
226
|
+
if not api_key:
|
|
227
|
+
logger.debug(f"Missing X-API-Key for: {path}")
|
|
228
|
+
return JSONResponse(
|
|
229
|
+
status_code=401,
|
|
230
|
+
content={"detail": "API key required. Include X-API-Key header."},
|
|
231
|
+
headers={"WWW-Authenticate": 'ApiKey realm="REM API"'},
|
|
232
|
+
)
|
|
233
|
+
if api_key != settings.api.api_key:
|
|
226
234
|
logger.warning(f"Invalid X-API-Key for: {path}")
|
|
227
235
|
return JSONResponse(
|
|
228
236
|
status_code=401,
|
|
229
237
|
content={"detail": "Invalid API key"},
|
|
230
238
|
headers={"WWW-Authenticate": 'ApiKey realm="REM API"'},
|
|
231
239
|
)
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
return JSONResponse(
|
|
235
|
-
status_code=401,
|
|
236
|
-
content={"detail": "API key required. Include X-API-Key header."},
|
|
237
|
-
headers={"WWW-Authenticate": 'ApiKey realm="REM API"'},
|
|
238
|
-
)
|
|
240
|
+
logger.debug("X-API-Key validated for access")
|
|
241
|
+
# API key valid - continue to check JWT for user identity
|
|
239
242
|
|
|
240
|
-
# Check for JWT token in Authorization header
|
|
243
|
+
# Check for JWT token in Authorization header (primary user identity)
|
|
241
244
|
jwt_user = self._check_jwt_token(request)
|
|
242
245
|
if jwt_user:
|
|
243
246
|
request.state.user = jwt_user
|
rem/cli/commands/ask.py
CHANGED
|
@@ -75,7 +75,7 @@ async def run_agent_streaming(
|
|
|
75
75
|
"""
|
|
76
76
|
Run agent in streaming mode using agent.iter() with usage limits.
|
|
77
77
|
|
|
78
|
-
Design Pattern
|
|
78
|
+
Design Pattern:
|
|
79
79
|
- Use agent.iter() for complete execution with tool call visibility
|
|
80
80
|
- run_stream() stops after first output, missing tool calls
|
|
81
81
|
- Stream tool call markers: [Calling: tool_name]
|
rem/cli/commands/db.py
CHANGED
|
@@ -333,64 +333,120 @@ def rebuild_cache(connection: str | None):
|
|
|
333
333
|
|
|
334
334
|
@click.command()
|
|
335
335
|
@click.argument("file_path", type=click.Path(exists=True, path_type=Path))
|
|
336
|
+
@click.option("--table", "-t", default=None, help="Target table name (required for non-YAML formats)")
|
|
336
337
|
@click.option("--user-id", default=None, help="User ID to scope data privately (default: public/shared)")
|
|
337
338
|
@click.option("--dry-run", is_flag=True, help="Show what would be loaded without loading")
|
|
338
|
-
def load(file_path: Path, user_id: str | None, dry_run: bool):
|
|
339
|
+
def load(file_path: Path, table: str | None, user_id: str | None, dry_run: bool):
|
|
339
340
|
"""
|
|
340
|
-
Load data from
|
|
341
|
+
Load data from file into database.
|
|
341
342
|
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
key_field: name
|
|
345
|
-
rows:
|
|
346
|
-
- name: Example
|
|
347
|
-
content: Test data...
|
|
343
|
+
Supports YAML with embedded metadata, or any tabular format via Polars
|
|
344
|
+
(jsonl, parquet, csv, json, arrow, etc.). For non-YAML formats, use --table.
|
|
348
345
|
|
|
349
346
|
Examples:
|
|
350
|
-
rem db load
|
|
351
|
-
rem db load data.
|
|
352
|
-
rem db load data.yaml --dry-run
|
|
347
|
+
rem db load data.yaml # YAML with metadata
|
|
348
|
+
rem db load data.jsonl -t resources # Any Polars-supported format
|
|
353
349
|
"""
|
|
354
|
-
asyncio.run(_load_async(file_path, user_id, dry_run))
|
|
350
|
+
asyncio.run(_load_async(file_path, table, user_id, dry_run))
|
|
355
351
|
|
|
356
352
|
|
|
357
|
-
|
|
353
|
+
def _load_dataframe_from_file(file_path: Path) -> "pl.DataFrame":
|
|
354
|
+
"""Load any Polars-supported file format into a DataFrame."""
|
|
355
|
+
import polars as pl
|
|
356
|
+
|
|
357
|
+
suffix = file_path.suffix.lower()
|
|
358
|
+
|
|
359
|
+
if suffix in {".jsonl", ".ndjson"}:
|
|
360
|
+
return pl.read_ndjson(file_path)
|
|
361
|
+
elif suffix in {".parquet", ".pq"}:
|
|
362
|
+
return pl.read_parquet(file_path)
|
|
363
|
+
elif suffix == ".csv":
|
|
364
|
+
return pl.read_csv(file_path)
|
|
365
|
+
elif suffix == ".json":
|
|
366
|
+
return pl.read_json(file_path)
|
|
367
|
+
elif suffix in {".ipc", ".arrow"}:
|
|
368
|
+
return pl.read_ipc(file_path)
|
|
369
|
+
else:
|
|
370
|
+
raise ValueError(f"Unsupported file format: {suffix}. Use any Polars-supported format.")
|
|
371
|
+
|
|
372
|
+
|
|
373
|
+
async def _load_async(file_path: Path, table: str | None, user_id: str | None, dry_run: bool):
|
|
358
374
|
"""Async implementation of load command."""
|
|
375
|
+
import polars as pl
|
|
359
376
|
import yaml
|
|
360
377
|
from ...models.core.inline_edge import InlineEdge
|
|
361
|
-
from ...models.entities import
|
|
378
|
+
from ...models.entities import SharedSession
|
|
362
379
|
from ...services.postgres import get_postgres_service
|
|
380
|
+
from ...utils.model_helpers import get_table_name
|
|
381
|
+
from ... import get_model_registry
|
|
363
382
|
|
|
364
383
|
logger.info(f"Loading data from: {file_path}")
|
|
365
384
|
scope_msg = f"user: {user_id}" if user_id else "public"
|
|
366
385
|
logger.info(f"Scope: {scope_msg}")
|
|
367
386
|
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
data = yaml.safe_load(f)
|
|
387
|
+
suffix = file_path.suffix.lower()
|
|
388
|
+
is_yaml = suffix in {".yaml", ".yml"}
|
|
371
389
|
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
if dry_run:
|
|
377
|
-
logger.info("DRY RUN - Would load:")
|
|
378
|
-
logger.info(yaml.dump(data, default_flow_style=False))
|
|
379
|
-
return
|
|
380
|
-
|
|
381
|
-
# Map table names to model classes
|
|
382
|
-
# CoreModel subclasses use Repository.upsert()
|
|
390
|
+
# Build MODEL_MAP dynamically from registry
|
|
391
|
+
registry = get_model_registry()
|
|
392
|
+
registry.register_core_models()
|
|
383
393
|
MODEL_MAP = {
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
"resources": Resource,
|
|
387
|
-
"messages": Message,
|
|
388
|
-
"schemas": Schema,
|
|
394
|
+
get_table_name(model): model
|
|
395
|
+
for model in registry.get_model_classes().values()
|
|
389
396
|
}
|
|
390
397
|
|
|
391
398
|
# Non-CoreModel tables that need direct SQL insertion
|
|
392
399
|
DIRECT_INSERT_TABLES = {"shared_sessions"}
|
|
393
400
|
|
|
401
|
+
# Parse file based on format
|
|
402
|
+
if is_yaml:
|
|
403
|
+
# YAML with embedded metadata
|
|
404
|
+
with open(file_path) as f:
|
|
405
|
+
data = yaml.safe_load(f)
|
|
406
|
+
|
|
407
|
+
if not isinstance(data, list):
|
|
408
|
+
logger.error("YAML must be a list of table definitions")
|
|
409
|
+
raise click.Abort()
|
|
410
|
+
|
|
411
|
+
if dry_run:
|
|
412
|
+
logger.info("DRY RUN - Would load:")
|
|
413
|
+
logger.info(yaml.dump(data, default_flow_style=False))
|
|
414
|
+
return
|
|
415
|
+
|
|
416
|
+
table_defs = data
|
|
417
|
+
else:
|
|
418
|
+
# Polars-supported format - require --table
|
|
419
|
+
if not table:
|
|
420
|
+
logger.error(f"For {suffix} files, --table is required. Example: rem db load {file_path.name} -t resources")
|
|
421
|
+
raise click.Abort()
|
|
422
|
+
|
|
423
|
+
try:
|
|
424
|
+
df = _load_dataframe_from_file(file_path)
|
|
425
|
+
except Exception as e:
|
|
426
|
+
logger.error(f"Failed to load file: {e}")
|
|
427
|
+
raise click.Abort()
|
|
428
|
+
|
|
429
|
+
rows = df.to_dicts()
|
|
430
|
+
|
|
431
|
+
if dry_run:
|
|
432
|
+
logger.info(f"DRY RUN - Would load {len(rows)} rows to table '{table}':")
|
|
433
|
+
logger.info(f"Columns: {list(df.columns)}")
|
|
434
|
+
|
|
435
|
+
# Validate first row against model if table is known
|
|
436
|
+
if table in MODEL_MAP and rows:
|
|
437
|
+
from ...utils.model_helpers import validate_data_for_model
|
|
438
|
+
result = validate_data_for_model(MODEL_MAP[table], rows[0])
|
|
439
|
+
if result.extra_fields:
|
|
440
|
+
logger.warning(f"Unknown fields (ignored): {result.extra_fields}")
|
|
441
|
+
if result.valid:
|
|
442
|
+
logger.success(f"Sample row validates OK. Required: {result.required_fields or '(none)'}")
|
|
443
|
+
else:
|
|
444
|
+
result.log_errors("Sample row")
|
|
445
|
+
return
|
|
446
|
+
|
|
447
|
+
# Wrap as single table definition
|
|
448
|
+
table_defs = [{"table": table, "rows": rows}]
|
|
449
|
+
|
|
394
450
|
# Connect to database
|
|
395
451
|
pg = get_postgres_service()
|
|
396
452
|
if not pg:
|
|
@@ -399,23 +455,24 @@ async def _load_async(file_path: Path, user_id: str | None, dry_run: bool):
|
|
|
399
455
|
|
|
400
456
|
await pg.connect()
|
|
401
457
|
|
|
458
|
+
# Start embedding worker for generating embeddings
|
|
459
|
+
if pg.embedding_worker:
|
|
460
|
+
await pg.embedding_worker.start()
|
|
461
|
+
|
|
402
462
|
try:
|
|
403
463
|
total_loaded = 0
|
|
404
464
|
|
|
405
|
-
for table_def in
|
|
465
|
+
for table_def in table_defs:
|
|
406
466
|
table_name = table_def["table"]
|
|
407
|
-
key_field = table_def.get("key_field", "id")
|
|
408
467
|
rows = table_def.get("rows", [])
|
|
409
468
|
|
|
410
469
|
# Handle direct insert tables (non-CoreModel)
|
|
411
470
|
if table_name in DIRECT_INSERT_TABLES:
|
|
412
471
|
for row_data in rows:
|
|
413
|
-
# Add tenant_id if not present
|
|
414
472
|
if "tenant_id" not in row_data:
|
|
415
473
|
row_data["tenant_id"] = "default"
|
|
416
474
|
|
|
417
475
|
if table_name == "shared_sessions":
|
|
418
|
-
# Insert shared_session directly
|
|
419
476
|
await pg.fetch(
|
|
420
477
|
"""INSERT INTO shared_sessions
|
|
421
478
|
(session_id, owner_user_id, shared_with_user_id, tenant_id)
|
|
@@ -434,16 +491,13 @@ async def _load_async(file_path: Path, user_id: str | None, dry_run: bool):
|
|
|
434
491
|
logger.warning(f"Unknown table: {table_name}, skipping")
|
|
435
492
|
continue
|
|
436
493
|
|
|
437
|
-
model_class = MODEL_MAP[table_name]
|
|
494
|
+
model_class = MODEL_MAP[table_name]
|
|
438
495
|
|
|
439
|
-
for row_data in rows:
|
|
440
|
-
#
|
|
441
|
-
#
|
|
442
|
-
# Pass --user-id to scope data privately to a specific user
|
|
443
|
-
if "user_id" not in row_data and user_id is not None:
|
|
444
|
-
row_data["user_id"] = user_id
|
|
496
|
+
for row_idx, row_data in enumerate(rows):
|
|
497
|
+
# user_id stays NULL for public data (accessible by any user)
|
|
498
|
+
# Only set tenant_id for scoping - the --user-id flag controls tenant scope
|
|
445
499
|
if "tenant_id" not in row_data and user_id is not None:
|
|
446
|
-
row_data["tenant_id"] =
|
|
500
|
+
row_data["tenant_id"] = user_id
|
|
447
501
|
|
|
448
502
|
# Convert graph_edges to InlineEdge format if present
|
|
449
503
|
if "graph_edges" in row_data:
|
|
@@ -452,30 +506,40 @@ async def _load_async(file_path: Path, user_id: str | None, dry_run: bool):
|
|
|
452
506
|
for edge in row_data["graph_edges"]
|
|
453
507
|
]
|
|
454
508
|
|
|
455
|
-
# Convert
|
|
456
|
-
# This handles fields like starts_timestamp, ends_timestamp, etc.
|
|
509
|
+
# Convert ISO timestamp strings
|
|
457
510
|
from ...utils.date_utils import parse_iso
|
|
458
511
|
for key, value in list(row_data.items()):
|
|
459
512
|
if isinstance(value, str) and (key.endswith("_timestamp") or key.endswith("_at")):
|
|
460
513
|
try:
|
|
461
514
|
row_data[key] = parse_iso(value)
|
|
462
515
|
except (ValueError, TypeError):
|
|
463
|
-
pass
|
|
516
|
+
pass
|
|
464
517
|
|
|
465
|
-
# Create model instance and upsert via repository
|
|
466
518
|
from ...services.postgres.repository import Repository
|
|
519
|
+
from ...utils.model_helpers import validate_data_for_model
|
|
520
|
+
|
|
521
|
+
result = validate_data_for_model(model_class, row_data)
|
|
522
|
+
if not result.valid:
|
|
523
|
+
result.log_errors(f"Row {row_idx + 1} ({table_name})")
|
|
524
|
+
raise click.Abort()
|
|
467
525
|
|
|
468
|
-
|
|
469
|
-
repo
|
|
470
|
-
await repo.upsert(instance) # type: ignore[arg-type]
|
|
526
|
+
repo = Repository(model_class, table_name, pg)
|
|
527
|
+
await repo.upsert(result.instance) # type: ignore[arg-type]
|
|
471
528
|
total_loaded += 1
|
|
472
529
|
|
|
473
|
-
|
|
474
|
-
name = getattr(instance, 'name', getattr(instance, 'id', '?'))
|
|
530
|
+
name = getattr(result.instance, 'name', getattr(result.instance, 'id', '?'))
|
|
475
531
|
logger.success(f"Loaded {table_name[:-1]}: {name}")
|
|
476
532
|
|
|
477
533
|
logger.success(f"Data loaded successfully! Total rows: {total_loaded}")
|
|
478
534
|
|
|
535
|
+
# Wait for embeddings to complete
|
|
536
|
+
if pg.embedding_worker and pg.embedding_worker.running:
|
|
537
|
+
queue_size = pg.embedding_worker.task_queue.qsize()
|
|
538
|
+
if queue_size > 0:
|
|
539
|
+
logger.info(f"Waiting for {queue_size} embeddings to complete...")
|
|
540
|
+
await pg.embedding_worker.stop()
|
|
541
|
+
logger.success("Embeddings generated successfully")
|
|
542
|
+
|
|
479
543
|
finally:
|
|
480
544
|
await pg.disconnect()
|
|
481
545
|
|