remdb 0.3.163__py3-none-any.whl → 0.3.181__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of remdb might be problematic. Click here for more details.

@@ -116,7 +116,8 @@ def mcp_tool_error_handler(func: Callable) -> Callable:
116
116
  # Otherwise wrap in success response
117
117
  return {"status": "success", **result}
118
118
  except Exception as e:
119
- logger.error(f"{func.__name__} failed: {e}", exc_info=True)
119
+ # Use %s format to avoid issues with curly braces in error messages
120
+ logger.opt(exception=True).error("{} failed: {}", func.__name__, str(e))
120
121
  return {
121
122
  "status": "error",
122
123
  "error": str(e),
@@ -154,6 +155,10 @@ async def search_rem(
154
155
  - Fast exact match across all tables
155
156
  - Uses indexed label_vector for instant retrieval
156
157
  - Example: LOOKUP "Sarah Chen" returns all entities named "Sarah Chen"
158
+ - **Ontology Note**: Ontology content may contain markdown links like
159
+ `[sertraline](../../drugs/antidepressants/sertraline.md)`. The link name
160
+ (e.g., "sertraline") can be used as a LOOKUP subject, while the relative
161
+ path provides semantic context (e.g., it's a drug, specifically an antidepressant).
157
162
 
158
163
  **FUZZY** - Fuzzy text matching with similarity threshold:
159
164
  - Finds partial matches and typos
@@ -380,9 +385,10 @@ async def ask_rem_agent(
380
385
  from ...utils.schema_loader import load_agent_schema
381
386
 
382
387
  # Create agent context
388
+ # Note: tenant_id defaults to "default" if user_id is None
383
389
  context = AgentContext(
384
390
  user_id=user_id,
385
- tenant_id=user_id, # Set tenant_id to user_id for backward compat
391
+ tenant_id=user_id or "default", # Use default tenant for anonymous users
386
392
  default_model=settings.llm.default_model,
387
393
  )
388
394
 
@@ -1130,3 +1136,82 @@ async def save_agent(
1130
1136
  result["message"] = f"Agent '{name}' saved. Use `/custom-agent {name}` to chat with it."
1131
1137
 
1132
1138
  return result
1139
+
1140
+
1141
+ # =============================================================================
1142
+ # Test/Debug Tools (for development only)
1143
+ # =============================================================================
1144
+
1145
+ @mcp_tool_error_handler
1146
+ async def test_error_handling(
1147
+ error_type: Literal["exception", "error_response", "timeout", "success"] = "success",
1148
+ delay_seconds: float = 0,
1149
+ error_message: str = "Test error occurred",
1150
+ ) -> dict[str, Any]:
1151
+ """
1152
+ Test tool for simulating different error scenarios.
1153
+
1154
+ **FOR DEVELOPMENT/TESTING ONLY** - This tool helps verify that error
1155
+ handling works correctly through the streaming layer.
1156
+
1157
+ Args:
1158
+ error_type: Type of error to simulate:
1159
+ - "success": Returns successful response (default)
1160
+ - "exception": Raises an exception (tests @mcp_tool_error_handler)
1161
+ - "error_response": Returns {"status": "error", ...} dict
1162
+ - "timeout": Delays for 60 seconds (simulates timeout)
1163
+ delay_seconds: Optional delay before responding (0-10 seconds)
1164
+ error_message: Custom error message for error scenarios
1165
+
1166
+ Returns:
1167
+ Dict with test results or error information
1168
+
1169
+ Examples:
1170
+ # Test successful response
1171
+ test_error_handling(error_type="success")
1172
+
1173
+ # Test exception handling
1174
+ test_error_handling(error_type="exception", error_message="Database connection failed")
1175
+
1176
+ # Test error response format
1177
+ test_error_handling(error_type="error_response", error_message="Resource not found")
1178
+
1179
+ # Test with delay
1180
+ test_error_handling(error_type="success", delay_seconds=2)
1181
+ """
1182
+ import asyncio
1183
+
1184
+ logger.info(f"test_error_handling called: type={error_type}, delay={delay_seconds}")
1185
+
1186
+ # Apply delay (capped at 10 seconds for safety)
1187
+ if delay_seconds > 0:
1188
+ await asyncio.sleep(min(delay_seconds, 10))
1189
+
1190
+ if error_type == "exception":
1191
+ # This tests the @mcp_tool_error_handler decorator
1192
+ raise RuntimeError(f"TEST EXCEPTION: {error_message}")
1193
+
1194
+ elif error_type == "error_response":
1195
+ # This tests how the streaming layer handles error status responses
1196
+ return {
1197
+ "status": "error",
1198
+ "error": error_message,
1199
+ "error_code": "TEST_ERROR",
1200
+ "recoverable": True,
1201
+ }
1202
+
1203
+ elif error_type == "timeout":
1204
+ # Simulate a very long operation (for testing client-side timeouts)
1205
+ await asyncio.sleep(60)
1206
+ return {"status": "success", "message": "Timeout test completed (should not reach here)"}
1207
+
1208
+ else: # success
1209
+ return {
1210
+ "status": "success",
1211
+ "message": "Test completed successfully",
1212
+ "test_data": {
1213
+ "error_type": error_type,
1214
+ "delay_applied": delay_seconds,
1215
+ "timestamp": str(asyncio.get_event_loop().time()),
1216
+ },
1217
+ }
rem/api/routers/auth.py CHANGED
@@ -30,14 +30,17 @@ Access Control Flow (send-code):
30
30
  │ ├── Yes → Check user.tier
31
31
  │ │ ├── tier == BLOCKED → Reject "Account is blocked"
32
32
  │ │ └── tier != BLOCKED → Allow (send code, existing users grandfathered)
33
- │ └── No (new user) → Check EMAIL__TRUSTED_EMAIL_DOMAINS
34
- │ ├── Setting configured → domain in trusted list?
35
- │ ├── Yes → Create user & send code
36
- │ └── No → Reject "Email domain not allowed for signup"
37
- └── Not configured (empty) → Create user & send code (no restrictions)
33
+ │ └── No (new user) → Check subscriber list first
34
+ │ ├── Email in subscribers table? → Allow (create user & send code)
35
+ └── Not a subscriber → Check EMAIL__TRUSTED_EMAIL_DOMAINS
36
+ ├── Setting configured → domain in trusted list?
37
+ │ ├── Yes → Create user & send code
38
+ │ │ └── No → Reject "Email domain not allowed for signup"
39
+ │ └── Not configured (empty) → Create user & send code (no restrictions)
38
40
 
39
41
  Key Behaviors:
40
42
  - Existing users: Always allowed to login (unless tier=BLOCKED)
43
+ - Subscribers: Always allowed to login (regardless of email domain)
41
44
  - New users: Must have email from trusted domain (if EMAIL__TRUSTED_EMAIL_DOMAINS is set)
42
45
  - No restrictions: Leave EMAIL__TRUSTED_EMAIL_DOMAINS empty to allow all domains
43
46
 
@@ -103,6 +106,7 @@ from ...services.postgres.service import PostgresService
103
106
  from ...services.user_service import UserService
104
107
  from ...auth.providers.email import EmailAuthProvider
105
108
  from ...auth.jwt import JWTService, get_jwt_service
109
+ from ...utils.user_id import email_to_user_id
106
110
 
107
111
  router = APIRouter(prefix="/api/auth", tags=["auth"])
108
112
 
@@ -429,8 +433,9 @@ async def callback(provider: str, request: Request):
429
433
  await user_service.link_anonymous_session(user_entity, anon_id)
430
434
 
431
435
  # Enrich session user with DB info
436
+ # user_id = UUID5 hash of email (deterministic, bijection)
432
437
  db_info = {
433
- "id": str(user_entity.id),
438
+ "id": email_to_user_id(user_info.get("email")),
434
439
  "tenant_id": user_entity.tenant_id,
435
440
  "tier": user_entity.tier.value if user_entity.tier else "free",
436
441
  "roles": [user_entity.role] if user_entity.role else [],
@@ -97,7 +97,7 @@ Context Building Flow:
97
97
  - Long messages include REM LOOKUP hints: "... [REM LOOKUP session-{id}-msg-{index}] ..."
98
98
  - Agent can retrieve full content on-demand using REM LOOKUP
99
99
  3. User profile provided as REM LOOKUP hint (on-demand by default)
100
- - Agent receives: "User ID: {user_id}. To load user profile: Use REM LOOKUP users/{user_id}"
100
+ - Agent receives: "User: {email}. To load user profile: Use REM LOOKUP \"{email}\""
101
101
  - Agent decides whether to load profile based on query
102
102
  4. If CHAT__AUTO_INJECT_USER_CONTEXT=true: User profile auto-loaded and injected
103
103
  5. Combines: system context + compressed session history + new messages
@@ -835,3 +835,21 @@ async def stream_openai_response_with_save(
835
835
  )
836
836
  except Exception as e:
837
837
  logger.error(f"Failed to save session messages: {e}", exc_info=True)
838
+
839
+ # Update session description with session_name (non-blocking, after all yields)
840
+ for tool_call in tool_calls:
841
+ if tool_call.get("tool_name") == "register_metadata" and tool_call.get("is_metadata"):
842
+ session_name = tool_call.get("arguments", {}).get("session_name")
843
+ if session_name:
844
+ try:
845
+ from ....models.entities import Session
846
+ from ....services.postgres import Repository
847
+ repo = Repository(Session, table_name="sessions")
848
+ session = await repo.get_by_id(session_id)
849
+ if session and session.description != session_name:
850
+ session.description = session_name
851
+ await repo.update(session)
852
+ logger.debug(f"Updated session {session_id} description to '{session_name}'")
853
+ except Exception as e:
854
+ logger.warning(f"Failed to update session description: {e}")
855
+ break
rem/auth/middleware.py CHANGED
@@ -7,14 +7,22 @@ Anonymous access with rate limiting when allow_anonymous=True.
7
7
  MCP endpoints are always protected unless explicitly disabled.
8
8
 
9
9
  Design Pattern:
10
- - Check X-API-Key header first (if API key auth enabled)
11
- - Check JWT token in Authorization header (Bearer token)
12
- - Check dev token (non-production only, starts with "dev_")
13
- - Check session for user on protected paths
10
+ - API Key (X-API-Key): Access control guardrail, NOT user identity
11
+ - JWT (Authorization: Bearer): Primary method for user identity
12
+ - Dev token: Non-production testing (starts with "dev_")
13
+ - Session: Backward compatibility for browser-based auth
14
14
  - MCP paths always require authentication (protected service)
15
- - If allow_anonymous=True: Allow unauthenticated requests (marked as ANONYMOUS tier)
16
- - If allow_anonymous=False: Return 401 for API calls, redirect browsers to login
17
- - Exclude auth endpoints and public paths
15
+
16
+ Authentication Flow:
17
+ 1. If API key enabled: Validate X-API-Key header (access gate)
18
+ 2. Check JWT token for user identity (primary)
19
+ 3. Check dev token for testing (non-production only)
20
+ 4. Check session for user (backward compatibility)
21
+ 5. If allow_anonymous=True: Allow as anonymous (rate-limited)
22
+ 6. If allow_anonymous=False: Return 401 / redirect to login
23
+
24
+ IMPORTANT: API key validates ACCESS, JWT identifies USER.
25
+ Both can be required: API key for access + JWT for user identity.
18
26
 
19
27
  Access Modes (configured in settings.auth):
20
28
  - enabled=true, allow_anonymous=true: Auth available, anonymous gets rate-limited access
@@ -24,10 +32,9 @@ Access Modes (configured in settings.auth):
24
32
  - mcp_requires_auth=false: MCP follows normal allow_anonymous rules (dev only)
25
33
 
26
34
  API Key Authentication (configured in settings.api):
27
- - api_key_enabled=true: Require X-API-Key header for protected endpoints
35
+ - api_key_enabled=true: Require X-API-Key header for access
28
36
  - api_key: The secret key to validate against
29
- - Provides simple programmatic access without OAuth flow
30
- - X-API-Key header takes precedence over session auth
37
+ - API key is an ACCESS GATE, not user identity - JWT still needed for user
31
38
 
32
39
  Dev Token Support (non-production only):
33
40
  - GET /api/auth/dev/token returns a Bearer token for test-user
@@ -212,32 +219,28 @@ class AuthMiddleware(BaseHTTPMiddleware):
212
219
  if not is_protected or is_excluded:
213
220
  return await call_next(request)
214
221
 
215
- # Check for X-API-Key header first (if enabled)
216
- api_key_user = self._check_api_key(request)
217
- if api_key_user:
218
- request.state.user = api_key_user
219
- request.state.is_anonymous = False
220
- return await call_next(request)
221
-
222
- # If API key auth is enabled but no valid key provided, reject immediately
222
+ # API key validation (access control, not user identity)
223
+ # API key is a guardrail for access - JWT identifies the actual user
223
224
  if settings.api.api_key_enabled:
224
- # Check if X-API-Key header was provided but invalid
225
- if request.headers.get("x-api-key"):
225
+ api_key = request.headers.get("x-api-key")
226
+ if not api_key:
227
+ logger.debug(f"Missing X-API-Key for: {path}")
228
+ return JSONResponse(
229
+ status_code=401,
230
+ content={"detail": "API key required. Include X-API-Key header."},
231
+ headers={"WWW-Authenticate": 'ApiKey realm="REM API"'},
232
+ )
233
+ if api_key != settings.api.api_key:
226
234
  logger.warning(f"Invalid X-API-Key for: {path}")
227
235
  return JSONResponse(
228
236
  status_code=401,
229
237
  content={"detail": "Invalid API key"},
230
238
  headers={"WWW-Authenticate": 'ApiKey realm="REM API"'},
231
239
  )
232
- # No API key provided when required
233
- logger.debug(f"Missing X-API-Key for: {path}")
234
- return JSONResponse(
235
- status_code=401,
236
- content={"detail": "API key required. Include X-API-Key header."},
237
- headers={"WWW-Authenticate": 'ApiKey realm="REM API"'},
238
- )
240
+ logger.debug("X-API-Key validated for access")
241
+ # API key valid - continue to check JWT for user identity
239
242
 
240
- # Check for JWT token in Authorization header
243
+ # Check for JWT token in Authorization header (primary user identity)
241
244
  jwt_user = self._check_jwt_token(request)
242
245
  if jwt_user:
243
246
  request.state.user = jwt_user
rem/cli/commands/ask.py CHANGED
@@ -75,7 +75,7 @@ async def run_agent_streaming(
75
75
  """
76
76
  Run agent in streaming mode using agent.iter() with usage limits.
77
77
 
78
- Design Pattern (from carrier):
78
+ Design Pattern:
79
79
  - Use agent.iter() for complete execution with tool call visibility
80
80
  - run_stream() stops after first output, missing tool calls
81
81
  - Stream tool call markers: [Calling: tool_name]
rem/cli/commands/db.py CHANGED
@@ -333,64 +333,120 @@ def rebuild_cache(connection: str | None):
333
333
 
334
334
  @click.command()
335
335
  @click.argument("file_path", type=click.Path(exists=True, path_type=Path))
336
+ @click.option("--table", "-t", default=None, help="Target table name (required for non-YAML formats)")
336
337
  @click.option("--user-id", default=None, help="User ID to scope data privately (default: public/shared)")
337
338
  @click.option("--dry-run", is_flag=True, help="Show what would be loaded without loading")
338
- def load(file_path: Path, user_id: str | None, dry_run: bool):
339
+ def load(file_path: Path, table: str | None, user_id: str | None, dry_run: bool):
339
340
  """
340
- Load data from YAML file into database.
341
+ Load data from file into database.
341
342
 
342
- File format:
343
- - table: resources
344
- key_field: name
345
- rows:
346
- - name: Example
347
- content: Test data...
343
+ Supports YAML with embedded metadata, or any tabular format via Polars
344
+ (jsonl, parquet, csv, json, arrow, etc.). For non-YAML formats, use --table.
348
345
 
349
346
  Examples:
350
- rem db load rem/tests/data/graph_seed.yaml
351
- rem db load data.yaml --user-id my-user # Private to user
352
- rem db load data.yaml --dry-run
347
+ rem db load data.yaml # YAML with metadata
348
+ rem db load data.jsonl -t resources # Any Polars-supported format
353
349
  """
354
- asyncio.run(_load_async(file_path, user_id, dry_run))
350
+ asyncio.run(_load_async(file_path, table, user_id, dry_run))
355
351
 
356
352
 
357
- async def _load_async(file_path: Path, user_id: str | None, dry_run: bool):
353
+ def _load_dataframe_from_file(file_path: Path) -> "pl.DataFrame":
354
+ """Load any Polars-supported file format into a DataFrame."""
355
+ import polars as pl
356
+
357
+ suffix = file_path.suffix.lower()
358
+
359
+ if suffix in {".jsonl", ".ndjson"}:
360
+ return pl.read_ndjson(file_path)
361
+ elif suffix in {".parquet", ".pq"}:
362
+ return pl.read_parquet(file_path)
363
+ elif suffix == ".csv":
364
+ return pl.read_csv(file_path)
365
+ elif suffix == ".json":
366
+ return pl.read_json(file_path)
367
+ elif suffix in {".ipc", ".arrow"}:
368
+ return pl.read_ipc(file_path)
369
+ else:
370
+ raise ValueError(f"Unsupported file format: {suffix}. Use any Polars-supported format.")
371
+
372
+
373
+ async def _load_async(file_path: Path, table: str | None, user_id: str | None, dry_run: bool):
358
374
  """Async implementation of load command."""
375
+ import polars as pl
359
376
  import yaml
360
377
  from ...models.core.inline_edge import InlineEdge
361
- from ...models.entities import Resource, Moment, User, Message, SharedSession, Schema
378
+ from ...models.entities import SharedSession
362
379
  from ...services.postgres import get_postgres_service
380
+ from ...utils.model_helpers import get_table_name
381
+ from ... import get_model_registry
363
382
 
364
383
  logger.info(f"Loading data from: {file_path}")
365
384
  scope_msg = f"user: {user_id}" if user_id else "public"
366
385
  logger.info(f"Scope: {scope_msg}")
367
386
 
368
- # Load YAML file
369
- with open(file_path) as f:
370
- data = yaml.safe_load(f)
387
+ suffix = file_path.suffix.lower()
388
+ is_yaml = suffix in {".yaml", ".yml"}
371
389
 
372
- if not isinstance(data, list):
373
- logger.error("YAML must be a list of table definitions")
374
- raise click.Abort()
375
-
376
- if dry_run:
377
- logger.info("DRY RUN - Would load:")
378
- logger.info(yaml.dump(data, default_flow_style=False))
379
- return
380
-
381
- # Map table names to model classes
382
- # CoreModel subclasses use Repository.upsert()
390
+ # Build MODEL_MAP dynamically from registry
391
+ registry = get_model_registry()
392
+ registry.register_core_models()
383
393
  MODEL_MAP = {
384
- "users": User,
385
- "moments": Moment,
386
- "resources": Resource,
387
- "messages": Message,
388
- "schemas": Schema,
394
+ get_table_name(model): model
395
+ for model in registry.get_model_classes().values()
389
396
  }
390
397
 
391
398
  # Non-CoreModel tables that need direct SQL insertion
392
399
  DIRECT_INSERT_TABLES = {"shared_sessions"}
393
400
 
401
+ # Parse file based on format
402
+ if is_yaml:
403
+ # YAML with embedded metadata
404
+ with open(file_path) as f:
405
+ data = yaml.safe_load(f)
406
+
407
+ if not isinstance(data, list):
408
+ logger.error("YAML must be a list of table definitions")
409
+ raise click.Abort()
410
+
411
+ if dry_run:
412
+ logger.info("DRY RUN - Would load:")
413
+ logger.info(yaml.dump(data, default_flow_style=False))
414
+ return
415
+
416
+ table_defs = data
417
+ else:
418
+ # Polars-supported format - require --table
419
+ if not table:
420
+ logger.error(f"For {suffix} files, --table is required. Example: rem db load {file_path.name} -t resources")
421
+ raise click.Abort()
422
+
423
+ try:
424
+ df = _load_dataframe_from_file(file_path)
425
+ except Exception as e:
426
+ logger.error(f"Failed to load file: {e}")
427
+ raise click.Abort()
428
+
429
+ rows = df.to_dicts()
430
+
431
+ if dry_run:
432
+ logger.info(f"DRY RUN - Would load {len(rows)} rows to table '{table}':")
433
+ logger.info(f"Columns: {list(df.columns)}")
434
+
435
+ # Validate first row against model if table is known
436
+ if table in MODEL_MAP and rows:
437
+ from ...utils.model_helpers import validate_data_for_model
438
+ result = validate_data_for_model(MODEL_MAP[table], rows[0])
439
+ if result.extra_fields:
440
+ logger.warning(f"Unknown fields (ignored): {result.extra_fields}")
441
+ if result.valid:
442
+ logger.success(f"Sample row validates OK. Required: {result.required_fields or '(none)'}")
443
+ else:
444
+ result.log_errors("Sample row")
445
+ return
446
+
447
+ # Wrap as single table definition
448
+ table_defs = [{"table": table, "rows": rows}]
449
+
394
450
  # Connect to database
395
451
  pg = get_postgres_service()
396
452
  if not pg:
@@ -399,23 +455,24 @@ async def _load_async(file_path: Path, user_id: str | None, dry_run: bool):
399
455
 
400
456
  await pg.connect()
401
457
 
458
+ # Start embedding worker for generating embeddings
459
+ if pg.embedding_worker:
460
+ await pg.embedding_worker.start()
461
+
402
462
  try:
403
463
  total_loaded = 0
404
464
 
405
- for table_def in data:
465
+ for table_def in table_defs:
406
466
  table_name = table_def["table"]
407
- key_field = table_def.get("key_field", "id")
408
467
  rows = table_def.get("rows", [])
409
468
 
410
469
  # Handle direct insert tables (non-CoreModel)
411
470
  if table_name in DIRECT_INSERT_TABLES:
412
471
  for row_data in rows:
413
- # Add tenant_id if not present
414
472
  if "tenant_id" not in row_data:
415
473
  row_data["tenant_id"] = "default"
416
474
 
417
475
  if table_name == "shared_sessions":
418
- # Insert shared_session directly
419
476
  await pg.fetch(
420
477
  """INSERT INTO shared_sessions
421
478
  (session_id, owner_user_id, shared_with_user_id, tenant_id)
@@ -434,16 +491,13 @@ async def _load_async(file_path: Path, user_id: str | None, dry_run: bool):
434
491
  logger.warning(f"Unknown table: {table_name}, skipping")
435
492
  continue
436
493
 
437
- model_class = MODEL_MAP[table_name] # Type is inferred from MODEL_MAP
494
+ model_class = MODEL_MAP[table_name]
438
495
 
439
- for row_data in rows:
440
- # Add user_id and tenant_id only if explicitly provided
441
- # Default is public (None) - data is shared/visible to all
442
- # Pass --user-id to scope data privately to a specific user
443
- if "user_id" not in row_data and user_id is not None:
444
- row_data["user_id"] = user_id
496
+ for row_idx, row_data in enumerate(rows):
497
+ # user_id stays NULL for public data (accessible by any user)
498
+ # Only set tenant_id for scoping - the --user-id flag controls tenant scope
445
499
  if "tenant_id" not in row_data and user_id is not None:
446
- row_data["tenant_id"] = row_data.get("user_id", user_id)
500
+ row_data["tenant_id"] = user_id
447
501
 
448
502
  # Convert graph_edges to InlineEdge format if present
449
503
  if "graph_edges" in row_data:
@@ -452,30 +506,40 @@ async def _load_async(file_path: Path, user_id: str | None, dry_run: bool):
452
506
  for edge in row_data["graph_edges"]
453
507
  ]
454
508
 
455
- # Convert any ISO timestamp strings with Z suffix to naive datetime
456
- # This handles fields like starts_timestamp, ends_timestamp, etc.
509
+ # Convert ISO timestamp strings
457
510
  from ...utils.date_utils import parse_iso
458
511
  for key, value in list(row_data.items()):
459
512
  if isinstance(value, str) and (key.endswith("_timestamp") or key.endswith("_at")):
460
513
  try:
461
514
  row_data[key] = parse_iso(value)
462
515
  except (ValueError, TypeError):
463
- pass # Not a valid datetime string, leave as-is
516
+ pass
464
517
 
465
- # Create model instance and upsert via repository
466
518
  from ...services.postgres.repository import Repository
519
+ from ...utils.model_helpers import validate_data_for_model
520
+
521
+ result = validate_data_for_model(model_class, row_data)
522
+ if not result.valid:
523
+ result.log_errors(f"Row {row_idx + 1} ({table_name})")
524
+ raise click.Abort()
467
525
 
468
- instance = model_class(**row_data)
469
- repo = Repository(model_class, table_name, pg) # Type inferred from MODEL_MAP
470
- await repo.upsert(instance) # type: ignore[arg-type]
526
+ repo = Repository(model_class, table_name, pg)
527
+ await repo.upsert(result.instance) # type: ignore[arg-type]
471
528
  total_loaded += 1
472
529
 
473
- # Log based on model type
474
- name = getattr(instance, 'name', getattr(instance, 'id', '?'))
530
+ name = getattr(result.instance, 'name', getattr(result.instance, 'id', '?'))
475
531
  logger.success(f"Loaded {table_name[:-1]}: {name}")
476
532
 
477
533
  logger.success(f"Data loaded successfully! Total rows: {total_loaded}")
478
534
 
535
+ # Wait for embeddings to complete
536
+ if pg.embedding_worker and pg.embedding_worker.running:
537
+ queue_size = pg.embedding_worker.task_queue.qsize()
538
+ if queue_size > 0:
539
+ logger.info(f"Waiting for {queue_size} embeddings to complete...")
540
+ await pg.embedding_worker.stop()
541
+ logger.success("Embeddings generated successfully")
542
+
479
543
  finally:
480
544
  await pg.disconnect()
481
545