@cdklabs/cdk-appmod-catalog-blueprints 1.13.0 → 1.14.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/.jsii +573 -136
- package/lib/document-processing/adapter/queued-s3-adapter.js +1 -1
- package/lib/document-processing/agentic-document-processing.js +1 -1
- package/lib/document-processing/base-document-processing.js +1 -1
- package/lib/document-processing/bedrock-document-processing.js +1 -1
- package/lib/document-processing/default-document-processing-config.js +1 -1
- package/lib/document-processing/resources/default-image-validator/package-lock.json +45 -45
- package/lib/document-processing/resources/default-image-validator/package.json +1 -1
- package/lib/framework/agents/base-agent.js +1 -1
- package/lib/framework/agents/batch-agent.js +4 -2
- package/lib/framework/agents/default-agent-config.js +1 -1
- package/lib/framework/agents/interactive-agent.d.ts +159 -2
- package/lib/framework/agents/interactive-agent.js +139 -19
- package/lib/framework/agents/knowledge-base/base-knowledge-base.js +1 -1
- package/lib/framework/agents/knowledge-base/bedrock-knowledge-base.js +1 -1
- package/lib/framework/agents/resources/interactive-agent-handler/index.py +561 -52
- package/lib/framework/agents/resources/interactive-agent-handler/requirements.txt +1 -0
- package/lib/framework/bedrock/bedrock.js +1 -1
- package/lib/framework/custom-resource/default-runtimes.js +1 -1
- package/lib/framework/foundation/access-log.js +1 -1
- package/lib/framework/foundation/eventbridge-broker.js +1 -1
- package/lib/framework/foundation/network.js +1 -1
- package/lib/framework/tests/framework-nag.test.js +2 -1
- package/lib/tsconfig.tsbuildinfo +1 -1
- package/lib/utilities/data-loader.js +1 -1
- package/lib/utilities/lambda-iam-utils.js +1 -1
- package/lib/utilities/observability/cloudfront-distribution-observability-property-injector.js +1 -1
- package/lib/utilities/observability/cloudwatch-transaction-search.js +1 -1
- package/lib/utilities/observability/default-observability-config.js +1 -1
- package/lib/utilities/observability/lambda-observability-property-injector.js +1 -1
- package/lib/utilities/observability/log-group-data-protection-utils.js +1 -1
- package/lib/utilities/observability/powertools-config.js +1 -1
- package/lib/utilities/observability/state-machine-observability-property-injector.js +1 -1
- package/lib/webapp/frontend-construct.js +1 -1
- package/package.json +4 -4
|
@@ -20,12 +20,16 @@ import importlib
|
|
|
20
20
|
import sys
|
|
21
21
|
import tempfile
|
|
22
22
|
import zipfile
|
|
23
|
+
import base64
|
|
24
|
+
import asyncio
|
|
25
|
+
import contextvars
|
|
26
|
+
import time
|
|
23
27
|
import boto3
|
|
24
28
|
from typing import Dict, Any, Optional, List
|
|
25
29
|
|
|
26
|
-
from fastapi import FastAPI, Request
|
|
30
|
+
from fastapi import FastAPI, Request, Header
|
|
27
31
|
from fastapi.middleware.cors import CORSMiddleware
|
|
28
|
-
from fastapi.responses import StreamingResponse
|
|
32
|
+
from fastapi.responses import StreamingResponse, JSONResponse
|
|
29
33
|
from pydantic import BaseModel
|
|
30
34
|
import uvicorn
|
|
31
35
|
|
|
@@ -38,12 +42,40 @@ from aws_lambda_powertools.metrics import MetricUnit
|
|
|
38
42
|
|
|
39
43
|
# Initialize AWS clients
|
|
40
44
|
s3_client = boto3.client('s3')
|
|
45
|
+
dynamodb = boto3.resource('dynamodb')
|
|
41
46
|
|
|
42
47
|
# Initialize observability
|
|
43
48
|
logger = Logger()
|
|
44
49
|
tracer = Tracer()
|
|
45
50
|
metrics = Metrics()
|
|
46
51
|
|
|
52
|
+
# =============================================================================
|
|
53
|
+
# SSE Event Queue for Tool-to-Handler Communication
|
|
54
|
+
# =============================================================================
|
|
55
|
+
# Context variable to hold an async queue for SSE events emitted by tools.
|
|
56
|
+
# This allows tools (like execute_script) to push events (schema, preview)
|
|
57
|
+
# directly into the SSE stream without relying on Strands' event types.
|
|
58
|
+
_sse_queue: contextvars.ContextVar[asyncio.Queue] = contextvars.ContextVar('sse_queue')
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def emit_sse_event(event_type: str, data: dict) -> None:
    """
    Queue an SSE event for the chat stream that is currently being served.

    Tools call this to surface extra events (e.g. 'schema', 'preview',
    'download') alongside the model's text chunks. The event is pushed onto
    the queue stored in the ``_sse_queue`` context variable as
    ``{'event': event_type, 'data': data}``. When no queue has been set in the
    current context (e.g. in tests or outside a chat request), the call is a
    silent no-op.

    NOTE(review): the queue is an ``asyncio.Queue``, which is not thread-safe.
    If tools run in worker threads, confirm that ``put_nowait`` from those
    threads is acceptable.

    Args:
        event_type: SSE event name to emit.
        data: JSON-serializable payload for the event.
    """
    try:
        active_queue = _sse_queue.get()
    except LookupError:
        # No streaming context: nothing to deliver the event to.
        return
    active_queue.put_nowait({'event': event_type, 'data': data})
|
|
78
|
+
|
|
47
79
|
# Load configuration from environment variables
|
|
48
80
|
MODEL_ID = os.getenv('MODEL_ID', 'anthropic.claude-3-5-sonnet-20241022-v2:0')
|
|
49
81
|
SYSTEM_PROMPT_BUCKET = os.getenv('SYSTEM_PROMPT_S3_BUCKET_NAME')
|
|
@@ -51,6 +83,103 @@ SYSTEM_PROMPT_KEY = os.getenv('SYSTEM_PROMPT_S3_KEY')
|
|
|
51
83
|
TOOLS_CONFIG = os.getenv('TOOLS_CONFIG', '[]')
|
|
52
84
|
KNOWLEDGE_BASE_SYSTEM_PROMPT_ADDITION = os.getenv('KNOWLEDGE_BASE_SYSTEM_PROMPT_ADDITION', '')
|
|
53
85
|
SESSION_BUCKET = os.getenv('SESSION_BUCKET')
|
|
86
|
+
SESSION_LOCK_TABLE = os.getenv('SESSION_LOCK_TABLE')
|
|
87
|
+
SESSION_INDEX_TABLE = os.environ.get('SESSION_INDEX_TABLE', '')
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def validate_and_repair_session(session_id: str, bucket: str) -> bool:
    """
    Validate session history for corrupted toolUse/toolResult pairs.

    Reads every ``message_<n>.json`` object under the session's
    ``agents/agent_default/messages/`` prefix and orders them by ``<n>``. It
    then checks that every ``toolUse`` block is followed by a ``toolResult``
    with the same ``toolUseId``. If any toolUse is left unmatched, every
    object under ``/session_<session_id>/`` is deleted so the next request
    starts fresh.

    Deletion is destructive, so it only happens when the whole history was
    read successfully. If any message file could not be read or parsed, the
    missing file might be the one holding the matching toolResult. In that
    case the session is left untouched instead of being wiped on a possibly
    transient error.

    Args:
        session_id: The session ID to validate
        bucket: S3 bucket containing session data

    Returns:
        True if the session is valid, could not be fully checked, or
        validation itself failed.
        False if the session was found corrupted and deleted.
    """
    try:
        paginator = s3_client.get_paginator('list_objects_v2')
        # NOTE(review): presumably mirrors the key layout written by the Strands
        # S3SessionManager (leading '/'); keep in sync with get_history.
        messages_prefix = f'/session_{session_id}/agents/agent_default/messages/'

        messages = []
        unreadable_keys = []
        for page in paginator.paginate(Bucket=bucket, Prefix=messages_prefix):
            for obj in page.get('Contents', []):
                key = obj['Key']
                if 'message_' not in key or not key.endswith('.json'):
                    continue
                try:
                    response = s3_client.get_object(Bucket=bucket, Key=key)
                    msg_data = json.loads(response['Body'].read().decode('utf-8'))
                    # Message index comes from the filename: message_<n>.json
                    filename = key.split('/')[-1]
                    msg_index = int(filename.replace('message_', '').replace('.json', ''))
                    messages.append((msg_index, msg_data))
                except Exception as e:
                    unreadable_keys.append(key)
                    logger.warning(f'Failed to read message file {key}: {e}')

        if not messages:
            # No messages yet, session is valid
            return True

        if unreadable_keys:
            # An incomplete view of the history cannot prove corruption; the
            # unread file may contain the "missing" toolResult.
            logger.warning(
                f'Skipping corruption check for session {session_id}: '
                f'{len(unreadable_keys)} message file(s) could not be read.'
            )
            return True

        # Tool pairing is order-sensitive, so walk messages chronologically.
        messages.sort(key=lambda item: item[0])

        # toolUse IDs still waiting for a matching toolResult
        pending_tool_uses = set()
        for _, msg_data in messages:
            content_list = msg_data.get('message', {}).get('content', [])
            if isinstance(content_list, str):
                # Simple text message, no tools
                continue
            for content in content_list:
                if not isinstance(content, dict):
                    continue
                if 'toolUse' in content:
                    tool_use_id = content['toolUse'].get('toolUseId')
                    if tool_use_id:
                        pending_tool_uses.add(tool_use_id)
                if 'toolResult' in content:
                    tool_use_id = content['toolResult'].get('toolUseId')
                    if tool_use_id:
                        pending_tool_uses.discard(tool_use_id)

        if not pending_tool_uses:
            return True  # Session is valid

        logger.warning(
            f'Session {session_id} is corrupted: {len(pending_tool_uses)} orphaned toolUse IDs. '
            'Deleting session to allow fresh start.'
        )

        # Delete all session files
        session_prefix = f'/session_{session_id}/'
        objects_to_delete = [
            {'Key': obj['Key']}
            for page in paginator.paginate(Bucket=bucket, Prefix=session_prefix)
            for obj in page.get('Contents', [])
        ]

        failed_deletes = 0
        # DeleteObjects accepts at most 1000 keys per request.
        for start in range(0, len(objects_to_delete), 1000):
            batch = objects_to_delete[start:start + 1000]
            result = s3_client.delete_objects(Bucket=bucket, Delete={'Objects': batch})
            # Per-key failures are reported in 'Errors' rather than raised.
            for error in result.get('Errors', []):
                failed_deletes += 1
                logger.warning(
                    f'Failed to delete {error.get("Key")} from corrupted session '
                    f'{session_id}: {error.get("Message")}'
                )

        if objects_to_delete:
            logger.info(
                f'Deleted {len(objects_to_delete) - failed_deletes} files from corrupted session {session_id}'
            )

        return False  # Session was deleted

    except Exception as e:
        logger.warning(f'Session validation failed for {session_id}: {e}. Continuing anyway.')
        return True  # Don't block on validation errors
|
|
54
183
|
|
|
55
184
|
|
|
56
185
|
def load_system_prompt() -> str:
|
|
@@ -126,12 +255,51 @@ def load_tools_from_s3() -> list:
|
|
|
126
255
|
return tools
|
|
127
256
|
|
|
128
257
|
|
|
258
|
+
def update_session_index(user_id: str, session_id: str, last_message: str = ''):
    """
    Upsert the DynamoDB session-index row for (user_id, session_id).

    Each call sets ``updated_at`` to the current UTC time in ISO format. It
    also stores a ``last_message`` preview of at most 100 characters.
    ``created_at`` is written only when the row does not already have one.

    Does nothing when SESSION_INDEX_TABLE is unset or either ID is empty.
    Errors are logged as warnings and never raised, so callers can treat the
    index as best-effort.

    Args:
        user_id: Partition key of the index row.
        session_id: Sort/secondary key of the index row.
        last_message: Message text used for the preview.
    """
    if not (SESSION_INDEX_TABLE and user_id and session_id):
        return

    try:
        from datetime import datetime, timezone

        index_table = dynamodb.Table(SESSION_INDEX_TABLE)
        timestamp = datetime.now(timezone.utc).isoformat()
        preview = last_message[:100] if last_message else ''  # Truncate preview

        # SET creates the item if missing; if_not_exists keeps the first created_at.
        index_table.update_item(
            Key={'user_id': user_id, 'session_id': session_id},
            UpdateExpression=(
                'SET updated_at = :updated_at, last_message = :last_message, '
                'created_at = if_not_exists(created_at, :now)'
            ),
            ExpressionAttributeValues={
                ':updated_at': timestamp,
                ':last_message': preview,
                ':now': timestamp,
            },
        )
        logger.info(f'Updated session index in DynamoDB for user={user_id[:8]}..., session={session_id}')
    except Exception as e:
        logger.warning(f'Error updating session index: {e}')
|
|
281
|
+
|
|
282
|
+
|
|
129
283
|
# Cold start: load system prompt and tools
|
|
130
284
|
SYSTEM_PROMPT = load_system_prompt()
|
|
131
285
|
if KNOWLEDGE_BASE_SYSTEM_PROMPT_ADDITION:
|
|
132
286
|
SYSTEM_PROMPT = SYSTEM_PROMPT + '\n\n' + KNOWLEDGE_BASE_SYSTEM_PROMPT_ADDITION
|
|
133
287
|
AGENT_TOOLS = load_tools_from_s3()
|
|
134
288
|
|
|
289
|
+
# Session locks to prevent concurrent request processing for the same session.
|
|
290
|
+
# This prevents race conditions in Strands S3SessionManager that can corrupt
|
|
291
|
+
# conversation history when multiple requests for the same session overlap.
|
|
292
|
+
_session_locks: Dict[str, asyncio.Lock] = {}
|
|
293
|
+
_session_locks_lock = asyncio.Lock()
|
|
294
|
+
|
|
295
|
+
|
|
296
|
+
async def get_session_lock(session_id: str) -> asyncio.Lock:
    """
    Return the in-process asyncio.Lock for ``session_id``, creating it on first use.

    Creation is guarded by ``_session_locks_lock`` so that concurrent callers
    asking for the same session receive the same Lock object.

    NOTE(review): nothing here removes entries from ``_session_locks``. The
    dict therefore grows with every distinct session_id seen by this process.
    Confirm whether that is acceptable for long-lived containers.
    """
    async with _session_locks_lock:
        lock = _session_locks.get(session_id)
        if lock is None:
            lock = asyncio.Lock()
            _session_locks[session_id] = lock
        return lock
|
|
302
|
+
|
|
135
303
|
|
|
136
304
|
# FastAPI app
|
|
137
305
|
app = FastAPI()
|
|
@@ -142,7 +310,7 @@ app = FastAPI()
|
|
|
142
310
|
app.add_middleware(
|
|
143
311
|
CORSMiddleware,
|
|
144
312
|
allow_origins=['*'],
|
|
145
|
-
allow_methods=['POST', 'OPTIONS'],
|
|
313
|
+
allow_methods=['GET', 'POST', 'OPTIONS'],
|
|
146
314
|
allow_headers=['Content-Type', 'Authorization'],
|
|
147
315
|
)
|
|
148
316
|
|
|
@@ -151,6 +319,47 @@ class ChatRequest(BaseModel):
|
|
|
151
319
|
"""Chat request body."""
|
|
152
320
|
message: str
|
|
153
321
|
session_id: Optional[str] = None
|
|
322
|
+
user_id: Optional[str] = None # Fallback if JWT extraction fails
|
|
323
|
+
|
|
324
|
+
|
|
325
|
+
def extract_user_from_jwt(authorization: Optional[str]) -> Dict[str, str]:
    """
    Extract user information from a JWT in an Authorization header value.

    The token's payload segment is base64url-decoded and parsed as JSON.
    The signature is NOT verified here; the design assumes API Gateway has
    already validated the token before the request reaches this handler.

    An optional leading ``Bearer`` scheme is stripped. The auth scheme is
    matched case-insensitively, per RFC 7235; a bare token is also accepted.

    Args:
        authorization: Raw Authorization header value, or None.

    Returns:
        ``{'user_id', 'username', 'email'}`` with string values, taken from
        the ``sub``, ``cognito:username`` (falling back to ``username``) and
        ``email`` claims. Missing or null claims become ''.
        An empty dict is returned when there is no token, the token is not
        a three-part JWT, the payload is not a JSON object, or decoding fails.
    """
    if not authorization:
        return {}

    try:
        # Strip only a leading scheme word instead of replacing every occurrence.
        token = authorization.strip()
        scheme, _, remainder = token.partition(' ')
        if scheme.lower() == 'bearer':
            token = remainder.strip()
        if not token:
            return {}

        # JWT format: header.payload.signature
        parts = token.split('.')
        if len(parts) != 3:
            return {}

        # base64url payloads are unpadded; restore '=' padding to a multiple of 4.
        payload_b64 = parts[1]
        payload_b64 += '=' * (-len(payload_b64) % 4)

        payload = json.loads(base64.urlsafe_b64decode(payload_b64).decode('utf-8'))
        if not isinstance(payload, dict):
            return {}

        sub = payload.get('sub')
        username = payload.get('cognito:username', payload.get('username'))
        email = payload.get('email')

        # Coerce to str so callers can rely on the Dict[str, str] contract
        # (e.g. slicing user_id for log output).
        return {
            'user_id': '' if sub is None else str(sub),
            'username': '' if username is None else str(username),
            'email': '' if email is None else str(email),
        }
    except Exception as e:
        logger.warning(f'Failed to extract user from JWT: {e}')
        return {}
|
|
154
363
|
|
|
155
364
|
|
|
156
365
|
def format_sse(data: str, event: Optional[str] = None) -> str:
|
|
@@ -165,66 +374,192 @@ def format_sse(data: str, event: Optional[str] = None) -> str:
|
|
|
165
374
|
|
|
166
375
|
|
|
167
376
|
@app.post('/chat')
|
|
168
|
-
async def chat(request: ChatRequest):
|
|
377
|
+
async def chat(request: ChatRequest, authorization: Optional[str] = Header(None)):
|
|
169
378
|
"""Handle chat request with SSE streaming response."""
|
|
170
379
|
session_id = request.session_id or str(uuid.uuid4())
|
|
171
380
|
user_message = request.message
|
|
172
381
|
|
|
173
|
-
|
|
382
|
+
# Extract user information from JWT for multi-tenant support
|
|
383
|
+
# Fall back to request body user_id if JWT extraction fails
|
|
384
|
+
user_info = extract_user_from_jwt(authorization)
|
|
385
|
+
user_id = user_info.get('user_id', '') or request.user_id or ''
|
|
386
|
+
|
|
387
|
+
logger.info(f'Chat request: session={session_id}, user={user_id[:8] if user_id else "anonymous"}, message={user_message[:80]}...')
|
|
174
388
|
metrics.add_metric(name='ChatRequests', unit=MetricUnit.Count, value=1)
|
|
175
389
|
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
390
|
+
# Get session lock to prevent concurrent request processing.
|
|
391
|
+
# This prevents race conditions that can corrupt conversation history
|
|
392
|
+
# when users send rapid retries or multiple messages simultaneously.
|
|
393
|
+
session_lock = await get_session_lock(session_id)
|
|
179
394
|
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
395
|
+
async def generate_sse():
|
|
396
|
+
# Send session metadata first (include user_id for frontend reference)
|
|
397
|
+
yield format_sse(json.dumps({'session_id': session_id, 'user_id': user_id}), event='metadata')
|
|
398
|
+
|
|
399
|
+
# Acquire distributed DynamoDB lock to prevent concurrent Lambda instances
|
|
400
|
+
# from processing the same session. This complements the in-memory asyncio.Lock
|
|
401
|
+
# which only works within a single Lambda instance.
|
|
402
|
+
ddb_lock_acquired = False
|
|
403
|
+
dynamodb_client = None
|
|
404
|
+
if SESSION_LOCK_TABLE:
|
|
405
|
+
try:
|
|
406
|
+
dynamodb_client = boto3.client('dynamodb')
|
|
407
|
+
now = int(time.time())
|
|
408
|
+
ttl = now + 300 # 5 minute lock TTL (tool executions can take a while)
|
|
409
|
+
|
|
410
|
+
dynamodb_client.put_item(
|
|
411
|
+
TableName=SESSION_LOCK_TABLE,
|
|
412
|
+
Item={
|
|
413
|
+
'session_id': {'S': session_id},
|
|
414
|
+
'locked_at': {'N': str(now)},
|
|
415
|
+
'ttl': {'N': str(ttl)}
|
|
416
|
+
},
|
|
417
|
+
ConditionExpression='attribute_not_exists(session_id) OR #ttl < :now',
|
|
418
|
+
ExpressionAttributeNames={'#ttl': 'ttl'},
|
|
419
|
+
ExpressionAttributeValues={':now': {'N': str(now)}}
|
|
188
420
|
)
|
|
421
|
+
ddb_lock_acquired = True
|
|
422
|
+
logger.info(f'Acquired distributed lock for session {session_id}')
|
|
423
|
+
except dynamodb_client.exceptions.ConditionalCheckFailedException:
|
|
424
|
+
logger.warning(f'Failed to acquire distributed lock for session {session_id} - already locked')
|
|
425
|
+
yield format_sse(json.dumps({
|
|
426
|
+
'error': 'Another request is already processing this session. Please wait and try again.'
|
|
427
|
+
}), event='error')
|
|
428
|
+
return
|
|
429
|
+
except Exception as e:
|
|
430
|
+
# Log but continue - fall back to in-memory lock only
|
|
431
|
+
logger.warning(f'DynamoDB lock acquisition failed, continuing with in-memory lock only: {e}')
|
|
189
432
|
|
|
190
|
-
|
|
191
|
-
conversation_manager = SlidingWindowConversationManager(window_size=20)
|
|
192
|
-
|
|
193
|
-
# Create Bedrock model
|
|
194
|
-
model = BedrockModel(model_id=MODEL_ID, streaming=True)
|
|
195
|
-
|
|
196
|
-
# Create agent with Strands-native session and conversation management.
|
|
197
|
-
# Strands handles session persistence and context windowing automatically.
|
|
198
|
-
# Disable the default callback handler so stream_async yields all events.
|
|
199
|
-
agent = Agent(
|
|
200
|
-
model=model,
|
|
201
|
-
system_prompt=SYSTEM_PROMPT,
|
|
202
|
-
tools=AGENT_TOOLS if AGENT_TOOLS else None,
|
|
203
|
-
session_manager=session_manager,
|
|
204
|
-
conversation_manager=conversation_manager,
|
|
205
|
-
callback_handler=None,
|
|
206
|
-
)
|
|
207
|
-
|
|
208
|
-
# Use stream_async for true token-by-token streaming.
|
|
209
|
-
# Each event with a "data" key contains a text chunk from the model.
|
|
210
|
-
async for event in agent.stream_async(user_message):
|
|
211
|
-
# Text chunk from model — stream it to the client immediately
|
|
212
|
-
if 'data' in event:
|
|
213
|
-
chunk = event['data']
|
|
214
|
-
full_response += chunk
|
|
215
|
-
yield format_sse(json.dumps({'text': chunk}))
|
|
216
|
-
|
|
217
|
-
# Session is saved automatically by Strands S3SessionManager
|
|
218
|
-
|
|
219
|
-
metrics.add_metric(name='ChatResponses', unit=MetricUnit.Count, value=1)
|
|
433
|
+
full_response = ''
|
|
220
434
|
|
|
221
|
-
|
|
222
|
-
|
|
435
|
+
# Set up SSE event queue for tool-to-handler communication.
|
|
436
|
+
# Tools can call emit_sse_event() to push events into this queue,
|
|
437
|
+
# and we'll yield them to the client during the streaming loop.
|
|
438
|
+
sse_event_queue: asyncio.Queue = asyncio.Queue()
|
|
439
|
+
_sse_queue.set(sse_event_queue)
|
|
440
|
+
|
|
441
|
+
# Acquire session lock to serialize requests for this session.
|
|
442
|
+
# This ensures conversation history remains consistent.
|
|
443
|
+
async with session_lock:
|
|
444
|
+
try:
|
|
445
|
+
# Validate session before loading - detect and clear corrupted sessions
|
|
446
|
+
# that have orphaned toolUse without matching toolResult
|
|
447
|
+
session_was_reset = False
|
|
448
|
+
if SESSION_BUCKET:
|
|
449
|
+
if not validate_and_repair_session(session_id, SESSION_BUCKET):
|
|
450
|
+
# Session was corrupted and deleted, notify user
|
|
451
|
+
session_was_reset = True
|
|
452
|
+
yield format_sse(json.dumps({
|
|
453
|
+
'text': '⚠️ Previous session was corrupted and has been reset. Starting fresh.\n\n'
|
|
454
|
+
}))
|
|
455
|
+
|
|
456
|
+
# Create Strands-native session manager (handles load/save automatically)
|
|
457
|
+
session_manager = None
|
|
458
|
+
if SESSION_BUCKET:
|
|
459
|
+
session_manager = StrandsS3SessionManager(
|
|
460
|
+
session_id=session_id,
|
|
461
|
+
bucket=SESSION_BUCKET,
|
|
462
|
+
)
|
|
463
|
+
|
|
464
|
+
# Write session index EARLY (before streaming) so session appears in list immediately
|
|
465
|
+
# This prevents race conditions where user reloads before streaming completes
|
|
466
|
+
if user_id:
|
|
467
|
+
update_session_index(user_id, session_id, user_message)
|
|
468
|
+
|
|
469
|
+
# Create Strands-native conversation manager for context windowing
|
|
470
|
+
conversation_manager = SlidingWindowConversationManager(window_size=20)
|
|
471
|
+
|
|
472
|
+
# Create Bedrock model
|
|
473
|
+
model = BedrockModel(model_id=MODEL_ID, streaming=True)
|
|
474
|
+
|
|
475
|
+
# Set session_id as environment variable for tools to access
|
|
476
|
+
# This ensures tools like execute_script can persist data even if
|
|
477
|
+
# the AI forgets to pass session_id as a parameter
|
|
478
|
+
os.environ['CURRENT_SESSION_ID'] = session_id
|
|
479
|
+
os.environ['CURRENT_USER_ID'] = user_id
|
|
480
|
+
|
|
481
|
+
# Build runtime system prompt with user context for tools like export_dataset
|
|
482
|
+
runtime_system_prompt = SYSTEM_PROMPT
|
|
483
|
+
if user_id or session_id:
|
|
484
|
+
runtime_system_prompt += f'\n\n# User Context\n- user_id: {user_id}\n- session_id: {session_id}\n'
|
|
485
|
+
|
|
486
|
+
# Create agent with Strands-native session and conversation management.
|
|
487
|
+
# Strands handles session persistence and context windowing automatically.
|
|
488
|
+
# Disable the default callback handler so stream_async yields all events.
|
|
489
|
+
agent = Agent(
|
|
490
|
+
model=model,
|
|
491
|
+
system_prompt=runtime_system_prompt,
|
|
492
|
+
tools=AGENT_TOOLS if AGENT_TOOLS else None,
|
|
493
|
+
session_manager=session_manager,
|
|
494
|
+
conversation_manager=conversation_manager,
|
|
495
|
+
callback_handler=None,
|
|
496
|
+
)
|
|
223
497
|
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
498
|
+
# Use stream_async for true token-by-token streaming.
|
|
499
|
+
# Each event with a "data" key contains a text chunk from the model.
|
|
500
|
+
# Tool results are emitted via the SSE event queue (contextvars).
|
|
501
|
+
async for event in agent.stream_async(user_message):
|
|
502
|
+
# Debug: log event keys to understand what Strands emits
|
|
503
|
+
event_keys = list(event.keys()) if isinstance(event, dict) else ['non-dict']
|
|
504
|
+
if event_keys != ['data']: # Don't spam logs with text chunks
|
|
505
|
+
logger.debug(f'Stream event keys: {event_keys}')
|
|
506
|
+
|
|
507
|
+
# Text chunk from model — stream it to the client immediately
|
|
508
|
+
if 'data' in event:
|
|
509
|
+
chunk = event['data']
|
|
510
|
+
full_response += chunk
|
|
511
|
+
yield format_sse(json.dumps({'text': chunk}))
|
|
512
|
+
|
|
513
|
+
# Drain the SSE event queue — tools push events here via emit_sse_event()
|
|
514
|
+
while not sse_event_queue.empty():
|
|
515
|
+
try:
|
|
516
|
+
queued_event = sse_event_queue.get_nowait()
|
|
517
|
+
event_type = queued_event.get('event', 'message')
|
|
518
|
+
event_data = queued_event.get('data', {})
|
|
519
|
+
logger.info(f'Emitting queued SSE event: {event_type}')
|
|
520
|
+
yield format_sse(json.dumps(event_data), event=event_type)
|
|
521
|
+
except asyncio.QueueEmpty:
|
|
522
|
+
break
|
|
523
|
+
|
|
524
|
+
# After streaming completes, drain any remaining events from the queue
|
|
525
|
+
while not sse_event_queue.empty():
|
|
526
|
+
try:
|
|
527
|
+
queued_event = sse_event_queue.get_nowait()
|
|
528
|
+
event_type = queued_event.get('event', 'message')
|
|
529
|
+
event_data = queued_event.get('data', {})
|
|
530
|
+
logger.info(f'Emitting final queued SSE event: {event_type}')
|
|
531
|
+
yield format_sse(json.dumps(event_data), event=event_type)
|
|
532
|
+
except asyncio.QueueEmpty:
|
|
533
|
+
break
|
|
534
|
+
|
|
535
|
+
# Session is saved automatically by Strands S3SessionManager
|
|
536
|
+
|
|
537
|
+
# Update session index with final timestamp (early write already created it)
|
|
538
|
+
if user_id:
|
|
539
|
+
update_session_index(user_id, session_id, user_message)
|
|
540
|
+
|
|
541
|
+
metrics.add_metric(name='ChatResponses', unit=MetricUnit.Count, value=1)
|
|
542
|
+
|
|
543
|
+
# Send done event
|
|
544
|
+
yield format_sse('{}', event='done')
|
|
545
|
+
|
|
546
|
+
except Exception as e:
|
|
547
|
+
logger.error(f'Error processing chat: {e}', exc_info=True)
|
|
548
|
+
metrics.add_metric(name='ChatErrors', unit=MetricUnit.Count, value=1)
|
|
549
|
+
yield format_sse(json.dumps({'error': 'An internal error occurred. Check logs for details.'}), event='error')
|
|
550
|
+
|
|
551
|
+
finally:
|
|
552
|
+
# Release the distributed DynamoDB lock
|
|
553
|
+
if ddb_lock_acquired and SESSION_LOCK_TABLE and dynamodb_client:
|
|
554
|
+
try:
|
|
555
|
+
dynamodb_client.delete_item(
|
|
556
|
+
TableName=SESSION_LOCK_TABLE,
|
|
557
|
+
Key={'session_id': {'S': session_id}}
|
|
558
|
+
)
|
|
559
|
+
logger.info(f'Released distributed lock for session {session_id}')
|
|
560
|
+
except Exception as e:
|
|
561
|
+
# Log but don't fail - TTL will clean up eventually
|
|
562
|
+
logger.warning(f'Failed to release distributed lock for session {session_id}: {e}')
|
|
228
563
|
|
|
229
564
|
return StreamingResponse(
|
|
230
565
|
generate_sse(),
|
|
@@ -237,6 +572,180 @@ async def chat(request: ChatRequest):
|
|
|
237
572
|
)
|
|
238
573
|
|
|
239
574
|
|
|
575
|
+
def extract_text_content(content: Any) -> Optional[str]:
    """
    Return the plain-text part of a Strands message ``content`` value.

    ``content`` may be a plain string, which is returned unchanged. It may
    also be a list of content blocks such as ``{"text": ...}``,
    ``{"toolUse": ...}`` or ``{"toolResult": ...}``. For a list, the ``text``
    values of dict blocks are concatenated in order and every other block is
    ignored.

    Args:
        content: The ``content`` field of a Strands message.

    Returns:
        The extracted text. None when ``content`` is neither a str nor a list,
        or is a list with no text blocks. A list whose text blocks are all
        empty yields ''.
    """
    if isinstance(content, str):
        return content
    if not isinstance(content, list):
        return None

    texts = [block['text'] for block in content if isinstance(block, dict) and 'text' in block]
    if not texts:
        return None
    return ''.join(texts)
|
|
602
|
+
|
|
603
|
+
|
|
604
|
+
@app.get('/history/{session_id}')
async def get_history(session_id: str, authorization: Optional[str] = Header(None)):
    """
    Retrieve chat history for a session.

    Reads the session's message files from S3 and returns the user/assistant
    text turns in chronological order; tool interactions are filtered out.
    If ``session-metadata/<session_id>/latest_result.json`` exists, its
    ``schema``, ``preview``, ``totalRows`` and ``downloads`` fields are
    included alongside the messages.

    The S3 reads are synchronous boto3 calls. They run in a worker thread via
    ``asyncio.to_thread`` so that loading history does not block the event
    loop, and with it every in-flight SSE chat stream.

    Args:
        session_id: The session ID to retrieve history for
        authorization: Bearer token header.
            NOTE(review): currently unused, so this endpoint does not verify
            that the caller owns ``session_id``. Confirm whether ownership
            should be enforced (e.g. against SESSION_INDEX_TABLE).

    Returns:
        A JSON object:
        ``{"messages": [{"role": "user"|"assistant", "content": str}, ...]}``
        plus any of the optional metadata keys above.
        On misconfiguration or error, the same shape is returned with an
        empty ``messages`` list, so clients only handle one response shape.
    """
    logger.info(f'History request: session={session_id}')

    if not SESSION_BUCKET:
        logger.warning('SESSION_BUCKET not configured, returning empty history')
        return JSONResponse(content={'messages': []}, status_code=200)

    try:
        result = await asyncio.to_thread(_load_session_history, session_id, SESSION_BUCKET)
    except Exception as e:
        logger.error(f'Error retrieving history for session {session_id}: {e}', exc_info=True)
        # Return empty history on error for graceful handling
        return JSONResponse(content={'messages': []}, status_code=200)

    logger.info(f'History request complete: session={session_id}, messages={len(result["messages"])}')
    return JSONResponse(content=result, status_code=200)


def _load_session_history(session_id: str, bucket: str) -> Dict[str, Any]:
    """
    Read a session's text messages and latest-result metadata from S3.

    This function is blocking; get_history runs it in a worker thread.
    Individual unreadable message files are logged and skipped. Metadata
    errors are logged and leave the metadata keys out. A failure while
    listing the session propagates to the caller.
    """
    paginator = s3_client.get_paginator('list_objects_v2')
    prefix = f'/session_{session_id}/agents/agent_default/messages/'

    indexed_messages = []
    for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
        for obj in page.get('Contents', []):
            key = obj['Key']
            if 'message_' not in key or not key.endswith('.json'):
                continue
            try:
                response = s3_client.get_object(Bucket=bucket, Key=key)
                msg_data = json.loads(response['Body'].read().decode('utf-8'))

                # Message index comes from the filename (message_<n>.json) for ordering
                filename = key.split('/')[-1]
                msg_index = int(filename.replace('message_', '').replace('.json', ''))

                message = msg_data.get('message', {})
                role = message.get('role')
                # Only include user and assistant messages with text content
                if role in ('user', 'assistant'):
                    text_content = extract_text_content(message.get('content'))
                    if text_content:
                        indexed_messages.append((msg_index, {'role': role, 'content': text_content}))
            except Exception as e:
                logger.warning(f'Failed to read message file {key}: {e}')

    indexed_messages.sort(key=lambda item: item[0])
    result: Dict[str, Any] = {'messages': [msg for _, msg in indexed_messages]}

    # Try to fetch metadata (schema/preview) if available
    metadata_key = f'session-metadata/{session_id}/latest_result.json'
    try:
        metadata_resp = s3_client.get_object(Bucket=bucket, Key=metadata_key)
        metadata = json.loads(metadata_resp['Body'].read().decode('utf-8'))
        logger.info(f'Found metadata for session {session_id}')
        for field in ('schema', 'preview', 'totalRows', 'downloads'):
            if field in metadata:
                result[field] = metadata[field]
    except s3_client.exceptions.NoSuchKey:
        logger.debug(f'No metadata found for session {session_id}')
    except Exception as e:
        logger.warning(f'Error fetching metadata for session {session_id}: {e}')

    return result
|
|
697
|
+
|
|
698
|
+
|
|
699
|
+
@app.get('/sessions')
async def list_sessions(authorization: Optional[str] = Header(None)):
    """
    List all sessions for the authenticated user.

    The user is identified by the ``sub`` claim of the bearer token. The
    function queries every session-index row for that user from DynamoDB,
    following pagination so users with many sessions are not truncated at
    the first 1 MB result page. Rows are returned sorted by ``updated_at``,
    newest first.

    The DynamoDB call is a synchronous boto3 call. It runs in a worker thread
    via ``asyncio.to_thread`` so it does not block the event loop.

    Args:
        authorization: Bearer token for authentication

    Returns:
        JSON array of session index items, e.g.
        ``[{"session_id": "...", "created_at": "...", "last_message": "...", "updated_at": "..."}]``.
        An empty array is returned when there is no user, the table is not
        configured, or the query fails.
    """
    user_info = extract_user_from_jwt(authorization)
    user_id = user_info.get('user_id', '')

    if not user_id:
        logger.info('List sessions request with no user_id, returning empty list')
        return []

    logger.info(f'List sessions request: user={user_id[:8]}...')

    if not SESSION_INDEX_TABLE:
        logger.warning('SESSION_INDEX_TABLE not configured, returning empty session list')
        return []

    try:
        sessions = await asyncio.to_thread(_query_user_sessions, user_id)
    except Exception as e:
        logger.error(f'Error listing sessions for user {user_id[:8]}...: {e}', exc_info=True)
        # Return empty array on error for graceful handling
        return []

    logger.info(f'Returning {len(sessions)} sessions for user={user_id[:8]}...')
    return sessions


def _query_user_sessions(user_id: str) -> List[Dict[str, Any]]:
    """
    Return every session-index item for ``user_id``, newest ``updated_at`` first.

    This function is blocking. A single Query call returns at most 1 MB of
    data, so the loop follows ``LastEvaluatedKey`` until all pages are read.
    """
    table = dynamodb.Table(SESSION_INDEX_TABLE)
    query_kwargs: Dict[str, Any] = {
        'KeyConditionExpression': 'user_id = :uid',
        'ExpressionAttributeValues': {':uid': user_id},
    }

    sessions: List[Dict[str, Any]] = []
    while True:
        response = table.query(**query_kwargs)
        sessions.extend(response.get('Items', []))
        last_key = response.get('LastEvaluatedKey')
        if not last_key:
            break
        query_kwargs['ExclusiveStartKey'] = last_key

    # Sort by updated_at descending (newest first); ISO-8601 strings sort chronologically
    sessions.sort(key=lambda item: item.get('updated_at', ''), reverse=True)
    return sessions
|
|
747
|
+
|
|
748
|
+
|
|
240
749
|
@app.get('/health')
|
|
241
750
|
async def health():
|
|
242
751
|
"""Health check endpoint for Lambda Web Adapter."""
|