ds-agent-cli 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. package/bin/ds-agent.js +451 -0
  2. package/ds_agent/__init__.py +8 -0
  3. package/package.json +28 -0
  4. package/requirements.txt +126 -0
  5. package/setup.py +35 -0
  6. package/src/__init__.py +7 -0
  7. package/src/_compress_tool_result.py +118 -0
  8. package/src/api/__init__.py +4 -0
  9. package/src/api/app.py +1626 -0
  10. package/src/cache/__init__.py +5 -0
  11. package/src/cache/cache_manager.py +561 -0
  12. package/src/cli.py +2886 -0
  13. package/src/dynamic_prompts.py +281 -0
  14. package/src/orchestrator.py +4799 -0
  15. package/src/progress_manager.py +139 -0
  16. package/src/reasoning/__init__.py +332 -0
  17. package/src/reasoning/business_summary.py +431 -0
  18. package/src/reasoning/data_understanding.py +356 -0
  19. package/src/reasoning/model_explanation.py +383 -0
  20. package/src/reasoning/reasoning_trace.py +239 -0
  21. package/src/registry/__init__.py +3 -0
  22. package/src/registry/tools_registry.py +3 -0
  23. package/src/session_memory.py +448 -0
  24. package/src/session_store.py +370 -0
  25. package/src/storage/__init__.py +19 -0
  26. package/src/storage/artifact_store.py +620 -0
  27. package/src/storage/helpers.py +116 -0
  28. package/src/storage/huggingface_storage.py +694 -0
  29. package/src/storage/r2_storage.py +0 -0
  30. package/src/storage/user_files_service.py +288 -0
  31. package/src/tools/__init__.py +335 -0
  32. package/src/tools/advanced_analysis.py +823 -0
  33. package/src/tools/advanced_feature_engineering.py +708 -0
  34. package/src/tools/advanced_insights.py +578 -0
  35. package/src/tools/advanced_preprocessing.py +549 -0
  36. package/src/tools/advanced_training.py +906 -0
  37. package/src/tools/agent_tool_mapping.py +326 -0
  38. package/src/tools/auto_pipeline.py +420 -0
  39. package/src/tools/autogluon_training.py +1480 -0
  40. package/src/tools/business_intelligence.py +860 -0
  41. package/src/tools/cloud_data_sources.py +581 -0
  42. package/src/tools/code_interpreter.py +390 -0
  43. package/src/tools/computer_vision.py +614 -0
  44. package/src/tools/data_cleaning.py +614 -0
  45. package/src/tools/data_profiling.py +593 -0
  46. package/src/tools/data_type_conversion.py +268 -0
  47. package/src/tools/data_wrangling.py +433 -0
  48. package/src/tools/eda_reports.py +284 -0
  49. package/src/tools/enhanced_feature_engineering.py +241 -0
  50. package/src/tools/feature_engineering.py +302 -0
  51. package/src/tools/matplotlib_visualizations.py +1327 -0
  52. package/src/tools/model_training.py +520 -0
  53. package/src/tools/nlp_text_analytics.py +761 -0
  54. package/src/tools/plotly_visualizations.py +497 -0
  55. package/src/tools/production_mlops.py +852 -0
  56. package/src/tools/time_series.py +507 -0
  57. package/src/tools/tools_registry.py +2133 -0
  58. package/src/tools/visualization_engine.py +559 -0
  59. package/src/utils/__init__.py +42 -0
  60. package/src/utils/error_recovery.py +313 -0
  61. package/src/utils/parallel_executor.py +402 -0
  62. package/src/utils/polars_helpers.py +248 -0
  63. package/src/utils/schema_extraction.py +132 -0
  64. package/src/utils/semantic_layer.py +392 -0
  65. package/src/utils/token_budget.py +411 -0
  66. package/src/utils/validation.py +377 -0
  67. package/src/workflow_state.py +154 -0
package/src/api/app.py ADDED
@@ -0,0 +1,1626 @@
1
+ """
2
+ FastAPI Application for Google Cloud Run
3
+ Thin HTTP wrapper around DataScienceCopilot - No logic changes, just API exposure.
4
+ """
5
+
6
+ import os
7
+ import sys
8
+ import tempfile
9
+ import shutil
10
+ import time
11
+ from pathlib import Path
12
+ from typing import Optional, Dict, Any, List
13
+ import logging
14
+ from dotenv import load_dotenv
15
+
16
+ # Load environment variables from .env file
17
+ load_dotenv()
18
+
19
+ from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Request, BackgroundTasks
20
+ from fastapi.responses import JSONResponse, FileResponse, StreamingResponse
21
+ from fastapi.middleware.cors import CORSMiddleware
22
+ from pydantic import BaseModel
23
+ import asyncio
24
+ import json
25
+ import numpy as np
26
+
27
+ # Import from parent package
28
+ from src.orchestrator import DataScienceCopilot
29
+ from src.progress_manager import progress_manager
30
+ from src.session_memory import SessionMemory
31
+
32
+ # Configure logging
33
+ logging.basicConfig(level=logging.INFO)
34
+ logger = logging.getLogger(__name__)
35
+
36
+ # JSON serializer that handles numpy types
37
def safe_json_dumps(obj):
    """Convert *obj* to a JSON string, handling numpy scalars/arrays,
    datetime/date/timedelta, nested containers, and arbitrary
    non-serializable objects (rendered as ``"<ClassName object>"``).
    """
    from datetime import datetime, date, timedelta

    def convert(o):
        # numpy scalar types are not JSON serializable - unwrap to builtins.
        if isinstance(o, (np.integer, np.int64, np.int32)):
            return int(o)
        elif isinstance(o, (np.floating, np.float64, np.float32)):
            return float(o)
        elif isinstance(o, np.bool_):
            # BUGFIX: np.bool_ is NOT a subclass of bool and has no __dict__,
            # so it previously fell through to `return o` and json.dumps
            # raised TypeError on numpy booleans.
            return bool(o)
        elif isinstance(o, np.ndarray):
            return o.tolist()
        elif isinstance(o, (datetime, date)):
            return o.isoformat()
        elif isinstance(o, timedelta):
            return str(o)
        elif isinstance(o, dict):
            return {k: convert(v) for k, v in o.items()}
        elif isinstance(o, (list, tuple)):
            # Tuples degrade to JSON arrays, matching json.dumps behavior.
            return [convert(item) for item in o]
        elif hasattr(o, '__dict__') and not isinstance(o, (str, int, float, bool, type(None))):
            # Non-serializable object (like DataScienceCopilot)
            return f"<{o.__class__.__name__} object>"
        elif hasattr(o, '__class__') and 'Figure' in o.__class__.__name__:
            # Figure-like objects without a __dict__ (rare) - placeholder only.
            return f"<{o.__class__.__name__} object>"
        return o

    return json.dumps(convert(obj))
64
+
65
+ # Initialize FastAPI
66
+ app = FastAPI(
67
+ title="Data Science Agent API",
68
+ description="Cloud Run wrapper for autonomous data science workflows",
69
+ version="1.0.0"
70
+ )
71
+
72
+ # Enable CORS
73
+ app.add_middleware(
74
+ CORSMiddleware,
75
+ allow_origins=["*"], # Configure this properly in production
76
+ allow_credentials=True,
77
+ allow_methods=["*"],
78
+ allow_headers=["*"],
79
+ )
80
+
81
+ # SSE event queues for real-time streaming
82
class ProgressEventManager:
    """Manages SSE connections and progress events for real-time updates.

    Keeps one list of asyncio.Queue subscribers per session plus the most
    recent event per session, so late-joining clients can query current
    status via :meth:`get_current_status`.
    """

    def __init__(self):
        # session_id -> list of subscriber queues (one per connected client)
        self.active_streams: Dict[str, List[asyncio.Queue]] = {}
        # session_id -> last event sent ({"type", "data", "timestamp"})
        self.session_status: Dict[str, Dict[str, Any]] = {}

    def create_stream(self, session_id: str) -> asyncio.Queue:
        """Create and register a new SSE stream queue for a session."""
        if session_id not in self.active_streams:
            self.active_streams[session_id] = []

        queue = asyncio.Queue()
        self.active_streams[session_id].append(queue)
        return queue

    def remove_stream(self, session_id: str, queue: asyncio.Queue):
        """Remove an SSE stream when a client disconnects.

        Drops the session entry entirely once its last subscriber is gone.
        Safe to call twice for the same queue (double-disconnect).
        """
        if session_id in self.active_streams:
            try:
                self.active_streams[session_id].remove(queue)
                if not self.active_streams[session_id]:
                    del self.active_streams[session_id]
            except (ValueError, KeyError):
                pass

    async def send_event(self, session_id: str, event_type: str, data: Dict[str, Any]):
        """Send an event to all connected clients for a session.

        Subscribers whose queue put does not complete within 1s are treated
        as dead and pruned. No-op (and no status snapshot) when the session
        has no subscribers.
        """
        if session_id not in self.active_streams:
            return

        # Store current status so late-joining clients can catch up.
        self.session_status[session_id] = {
            "type": event_type,
            "data": data,
            "timestamp": time.time()
        }

        # Fan out to all connected streams, collecting dead subscribers.
        dead_queues = []
        for queue in self.active_streams[session_id]:
            try:
                await asyncio.wait_for(queue.put((event_type, data)), timeout=1.0)
            except Exception:
                # BUGFIX: was `except (asyncio.TimeoutError, Exception)` -
                # Exception already subsumes TimeoutError, so the tuple was
                # redundant and misleading.
                dead_queues.append(queue)

        # Clean up dead queues
        for queue in dead_queues:
            self.remove_stream(session_id, queue)

    def get_current_status(self, session_id: str) -> Optional[Dict[str, Any]]:
        """Return the most recent event for a session, or None if unseen."""
        return self.session_status.get(session_id)

    def clear_session(self, session_id: str):
        """Tear down all streams and status for a session.

        Sends a best-effort terminal ("complete", {}) event so connected
        clients can close cleanly.
        """
        if session_id in self.active_streams:
            for queue in self.active_streams[session_id]:
                try:
                    queue.put_nowait(("complete", {}))
                except Exception:
                    # BUGFIX: was a bare `except:`, which also swallowed
                    # KeyboardInterrupt/SystemExit.
                    pass
            del self.active_streams[session_id]

        if session_id in self.session_status:
            del self.session_status[session_id]
149
+
150
+ # 👥 MULTI-USER SUPPORT: Session state isolation
151
+ # Heavy components (SBERT, tools, LLM client) are shared via global 'agent'
152
+ # Only session memory is isolated per user for fast initialization
153
+
154
+ from dataclasses import dataclass
155
+ from datetime import datetime, timedelta
156
+ import threading
157
+
158
@dataclass
class SessionState:
    """Wrapper for a per-user session with metadata used for TTL/LRU cleanup."""
    # The user's SessionMemory instance (typed Any; the concrete class is
    # src.session_memory.SessionMemory - see get_agent_for_session).
    session: Any
    # When this session entry was first created.
    created_at: datetime
    # Updated on every request; drives TTL expiry and LRU eviction.
    last_accessed: datetime
    # Number of requests served by this session so far.
    request_count: int = 0
165
+
166
# Per-session bookkeeping: only lightweight session memory is isolated per
# user; heavy components live on the shared global `agent` below.
session_states: Dict[str, SessionState] = {}  # session_id -> SessionState
agent_cache_lock = asyncio.Lock()  # guards session_states and agent.session swaps
MAX_CACHED_SESSIONS = 50  # Increased limit for scale
SESSION_TTL_MINUTES = 60  # Sessions expire after 1 hour of inactivity
logger.info("👥 Multi-user session isolation initialized (fast mode)")

# Global agent - Heavy components loaded ONCE at startup
# SBERT model, tool functions, LLM client are shared across all users
# CRITICAL: We use threading.local() to ensure thread-safe session isolation
agent: Optional[DataScienceCopilot] = None
agent_thread_local = threading.local()  # Thread-local storage for session isolation
# BUGFIX: removed two dead rebindings present in the original:
#   - `agent = None` immediately after the annotated `agent: ... = None`
#   - a second `session_states: Dict[str, any] = {}` that silently replaced
#     the properly typed declaration above (and used the builtin `any`
#     instead of typing.Any). Both dicts were empty at import time, so
#     dropping the duplicates is behavior-identical.
181
+
182
+
183
async def get_agent_for_session(session_id: str) -> DataScienceCopilot:
    """
    Get agent with isolated session state.

    OPTIMIZATION: Heavy components (SBERT, tools, LLM client) are shared.
    Session state is isolated using thread-local storage to prevent race conditions.
    This reduces per-user initialization from 20s to <1s.

    THREAD SAFETY: Uses threading.local() so each request thread gets its own
    agent reference with isolated session, preventing cross-contamination.

    NOTE(review): despite the thread-local copy, the chosen session is also
    assigned onto the shared global `agent` below; true isolation therefore
    relies on the module-level workflow_lock serializing analyses - confirm
    before allowing concurrent workflows.

    Args:
        session_id: Unique session identifier

    Returns:
        DataScienceCopilot instance with isolated session for this user
    """
    global agent

    async with agent_cache_lock:
        # Ensure base agent exists (heavy components loaded once at startup)
        if agent is None:
            logger.warning("Base agent not initialized - this shouldn't happen after startup")
            provider = os.getenv("LLM_PROVIDER", "mistral")
            agent = DataScienceCopilot(
                reasoning_effort="medium",
                provider=provider,
                use_compact_prompts=False
            )

        # Clean up expired sessions periodically (every 10th request)
        # NOTE(review): this fires only when the cache SIZE is an exact
        # multiple of 10, not literally every 10th request.
        if len(session_states) > 0 and len(session_states) % 10 == 0:
            cleanup_expired_sessions()

        now = datetime.now()

        # Check if we have cached session memory for this session
        if session_id in session_states:
            state = session_states[session_id]
            state.last_accessed = now
            state.request_count += 1
            logger.info(f"[♻️] Reusing session {session_id[:8]}... (requests: {state.request_count})")

            # Store in thread-local storage for isolation
            agent_thread_local.session = state.session
            agent_thread_local.session_id = session_id

            # Return agent with session set (safe because of workflow_lock)
            agent.session = state.session
            agent.http_session_key = session_id
            return agent

        # 🚀 FAST PATH: Create new session memory only (no SBERT reload!)
        logger.info(f"[🆕] Creating lightweight session for {session_id[:8]}...")

        # Create isolated session memory for this user
        new_session = SessionMemory(session_id=session_id)

        # Cache management: Remove expired first, then LRU if still over limit
        if len(session_states) >= MAX_CACHED_SESSIONS:
            expired_count = cleanup_expired_sessions()

            # If still over limit after cleanup, remove least recently used
            if len(session_states) >= MAX_CACHED_SESSIONS:
                # Sort by last_accessed and remove oldest
                sorted_sessions = sorted(session_states.items(), key=lambda x: x[1].last_accessed)
                oldest_session_id = sorted_sessions[0][0]
                logger.info(f"[🗑️] Cache full, removing LRU session {oldest_session_id[:8]}...")
                del session_states[oldest_session_id]

        # Create session state wrapper with metadata
        session_state = SessionState(
            session=new_session,
            created_at=now,
            last_accessed=now,
            request_count=1
        )
        session_states[session_id] = session_state

        # Store in thread-local storage
        agent_thread_local.session = new_session
        agent_thread_local.session_id = session_id

        # Set session on shared agent (safe with workflow_lock)
        agent.session = new_session
        agent.http_session_key = session_id

        logger.info(f"✅ Session created for {session_id[:8]} (cache: {len(session_states)}/{MAX_CACHED_SESSIONS}) - <1s init")

        return agent
273
+
274
def cleanup_expired_sessions():
    """Evict sessions idle longer than SESSION_TTL_MINUTES.

    Returns:
        The number of sessions removed.
    """
    # A session is stale when its last access predates the cutoff instant.
    cutoff = datetime.now() - timedelta(minutes=SESSION_TTL_MINUTES)
    stale_ids = [sid for sid, state in session_states.items()
                 if state.last_accessed < cutoff]

    for session_id in stale_ids:
        logger.info(f"[🗑️] Removing expired session {session_id[:8]}... (inactive for {SESSION_TTL_MINUTES}min)")
        del session_states[session_id]

    return len(stale_ids)
290
+
291
# 🔒 REQUEST QUEUING: Global lock to prevent concurrent workflows
# This ensures only one analysis runs at a time, preventing:
# - Race conditions on file writes
# - Memory exhaustion from parallel model training
# - Session state corruption
# NOTE(review): a module-level asyncio.Lock only serializes within a single
# event loop / process - confirm deployment uses one worker.
workflow_lock = asyncio.Lock()
logger.info("🔒 Workflow lock initialized for request queuing")
298
+
299
+
300
@app.on_event("startup")
async def startup_event():
    """Initialize DataScienceCopilot on service startup.

    Builds one shared agent used for health checks; real requests go through
    get_agent_for_session() for per-user session isolation. Any failure here
    is re-raised so the service does not come up half-initialized.
    """
    global agent
    try:
        logger.info("Initializing legacy global agent for health checks...")
        llm_provider = os.getenv("LLM_PROVIDER", "mistral")

        # Multi-agent routing is always on, so compact prompts stay disabled.
        agent = DataScienceCopilot(
            reasoning_effort="medium",
            provider=llm_provider,
            use_compact_prompts=False
        )
        logger.info(f"✅ Health check agent initialized with provider: {agent.provider}")
        logger.info("👥 Per-session agents enabled - each user gets isolated instance")
        logger.info("🤖 Multi-agent architecture enabled with 5 specialists")
    except Exception as e:
        logger.error(f"❌ Failed to initialize agent: {e}")
        raise
322
+
323
+
324
@app.get("/api/health")
async def root():
    """Health check endpoint."""
    agent_ready = bool(agent)
    return {
        "service": "Data Science Agent API",
        "status": "healthy",
        # Report placeholders instead of failing when startup hasn't run yet.
        "provider": agent.provider if agent_ready else "not initialized",
        "tools_available": len(agent.tool_functions) if agent_ready else 0,
    }
333
+
334
+
335
@app.get("/api/progress/{session_id}")
async def get_progress(session_id: str):
    """Get progress updates for a specific session (legacy polling endpoint)."""
    # A session counts as active while at least one SSE subscriber is attached.
    subscriber_count = progress_manager.get_subscriber_count(session_id)
    current_state = "active" if subscriber_count > 0 else "idle"
    return {
        "session_id": session_id,
        "steps": progress_manager.get_history(session_id),
        "current": {"status": current_state},
    }
343
+
344
+
345
@app.get("/api/progress/stream/{session_id}")
async def stream_progress(session_id: str):
    """Stream real-time progress updates using Server-Sent Events (SSE).

    This endpoint connects clients to the global progress_manager which
    receives events from the orchestrator as tools execute.

    Events:
    - tool_executing: When a tool begins execution
    - tool_completed: When a tool finishes successfully
    - tool_failed: When a tool fails
    - token_update: Token budget updates
    - analysis_complete: When the entire workflow finishes
    """
    print(f"[SSE] ENDPOINT: Client connected for session_id={session_id}")

    # CRITICAL: Create queue and register subscriber IMMEDIATELY, before any
    # awaits, so no events can be emitted between connect and registration.
    # NOTE(review): this reaches into progress_manager._queues directly
    # (private attribute) - confirm progress_manager has no public register API.
    queue = asyncio.Queue(maxsize=100)
    if session_id not in progress_manager._queues:
        progress_manager._queues[session_id] = []
    progress_manager._queues[session_id].append(queue)
    print(f"[SSE] Queue registered, total subscribers: {len(progress_manager._queues[session_id])}")

    async def event_generator():
        # Yields SSE-framed lines ("data: ...\n\n"); cleanup runs in finally.
        try:
            # Send initial connection event so the client knows the stream is live.
            connection_event = {
                'type': 'connected',
                'message': '🔗 Connected to progress stream',
                'session_id': session_id
            }
            print(f"[SSE] SENDING connection event to client")
            yield f"data: {safe_json_dumps(connection_event)}\n\n"

            # 🔥 FIX: Replay any events that were emitted BEFORE this subscriber connected.
            # This handles the race condition where background analysis starts emitting
            # events before the client's SSE reconnection completes.
            history = progress_manager.get_history(session_id)
            if history:
                print(f"[SSE] Replaying {len(history)} missed events for late-joining subscriber")
                for past_event in history:
                    if past_event.get('type') != 'analysis_complete':
                        yield f"data: {safe_json_dumps(past_event)}\n\n"
                    else:
                        # Analysis already finished before we connected:
                        # deliver the terminal event and end the stream.
                        yield f"data: {safe_json_dumps(past_event)}\n\n"
                        print(f"[SSE] Analysis already completed before subscriber connected - closing")
                        return
            else:
                print(f"[SSE] No history to replay (fresh session)")

            print(f"[SSE] Starting event stream loop for session {session_id}")

            # Stream new events from the queue. Polls with get_nowait rather
            # than awaiting queue.get() to avoid blocking issues; sends an SSE
            # comment line as keepalive when idle.
            while True:
                if not queue.empty():
                    event = queue.get_nowait()
                    print(f"[SSE] GOT event from queue: {event.get('type')}")
                    yield f"data: {safe_json_dumps(event)}\n\n"

                    # Terminal event ends the stream.
                    if event.get('type') == 'analysis_complete':
                        break
                else:
                    # No events available, send keepalive and wait
                    yield f": keepalive\n\n"
                    await asyncio.sleep(0.5)  # Poll every 500ms

        except asyncio.CancelledError:
            # Client disconnected; normal termination path for SSE.
            logger.info(f"SSE stream cancelled for session {session_id}")
        except Exception as e:
            logger.error(f"SSE error for session {session_id}: {e}")
        finally:
            # Always deregister this subscriber's queue.
            if session_id in progress_manager._queues and queue in progress_manager._queues[session_id]:
                progress_manager._queues[session_id].remove(queue)
            logger.info(f"SSE stream closed for session {session_id}")

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no"  # Disable nginx buffering
        }
    )
433
+
434
+
435
@app.get("/health")
async def health_check():
    """
    Health check for Cloud Run.
    Returns 200 if service is ready to accept requests.
    """
    # Cloud Run only routes traffic once this endpoint returns 200.
    if agent is None:
        raise HTTPException(status_code=503, detail="Agent not initialized")

    status_payload = {
        "status": "healthy",
        "agent_ready": True,
        "provider": agent.provider,
        "tools_count": len(agent.tool_functions),
    }
    return status_payload
450
+
451
+
452
class AnalysisRequest(BaseModel):
    """Request model for analysis endpoint (JSON body)."""
    # Natural-language description of the analysis task to run.
    task_description: str
    # Optional target column name for supervised/prediction tasks.
    target_col: Optional[str] = None
    # Reuse cached intermediate results when available.
    use_cache: bool = True
    # Upper bound on workflow iterations.
    max_iterations: int = 20
458
+
459
+
460
def run_analysis_background(file_path: str, task_description: str, target_col: Optional[str],
                            use_cache: bool, max_iterations: int, session_id: str):
    """Background task: run a full analysis for *session_id* and emit SSE events.

    Runs synchronously in a FastAPI BackgroundTasks worker thread and owns
    its own event loop. All progress and outcomes are reported through
    progress_manager; this function never raises - failures are emitted as
    'analysis_failed' events instead.

    Args:
        file_path: Dataset path ("" for follow-up queries resolved from session memory).
        task_description: Natural-language task to run.
        target_col: Optional target column for supervised tasks.
        use_cache: Whether the orchestrator may reuse cached results.
        max_iterations: Cap on workflow iterations.
        session_id: Session UUID used for agent lookup and SSE routing.
    """
    async def _run_with_lock():
        """Wrap analysis in the global workflow lock to ensure sequential execution."""
        async with workflow_lock:
            try:
                logger.info(f"[BACKGROUND] Starting analysis for session {session_id[:8]}...")

                # 🧹 Clear SSE history for fresh event stream (prevents duplicate results)
                print(f"[🧹] Clearing SSE history for {session_id[:8]}...")
                if session_id in progress_manager._history:
                    progress_manager._history[session_id] = []

                # 👥 Get isolated agent for this session
                session_agent = await get_agent_for_session(session_id)

                result = session_agent.analyze(
                    file_path=file_path,
                    task_description=task_description,
                    target_col=target_col,
                    use_cache=use_cache,
                    max_iterations=max_iterations
                )

                logger.info(f"[BACKGROUND] Analysis completed for session {session_id[:8]}...")

                # Send appropriate completion event based on status
                if result.get("status") == "error":
                    progress_manager.emit(session_id, {
                        "type": "analysis_failed",
                        "status": "error",
                        "message": result.get("summary", "❌ Analysis failed"),
                        "error": result.get("error", "Analysis error"),
                        "result": result
                    })
                else:
                    progress_manager.emit(session_id, {
                        "type": "analysis_complete",
                        "status": result.get("status"),
                        "message": "✅ Analysis completed successfully!",
                        "result": result
                    })

            except Exception as e:
                logger.error(f"[BACKGROUND] Analysis failed for session {session_id[:8]}...: {e}")
                progress_manager.emit(session_id, {
                    "type": "analysis_failed",
                    "error": str(e),
                    "message": f"❌ Analysis failed: {str(e)}"
                })

    # BUGFIX: the original used asyncio.get_event_loop()/new_event_loop() +
    # run_until_complete() and never closed the loop, leaking one event loop
    # per background-task thread. asyncio.run() creates, runs, and closes a
    # fresh loop - identical semantics here since this runs in a worker
    # thread with no loop of its own.
    asyncio.run(_run_with_lock())
521
+
522
+
523
@app.post("/run-async")
async def run_analysis_async(
    background_tasks: BackgroundTasks,
    file: Optional[UploadFile] = File(None),
    task_description: str = Form(...),
    target_col: Optional[str] = Form(None),
    session_id: Optional[str] = Form(None),
    use_cache: bool = Form(False),  # Disabled to show multi-agent in action
    max_iterations: int = Form(20)
) -> JSONResponse:
    """
    Start analysis in background and return session UUID immediately.
    Clients can connect SSE with this UUID to receive real-time updates.

    For follow-up queries, send the same session_id to maintain context.

    Returns 503 before startup, 400 when no file is given and the session
    has no cached dataset, otherwise 200 with the session_id to stream on.
    """
    if agent is None:
        raise HTTPException(status_code=503, detail="Agent not initialized")

    # 🆔 Session ID handling:
    # - If client sends a valid-looking UUID, reuse it (follow-up query)
    # - Otherwise generate a new one (first query)
    import uuid
    if session_id and '-' in session_id and len(session_id) > 20:
        logger.info(f"[ASYNC] Reusing session: {session_id[:8]}... (follow-up)")
    else:
        session_id = str(uuid.uuid4())
        logger.info(f"[ASYNC] Created new session: {session_id[:8]}...")

    # Handle file upload
    temp_file_path = None
    if file:
        temp_dir = Path("/tmp") / "data_science_agent"
        temp_dir.mkdir(parents=True, exist_ok=True)
        # SECURITY FIX: keep only the basename of the client-supplied
        # filename. The original joined file.filename directly, so a name
        # like "../../etc/cron.d/x" could escape the temp directory
        # (path traversal on an untrusted input).
        temp_file_path = temp_dir / Path(file.filename).name

        with open(temp_file_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)

        logger.info(f"[ASYNC] File saved: {file.filename}")
    else:
        # 🛡️ VALIDATION: Check if this session has dataset cached
        has_dataset = False
        async with agent_cache_lock:
            # Check session_states cache for this specific session_id
            if session_id in session_states:
                state = session_states[session_id]
                cached_session = state.session  # Extract SessionMemory from wrapper
                if hasattr(cached_session, 'last_dataset') and cached_session.last_dataset:
                    has_dataset = True
                    logger.info(f"[ASYNC] Follow-up query for session {session_id[:8]}... - using cached dataset")

        if not has_dataset:
            logger.warning(f"[ASYNC] No file uploaded and no dataset for session {session_id[:8]}...")
            return JSONResponse(
                content={
                    "success": False,
                    "error": "No dataset available",
                    "message": "Please upload a CSV, Excel, or Parquet file first.",
                    "session_id": session_id
                },
                status_code=400
            )

    # Start background analysis
    background_tasks.add_task(
        run_analysis_background,
        file_path=str(temp_file_path) if temp_file_path else "",
        task_description=task_description,
        target_col=target_col,
        use_cache=use_cache,
        max_iterations=max_iterations,
        session_id=session_id
    )

    # Return UUID immediately so client can connect SSE
    return JSONResponse(content={
        "session_id": session_id,
        "status": "started",
        "message": "Analysis started in background"
    })
606
+
607
+
608
+ @app.post("/run")
609
+ async def run_analysis(
610
+ file: Optional[UploadFile] = File(None, description="Dataset file (CSV or Parquet) - optional for follow-up requests"),
611
+ task_description: str = Form(..., description="Natural language task description"),
612
+ target_col: Optional[str] = Form(None, description="Target column name for prediction"),
613
+ use_cache: bool = Form(False, description="Enable caching for expensive operations"), # Disabled to show multi-agent
614
+ max_iterations: int = Form(20, description="Maximum workflow iterations"),
615
+ session_id: Optional[str] = Form(None, description="Session ID for follow-up requests")
616
+ ) -> JSONResponse:
617
+ """
618
+ Run complete data science workflow on uploaded dataset.
619
+
620
+ This is a thin wrapper - all logic lives in DataScienceCopilot.analyze().
621
+
622
+ Args:
623
+ file: CSV or Parquet file upload
624
+ task_description: Natural language description of the task
625
+ target_col: Optional target column for ML tasks
626
+ use_cache: Whether to use cached results
627
+ max_iterations: Maximum number of workflow steps
628
+
629
+ Returns:
630
+ JSON response with analysis results, workflow history, and execution stats
631
+
632
+ Example:
633
+ ```bash
634
+ curl -X POST http://localhost:8080/run \
635
+ -F "file=@data.csv" \
636
+ -F "task_description=Analyze this dataset and predict house prices" \
637
+ -F "target_col=price"
638
+ ```
639
+ """
640
+ if agent is None:
641
+ raise HTTPException(status_code=503, detail="Agent not initialized")
642
+
643
+ # 🆔 Generate or use provided session ID
644
+ if not session_id:
645
+ import uuid
646
+ session_id = str(uuid.uuid4())
647
+ logger.info(f"[SYNC] Created new session: {session_id[:8]}...")
648
+ else:
649
+ logger.info(f"[SYNC] Using provided session: {session_id[:8]}...")
650
+
651
+ # 👥 Get isolated agent for this session
652
+ session_agent = await get_agent_for_session(session_id)
653
+
654
+ # Handle follow-up requests (no file, using session memory)
655
+ if file is None:
656
+ logger.info(f"Follow-up request without file, using session memory")
657
+ logger.info(f"Task: {task_description}")
658
+
659
+ # 🛡️ VALIDATION: Check if session has a dataset
660
+ if not (hasattr(session_agent, 'session') and session_agent.session and session_agent.session.last_dataset):
661
+ logger.warning("No file uploaded and no session dataset available")
662
+ return JSONResponse(
663
+ content={
664
+ "success": False,
665
+ "error": "No dataset available",
666
+ "message": "Please upload a CSV, Excel, or Parquet file first before asking questions."
667
+ },
668
+ status_code=400
669
+ )
670
+
671
+ # Get the agent's actual session UUID for SSE routing
672
+ actual_session_id = session_agent.session.session_id if hasattr(session_agent, 'session') and session_agent.session else session_id
673
+ print(f"[SSE] Follow-up using agent session UUID: {actual_session_id}")
674
+
675
+ # NO progress_callback - orchestrator emits directly to UUID
676
+
677
+ try:
678
+ # Agent's session memory should resolve file_path from context
679
+ result = session_agent.analyze(
680
+ file_path="", # Empty - will be resolved by session memory
681
+ task_description=task_description,
682
+ target_col=target_col,
683
+ use_cache=use_cache,
684
+ max_iterations=max_iterations
685
+ )
686
+
687
+ logger.info(f"Follow-up analysis completed: {result.get('status')}")
688
+
689
+ # Send appropriate completion event based on status
690
+ if result.get("status") == "error":
691
+ progress_manager.emit(actual_session_id, {
692
+ "type": "analysis_failed",
693
+ "status": "error",
694
+ "message": result.get("summary", "❌ Analysis failed"),
695
+ "error": result.get("error", "No dataset available")
696
+ })
697
+ else:
698
+ progress_manager.emit(actual_session_id, {
699
+ "type": "analysis_complete",
700
+ "status": result.get("status"),
701
+ "message": "✅ Analysis completed successfully!"
702
+ })
703
+
704
+ # Make result JSON serializable
705
+ def make_json_serializable(obj):
706
+ if isinstance(obj, dict):
707
+ return {k: make_json_serializable(v) for k, v in obj.items()}
708
+ elif isinstance(obj, list):
709
+ return [make_json_serializable(item) for item in obj]
710
+ elif hasattr(obj, '__class__') and obj.__class__.__name__ in ['Figure', 'Axes', 'Artist']:
711
+ return f"<{obj.__class__.__name__} object - see artifacts>"
712
+ elif isinstance(obj, (str, int, float, bool, type(None))):
713
+ return obj
714
+ else:
715
+ try:
716
+ return str(obj)
717
+ except:
718
+ return f"<{type(obj).__name__}>"
719
+
720
+ serializable_result = make_json_serializable(result)
721
+
722
+ return JSONResponse(
723
+ content={
724
+ "success": result.get("status") == "success",
725
+ "result": serializable_result,
726
+ "metadata": {
727
+ "filename": "session_context",
728
+ "task": task_description,
729
+ "target": target_col,
730
+ "provider": agent.provider,
731
+ "follow_up": True
732
+ }
733
+ },
734
+ status_code=200
735
+ )
736
+
737
+ except Exception as e:
738
+ logger.error(f"Follow-up analysis failed: {str(e)}", exc_info=True)
739
+ raise HTTPException(
740
+ status_code=500,
741
+ detail={
742
+ "error": str(e),
743
+ "error_type": type(e).__name__,
744
+ "message": "Follow-up request failed. Make sure you've uploaded a file first."
745
+ }
746
+ )
747
+
748
+ # Validate file format for new uploads
749
+ filename = file.filename.lower()
750
+ if not (filename.endswith('.csv') or filename.endswith('.parquet')):
751
+ raise HTTPException(
752
+ status_code=400,
753
+ detail="Invalid file format. Only CSV and Parquet files are supported."
754
+ )
755
+
756
+ # Use /tmp for Cloud Run (ephemeral storage)
757
+ temp_dir = Path("/tmp") / "data_science_agent"
758
+ temp_dir.mkdir(parents=True, exist_ok=True)
759
+
760
+ temp_file_path = None
761
+
762
+ try:
763
+ # Save uploaded file to temporary location
764
+ temp_file_path = temp_dir / file.filename
765
+ logger.info(f"Saving uploaded file to: {temp_file_path}")
766
+
767
+ with open(temp_file_path, "wb") as buffer:
768
+ shutil.copyfileobj(file.file, buffer)
769
+
770
+ logger.info(f"File saved successfully: {file.filename} ({os.path.getsize(temp_file_path)} bytes)")
771
+
772
+ # Get the agent's actual session UUID for SSE routing (BEFORE analyze())
773
+ actual_session_id = session_agent.session.session_id if hasattr(session_agent, 'session') and session_agent.session else session_id
774
+ print(f"[SSE] File upload using agent session UUID: {actual_session_id}")
775
+
776
+ # NO progress_callback - orchestrator emits directly to UUID
777
+
778
+ # Call existing agent logic
779
+ logger.info(f"Starting analysis with task: {task_description}")
780
+ result = session_agent.analyze(
781
+ file_path=str(temp_file_path),
782
+ task_description=task_description,
783
+ target_col=target_col,
784
+ use_cache=use_cache,
785
+ max_iterations=max_iterations
786
+ )
787
+
788
+ logger.info(f"Analysis completed: {result.get('status')}")
789
+
790
+ # Send appropriate completion event based on status
791
+ if result.get("status") == "error":
792
+ progress_manager.emit(actual_session_id, {
793
+ "type": "analysis_failed",
794
+ "status": "error",
795
+ "message": result.get("summary", "❌ Analysis failed"),
796
+ "error": result.get("error", "Analysis error")
797
+ })
798
+ else:
799
+ progress_manager.emit(actual_session_id, {
800
+ "type": "analysis_complete",
801
+ "status": result.get("status"),
802
+ "message": "✅ Analysis completed successfully!"
803
+ })
804
+
805
+ # Filter out non-JSON-serializable objects (like matplotlib/plotly Figures)
806
+ def make_json_serializable(obj):
807
+ """Recursively convert objects to JSON-serializable format."""
808
+ if isinstance(obj, dict):
809
+ return {k: make_json_serializable(v) for k, v in obj.items()}
810
+ elif isinstance(obj, list):
811
+ return [make_json_serializable(item) for item in obj]
812
+ elif hasattr(obj, '__class__') and obj.__class__.__name__ in ['Figure', 'Axes', 'Artist']:
813
+ # Skip matplotlib/plotly Figure objects
814
+ return f"<{obj.__class__.__name__} object - see artifacts>"
815
+ elif isinstance(obj, (str, int, float, bool, type(None))):
816
+ return obj
817
+ else:
818
+ # Try to convert to string for other types
819
+ try:
820
+ return str(obj)
821
+ except:
822
+ return f"<{type(obj).__name__}>"
823
+
824
+ serializable_result = make_json_serializable(result)
825
+
826
+ # Return result with ACTUAL session UUID for SSE
827
+ return JSONResponse(
828
+ content={
829
+ "success": result.get("status") == "success",
830
+ "result": serializable_result,
831
+ "session_id": actual_session_id, # Return UUID for SSE connection
832
+ "metadata": {
833
+ "filename": file.filename,
834
+ "task": task_description,
835
+ "target": target_col,
836
+ "provider": agent.provider
837
+ }
838
+ },
839
+ status_code=200
840
+ )
841
+
842
+ except Exception as e:
843
+ logger.error(f"Analysis failed: {str(e)}", exc_info=True)
844
+ raise HTTPException(
845
+ status_code=500,
846
+ detail={
847
+ "error": str(e),
848
+ "error_type": type(e).__name__,
849
+ "message": "Analysis workflow failed. Check logs for details."
850
+ }
851
+ )
852
+
853
+ finally:
854
+ # Keep temporary file for session continuity (follow-up requests)
855
+ # Files in /tmp are automatically cleaned up by the OS
856
+ # For HuggingFace Spaces: space restart clears /tmp
857
+ # For production: implement session-based cleanup after timeout
858
+ pass
859
+
860
+
861
@app.post("/profile")
async def profile_dataset(
    file: UploadFile = File(..., description="Dataset file (CSV or Parquet)")
) -> JSONResponse:
    """
    Quick dataset profiling without full workflow.

    Returns basic statistics, data types, and quality issues.
    Useful for initial data exploration without running full analysis.

    Raises:
        HTTPException: 503 if the agent is not initialized, 400 for
            unsupported file formats, 500 if profiling fails.

    Example:
        ```bash
        curl -X POST http://localhost:8080/profile \
            -F "file=@data.csv"
        ```
    """
    if agent is None:
        raise HTTPException(status_code=503, detail="Agent not initialized")

    filename = file.filename.lower()
    if not (filename.endswith('.csv') or filename.endswith('.parquet')):
        raise HTTPException(
            status_code=400,
            detail="Invalid file format. Only CSV and Parquet files are supported."
        )

    temp_dir = Path("/tmp") / "data_science_agent"
    temp_dir.mkdir(parents=True, exist_ok=True)
    temp_file_path = None

    try:
        # Save file temporarily. Use only the basename of the client-supplied
        # filename: a crafted name such as "../../etc/cron.d/x.csv" must not
        # escape temp_dir (path traversal).
        safe_name = Path(file.filename).name
        temp_file_path = temp_dir / safe_name
        with open(temp_file_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)

        # Import profiling tool directly
        from tools.data_profiling import profile_dataset as profile_tool
        from tools.data_profiling import detect_data_quality_issues

        # Run profiling tools
        logger.info(f"Profiling dataset: {file.filename}")
        profile_result = profile_tool(str(temp_file_path))
        quality_result = detect_data_quality_issues(str(temp_file_path))

        return JSONResponse(
            content={
                "success": True,
                "filename": file.filename,
                "profile": profile_result,
                "quality_issues": quality_result
            },
            status_code=200
        )

    except Exception as e:
        logger.error(f"Profiling failed: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail={
                "error": str(e),
                "error_type": type(e).__name__
            }
        )

    finally:
        # Profiling is one-shot (unlike /run, which keeps files for
        # follow-ups), so the temp file is removed immediately.
        if temp_file_path and temp_file_path.exists():
            try:
                temp_file_path.unlink()
            except Exception as e:
                logger.warning(f"Failed to cleanup temp file: {e}")
932
+
933
+
934
@app.get("/tools")
async def list_tools():
    """
    List all available tools in the agent.

    Returns tool names organized by category.
    Useful for understanding agent capabilities.
    """
    # The agent is created at startup; without it there is nothing to list.
    if agent is None:
        raise HTTPException(status_code=503, detail="Agent not initialized")

    from tools.tools_registry import get_tools_by_category

    registered = agent.tool_functions
    return {
        "total_tools": len(registered),
        "tools_by_category": get_tools_by_category(),
        "all_tools": list(registered.keys()),
    }
952
+
953
+
954
class ChatMessage(BaseModel):
    """Chat message model: one turn of the /chat conversation."""
    role: str  # 'user' or 'assistant'
    content: str  # message text for that turn
958
+
959
+
960
class ChatRequest(BaseModel):
    """Chat request model for POST /chat."""
    messages: List[ChatMessage]  # conversation history; /chat treats the last entry as the current message
    stream: bool = False  # NOTE(review): accepted but not acted on by /chat — confirm before relying on it
964
+
965
+
966
@app.post("/chat")
async def chat(request: ChatRequest) -> JSONResponse:
    """
    Chat endpoint for conversational interface.

    Processes chat messages and returns agent responses.
    Uses the same underlying agent as /run but in chat format.

    Args:
        request: Chat request with message history

    Returns:
        JSON response with agent's reply

    Raises:
        HTTPException: 503 if the agent is missing, 400 if no user message
            is present, 500 for configuration or Gemini errors.
    """
    if agent is None:
        raise HTTPException(status_code=503, detail="Agent not initialized")

    try:
        # Extract the latest user message
        user_messages = [msg for msg in request.messages if msg.role == "user"]
        if not user_messages:
            raise HTTPException(status_code=400, detail="No user message found")

        latest_message = user_messages[-1].content

        # Check for API key
        api_key = os.getenv("GOOGLE_API_KEY") or os.getenv("GEMINI_API_KEY")
        if not api_key:
            raise HTTPException(
                status_code=500,
                detail="GOOGLE_API_KEY or GEMINI_API_KEY not configured. Please set the environment variable."
            )

        # Use Google Gemini API
        import google.generativeai as genai

        logger.info(f"Configuring Gemini with API key (length: {len(api_key)})")
        genai.configure(api_key=api_key)

        # Safety settings for data science content
        safety_settings = [
            {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
            {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
            {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
            {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
        ]

        # Resolve the model name once so the response metadata below reports
        # the model actually used (it was previously hard-coded to
        # "gemini-2.0-flash-exp" regardless of GEMINI_MODEL).
        model_name = os.getenv("GEMINI_MODEL", "gemini-2.5-flash-lite")

        # Initialize Gemini model (system_instruction not supported in this SDK version)
        model = genai.GenerativeModel(
            model_name=model_name,
            generation_config={"temperature": 0.7},
            safety_settings=safety_settings
        )

        # System message will be prepended to first user message
        system_msg = "You are a Senior Data Science Autonomous Agent. You help users with end-to-end machine learning, data profiling, visualization, and strategic insights. Use a professional, technical yet accessible tone. Provide code snippets in Python if requested. You have access to tools for data analysis, ML training, visualization, and more.\\n\\n"

        # Convert messages to Gemini format (exclude system message, just conversation)
        chat_history = []
        first_user_msg = True
        for msg in request.messages[:-1]:  # Exclude the latest message
            content = msg.content
            # Prepend system instruction to first user message
            if first_user_msg and msg.role == "user":
                content = system_msg + content
                first_user_msg = False
            chat_history.append({
                "role": "user" if msg.role == "user" else "model",
                "parts": [content]
            })

        # Start chat with history
        chat = model.start_chat(history=chat_history)

        # Send the latest message
        response = chat.send_message(latest_message)

        assistant_message = response.text

        return JSONResponse(
            content={
                "success": True,
                "message": assistant_message,
                "model": model_name,
                "provider": "gemini"
            },
            status_code=200
        )

    except HTTPException:
        # Re-raise as-is so the 400 above is not converted into a 500 by the
        # generic handler below (matches the other endpoints in this file).
        raise
    except Exception as e:
        logger.error(f"Chat failed: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail={
                "error": str(e),
                "error_type": type(e).__name__
            }
        )
1064
+
1065
+
1066
+ # ==================== FILE STORAGE API ====================
1067
+ # These endpoints handle persistent file storage with R2 + Supabase
1068
+
1069
class FileMetadataResponse(BaseModel):
    """Response model for file metadata.

    Mirrors a record from the user-files store, plus an optional presigned
    ``download_url`` (left as None for files served inline, e.g. plot data).
    """
    id: str  # unique file id in the files store
    file_type: str  # FileType value: plot, csv, report, model
    file_name: str  # original file name shown to the user
    size_bytes: int  # stored object size in bytes
    created_at: str  # ISO-8601 timestamp (serialized via .isoformat())
    expires_at: str  # ISO-8601 timestamp after which the file may be purged
    download_url: Optional[str] = None  # presigned R2 URL when applicable
    metadata: Dict[str, Any] = {}  # free-form extras (pydantic copies this default per instance)
1079
+
1080
class UserFilesResponse(BaseModel):
    """Response model for user files list (GET /api/files)."""
    success: bool  # True in both the populated and storage-unconfigured cases
    files: List[FileMetadataResponse]  # files matching the query filters
    total_count: int  # number of files returned
    total_size_mb: float  # combined size of listed files, rounded to 2 decimals
1086
+
1087
@app.get("/api/files")
async def get_user_files(
    user_id: str,
    file_type: Optional[str] = None,
    session_id: Optional[str] = None
):
    """
    Get all files for a user.

    Query params:
        - user_id: User ID (required)
        - file_type: Filter by type (plot, csv, report, model)
        - session_id: Filter by chat session

    Returns:
        UserFilesResponse with per-file download URLs and aggregate size.
        When the storage services are not installed, an empty listing is
        returned instead of an error.
    """
    try:
        from src.storage.user_files_service import get_files_service, FileType
        from src.storage.r2_storage import get_r2_service

        files_service = get_files_service()
        r2_service = get_r2_service()

        # Convert file_type string to enum if provided. An unknown value is a
        # client mistake, so answer 400 instead of letting the ValueError fall
        # through to the generic 500 handler below.
        file_type_enum = None
        if file_type:
            try:
                file_type_enum = FileType(file_type)
            except ValueError:
                raise HTTPException(
                    status_code=400,
                    detail=f"Invalid file_type: {file_type}"
                )

        files = files_service.get_user_files(
            user_id=user_id,
            file_type=file_type_enum,
            session_id=session_id
        )

        # Generate download URLs (CSVs get a download link; reports and plots
        # get a report link; other types get none).
        file_responses = []
        total_size = 0
        for f in files:
            download_url = None
            if f.file_type == FileType.CSV:
                download_url = r2_service.get_csv_download_url(f.r2_key)
            elif f.file_type in [FileType.REPORT, FileType.PLOT]:
                download_url = r2_service.get_report_url(f.r2_key)

            file_responses.append(FileMetadataResponse(
                id=f.id,
                file_type=f.file_type.value,
                file_name=f.file_name,
                size_bytes=f.size_bytes,
                created_at=f.created_at.isoformat(),
                expires_at=f.expires_at.isoformat(),
                download_url=download_url,
                metadata=f.metadata
            ))
            total_size += f.size_bytes

        return UserFilesResponse(
            success=True,
            files=file_responses,
            total_count=len(files),
            total_size_mb=round(total_size / (1024 * 1024), 2)
        )

    except ImportError:
        # Storage services not configured: behave as "no files" (best-effort)
        # rather than failing the whole request.
        return UserFilesResponse(
            success=True,
            files=[],
            total_count=0,
            total_size_mb=0
        )
    except HTTPException:
        # Preserve deliberate status codes (e.g. the 400 above).
        raise
    except Exception as e:
        logger.error(f"Error fetching user files: {e}")
        raise HTTPException(status_code=500, detail=str(e))
1159
+
1160
@app.get("/api/files/{file_id}")
async def get_file(file_id: str):
    """Get a specific file by ID with download URL."""
    try:
        from src.storage.user_files_service import get_files_service, FileType
        from src.storage.r2_storage import get_r2_service

        files_service = get_files_service()
        r2_service = get_r2_service()

        file = files_service.get_file_by_id(file_id)
        if not file:
            raise HTTPException(status_code=404, detail="File not found")

        # Plots are special-cased: the plot payload is returned inline
        # instead of via a download URL.
        if file.file_type == FileType.PLOT:
            plot_data = r2_service.get_plot_data(file.r2_key)
            return {
                "success": True,
                "file": {
                    "id": file.id,
                    "file_type": file.file_type.value,
                    "file_name": file.file_name,
                    "metadata": file.metadata
                },
                "plot_data": plot_data
            }

        # Everything else carries a URL: CSVs get a download link, the rest
        # get a report link.
        if file.file_type == FileType.CSV:
            download_url = r2_service.get_csv_download_url(file.r2_key)
        else:
            download_url = r2_service.get_report_url(file.r2_key)

        return {
            "success": True,
            "file": FileMetadataResponse(
                id=file.id,
                file_type=file.file_type.value,
                file_name=file.file_name,
                size_bytes=file.size_bytes,
                created_at=file.created_at.isoformat(),
                expires_at=file.expires_at.isoformat(),
                download_url=download_url,
                metadata=file.metadata
            )
        }

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error fetching file: {e}")
        raise HTTPException(status_code=500, detail=str(e))
1213
+
1214
@app.delete("/api/files/{file_id}")
async def delete_file(file_id: str, user_id: str):
    """Delete a file (both from R2 and Supabase)."""
    try:
        from src.storage.user_files_service import get_files_service
        from src.storage.r2_storage import get_r2_service

        files_service = get_files_service()
        r2_service = get_r2_service()

        record = files_service.get_file_by_id(file_id)
        if not record:
            raise HTTPException(status_code=404, detail="File not found")

        # Only the owner may delete.
        if record.user_id != user_id:
            raise HTTPException(status_code=403, detail="Not authorized")

        # Remove the stored object first, then its metadata row.
        r2_service.delete_file(record.r2_key)
        files_service.hard_delete_file(file_id)

        return {"success": True, "message": "File deleted"}

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error deleting file: {e}")
        raise HTTPException(status_code=500, detail=str(e))
1245
+
1246
@app.get("/api/files/stats/{user_id}")
async def get_storage_stats(user_id: str):
    """Get storage statistics for a user."""
    try:
        from src.storage.user_files_service import get_files_service

        stats = get_files_service().get_user_storage_stats(user_id)
        return {"success": True, "stats": stats}

    except Exception as e:
        # Best-effort endpoint: any failure (including a missing storage
        # service) is logged and reported as an empty usage summary so the
        # UI never breaks on this call.
        logger.error(f"Error getting stats: {e}")
        empty_stats = {
            "total_files": 0,
            "total_size_bytes": 0,
            "total_size_mb": 0,
            "by_type": {}
        }
        return {"success": True, "stats": empty_stats}
1271
+
1272
@app.post("/api/files/extend/{file_id}")
async def extend_file_expiration(file_id: str, user_id: str, days: int = 7):
    """Extend a file's expiration date."""
    try:
        from src.storage.user_files_service import get_files_service

        files_service = get_files_service()

        record = files_service.get_file_by_id(file_id)
        if not record:
            raise HTTPException(status_code=404, detail="File not found")

        # Only the owner may extend the lifetime.
        if record.user_id != user_id:
            raise HTTPException(status_code=403, detail="Not authorized")

        return {"success": files_service.extend_expiration(file_id, days)}

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error extending expiration: {e}")
        raise HTTPException(status_code=500, detail=str(e))
1296
+
1297
+
1298
+ # Error handlers
1299
@app.exception_handler(HTTPException)
async def http_exception_handler(request, exc):
    """Custom error response format."""
    # Wrap every HTTPException in the API's uniform error envelope.
    payload = {
        "success": False,
        "error": exc.detail,
        "status_code": exc.status_code
    }
    return JSONResponse(status_code=exc.status_code, content=payload)
1310
+
1311
+
1312
@app.exception_handler(Exception)
async def general_exception_handler(request, exc):
    """Catch-all error handler."""
    # Log the full traceback, then answer with a generic 500 envelope that
    # still carries the exception text and type for debugging.
    logger.error(f"Unhandled exception: {str(exc)}", exc_info=True)
    payload = {
        "success": False,
        "error": "Internal server error",
        "detail": str(exc),
        "error_type": type(exc).__name__
    }
    return JSONResponse(status_code=500, content=payload)
1325
+
1326
+
1327
@app.get("/outputs/{file_path:path}")
async def serve_output_files(file_path: str):
    """
    Serve generated output files (reports, plots, models, etc.).
    Checks multiple locations: ./outputs, /tmp/data_science_agent/outputs, and /tmp/data_science_agent.

    Raises:
        HTTPException: 404 if the file is not found in any location,
            403 if the resolved path escapes the allowed directories.
    """
    # Locations to check (in order of priority)
    search_paths = [
        Path("./outputs") / file_path,  # Local development
        Path("/tmp/data_science_agent/outputs") / file_path,  # Production with subdirs
        Path("/tmp/data_science_agent") / file_path,  # Production flat OR relative paths like plots/xxx.html
        Path("/tmp/data_science_agent/outputs") / Path(file_path).name,  # Production filename only
        Path("/tmp/data_science_agent") / Path(file_path).name,  # Production root filename only
        Path("./outputs") / Path(file_path).name,  # Local development filename only
    ]

    output_path = None
    for path in search_paths:
        logger.debug(f"Checking path: {path}")
        if path.exists() and path.is_file():
            output_path = path
            logger.info(f"Found file at: {path}")
            break

    if output_path is None:
        logger.error(f"File not found in any location: {file_path}")
        logger.error(f"Searched paths: {[str(p) for p in search_paths]}")
        raise HTTPException(status_code=404, detail=f"File not found: {file_path}")

    # Security: prevent directory traversal
    resolved_path = output_path.resolve()
    allowed_bases = [
        Path("./outputs").resolve(),
        Path("/tmp/data_science_agent").resolve()
    ]

    # Check if path is within allowed directories
    is_allowed = False
    for base in allowed_bases:
        try:
            resolved_path.relative_to(base)
            is_allowed = True
            break
        except ValueError:
            continue

    if not is_allowed:
        raise HTTPException(status_code=403, detail="Access denied")

    # Determine the media type from the extension, case-insensitively (the
    # previous endswith() chain missed uppercase extensions like ".PNG").
    # Unknown extensions yield None, letting FileResponse guess.
    media_types = {
        ".html": "text/html",
        ".csv": "text/csv",
        ".json": "application/json",
        ".png": "image/png",
        ".jpg": "image/jpeg",
        ".jpeg": "image/jpeg",
    }
    media_type = media_types.get(Path(file_path).suffix.lower())

    return FileResponse(output_path, media_type=media_type)
1390
+
1391
+
1392
+ # ============== HUGGINGFACE EXPORT ENDPOINT ==============
1393
+
1394
class HuggingFaceExportRequest(BaseModel):
    """Request model for HuggingFace export (POST /api/export/huggingface)."""
    user_id: str  # Supabase user id; used to look up the stored HF token in hf_tokens
    session_id: str  # session whose output artifacts should be exported
1398
+
1399
@app.post("/api/export/huggingface")
async def export_to_huggingface(request: HuggingFaceExportRequest):
    """
    Export session assets (datasets, models, plots) to user's HuggingFace account.

    Requires user to have connected their HuggingFace token in settings.

    Flow:
        1. Look up the user's HF token in the Supabase ``hf_tokens`` table.
        2. Initialize the HuggingFace storage service with that token.
        3. Glob the session, global and /tmp output directories for CSVs,
           pickled models, HTML plots and PNG images, uploading each match.
        4. Report per-file successes and failures.

    Raises:
        HTTPException: 404 when HF is not connected, 400 when the stored
            token is empty, 500 for configuration/upload failures.
    """
    import glob

    logger.info(f"[HF Export] Starting export for user {request.user_id[:8]}... session {request.session_id[:8]}...")

    try:
        # Try to import supabase - may not be installed
        try:
            from supabase import create_client, Client
        except ImportError as e:
            logger.error(f"[HF Export] Supabase package not installed: {e}")
            raise HTTPException(status_code=500, detail="Server error: supabase package not installed")

        # Get user's HuggingFace credentials from Supabase
        supabase_url = os.getenv("SUPABASE_URL")
        supabase_key = os.getenv("SUPABASE_SERVICE_ROLE_KEY")

        logger.info(f"[HF Export] Supabase URL configured: {bool(supabase_url)}, Key configured: {bool(supabase_key)}")

        if not supabase_url or not supabase_key:
            raise HTTPException(status_code=500, detail="Supabase configuration missing")

        supabase: Client = create_client(supabase_url, supabase_key)

        # Fetch user's HuggingFace token from hf_tokens table (not user_profiles)
        logger.info(f"[HF Export] Fetching HF token from hf_tokens table...")
        try:
            result = supabase.table("hf_tokens").select(
                "huggingface_token, huggingface_username"
            ).eq("user_id", request.user_id).execute()

            logger.info(f"[HF Export] Query result: {result.data}")

            if not result.data or len(result.data) == 0:
                raise HTTPException(status_code=404, detail="HuggingFace not connected. Please connect in Settings first.")

            row = result.data[0]
            hf_token = row.get("huggingface_token")
            hf_username = row.get("huggingface_username")
        except HTTPException:
            raise
        except Exception as e:
            logger.error(f"[HF Export] Supabase query error: {e}")
            raise HTTPException(status_code=500, detail=f"Database error: {str(e)}")

        if not hf_token:
            raise HTTPException(
                status_code=400,
                detail="HuggingFace token not found. Please connect in Settings."
            )

        # Import HuggingFace storage service
        try:
            from src.storage.huggingface_storage import HuggingFaceStorage
            logger.info(f"[HF Export] HuggingFaceStorage imported successfully")
        except ImportError as e:
            logger.error(f"[HF Export] Failed to import HuggingFaceStorage: {e}")
            raise HTTPException(status_code=500, detail=f"Server error: {str(e)}")

        try:
            hf_service = HuggingFaceStorage(hf_token=hf_token)
            logger.info(f"[HF Export] HuggingFaceStorage initialized for user: {hf_username}")
        except Exception as e:
            logger.error(f"[HF Export] Failed to initialize HuggingFaceStorage: {e}")
            raise HTTPException(status_code=500, detail=f"HuggingFace error: {str(e)}")

        # Collect all session assets
        uploaded_files = []
        errors = []

        # Session-specific output directory - check /tmp/data_science_agent for HF Spaces
        session_outputs_dir = Path(f"./outputs/{request.session_id}")
        global_outputs_dir = Path("./outputs")
        tmp_outputs_dir = Path("/tmp/data_science_agent")

        logger.info(f"[HF Export] Looking for files in: {session_outputs_dir}, {global_outputs_dir}, {tmp_outputs_dir}")

        def _upload_batch(patterns, kind, log_label, upload, skip=None):
            """Glob every pattern and upload each match, recording the outcome.

            Args:
                patterns: glob Paths to search.
                kind: Capitalized asset kind used in error strings ("Dataset", ...);
                    its lowercase form is the "type" in uploaded_files entries.
                log_label: label used in upload log lines ("dataset", "HTML plot", ...).
                upload: callable taking the matched file path and returning the
                    storage service's result dict (keys: success, url, error).
                skip: optional predicate; matches for which it returns True are ignored.
            """
            for pattern in patterns:
                for matched in glob.glob(str(pattern)):
                    name = Path(matched).name
                    if skip is not None and skip(matched):
                        continue
                    try:
                        logger.info(f"[HF Export] Uploading {log_label}: {matched}")
                        result = upload(matched)
                        if result.get("success"):
                            uploaded_files.append({"type": kind.lower(), "name": name, "url": result.get("url")})
                        else:
                            errors.append(f"{kind} {name}: {result.get('error', 'Unknown error')}")
                    except Exception as e:
                        logger.error(f"[HF Export] {kind} upload error: {e}")
                        errors.append(f"{kind} {name}: {str(e)}")

        # Upload datasets (CSVs)
        _upload_batch(
            [session_outputs_dir / "*.csv",
             global_outputs_dir / "*.csv",
             tmp_outputs_dir / "*.csv"],
            "Dataset", "dataset",
            lambda p: hf_service.upload_dataset(
                file_path=p,
                session_id=request.session_id,
                file_name=Path(p).name,
                compress=True
            )
        )

        # Upload models (PKL files)
        _upload_batch(
            [session_outputs_dir / "models" / "*.pkl",
             global_outputs_dir / "models" / "*.pkl",
             tmp_outputs_dir / "models" / "*.pkl"],
            "Model", "model",
            lambda p: hf_service.upload_model(
                model_path=p,
                session_id=request.session_id,
                model_name=Path(p).stem,
                model_type="sklearn"
            )
        )

        # Upload visualizations (HTML plots) - use generic file upload,
        # skipping index.html or other non-plot files
        _upload_batch(
            [session_outputs_dir / "*.html",
             global_outputs_dir / "*.html",
             session_outputs_dir / "plots" / "*.html",
             global_outputs_dir / "plots" / "*.html",
             tmp_outputs_dir / "*.html",
             tmp_outputs_dir / "plots" / "*.html"],
            "Plot", "HTML plot",
            lambda p: hf_service.upload_generic_file(
                file_path=p,
                session_id=request.session_id,
                subfolder="plots"
            ),
            skip=lambda p: "index" in Path(p).name.lower()
        )

        # Upload PNG images - use generic file upload
        _upload_batch(
            [session_outputs_dir / "*.png",
             global_outputs_dir / "*.png",
             session_outputs_dir / "plots" / "*.png",
             global_outputs_dir / "plots" / "*.png",
             tmp_outputs_dir / "*.png",
             tmp_outputs_dir / "plots" / "*.png"],
            "Image", "image",
            lambda p: hf_service.upload_generic_file(
                file_path=p,
                session_id=request.session_id,
                subfolder="images"
            )
        )

        # Nothing succeeded but failures occurred: surface them as a 500.
        if not uploaded_files and errors:
            logger.error(f"[HF Export] All uploads failed: {errors}")
            raise HTTPException(
                status_code=500,
                detail=f"Export failed: {'; '.join(errors)}"
            )

        # Nothing matched at all: success, with a hint to run analysis first.
        if not uploaded_files and not errors:
            logger.info(f"[HF Export] No files found to export")
            return JSONResponse({
                "success": True,
                "uploaded_files": [],
                "errors": None,
                "message": "No files found to export. Run some analysis first to generate outputs."
            })

        logger.info(f"[HF Export] Export completed: {len(uploaded_files)} files uploaded, {len(errors)} errors")
        return JSONResponse({
            "success": True,
            "uploaded_files": uploaded_files,
            "errors": errors if errors else None,
            "message": f"Successfully exported {len(uploaded_files)} files to HuggingFace"
        })

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"HuggingFace export failed: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Export failed: {str(e)}")
1614
+
1615
# Cloud Run listens on PORT environment variable
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(
        "app:app",
        host="0.0.0.0",
        port=int(os.getenv("PORT", 8080)),  # default matches Cloud Run's convention
        log_level="info",
    )