remdb 0.3.133__py3-none-any.whl → 0.3.157__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51):
  1. rem/agentic/agents/__init__.py +16 -0
  2. rem/agentic/agents/agent_manager.py +310 -0
  3. rem/agentic/context_builder.py +5 -3
  4. rem/agentic/mcp/tool_wrapper.py +48 -6
  5. rem/agentic/providers/phoenix.py +91 -21
  6. rem/agentic/providers/pydantic_ai.py +77 -43
  7. rem/api/deps.py +2 -2
  8. rem/api/main.py +1 -1
  9. rem/api/mcp_router/server.py +2 -0
  10. rem/api/mcp_router/tools.py +90 -0
  11. rem/api/routers/auth.py +208 -4
  12. rem/api/routers/chat/streaming.py +77 -22
  13. rem/auth/__init__.py +13 -3
  14. rem/auth/middleware.py +66 -1
  15. rem/auth/providers/__init__.py +4 -1
  16. rem/auth/providers/email.py +215 -0
  17. rem/cli/commands/configure.py +3 -4
  18. rem/cli/commands/experiments.py +50 -49
  19. rem/cli/commands/session.py +336 -0
  20. rem/cli/dreaming.py +2 -2
  21. rem/cli/main.py +2 -0
  22. rem/models/core/experiment.py +4 -14
  23. rem/models/entities/__init__.py +4 -0
  24. rem/models/entities/ontology.py +1 -1
  25. rem/models/entities/ontology_config.py +1 -1
  26. rem/models/entities/subscriber.py +175 -0
  27. rem/models/entities/user.py +1 -0
  28. rem/schemas/agents/core/agent-builder.yaml +134 -0
  29. rem/schemas/agents/examples/contract-analyzer.yaml +1 -1
  30. rem/schemas/agents/examples/contract-extractor.yaml +1 -1
  31. rem/schemas/agents/examples/cv-parser.yaml +1 -1
  32. rem/services/__init__.py +3 -1
  33. rem/services/content/service.py +4 -3
  34. rem/services/email/__init__.py +10 -0
  35. rem/services/email/service.py +459 -0
  36. rem/services/email/templates.py +360 -0
  37. rem/services/postgres/README.md +38 -0
  38. rem/services/postgres/diff_service.py +19 -3
  39. rem/services/postgres/pydantic_to_sqlalchemy.py +45 -13
  40. rem/services/session/compression.py +113 -50
  41. rem/services/session/reload.py +14 -7
  42. rem/settings.py +191 -4
  43. rem/sql/migrations/002_install_models.sql +91 -91
  44. rem/sql/migrations/005_schema_update.sql +145 -0
  45. rem/utils/README.md +45 -0
  46. rem/utils/files.py +157 -1
  47. rem/utils/vision.py +1 -1
  48. {remdb-0.3.133.dist-info → remdb-0.3.157.dist-info}/METADATA +7 -5
  49. {remdb-0.3.133.dist-info → remdb-0.3.157.dist-info}/RECORD +51 -42
  50. {remdb-0.3.133.dist-info → remdb-0.3.157.dist-info}/WHEEL +0 -0
  51. {remdb-0.3.133.dist-info → remdb-0.3.157.dist-info}/entry_points.txt +0 -0
@@ -564,9 +564,37 @@ async def create_agent(
564
564
  mcp_server_configs = []
565
565
  resource_configs = []
566
566
 
567
- # Default to rem.mcp_server if no MCP servers configured
567
+ # Auto-detect local MCP server if not explicitly configured
568
+ # This makes mcp_servers config optional - agents get tools automatically
568
569
  if not mcp_server_configs:
569
- mcp_server_configs = [{"type": "local", "module": "rem.mcp_server", "id": "rem"}]
570
+ import importlib
571
+ import os
572
+ import sys
573
+
574
+ # Ensure current working directory is in sys.path for local imports
575
+ cwd = os.getcwd()
576
+ if cwd not in sys.path:
577
+ sys.path.insert(0, cwd)
578
+
579
+ # Try common local MCP server module paths first
580
+ auto_detect_modules = [
581
+ "tools.mcp_server", # Convention: tools/mcp_server.py
582
+ "mcp_server", # Alternative: mcp_server.py in root
583
+ ]
584
+ for module_path in auto_detect_modules:
585
+ try:
586
+ mcp_module = importlib.import_module(module_path)
587
+ if hasattr(mcp_module, "mcp"):
588
+ logger.info(f"Auto-detected local MCP server: {module_path}")
589
+ mcp_server_configs = [{"type": "local", "module": module_path, "id": "auto-detected"}]
590
+ break
591
+ except ImportError:
592
+ continue
593
+
594
+ # Fall back to REM's default MCP server if no local server found
595
+ if not mcp_server_configs:
596
+ logger.debug("No local MCP server found, using REM default")
597
+ mcp_server_configs = [{"type": "local", "module": "rem.mcp_server", "id": "rem"}]
570
598
 
571
599
  # Extract temperature and max_iterations from schema metadata (with fallback to settings defaults)
572
600
  if metadata:
@@ -612,46 +640,51 @@ async def create_agent(
612
640
  search_rem_suffix += f"Example: `SEARCH \"your query\" FROM {default_table} LIMIT 10`"
613
641
 
614
642
  # Add tools from MCP server (in-process, no subprocess)
615
- if mcp_server_configs:
616
- for server_config in mcp_server_configs:
617
- server_type = server_config.get("type")
618
- server_id = server_config.get("id", "mcp-server")
619
-
620
- if server_type == "local":
621
- # Import MCP server directly (in-process)
622
- module_path = server_config.get("module", "rem.mcp_server")
623
-
624
- try:
625
- # Dynamic import of MCP server module
626
- import importlib
627
- mcp_module = importlib.import_module(module_path)
628
- mcp_server = mcp_module.mcp
629
-
630
- # Extract tools from MCP server (get_tools is async)
631
- from ..mcp.tool_wrapper import create_mcp_tool_wrapper
632
-
633
- # Await async get_tools() call
634
- mcp_tools_dict = await mcp_server.get_tools()
635
-
636
- for tool_name, tool_func in mcp_tools_dict.items():
637
- # Add description suffix to search_rem tool if schema specifies a default table
638
- tool_suffix = search_rem_suffix if tool_name == "search_rem" else None
639
-
640
- wrapped_tool = create_mcp_tool_wrapper(
641
- tool_name,
642
- tool_func,
643
- user_id=context.user_id if context else None,
644
- description_suffix=tool_suffix,
645
- )
646
- tools.append(wrapped_tool)
647
- logger.debug(f"Loaded MCP tool: {tool_name}" + (" (with schema suffix)" if tool_suffix else ""))
648
-
649
- logger.info(f"Loaded {len(mcp_tools_dict)} tools from MCP server: {server_id} (in-process)")
650
-
651
- except Exception as e:
652
- logger.error(f"Failed to load MCP server {server_id}: {e}", exc_info=True)
653
- else:
654
- logger.warning(f"Unsupported MCP server type: {server_type}")
643
+ # Track loaded MCP servers for resource resolution
644
+ loaded_mcp_server = None
645
+
646
+ for server_config in mcp_server_configs:
647
+ server_type = server_config.get("type")
648
+ server_id = server_config.get("id", "mcp-server")
649
+
650
+ if server_type == "local":
651
+ # Import MCP server directly (in-process)
652
+ module_path = server_config.get("module", "rem.mcp_server")
653
+
654
+ try:
655
+ # Dynamic import of MCP server module
656
+ import importlib
657
+ mcp_module = importlib.import_module(module_path)
658
+ mcp_server = mcp_module.mcp
659
+
660
+ # Store the loaded server for resource resolution
661
+ loaded_mcp_server = mcp_server
662
+
663
+ # Extract tools from MCP server (get_tools is async)
664
+ from ..mcp.tool_wrapper import create_mcp_tool_wrapper
665
+
666
+ # Await async get_tools() call
667
+ mcp_tools_dict = await mcp_server.get_tools()
668
+
669
+ for tool_name, tool_func in mcp_tools_dict.items():
670
+ # Add description suffix to search_rem tool if schema specifies a default table
671
+ tool_suffix = search_rem_suffix if tool_name == "search_rem" else None
672
+
673
+ wrapped_tool = create_mcp_tool_wrapper(
674
+ tool_name,
675
+ tool_func,
676
+ user_id=context.user_id if context else None,
677
+ description_suffix=tool_suffix,
678
+ )
679
+ tools.append(wrapped_tool)
680
+ logger.debug(f"Loaded MCP tool: {tool_name}" + (" (with schema suffix)" if tool_suffix else ""))
681
+
682
+ logger.info(f"Loaded {len(mcp_tools_dict)} tools from MCP server: {server_id} (in-process)")
683
+
684
+ except Exception as e:
685
+ logger.error(f"Failed to load MCP server {server_id}: {e}", exc_info=True)
686
+ else:
687
+ logger.warning(f"Unsupported MCP server type: {server_type}")
655
688
 
656
689
  # Convert resources to tools (MCP convenience syntax)
657
690
  # Resources declared in agent YAML become callable tools - eliminates
@@ -693,8 +726,9 @@ async def create_agent(
693
726
  resource_uris.append((tool_name, tool_desc))
694
727
 
695
728
  # Create tools from collected resource URIs
729
+ # Pass the loaded MCP server so resources can be resolved from it
696
730
  for uri, usage in resource_uris:
697
- resource_tool = create_resource_tool(uri, usage)
731
+ resource_tool = create_resource_tool(uri, usage, mcp_server=loaded_mcp_server)
698
732
  tools.append(resource_tool)
699
733
  logger.debug(f"Loaded resource as tool: {uri}")
700
734
 
rem/api/deps.py CHANGED
@@ -185,8 +185,8 @@ async def get_user_filter(
185
185
  f"User {user.get('email')} attempted to filter by user_id={x_user_id}"
186
186
  )
187
187
  else:
188
- # Anonymous: could use anonymous tracking ID or restrict access
189
- # For now, anonymous can't access user-scoped data
188
+ # Anonymous: use anonymous tracking ID
189
+ # Note: user_id should come from JWT, not from parameters
190
190
  anon_id = getattr(request.state, "anon_id", None)
191
191
  if anon_id:
192
192
  filters["user_id"] = f"anon:{anon_id}"
rem/api/main.py CHANGED
@@ -304,7 +304,7 @@ def create_app() -> FastAPI:
304
304
  app.add_middleware(
305
305
  AuthMiddleware,
306
306
  protected_paths=["/api/v1"],
307
- excluded_paths=["/api/auth", "/api/dev", "/api/v1/mcp/auth"],
307
+ excluded_paths=["/api/auth", "/api/dev", "/api/v1/mcp/auth", "/api/v1/slack"],
308
308
  # Allow anonymous when auth is disabled, otherwise use setting
309
309
  allow_anonymous=(not settings.auth.enabled) or settings.auth.allow_anonymous,
310
310
  # MCP requires auth only when auth is fully enabled
@@ -182,6 +182,7 @@ def create_mcp_server(is_local: bool = False) -> FastMCP:
182
182
  list_schema,
183
183
  read_resource,
184
184
  register_metadata,
185
+ save_agent,
185
186
  search_rem,
186
187
  )
187
188
 
@@ -191,6 +192,7 @@ def create_mcp_server(is_local: bool = False) -> FastMCP:
191
192
  mcp.tool()(register_metadata)
192
193
  mcp.tool()(list_schema)
193
194
  mcp.tool()(get_schema)
195
+ mcp.tool()(save_agent)
194
196
 
195
197
  # File ingestion tool (with local path support for local servers)
196
198
  # Wrap to inject is_local parameter
@@ -1040,3 +1040,93 @@ async def get_schema(
1040
1040
  logger.info(f"Retrieved schema for table '{table_name}' with {len(column_defs)} columns")
1041
1041
 
1042
1042
  return result
1043
+
1044
+
1045
+ @mcp_tool_error_handler
1046
+ async def save_agent(
1047
+ name: str,
1048
+ description: str,
1049
+ properties: dict[str, Any] | None = None,
1050
+ required: list[str] | None = None,
1051
+ tools: list[str] | None = None,
1052
+ tags: list[str] | None = None,
1053
+ version: str = "1.0.0",
1054
+ user_id: str | None = None,
1055
+ ) -> dict[str, Any]:
1056
+ """
1057
+ Save an agent schema to REM, making it available for use.
1058
+
1059
+ This tool creates or updates an agent definition in the user's schema space.
1060
+ The agent becomes immediately available for conversations.
1061
+
1062
+ **Default Tools**: All agents automatically get `search_rem` and `register_metadata`
1063
+ tools unless explicitly overridden.
1064
+
1065
+ Args:
1066
+ name: Agent name in kebab-case (e.g., "code-reviewer", "sales-assistant").
1067
+ Must be unique within the user's schema space.
1068
+ description: The agent's system prompt. This is the full instruction set
1069
+ that defines the agent's behavior, personality, and capabilities.
1070
+ Use markdown formatting for structure.
1071
+ properties: Output schema properties as a dict. Each property should have:
1072
+ - type: "string", "number", "boolean", "array", "object"
1073
+ - description: What this field captures
1074
+ Example: {"answer": {"type": "string", "description": "Response to user"}}
1075
+ If not provided, defaults to a simple {"answer": {"type": "string"}} schema.
1076
+ required: List of required property names. Defaults to ["answer"] if not provided.
1077
+ tools: List of tool names the agent can use. Defaults to ["search_rem", "register_metadata"].
1078
+ tags: Optional tags for categorizing the agent.
1079
+ version: Semantic version string (default: "1.0.0").
1080
+ user_id: User identifier for scoping. Uses authenticated user if not provided.
1081
+
1082
+ Returns:
1083
+ Dict with:
1084
+ - status: "success" or "error"
1085
+ - agent_name: Name of the saved agent
1086
+ - version: Version saved
1087
+ - message: Human-readable status
1088
+
1089
+ Examples:
1090
+ # Create a simple agent
1091
+ save_agent(
1092
+ name="greeting-bot",
1093
+ description="You are a friendly greeter. Say hello warmly.",
1094
+ properties={"answer": {"type": "string", "description": "Greeting message"}},
1095
+ required=["answer"]
1096
+ )
1097
+
1098
+ # Create agent with structured output
1099
+ save_agent(
1100
+ name="sentiment-analyzer",
1101
+ description="Analyze sentiment of text provided by the user.",
1102
+ properties={
1103
+ "answer": {"type": "string", "description": "Analysis explanation"},
1104
+ "sentiment": {"type": "string", "enum": ["positive", "negative", "neutral"]},
1105
+ "confidence": {"type": "number", "minimum": 0, "maximum": 1}
1106
+ },
1107
+ required=["answer", "sentiment"],
1108
+ tags=["analysis", "nlp"]
1109
+ )
1110
+ """
1111
+ from ...agentic.agents.agent_manager import save_agent as _save_agent
1112
+
1113
+ # Get user_id from context if not provided
1114
+ user_id = AgentContext.get_user_id_or_default(user_id, source="save_agent")
1115
+
1116
+ # Delegate to agent_manager
1117
+ result = await _save_agent(
1118
+ name=name,
1119
+ description=description,
1120
+ user_id=user_id,
1121
+ properties=properties,
1122
+ required=required,
1123
+ tools=tools,
1124
+ tags=tags,
1125
+ version=version,
1126
+ )
1127
+
1128
+ # Add helpful message for Slack users
1129
+ if result.get("status") == "success":
1130
+ result["message"] = f"Agent '{name}' saved. Use `/custom-agent {name}` to chat with it."
1131
+
1132
+ return result
rem/api/routers/auth.py CHANGED
@@ -1,20 +1,68 @@
1
1
  """
2
- OAuth 2.1 Authentication Router.
2
+ Authentication Router.
3
3
 
4
- Leverages Authlib for standards-compliant OAuth/OIDC implementation.
5
- Minimal custom code - Authlib handles PKCE, token validation, JWKS.
4
+ Supports multiple authentication methods:
5
+ 1. Email (passwordless): POST /api/auth/email/send-code, POST /api/auth/email/verify
6
+ 2. OAuth (Google, Microsoft): GET /api/auth/{provider}/login, GET /api/auth/{provider}/callback
6
7
 
7
8
  Endpoints:
9
+ - POST /api/auth/email/send-code - Send login code to email
10
+ - POST /api/auth/email/verify - Verify code and create session
8
11
  - GET /api/auth/{provider}/login - Initiate OAuth flow
9
12
  - GET /api/auth/{provider}/callback - OAuth callback
10
13
  - POST /api/auth/logout - Clear session
11
14
  - GET /api/auth/me - Current user info
12
15
 
13
16
  Supported providers:
17
+ - email: Passwordless email login
14
18
  - google: Google OAuth 2.0 / OIDC
15
19
  - microsoft: Microsoft Entra ID OIDC
16
20
 
17
- Design Pattern (OAuth 2.1 + PKCE):
21
+ =============================================================================
22
+ Email Authentication Access Control
23
+ =============================================================================
24
+
25
+ The email auth provider implements a tiered access control system:
26
+
27
+ Access Control Flow (send-code):
28
+ User requests login code
29
+ ├── User exists in database?
30
+ │ ├── Yes → Check user.tier
31
+ │ │ ├── tier == BLOCKED → Reject "Account is blocked"
32
+ │ │ └── tier != BLOCKED → Allow (send code, existing users grandfathered)
33
+ │ └── No (new user) → Check EMAIL__TRUSTED_EMAIL_DOMAINS
34
+ │ ├── Setting configured → domain in trusted list?
35
+ │ │ ├── Yes → Create user & send code
36
+ │ │ └── No → Reject "Email domain not allowed for signup"
37
+ │ └── Not configured (empty) → Create user & send code (no restrictions)
38
+
39
+ Key Behaviors:
40
+ - Existing users: Always allowed to login (unless tier=BLOCKED)
41
+ - New users: Must have email from trusted domain (if EMAIL__TRUSTED_EMAIL_DOMAINS is set)
42
+ - No restrictions: Leave EMAIL__TRUSTED_EMAIL_DOMAINS empty to allow all domains
43
+
44
+ User Tiers (models.entities.UserTier):
45
+ - BLOCKED: Cannot login (rejected at send-code)
46
+ - ANONYMOUS: Rate-limited anonymous access
47
+ - FREE: Standard free tier
48
+ - BASIC/PRO: Paid tiers with additional features
49
+
50
+ Configuration:
51
+ # Allow only specific domains for new signups
52
+ EMAIL__TRUSTED_EMAIL_DOMAINS=siggymd.ai,example.com
53
+
54
+ # Allow all domains (no restrictions)
55
+ EMAIL__TRUSTED_EMAIL_DOMAINS=
56
+
57
+ Example blocking a user:
58
+ user = await user_repo.get_by_id(user_id, tenant_id="default")
59
+ user.tier = UserTier.BLOCKED
60
+ await user_repo.upsert(user)
61
+
62
+ =============================================================================
63
+ OAuth Design Pattern (OAuth 2.1 + PKCE)
64
+ =============================================================================
65
+
18
66
  1. User clicks "Login with Google"
19
67
  2. /login generates state + PKCE code_verifier
20
68
  3. Store code_verifier in session
@@ -37,6 +85,7 @@ Environment variables:
37
85
  AUTH__MICROSOFT__CLIENT_ID=<microsoft-client-id>
38
86
  AUTH__MICROSOFT__CLIENT_SECRET=<microsoft-client-secret>
39
87
  AUTH__MICROSOFT__TENANT=common
88
+ EMAIL__TRUSTED_EMAIL_DOMAINS=example.com # Optional: restrict new signups
40
89
 
41
90
  References:
42
91
  - Authlib: https://docs.authlib.org/en/latest/
@@ -46,11 +95,13 @@ References:
46
95
  from fastapi import APIRouter, HTTPException, Request
47
96
  from fastapi.responses import RedirectResponse
48
97
  from authlib.integrations.starlette_client import OAuth
98
+ from pydantic import BaseModel, EmailStr
49
99
  from loguru import logger
50
100
 
51
101
  from ...settings import settings
52
102
  from ...services.postgres.service import PostgresService
53
103
  from ...services.user_service import UserService
104
+ from ...auth.providers.email import EmailAuthProvider
54
105
 
55
106
  router = APIRouter(prefix="/api/auth", tags=["auth"])
56
107
 
@@ -87,6 +138,159 @@ if settings.auth.microsoft.client_id:
87
138
  logger.info(f"Microsoft OAuth provider registered (tenant: {tenant})")
88
139
 
89
140
 
141
+ # =============================================================================
142
+ # Email Authentication Endpoints
143
+ # =============================================================================
144
+
145
+
146
+ class EmailSendCodeRequest(BaseModel):
147
+ """Request to send login code."""
148
+ email: EmailStr
149
+
150
+
151
+ class EmailVerifyRequest(BaseModel):
152
+ """Request to verify login code."""
153
+ email: EmailStr
154
+ code: str
155
+
156
+
157
+ @router.post("/email/send-code")
158
+ async def send_email_code(request: Request, body: EmailSendCodeRequest):
159
+ """
160
+ Send a login code to an email address.
161
+
162
+ Creates user if not exists (using deterministic UUID from email).
163
+ Stores code in user metadata with expiry.
164
+
165
+ Args:
166
+ request: FastAPI request
167
+ body: EmailSendCodeRequest with email
168
+
169
+ Returns:
170
+ Success status and message
171
+ """
172
+ if not settings.email.is_configured:
173
+ raise HTTPException(
174
+ status_code=501,
175
+ detail="Email authentication is not configured"
176
+ )
177
+
178
+ # Get database connection
179
+ if not settings.postgres.enabled:
180
+ raise HTTPException(
181
+ status_code=501,
182
+ detail="Database is required for email authentication"
183
+ )
184
+
185
+ db = PostgresService()
186
+ try:
187
+ await db.connect()
188
+
189
+ # Initialize email auth provider
190
+ email_auth = EmailAuthProvider()
191
+
192
+ # Send code
193
+ result = await email_auth.send_code(
194
+ email=body.email,
195
+ db=db,
196
+ )
197
+
198
+ if result.success:
199
+ return {
200
+ "success": True,
201
+ "message": result.message,
202
+ "email": result.email,
203
+ }
204
+ else:
205
+ raise HTTPException(
206
+ status_code=400,
207
+ detail=result.message or result.error
208
+ )
209
+
210
+ except HTTPException:
211
+ raise
212
+ except Exception as e:
213
+ logger.error(f"Error sending login code: {e}")
214
+ raise HTTPException(status_code=500, detail="Failed to send login code")
215
+ finally:
216
+ await db.disconnect()
217
+
218
+
219
+ @router.post("/email/verify")
220
+ async def verify_email_code(request: Request, body: EmailVerifyRequest):
221
+ """
222
+ Verify login code and create session.
223
+
224
+ Args:
225
+ request: FastAPI request
226
+ body: EmailVerifyRequest with email and code
227
+
228
+ Returns:
229
+ Success status with user info
230
+ """
231
+ if not settings.email.is_configured:
232
+ raise HTTPException(
233
+ status_code=501,
234
+ detail="Email authentication is not configured"
235
+ )
236
+
237
+ if not settings.postgres.enabled:
238
+ raise HTTPException(
239
+ status_code=501,
240
+ detail="Database is required for email authentication"
241
+ )
242
+
243
+ db = PostgresService()
244
+ try:
245
+ await db.connect()
246
+
247
+ # Initialize email auth provider
248
+ email_auth = EmailAuthProvider()
249
+
250
+ # Verify code
251
+ result = await email_auth.verify_code(
252
+ email=body.email,
253
+ code=body.code,
254
+ db=db,
255
+ )
256
+
257
+ if not result.success:
258
+ raise HTTPException(
259
+ status_code=400,
260
+ detail=result.message or result.error
261
+ )
262
+
263
+ # Create session - compatible with OAuth session format
264
+ user_dict = email_auth.get_user_dict(
265
+ email=result.email,
266
+ user_id=result.user_id,
267
+ )
268
+
269
+ # Store user in session
270
+ request.session["user"] = user_dict
271
+
272
+ logger.info(f"User authenticated via email: {result.email}")
273
+
274
+ return {
275
+ "success": True,
276
+ "message": result.message,
277
+ "user": user_dict,
278
+ }
279
+
280
+ except HTTPException:
281
+ raise
282
+ except Exception as e:
283
+ logger.error(f"Error verifying login code: {e}")
284
+ raise HTTPException(status_code=500, detail="Failed to verify login code")
285
+ finally:
286
+ await db.disconnect()
287
+
288
+
289
+ # =============================================================================
290
+ # OAuth Authentication Endpoints
291
+ # =============================================================================
292
+
293
+
90
294
  @router.get("/{provider}/login")
91
295
  async def login(provider: str, request: Request):
92
296
  """
@@ -76,6 +76,9 @@ async def stream_openai_response(
76
76
  agent_schema: str | None = None,
77
77
  # Mutable container to capture trace context (deterministic, not AI-dependent)
78
78
  trace_context_out: dict | None = None,
79
+ # Mutable container to capture tool calls for persistence
80
+ # Format: list of {"tool_name": str, "tool_id": str, "arguments": dict, "result": any}
81
+ tool_calls_out: list | None = None,
79
82
  ) -> AsyncGenerator[str, None]:
80
83
  """
81
84
  Stream Pydantic AI agent responses with rich SSE events.
@@ -146,6 +149,9 @@ async def stream_openai_response(
146
149
  pending_tool_completions: list[tuple[str, str]] = []
147
150
  # Track if metadata was registered via register_metadata tool
148
151
  metadata_registered = False
152
+ # Track pending tool calls with full data for persistence
153
+ # Maps tool_id -> {"tool_name": str, "tool_id": str, "arguments": dict}
154
+ pending_tool_data: dict[str, dict] = {}
149
155
 
150
156
  try:
151
157
  # Emit initial progress event
@@ -299,6 +305,13 @@ async def stream_openai_response(
299
305
  arguments=args_dict
300
306
  ))
301
307
 
308
+ # Track tool call data for persistence (especially register_metadata)
309
+ pending_tool_data[tool_id] = {
310
+ "tool_name": tool_name,
311
+ "tool_id": tool_id,
312
+ "arguments": args_dict,
313
+ }
314
+
302
315
  # Update progress
303
316
  current_step = 2
304
317
  total_steps = 4 # Added tool execution step
@@ -421,6 +434,15 @@ async def stream_openai_response(
421
434
  hidden=False,
422
435
  ))
423
436
 
437
+ # Capture tool call with result for persistence
438
+ # Special handling for register_metadata - always capture full data
439
+ if tool_calls_out is not None and tool_id in pending_tool_data:
440
+ tool_data = pending_tool_data[tool_id]
441
+ tool_data["result"] = result_content
442
+ tool_data["is_metadata"] = is_metadata_event
443
+ tool_calls_out.append(tool_data)
444
+ del pending_tool_data[tool_id]
445
+
424
446
  if not is_metadata_event:
425
447
  # Normal tool completion - emit ToolCallEvent
426
448
  result_str = str(result_content)
@@ -728,6 +750,9 @@ async def stream_openai_response_with_save(
728
750
  # Accumulate content during streaming
729
751
  accumulated_content = []
730
752
 
753
+ # Capture tool calls for persistence (especially register_metadata)
754
+ tool_calls: list = []
755
+
731
756
  async for chunk in stream_openai_response(
732
757
  agent=agent,
733
758
  prompt=prompt,
@@ -737,6 +762,7 @@ async def stream_openai_response_with_save(
737
762
  session_id=session_id,
738
763
  message_id=message_id,
739
764
  trace_context_out=trace_context, # Pass container to capture trace IDs
765
+ tool_calls_out=tool_calls, # Capture tool calls for persistence
740
766
  ):
741
767
  yield chunk
742
768
 
@@ -755,28 +781,57 @@ async def stream_openai_response_with_save(
755
781
  except (json.JSONDecodeError, KeyError, IndexError):
756
782
  pass # Skip non-JSON or malformed chunks
757
783
 
758
- # After streaming completes, save the assistant response
759
- if settings.postgres.enabled and session_id and accumulated_content:
760
- full_content = "".join(accumulated_content)
784
+ # After streaming completes, save tool calls and assistant response
785
+ # Note: All messages stored UNCOMPRESSED. Compression happens on reload.
786
+ if settings.postgres.enabled and session_id:
761
787
  # Get captured trace context from container (deterministically captured inside agent execution)
762
788
  captured_trace_id = trace_context.get("trace_id")
763
789
  captured_span_id = trace_context.get("span_id")
764
- assistant_message = {
765
- "id": message_id, # Use pre-generated ID for consistency with metadata event
766
- "role": "assistant",
767
- "content": full_content,
768
- "timestamp": to_iso(utc_now()),
769
- "trace_id": captured_trace_id,
770
- "span_id": captured_span_id,
771
- }
772
- try:
773
- store = SessionMessageStore(user_id=user_id or settings.test.effective_user_id)
774
- await store.store_session_messages(
775
- session_id=session_id,
776
- messages=[assistant_message],
777
- user_id=user_id,
778
- compress=True, # Compress long assistant responses
779
- )
780
- logger.debug(f"Saved assistant response {message_id} to session {session_id} ({len(full_content)} chars)")
781
- except Exception as e:
782
- logger.error(f"Failed to save assistant response: {e}", exc_info=True)
790
+ timestamp = to_iso(utc_now())
791
+
792
+ messages_to_store = []
793
+
794
+ # First, store tool call messages (message_type: "tool")
795
+ for tool_call in tool_calls:
796
+ tool_message = {
797
+ "role": "tool",
798
+ "content": json.dumps(tool_call.get("result", {}), default=str),
799
+ "timestamp": timestamp,
800
+ "trace_id": captured_trace_id,
801
+ "span_id": captured_span_id,
802
+ # Store tool call details in a way that can be reconstructed
803
+ "tool_call_id": tool_call.get("tool_id"),
804
+ "tool_name": tool_call.get("tool_name"),
805
+ "tool_arguments": tool_call.get("arguments"),
806
+ }
807
+ messages_to_store.append(tool_message)
808
+
809
+ # Then store assistant text response (if any)
810
+ if accumulated_content:
811
+ full_content = "".join(accumulated_content)
812
+ assistant_message = {
813
+ "id": message_id, # Use pre-generated ID for consistency with metadata event
814
+ "role": "assistant",
815
+ "content": full_content,
816
+ "timestamp": timestamp,
817
+ "trace_id": captured_trace_id,
818
+ "span_id": captured_span_id,
819
+ }
820
+ messages_to_store.append(assistant_message)
821
+
822
+ if messages_to_store:
823
+ try:
824
+ store = SessionMessageStore(user_id=user_id or settings.test.effective_user_id)
825
+ await store.store_session_messages(
826
+ session_id=session_id,
827
+ messages=messages_to_store,
828
+ user_id=user_id,
829
+ compress=False, # Store uncompressed; compression happens on reload
830
+ )
831
+ logger.debug(
832
+ f"Saved {len(tool_calls)} tool calls and "
833
+ f"{'assistant response' if accumulated_content else 'no text'} "
834
+ f"to session {session_id}"
835
+ )
836
+ except Exception as e:
837
+ logger.error(f"Failed to save session messages: {e}", exc_info=True)