remdb-0.3.180-py3-none-any.whl → remdb-0.3.258-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70)
  1. rem/agentic/README.md +36 -2
  2. rem/agentic/__init__.py +10 -1
  3. rem/agentic/context.py +185 -1
  4. rem/agentic/context_builder.py +56 -35
  5. rem/agentic/mcp/tool_wrapper.py +2 -2
  6. rem/agentic/providers/pydantic_ai.py +303 -111
  7. rem/agentic/schema.py +2 -2
  8. rem/api/main.py +1 -1
  9. rem/api/mcp_router/resources.py +223 -0
  10. rem/api/mcp_router/server.py +4 -0
  11. rem/api/mcp_router/tools.py +608 -166
  12. rem/api/routers/admin.py +30 -4
  13. rem/api/routers/auth.py +219 -20
  14. rem/api/routers/chat/child_streaming.py +393 -0
  15. rem/api/routers/chat/completions.py +77 -40
  16. rem/api/routers/chat/sse_events.py +7 -3
  17. rem/api/routers/chat/streaming.py +381 -291
  18. rem/api/routers/chat/streaming_utils.py +325 -0
  19. rem/api/routers/common.py +18 -0
  20. rem/api/routers/dev.py +7 -1
  21. rem/api/routers/feedback.py +11 -3
  22. rem/api/routers/messages.py +176 -38
  23. rem/api/routers/models.py +9 -1
  24. rem/api/routers/query.py +17 -15
  25. rem/api/routers/shared_sessions.py +16 -0
  26. rem/auth/jwt.py +19 -4
  27. rem/auth/middleware.py +42 -28
  28. rem/cli/README.md +62 -0
  29. rem/cli/commands/ask.py +205 -114
  30. rem/cli/commands/db.py +55 -31
  31. rem/cli/commands/experiments.py +1 -1
  32. rem/cli/commands/process.py +179 -43
  33. rem/cli/commands/query.py +109 -0
  34. rem/cli/commands/session.py +117 -0
  35. rem/cli/main.py +2 -0
  36. rem/models/core/experiment.py +1 -1
  37. rem/models/entities/ontology.py +18 -20
  38. rem/models/entities/session.py +1 -0
  39. rem/schemas/agents/core/agent-builder.yaml +1 -1
  40. rem/schemas/agents/rem.yaml +1 -1
  41. rem/schemas/agents/test_orchestrator.yaml +42 -0
  42. rem/schemas/agents/test_structured_output.yaml +52 -0
  43. rem/services/content/providers.py +151 -49
  44. rem/services/content/service.py +18 -5
  45. rem/services/embeddings/worker.py +26 -12
  46. rem/services/postgres/__init__.py +28 -3
  47. rem/services/postgres/diff_service.py +57 -5
  48. rem/services/postgres/programmable_diff_service.py +635 -0
  49. rem/services/postgres/pydantic_to_sqlalchemy.py +2 -2
  50. rem/services/postgres/register_type.py +11 -10
  51. rem/services/postgres/repository.py +39 -28
  52. rem/services/postgres/schema_generator.py +5 -5
  53. rem/services/postgres/sql_builder.py +6 -5
  54. rem/services/rem/README.md +4 -3
  55. rem/services/rem/parser.py +7 -10
  56. rem/services/rem/service.py +47 -0
  57. rem/services/session/__init__.py +8 -1
  58. rem/services/session/compression.py +47 -5
  59. rem/services/session/pydantic_messages.py +310 -0
  60. rem/services/session/reload.py +2 -1
  61. rem/settings.py +92 -7
  62. rem/sql/migrations/001_install.sql +125 -7
  63. rem/sql/migrations/002_install_models.sql +159 -149
  64. rem/sql/migrations/004_cache_system.sql +10 -276
  65. rem/sql/migrations/migrate_session_id_to_uuid.sql +45 -0
  66. rem/utils/schema_loader.py +180 -120
  67. {remdb-0.3.180.dist-info → remdb-0.3.258.dist-info}/METADATA +7 -6
  68. {remdb-0.3.180.dist-info → remdb-0.3.258.dist-info}/RECORD +70 -61
  69. {remdb-0.3.180.dist-info → remdb-0.3.258.dist-info}/WHEEL +0 -0
  70. {remdb-0.3.180.dist-info → remdb-0.3.258.dist-info}/entry_points.txt +0 -0
rem/schemas/agents/test_structured_output.yaml (new file)
@@ -0,0 +1,52 @@
+ # =============================================================================
+ # TEST STRUCTURED OUTPUT AGENT
+ # =============================================================================
+ # Simple agent for testing structured_output: true functionality
+ # =============================================================================
+
+ name: test_structured_output
+ version: "1.0"
+ description: |
+   You are a test agent that produces structured output.
+
+   Your ONLY job is to return a structured response matching the schema below.
+
+   Based on the user's input, extract:
+   - summary: A brief summary of what they said
+   - sentiment: positive, negative, or neutral
+   - keywords: List of key words from their message
+
+   DO NOT ask questions. Just produce the structured output.
+
+ type: object
+ properties:
+   result:
+     type: object
+     description: Structured analysis result
+     properties:
+       summary:
+         type: string
+         description: Brief summary of the input
+       sentiment:
+         type: string
+         enum: [positive, negative, neutral]
+         description: Overall sentiment
+       keywords:
+         type: array
+         items:
+           type: string
+         description: Key words extracted from input
+     required: [summary, sentiment, keywords]
+     additionalProperties: false
+
+ required:
+   - result
+
+ json_schema_extra:
+   kind: agent
+   name: test_structured_output
+   version: "1.0.0"
+   tags: [test, structured-output]
+   structured_output: true
+   tools: []
+   resources: []
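Note: the schema block above doubles as a JSON Schema for the agent's output. A minimal sketch of validating a model response against the declared `result` object, using the third-party `jsonschema` package (the schema literal is transcribed from the YAML above; the sample payload is invented):

```python
from jsonschema import validate  # third-party: pip install jsonschema

RESULT_SCHEMA = {
    "type": "object",
    "properties": {
        "summary": {"type": "string"},
        "sentiment": {"type": "string", "enum": ["positive", "negative", "neutral"]},
        "keywords": {"type": "array", "items": {"type": "string"}},
    },
    "required": ["summary", "sentiment", "keywords"],
    "additionalProperties": False,
}

# Invented sample payload; validate() raises ValidationError on a mismatch.
sample = {"summary": "User praised the release", "sentiment": "positive", "keywords": ["release"]}
validate(instance=sample, schema=RESULT_SCHEMA)
```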
rem/services/content/providers.py
@@ -118,15 +118,40 @@ class DocProvider(ContentProvider):
      - Images (.png, .jpg) - OCR text extraction

      Handles:
-     - Text extraction with OCR fallback
+     - Text extraction with automatic OCR fallback for scanned documents
      - Table detection and extraction
      - Daemon process workaround for multiprocessing restrictions
+
+     Environment Variables:
+         EXTRACTION_OCR_FALLBACK: Enable OCR fallback (default: true)
+         EXTRACTION_OCR_THRESHOLD: Min chars before triggering OCR fallback (default: 100)
+         EXTRACTION_FORCE_OCR: Always use OCR, skip native extraction (default: false)
+         EXTRACTION_OCR_LANGUAGE: Tesseract language codes (default: eng)
      """

      @property
      def name(self) -> str:
          return "doc"

+     def _get_env_bool(self, key: str, default: bool) -> bool:
+         """Get boolean from environment variable."""
+         import os
+         val = os.environ.get(key, "").lower()
+         if val in ("true", "1", "yes"):
+             return True
+         elif val in ("false", "0", "no"):
+             return False
+         return default
+
+     def _get_env_int(self, key: str, default: int) -> int:
+         """Get integer from environment variable."""
+         import os
+         val = os.environ.get(key, "")
+         try:
+             return int(val) if val else default
+         except ValueError:
+             return default
+
      def _is_daemon_process(self) -> bool:
          """Check if running in a daemon process."""
          try:
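The new `_get_env_bool`/`_get_env_int` helpers make the four `EXTRACTION_*` variables the whole configuration surface. A hedged usage sketch (the input file is a placeholder; only the `"text"` key of the return value is visible in this diff):

```python
import os
from pathlib import Path

from rem.services.content.providers import DocProvider

os.environ["EXTRACTION_OCR_FALLBACK"] = "1"        # "true"/"1"/"yes" all parse as True
os.environ["EXTRACTION_OCR_THRESHOLD"] = "200"     # retry with OCR below 200 chars
os.environ["EXTRACTION_FORCE_OCR"] = "false"       # keep the native pass as the fast path
os.environ["EXTRACTION_OCR_LANGUAGE"] = "eng+deu"  # Tesseract multi-language syntax

provider = DocProvider()
pdf_bytes = Path("scanned.pdf").read_bytes()       # placeholder input file
result = provider.extract(pdf_bytes, {"content_type": "application/pdf"})
print(result["text"][:200])
```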
@@ -134,29 +159,34 @@ class DocProvider(ContentProvider):
          except Exception:
              return False

-     def _parse_in_subprocess(self, file_path: Path) -> dict:
+     def _parse_in_subprocess(self, file_path: Path, force_ocr: bool = False) -> dict:
          """Run kreuzberg in a separate subprocess to bypass daemon restrictions."""
-         script = """
+         import os
+         ocr_language = os.environ.get("EXTRACTION_OCR_LANGUAGE", "eng")
+
+         script = f"""
  import json
  import sys
  from pathlib import Path
- from kreuzberg import ExtractionConfig, extract_file_sync
+ from kreuzberg import ExtractionConfig, OcrConfig, extract_file_sync

- # Parse document with kreuzberg 3.x
- config = ExtractionConfig(
-     extract_tables=True,
-     chunk_content=False,
-     extract_keywords=False,
- )
+ force_ocr = {force_ocr}
+
+ if force_ocr:
+     config = ExtractionConfig(
+         force_ocr=True,
+         ocr=OcrConfig(backend="tesseract", language="{ocr_language}")
+     )
+ else:
+     config = ExtractionConfig()

  result = extract_file_sync(Path(sys.argv[1]), config=config)

- # Serialize result to JSON
- output = {
+ output = {{
      'content': result.content,
-     'tables': [t.model_dump() for t in result.tables] if result.tables else [],
-     'metadata': result.metadata
- }
+     'tables': [],
+     'metadata': {{}}
+ }}
  print(json.dumps(output))
  """

@@ -173,9 +203,41 @@ print(json.dumps(output))

          return json.loads(result.stdout)

+     def _extract_with_config(self, tmp_path: Path, force_ocr: bool = False) -> tuple[str, dict]:
+         """Extract content with optional OCR config."""
+         import os
+         from kreuzberg import ExtractionConfig, OcrConfig, extract_file_sync
+
+         ocr_language = os.environ.get("EXTRACTION_OCR_LANGUAGE", "eng")
+
+         if force_ocr:
+             config = ExtractionConfig(
+                 force_ocr=True,
+                 ocr=OcrConfig(backend="tesseract", language=ocr_language)
+             )
+             parser_name = "kreuzberg_ocr"
+         else:
+             config = ExtractionConfig()
+             parser_name = "kreuzberg"
+
+         result = extract_file_sync(tmp_path, config=config)
+         text = result.content
+
+         extraction_metadata = {
+             "parser": parser_name,
+             "file_extension": tmp_path.suffix,
+         }
+
+         return text, extraction_metadata
+
      def extract(self, content: bytes, metadata: dict[str, Any]) -> dict[str, Any]:
          """
-         Extract document content using Kreuzberg.
+         Extract document content using Kreuzberg with intelligent OCR fallback.
+
+         Process:
+         1. Try native text extraction first (fast, preserves structure)
+         2. If content is minimal (< threshold chars), retry with OCR
+         3. Use OCR result if it's better than native result

          Args:
              content: Document file bytes
@@ -184,49 +246,89 @@ print(json.dumps(output))
          Returns:
              dict with text and extraction metadata
          """
+         # Get OCR settings from environment
+         force_ocr = self._get_env_bool("EXTRACTION_FORCE_OCR", False)
+         ocr_fallback = self._get_env_bool("EXTRACTION_OCR_FALLBACK", True)
+         ocr_threshold = self._get_env_int("EXTRACTION_OCR_THRESHOLD", 100)
+
          # Write bytes to temp file for kreuzberg
-         # Detect extension from metadata
          content_type = metadata.get("content_type", "")
          suffix = get_extension(content_type, default=".pdf")

          with temp_file_from_bytes(content, suffix=suffix) as tmp_path:
+             ocr_used = False
+             ocr_fallback_triggered = False
+             native_char_count = 0
+
              # Check if running in daemon process
              if self._is_daemon_process():
-                 logger.info("Daemon process detected - using subprocess workaround for document parsing")
+                 logger.info("Daemon process detected - using subprocess workaround")
                  try:
-                     result_dict = self._parse_in_subprocess(tmp_path)
-                     text = result_dict["content"]
-                     extraction_metadata = {
-                         "table_count": len(result_dict["tables"]),
-                         "parser": "kreuzberg_subprocess",
-                         "file_extension": tmp_path.suffix,
-                     }
+                     if force_ocr:
+                         result_dict = self._parse_in_subprocess(tmp_path, force_ocr=True)
+                         text = result_dict["content"]
+                         ocr_used = True
+                         extraction_metadata = {
+                             "parser": "kreuzberg_subprocess_ocr",
+                             "file_extension": tmp_path.suffix,
+                         }
+                     else:
+                         # Try native first
+                         result_dict = self._parse_in_subprocess(tmp_path, force_ocr=False)
+                         text = result_dict["content"]
+                         native_char_count = len(text)
+
+                         # OCR fallback if content is minimal
+                         if ocr_fallback and len(text.strip()) < ocr_threshold:
+                             logger.warning(f"Content below threshold ({len(text.strip())} < {ocr_threshold}) - trying OCR fallback")
+                             try:
+                                 ocr_result = self._parse_in_subprocess(tmp_path, force_ocr=True)
+                                 ocr_text = ocr_result["content"]
+                                 if len(ocr_text.strip()) > len(text.strip()):
+                                     logger.info(f"OCR fallback improved result: {len(ocr_text)} chars (was {native_char_count})")
+                                     text = ocr_text
+                                     ocr_used = True
+                                     ocr_fallback_triggered = True
+                             except Exception as e:
+                                 logger.warning(f"OCR fallback failed in subprocess: {e}")
+
+                         extraction_metadata = {
+                             "parser": "kreuzberg_subprocess" if not ocr_used else "kreuzberg_subprocess_ocr_fallback",
+                             "file_extension": tmp_path.suffix,
+                         }
                  except Exception as e:
-                     logger.error(f"Subprocess parsing failed: {e}. Falling back to text-only.")
-                     # Fallback to simple text extraction (kreuzberg 3.x API)
-                     from kreuzberg import ExtractionConfig, extract_file_sync
-                     config = ExtractionConfig(extract_tables=False)
-                     result = extract_file_sync(tmp_path, config=config)
-                     text = result.content
-                     extraction_metadata = {
-                         "parser": "kreuzberg_fallback",
-                         "file_extension": tmp_path.suffix,
-                     }
+                     logger.error(f"Subprocess parsing failed: {e}. Falling back to direct call.")
+                     text, extraction_metadata = self._extract_with_config(tmp_path, force_ocr=force_ocr)
+                     ocr_used = force_ocr
              else:
-                 # Normal execution (not in daemon) - kreuzberg 4.x with native ONNX/Rust
-                 from kreuzberg import ExtractionConfig, extract_file_sync
-                 config = ExtractionConfig(
-                     enable_quality_processing=True,  # Enables table extraction with native ONNX
-                     chunk_content=False,  # We handle chunking ourselves
-                     extract_tables=False,  # Disable table extraction to avoid PyTorch dependency
-                 )
-                 result = extract_file_sync(tmp_path, config=config)
-                 text = result.content
-                 extraction_metadata = {
-                     "table_count": len(result.tables) if result.tables else 0,
-                     "parser": "kreuzberg",
-                     "file_extension": tmp_path.suffix,
-                 }
+                 # Normal execution (not in daemon)
+                 if force_ocr:
+                     text, extraction_metadata = self._extract_with_config(tmp_path, force_ocr=True)
+                     ocr_used = True
+                 else:
+                     # Try native first
+                     text, extraction_metadata = self._extract_with_config(tmp_path, force_ocr=False)
+                     native_char_count = len(text)
+
+                     # OCR fallback if content is minimal
+                     if ocr_fallback and len(text.strip()) < ocr_threshold:
+                         logger.warning(f"Content below threshold ({len(text.strip())} < {ocr_threshold}) - trying OCR fallback")
+                         try:
+                             ocr_text, _ = self._extract_with_config(tmp_path, force_ocr=True)
+                             if len(ocr_text.strip()) > len(text.strip()):
+                                 logger.info(f"OCR fallback improved result: {len(ocr_text)} chars (was {native_char_count})")
+                                 text = ocr_text
+                                 ocr_used = True
+                                 ocr_fallback_triggered = True
+                                 extraction_metadata["parser"] = "kreuzberg_ocr_fallback"
+                         except Exception as e:
+                             logger.warning(f"OCR fallback failed: {e}")
+
+             # Add OCR metadata
+             extraction_metadata["ocr_used"] = ocr_used
+             extraction_metadata["ocr_fallback_triggered"] = ocr_fallback_triggered
+             extraction_metadata["native_char_count"] = native_char_count
+             extraction_metadata["final_char_count"] = len(text)

          return {
              "text": text,
rem/services/content/service.py
@@ -274,7 +274,7 @@ class ContentService:
      async def ingest_file(
          self,
          file_uri: str,
-         user_id: str,
+         user_id: str | None = None,
          category: str | None = None,
          tags: list[str] | None = None,
          is_local_server: bool = False,
@@ -283,6 +283,10 @@ class ContentService:
          """
          Complete file ingestion pipeline: read → store → parse → chunk → embed.

+         **IMPORTANT: Data is PUBLIC by default (user_id=None).**
+         This is correct for shared knowledge bases (ontologies, procedures, reference data).
+         Private user-scoped data is rarely needed - only set user_id for truly personal content.
+
          **CENTRALIZED INGESTION**: This is the single entry point for all file ingestion
          in REM. It handles:
@@ -319,7 +323,9 @@ class ContentService:

          Args:
              file_uri: Source file location (local path, s3://, or https://)
-             user_id: User identifier for data isolation and ownership
+             user_id: User identifier for PRIVATE data only. Default None = PUBLIC/shared.
+                 Leave as None for shared knowledge bases, ontologies, reference data.
+                 Only set for truly private user-specific content.
              category: Optional category tag (document, code, audio, etc.)
              tags: Optional list of tags
              is_local_server: True if running as local/stdio MCP server
@@ -347,12 +353,19 @@ class ContentService:

          Example:
              >>> service = ContentService()
+             >>> # PUBLIC data (default) - visible to all users
              >>> result = await service.ingest_file(
-             ...     file_uri="s3://bucket/contract.pdf",
-             ...     user_id="user-123",
-             ...     category="legal"
+             ...     file_uri="s3://bucket/procedure.pdf",
+             ...     category="medical"
              ... )
              >>> print(f"Created {result['resources_created']} searchable chunks")
+             >>>
+             >>> # PRIVATE data (rare) - only for user-specific content
+             >>> result = await service.ingest_file(
+             ...     file_uri="s3://bucket/personal-notes.pdf",
+             ...     user_id="user-123",  # Only this user can access
+             ...     category="personal"
+             ... )
          """
          from pathlib import Path
          from uuid import uuid4
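The `user_id=None` default only makes sense if reads filter on it. The read path isn't shown in this diff, but the convention implies a visibility filter along these lines (table and column names are illustrative, not taken from the package):

```python
# Hypothetical read-side filter implied by the public-by-default convention.
VISIBILITY_FILTER = """
    SELECT id, content
    FROM resources
    WHERE user_id IS NULL   -- public/shared rows, the ingest_file default
       OR user_id = $1      -- plus the caller's own private rows
"""
```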
rem/services/embeddings/worker.py
@@ -23,6 +23,8 @@ Future:
  import asyncio
  import os
  from typing import Any, Optional
+ import hashlib
+ import uuid
  from uuid import uuid4

  import httpx
@@ -108,6 +110,7 @@ class EmbeddingWorker:
          self.task_queue: asyncio.Queue = asyncio.Queue()
          self.workers: list[asyncio.Task] = []
          self.running = False
+         self._in_flight_count = 0  # Track tasks being processed (not just in queue)

          # Store API key for direct HTTP requests
          from ...settings import settings
@@ -143,17 +146,18 @@ class EmbeddingWorker:
              return

          queue_size = self.task_queue.qsize()
-         logger.debug(f"Stopping EmbeddingWorker (processing {queue_size} queued tasks first)")
+         in_flight = self._in_flight_count
+         logger.debug(f"Stopping EmbeddingWorker (queue={queue_size}, in_flight={in_flight})")

-         # Wait for queue to drain (with timeout)
+         # Wait for both queue to drain AND in-flight tasks to complete
          max_wait = 30  # 30 seconds max
          waited = 0.0
-         while not self.task_queue.empty() and waited < max_wait:
+         while (not self.task_queue.empty() or self._in_flight_count > 0) and waited < max_wait:
              await asyncio.sleep(0.5)
              waited += 0.5

-         if not self.task_queue.empty():
-             remaining = self.task_queue.qsize()
+         if not self.task_queue.empty() or self._in_flight_count > 0:
+             remaining = self.task_queue.qsize() + self._in_flight_count
              logger.warning(
                  f"EmbeddingWorker timeout: {remaining} tasks remaining after {max_wait}s"
              )
@@ -205,12 +209,18 @@ class EmbeddingWorker:
                  if not batch:
                      continue

-                 logger.debug(f"Worker {worker_id} processing batch of {len(batch)} tasks")
+                 # Track in-flight tasks
+                 self._in_flight_count += len(batch)

-                 # Generate embeddings for batch
-                 await self._process_batch(batch)
+                 logger.debug(f"Worker {worker_id} processing batch of {len(batch)} tasks")

-                 logger.debug(f"Worker {worker_id} completed batch")
+                 try:
+                     # Generate embeddings for batch
+                     await self._process_batch(batch)
+                     logger.debug(f"Worker {worker_id} completed batch")
+                 finally:
+                     # Always decrement in-flight count, even on error
+                     self._in_flight_count -= len(batch)

              except asyncio.CancelledError:
                  logger.debug(f"Worker {worker_id} cancelled")
@@ -373,7 +383,11 @@ class EmbeddingWorker:
          for task, embedding in zip(tasks, embeddings):
              table_name = f"embeddings_{task.table_name}"

-             # Build upsert SQL
+             # Generate deterministic ID from key fields (entity_id, field_name, provider)
+             key_string = f"{task.entity_id}:{task.field_name}:{task.provider}"
+             embedding_id = str(uuid.UUID(hashlib.md5(key_string.encode()).hexdigest()))
+
+             # Build upsert SQL - conflict on deterministic ID
              sql = f"""
                  INSERT INTO {table_name} (
                      id,
@@ -386,7 +400,7 @@ class EmbeddingWorker:
                      updated_at
                  )
                  VALUES ($1, $2, $3, $4, $5, $6, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
-                 ON CONFLICT (entity_id, field_name, provider)
+                 ON CONFLICT (id)
                  DO UPDATE SET
                      model = EXCLUDED.model,
                      embedding = EXCLUDED.embedding,
@@ -400,7 +414,7 @@ class EmbeddingWorker:
              await self.postgres_service.execute(
                  sql,
                  (
-                     str(uuid4()),
+                     embedding_id,
                      task.entity_id,
                      task.field_name,
                      task.provider,
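The id switch is what makes `ON CONFLICT (id)` equivalent to the old composite-key conflict: the UUID is now a pure function of `(entity_id, field_name, provider)`, so re-embedding the same field always targets the same row. A small demonstration (MD5 here provides stability, not security; `uuid.UUID` accepts a 32-character hex digest directly):

```python
import hashlib
import uuid

def embedding_id(entity_id: str, field_name: str, provider: str) -> str:
    key = f"{entity_id}:{field_name}:{provider}"
    return str(uuid.UUID(hashlib.md5(key.encode()).hexdigest()))

# Same inputs, same id - across processes and across runs.
assert embedding_id("e1", "body", "openai") == embedding_id("e1", "body", "openai")
print(embedding_id("e1", "body", "openai"))  # stable value every run
```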
rem/services/postgres/__init__.py
@@ -3,22 +3,47 @@ PostgreSQL service for CloudNativePG database operations.
  """

  from .diff_service import DiffService, SchemaDiff
+ from .programmable_diff_service import (
+     DiffResult,
+     ObjectDiff,
+     ObjectType,
+     ProgrammableDiffService,
+ )
  from .repository import Repository
  from .service import PostgresService

+ _postgres_instance: PostgresService | None = None
+
+
  def get_postgres_service() -> PostgresService | None:
      """
-     Get PostgresService instance.
+     Get PostgresService singleton instance.

      Returns None if Postgres is disabled.
+     Uses singleton pattern to prevent connection pool exhaustion.
      """
+     global _postgres_instance
+
      from ...settings import settings

      if not settings.postgres.enabled:
          return None

-     return PostgresService()
+     if _postgres_instance is None:
+         _postgres_instance = PostgresService()
+
+     return _postgres_instance


- __all__ = ["PostgresService", "get_postgres_service", "Repository", "DiffService", "SchemaDiff"]
+ __all__ = [
+     "DiffResult",
+     "DiffService",
+     "ObjectDiff",
+     "ObjectType",
+     "PostgresService",
+     "ProgrammableDiffService",
+     "Repository",
+     "SchemaDiff",
+     "get_postgres_service",
+ ]
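The module-level singleton is the standard fix for pool exhaustion: every caller now shares one `PostgresService`, and therefore one connection pool. The pattern in miniature (a toy `Service` class, not the real one):

```python
class Service:
    def __init__(self) -> None:
        print("pool created")  # expensive setup happens exactly once

_instance: Service | None = None

def get_service() -> Service:
    global _instance
    if _instance is None:
        _instance = Service()
    return _instance

assert get_service() is get_service()  # same object, same pool
```

Note the lock-free check is only safe when first use happens on a single-threaded startup path, which is the typical case for asyncio services; the real `get_postgres_service` additionally short-circuits to `None` when Postgres is disabled.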
rem/services/postgres/diff_service.py
@@ -5,12 +5,17 @@ Uses Alembic autogenerate to detect differences between:
  - Target schema (derived from Pydantic models)
  - Current database schema

+ Also compares programmable objects (functions, triggers, views) which
+ Alembic does not track.
+
  This enables:
  1. Local development: See what would change before applying migrations
  2. CI validation: Detect drift between code and database (--check mode)
  3. Migration generation: Create incremental migration files
  """

+ import asyncio
+ import re
  from dataclasses import dataclass, field
  from pathlib import Path
  from typing import Optional
@@ -51,11 +56,14 @@ class SchemaDiff:
      sql: str = ""
      upgrade_ops: Optional[ops.UpgradeOps] = None
      filtered_count: int = 0  # Number of operations filtered out by strategy
+     # Programmable objects (functions, triggers, views)
+     programmable_summary: list[str] = field(default_factory=list)
+     programmable_sql: str = ""

      @property
      def change_count(self) -> int:
          """Total number of detected changes."""
-         return len(self.summary)
+         return len(self.summary) + len(self.programmable_summary)


  class DiffService:
  class DiffService:
@@ -127,10 +135,13 @@ class DiffService:
127
135
  # These are now generated in pydantic_to_sqlalchemy
128
136
  return True
129
137
 
130
- def compute_diff(self) -> SchemaDiff:
138
+ def compute_diff(self, include_programmable: bool = True) -> SchemaDiff:
131
139
  """
132
140
  Compare Pydantic models against database and return differences.
133
141
 
142
+ Args:
143
+ include_programmable: If True, also diff functions/triggers/views
144
+
134
145
  Returns:
135
146
  SchemaDiff with detected changes
136
147
  """
@@ -167,21 +178,62 @@ class DiffService:
          for op in filtered_ops:
              summary.extend(self._describe_operation(op))

-         has_changes = len(summary) > 0
-
          # Generate SQL if there are changes
          sql = ""
-         if has_changes and upgrade_ops:
+         if summary and upgrade_ops:
              sql = self._render_sql(upgrade_ops, engine)

+         # Programmable objects diff (functions, triggers, views)
+         programmable_summary = []
+         programmable_sql = ""
+         if include_programmable:
+             prog_summary, prog_sql = self._compute_programmable_diff()
+             programmable_summary = prog_summary
+             programmable_sql = prog_sql
+
+         has_changes = len(summary) > 0 or len(programmable_summary) > 0
+
          return SchemaDiff(
              has_changes=has_changes,
              summary=summary,
              sql=sql,
              upgrade_ops=upgrade_ops,
              filtered_count=filtered_count,
+             programmable_summary=programmable_summary,
+             programmable_sql=programmable_sql,
          )

+     def _compute_programmable_diff(self) -> tuple[list[str], str]:
+         """
+         Compute diff for programmable objects (functions, triggers, views).
+
+         Returns:
+             Tuple of (summary_lines, sync_sql)
+         """
+         from .programmable_diff_service import ProgrammableDiffService
+
+         service = ProgrammableDiffService()
+
+         # Run async diff in sync context
+         try:
+             loop = asyncio.get_event_loop()
+         except RuntimeError:
+             loop = asyncio.new_event_loop()
+             asyncio.set_event_loop(loop)
+
+         result = loop.run_until_complete(service.compute_diff())
+
+         summary = []
+         for diff in result.diffs:
+             if diff.status == "missing":
+                 summary.append(f"+ {diff.object_type.value.upper()} {diff.name} (missing)")
+             elif diff.status == "different":
+                 summary.append(f"~ {diff.object_type.value.upper()} {diff.name} (different)")
+             elif diff.status == "extra":
+                 summary.append(f"- {diff.object_type.value.upper()} {diff.name} (extra in db)")
+
+         return summary, result.sync_sql
+
      def _filter_operations(self, operations: list) -> tuple[list, int]:
          """
          Filter operations based on migration strategy.
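`_compute_programmable_diff` bridges from the synchronous `compute_diff` into async code. Worth noting: `loop.run_until_complete` fails if a loop is already running (e.g. inside a FastAPI handler), so the pattern assumes a plain synchronous caller such as a CLI command. The bridge in isolation, with a stub coroutine standing in for the async diff queries:

```python
import asyncio

async def compute() -> str:
    await asyncio.sleep(0)  # stand-in for the async database queries
    return "-- sync_sql here"

def compute_sync() -> str:
    # get_event_loop() raises RuntimeError when no loop is set on this
    # thread (and is deprecated territory in newer Pythons), hence the fallback.
    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
    return loop.run_until_complete(compute())

print(compute_sync())
```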