vibesurf 0.1.10__py3-none-any.whl → 0.1.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of vibesurf might be problematic. Click here for more details.

Files changed (51) hide show
  1. vibe_surf/_version.py +2 -2
  2. vibe_surf/agents/browser_use_agent.py +68 -45
  3. vibe_surf/agents/prompts/report_writer_prompt.py +73 -0
  4. vibe_surf/agents/prompts/vibe_surf_prompt.py +85 -172
  5. vibe_surf/agents/report_writer_agent.py +380 -226
  6. vibe_surf/agents/vibe_surf_agent.py +879 -825
  7. vibe_surf/agents/views.py +130 -0
  8. vibe_surf/backend/api/activity.py +3 -1
  9. vibe_surf/backend/api/browser.py +9 -5
  10. vibe_surf/backend/api/config.py +8 -5
  11. vibe_surf/backend/api/files.py +59 -50
  12. vibe_surf/backend/api/models.py +2 -2
  13. vibe_surf/backend/api/task.py +45 -12
  14. vibe_surf/backend/database/manager.py +24 -18
  15. vibe_surf/backend/database/queries.py +199 -192
  16. vibe_surf/backend/database/schemas.py +1 -1
  17. vibe_surf/backend/main.py +4 -2
  18. vibe_surf/backend/shared_state.py +28 -35
  19. vibe_surf/backend/utils/encryption.py +3 -1
  20. vibe_surf/backend/utils/llm_factory.py +41 -36
  21. vibe_surf/browser/agent_browser_session.py +0 -4
  22. vibe_surf/browser/browser_manager.py +14 -8
  23. vibe_surf/browser/utils.py +5 -3
  24. vibe_surf/browser/watchdogs/dom_watchdog.py +0 -45
  25. vibe_surf/chrome_extension/background.js +4 -0
  26. vibe_surf/chrome_extension/scripts/api-client.js +13 -0
  27. vibe_surf/chrome_extension/scripts/file-manager.js +27 -71
  28. vibe_surf/chrome_extension/scripts/session-manager.js +21 -3
  29. vibe_surf/chrome_extension/scripts/ui-manager.js +831 -48
  30. vibe_surf/chrome_extension/sidepanel.html +21 -4
  31. vibe_surf/chrome_extension/styles/activity.css +365 -5
  32. vibe_surf/chrome_extension/styles/input.css +139 -0
  33. vibe_surf/cli.py +4 -22
  34. vibe_surf/common.py +35 -0
  35. vibe_surf/llm/openai_compatible.py +148 -93
  36. vibe_surf/logger.py +99 -0
  37. vibe_surf/{controller/vibesurf_tools.py → tools/browser_use_tools.py} +233 -219
  38. vibe_surf/tools/file_system.py +415 -0
  39. vibe_surf/{controller → tools}/mcp_client.py +4 -3
  40. vibe_surf/tools/report_writer_tools.py +21 -0
  41. vibe_surf/tools/vibesurf_tools.py +657 -0
  42. vibe_surf/tools/views.py +120 -0
  43. {vibesurf-0.1.10.dist-info → vibesurf-0.1.11.dist-info}/METADATA +6 -2
  44. {vibesurf-0.1.10.dist-info → vibesurf-0.1.11.dist-info}/RECORD +49 -43
  45. vibe_surf/controller/file_system.py +0 -53
  46. vibe_surf/controller/views.py +0 -37
  47. /vibe_surf/{controller → tools}/__init__.py +0 -0
  48. {vibesurf-0.1.10.dist-info → vibesurf-0.1.11.dist-info}/WHEEL +0 -0
  49. {vibesurf-0.1.10.dist-info → vibesurf-0.1.11.dist-info}/entry_points.txt +0 -0
  50. {vibesurf-0.1.10.dist-info → vibesurf-0.1.11.dist-info}/licenses/LICENSE +0 -0
  51. {vibesurf-0.1.10.dist-info → vibesurf-0.1.11.dist-info}/top_level.txt +0 -0
@@ -13,33 +13,36 @@ from ..utils.encryption import encrypt_api_key, decrypt_api_key
13
13
  import logging
14
14
  import json
15
15
 
16
- logger = logging.getLogger(__name__)
16
+ from vibe_surf.logger import get_logger
17
+
18
+ logger = get_logger(__name__)
19
+
17
20
 
18
21
  class LLMProfileQueries:
19
22
  """Query operations for LLMProfile model"""
20
-
23
+
21
24
  @staticmethod
22
25
  async def create_profile(
23
- db: AsyncSession,
24
- profile_name: str,
25
- provider: str,
26
- model: str,
27
- api_key: Optional[str] = None,
28
- base_url: Optional[str] = None,
29
- temperature: Optional[float] = None,
30
- max_tokens: Optional[int] = None,
31
- top_p: Optional[float] = None,
32
- frequency_penalty: Optional[float] = None,
33
- seed: Optional[int] = None,
34
- provider_config: Optional[Dict[str, Any]] = None,
35
- description: Optional[str] = None,
36
- is_default: bool = False
26
+ db: AsyncSession,
27
+ profile_name: str,
28
+ provider: str,
29
+ model: str,
30
+ api_key: Optional[str] = None,
31
+ base_url: Optional[str] = None,
32
+ temperature: Optional[float] = None,
33
+ max_tokens: Optional[int] = None,
34
+ top_p: Optional[float] = None,
35
+ frequency_penalty: Optional[float] = None,
36
+ seed: Optional[int] = None,
37
+ provider_config: Optional[Dict[str, Any]] = None,
38
+ description: Optional[str] = None,
39
+ is_default: bool = False
37
40
  ) -> Dict[str, Any]:
38
41
  """Create a new LLM profile with encrypted API key"""
39
42
  try:
40
43
  # Encrypt API key if provided
41
44
  encrypted_api_key = encrypt_api_key(api_key) if api_key else None
42
-
45
+
43
46
  profile = LLMProfile(
44
47
  profile_name=profile_name,
45
48
  provider=provider,
@@ -55,11 +58,11 @@ class LLMProfileQueries:
55
58
  description=description,
56
59
  is_default=is_default
57
60
  )
58
-
61
+
59
62
  db.add(profile)
60
63
  await db.flush()
61
64
  await db.refresh(profile)
62
-
65
+
63
66
  # Extract data immediately to avoid greenlet issues
64
67
  profile_data = {
65
68
  "profile_id": profile.profile_id,
@@ -80,12 +83,12 @@ class LLMProfileQueries:
80
83
  "updated_at": profile.updated_at,
81
84
  "last_used_at": profile.last_used_at
82
85
  }
83
-
86
+
84
87
  return profile_data
85
88
  except Exception as e:
86
89
  logger.error(f"Failed to create LLM profile {profile_name}: {e}")
87
90
  raise
88
-
91
+
89
92
  @staticmethod
90
93
  async def get_profile(db: AsyncSession, profile_name: str) -> Optional[LLMProfile]:
91
94
  """Get LLM profile by name"""
@@ -102,7 +105,7 @@ class LLMProfileQueries:
102
105
  except Exception as e:
103
106
  logger.error(f"Failed to get LLM profile {profile_name}: {e}")
104
107
  raise
105
-
108
+
106
109
  @staticmethod
107
110
  async def get_profile_with_decrypted_key(db: AsyncSession, profile_name: str) -> Optional[Dict[str, Any]]:
108
111
  """Get LLM profile with decrypted API key"""
@@ -110,10 +113,10 @@ class LLMProfileQueries:
110
113
  profile = await LLMProfileQueries.get_profile(db, profile_name)
111
114
  if not profile:
112
115
  return None
113
-
116
+
114
117
  # Decrypt API key
115
118
  decrypted_api_key = decrypt_api_key(profile.encrypted_api_key) if profile.encrypted_api_key else None
116
-
119
+
117
120
  return {
118
121
  "profile_id": profile.profile_id,
119
122
  "profile_name": profile.profile_name,
@@ -137,42 +140,42 @@ class LLMProfileQueries:
137
140
  except Exception as e:
138
141
  logger.error(f"Failed to get LLM profile with decrypted key {profile_name}: {e}")
139
142
  raise
140
-
143
+
141
144
  @staticmethod
142
145
  async def list_profiles(
143
- db: AsyncSession,
144
- active_only: bool = True,
145
- limit: int = 50,
146
- offset: int = 0
146
+ db: AsyncSession,
147
+ active_only: bool = True,
148
+ limit: int = 50,
149
+ offset: int = 0
147
150
  ) -> List[LLMProfile]:
148
151
  """List LLM profiles"""
149
152
  try:
150
153
  query = select(LLMProfile)
151
-
154
+
152
155
  if active_only:
153
156
  query = query.where(LLMProfile.is_active == True)
154
-
157
+
155
158
  query = query.order_by(desc(LLMProfile.last_used_at), desc(LLMProfile.created_at))
156
159
  query = query.limit(limit).offset(offset)
157
-
160
+
158
161
  result = await db.execute(query)
159
162
  profiles = result.scalars().all()
160
-
163
+
161
164
  # Ensure all attributes are loaded for each profile
162
165
  for profile in profiles:
163
166
  _ = (profile.profile_id, profile.created_at, profile.updated_at,
164
167
  profile.last_used_at, profile.is_active, profile.is_default)
165
-
168
+
166
169
  return profiles
167
170
  except Exception as e:
168
171
  logger.error(f"Failed to list LLM profiles: {e}")
169
172
  raise
170
-
173
+
171
174
  @staticmethod
172
175
  async def update_profile(
173
- db: AsyncSession,
174
- profile_name: str,
175
- updates: Dict[str, Any]
176
+ db: AsyncSession,
177
+ profile_name: str,
178
+ updates: Dict[str, Any]
176
179
  ) -> bool:
177
180
  """Update LLM profile"""
178
181
  try:
@@ -183,18 +186,18 @@ class LLMProfileQueries:
183
186
  updates["encrypted_api_key"] = encrypt_api_key(api_key)
184
187
  else:
185
188
  updates["encrypted_api_key"] = None
186
-
189
+
187
190
  result = await db.execute(
188
191
  update(LLMProfile)
189
192
  .where(LLMProfile.profile_name == profile_name)
190
193
  .values(**updates)
191
194
  )
192
-
195
+
193
196
  return result.rowcount > 0
194
197
  except Exception as e:
195
198
  logger.error(f"Failed to update LLM profile {profile_name}: {e}")
196
199
  raise
197
-
200
+
198
201
  @staticmethod
199
202
  async def delete_profile(db: AsyncSession, profile_name: str) -> bool:
200
203
  """Delete LLM profile"""
@@ -206,7 +209,7 @@ class LLMProfileQueries:
206
209
  except Exception as e:
207
210
  logger.error(f"Failed to delete LLM profile {profile_name}: {e}")
208
211
  raise
209
-
212
+
210
213
  @staticmethod
211
214
  async def get_default_profile(db: AsyncSession) -> Optional[LLMProfile]:
212
215
  """Get the default LLM profile"""
@@ -223,7 +226,7 @@ class LLMProfileQueries:
223
226
  except Exception as e:
224
227
  logger.error(f"Failed to get default LLM profile: {e}")
225
228
  raise
226
-
229
+
227
230
  @staticmethod
228
231
  async def set_default_profile(db: AsyncSession, profile_name: str) -> bool:
229
232
  """Set a profile as default (and unset others)"""
@@ -232,19 +235,19 @@ class LLMProfileQueries:
232
235
  await db.execute(
233
236
  update(LLMProfile).values(is_default=False)
234
237
  )
235
-
238
+
236
239
  # Then set the specified profile as default
237
240
  result = await db.execute(
238
241
  update(LLMProfile)
239
242
  .where(LLMProfile.profile_name == profile_name)
240
243
  .values(is_default=True)
241
244
  )
242
-
245
+
243
246
  return result.rowcount > 0
244
247
  except Exception as e:
245
248
  logger.error(f"Failed to set default LLM profile {profile_name}: {e}")
246
249
  raise
247
-
250
+
248
251
  @staticmethod
249
252
  async def update_last_used(db: AsyncSession, profile_name: str) -> bool:
250
253
  """Update the last_used_at timestamp for a profile"""
@@ -259,16 +262,17 @@ class LLMProfileQueries:
259
262
  logger.error(f"Failed to update last_used for LLM profile {profile_name}: {e}")
260
263
  raise
261
264
 
265
+
262
266
  class McpProfileQueries:
263
267
  """Query operations for McpProfile model"""
264
-
268
+
265
269
  @staticmethod
266
270
  async def create_profile(
267
- db: AsyncSession,
268
- display_name: str,
269
- mcp_server_name: str,
270
- mcp_server_params: Dict[str, Any],
271
- description: Optional[str] = None
271
+ db: AsyncSession,
272
+ display_name: str,
273
+ mcp_server_name: str,
274
+ mcp_server_params: Dict[str, Any],
275
+ description: Optional[str] = None
272
276
  ) -> Dict[str, Any]:
273
277
  """Create a new MCP profile"""
274
278
  try:
@@ -278,11 +282,11 @@ class McpProfileQueries:
278
282
  mcp_server_params=mcp_server_params,
279
283
  description=description
280
284
  )
281
-
285
+
282
286
  db.add(profile)
283
287
  await db.flush()
284
288
  await db.refresh(profile)
285
-
289
+
286
290
  # Extract data immediately to avoid greenlet issues
287
291
  profile_data = {
288
292
  "mcp_id": profile.mcp_id,
@@ -295,12 +299,12 @@ class McpProfileQueries:
295
299
  "updated_at": profile.updated_at,
296
300
  "last_used_at": profile.last_used_at
297
301
  }
298
-
302
+
299
303
  return profile_data
300
304
  except Exception as e:
301
305
  logger.error(f"Failed to create MCP profile {display_name}: {e}")
302
306
  raise
303
-
307
+
304
308
  @staticmethod
305
309
  async def get_profile(db: AsyncSession, mcp_id: str) -> Optional[McpProfile]:
306
310
  """Get MCP profile by ID"""
@@ -317,7 +321,7 @@ class McpProfileQueries:
317
321
  except Exception as e:
318
322
  logger.error(f"Failed to get MCP profile {mcp_id}: {e}")
319
323
  raise
320
-
324
+
321
325
  @staticmethod
322
326
  async def get_profile_by_display_name(db: AsyncSession, display_name: str) -> Optional[McpProfile]:
323
327
  """Get MCP profile by display name"""
@@ -333,37 +337,37 @@ class McpProfileQueries:
333
337
  except Exception as e:
334
338
  logger.error(f"Failed to get MCP profile by display name {display_name}: {e}")
335
339
  raise
336
-
340
+
337
341
  @staticmethod
338
342
  async def list_profiles(
339
- db: AsyncSession,
340
- active_only: bool = True,
341
- limit: int = 50,
342
- offset: int = 0
343
+ db: AsyncSession,
344
+ active_only: bool = True,
345
+ limit: int = 50,
346
+ offset: int = 0
343
347
  ) -> List[McpProfile]:
344
348
  """List MCP profiles"""
345
349
  try:
346
350
  query = select(McpProfile)
347
-
351
+
348
352
  if active_only:
349
353
  query = query.where(McpProfile.is_active == True)
350
-
354
+
351
355
  query = query.order_by(desc(McpProfile.last_used_at), desc(McpProfile.created_at))
352
356
  query = query.limit(limit).offset(offset)
353
-
357
+
354
358
  result = await db.execute(query)
355
359
  profiles = result.scalars().all()
356
-
360
+
357
361
  # Ensure all attributes are loaded for each profile
358
362
  for profile in profiles:
359
363
  _ = (profile.mcp_id, profile.created_at, profile.updated_at,
360
364
  profile.last_used_at, profile.is_active)
361
-
365
+
362
366
  return profiles
363
367
  except Exception as e:
364
368
  logger.error(f"Failed to list MCP profiles: {e}")
365
369
  raise
366
-
370
+
367
371
  @staticmethod
368
372
  async def get_active_profiles(db: AsyncSession) -> List[McpProfile]:
369
373
  """Get all active MCP profiles"""
@@ -372,41 +376,41 @@ class McpProfileQueries:
372
376
  select(McpProfile).where(McpProfile.is_active == True)
373
377
  )
374
378
  profiles = result.scalars().all()
375
-
379
+
376
380
  # Ensure all attributes are loaded for each profile
377
381
  for profile in profiles:
378
382
  _ = (profile.mcp_id, profile.created_at, profile.updated_at,
379
383
  profile.last_used_at, profile.is_active)
380
-
384
+
381
385
  return profiles
382
386
  except Exception as e:
383
387
  logger.error(f"Failed to get active MCP profiles: {e}")
384
388
  raise
385
-
389
+
386
390
  @staticmethod
387
391
  async def update_profile(
388
- db: AsyncSession,
389
- mcp_id: str,
390
- updates: Dict[str, Any]
392
+ db: AsyncSession,
393
+ mcp_id: str,
394
+ updates: Dict[str, Any]
391
395
  ) -> bool:
392
396
  """Update MCP profile"""
393
397
  try:
394
398
  logger.info(f"Updating profile {mcp_id}")
395
-
399
+
396
400
  result = await db.execute(
397
401
  update(McpProfile)
398
402
  .where(McpProfile.mcp_id == mcp_id)
399
403
  .values(**updates)
400
404
  )
401
-
405
+
402
406
  rows_affected = result.rowcount
403
407
  logger.info(f"Update query affected {rows_affected} rows")
404
-
408
+
405
409
  return rows_affected > 0
406
410
  except Exception as e:
407
411
  logger.error(f"Failed to update MCP profile {mcp_id}: {e}")
408
412
  raise
409
-
413
+
410
414
  @staticmethod
411
415
  async def delete_profile(db: AsyncSession, mcp_id: str) -> bool:
412
416
  """Delete MCP profile"""
@@ -418,7 +422,7 @@ class McpProfileQueries:
418
422
  except Exception as e:
419
423
  logger.error(f"Failed to delete MCP profile {mcp_id}: {e}")
420
424
  raise
421
-
425
+
422
426
  @staticmethod
423
427
  async def update_last_used(db: AsyncSession, mcp_id: str) -> bool:
424
428
  """Update the last_used_at timestamp for a profile"""
@@ -433,30 +437,31 @@ class McpProfileQueries:
433
437
  logger.error(f"Failed to update last_used for MCP profile {mcp_id}: {e}")
434
438
  raise
435
439
 
440
+
436
441
  class TaskQueries:
437
442
  """Database queries for task management with LLM Profile support"""
438
-
443
+
439
444
  @staticmethod
440
445
  async def save_task(
441
- db: AsyncSession,
442
- task_id: str,
443
- session_id: str,
444
- task_description: str,
445
- llm_profile_name: str,
446
- upload_files_path: Optional[str] = None,
447
- workspace_dir: Optional[str] = None,
448
- mcp_server_config: Optional[str] = None, # JSON string
449
- task_result: Optional[str] = None,
450
- task_status: str = "pending",
451
- error_message: Optional[str] = None,
452
- report_path: Optional[str] = None
446
+ db: AsyncSession,
447
+ task_id: str,
448
+ session_id: str,
449
+ task_description: str,
450
+ llm_profile_name: str,
451
+ upload_files_path: Optional[str] = None,
452
+ workspace_dir: Optional[str] = None,
453
+ mcp_server_config: Optional[str] = None, # JSON string
454
+ task_result: Optional[str] = None,
455
+ task_status: str = "pending",
456
+ error_message: Optional[str] = None,
457
+ report_path: Optional[str] = None
453
458
  ) -> Task:
454
459
  """Create or update a task record"""
455
460
  try:
456
461
  # Check if task exists
457
462
  result = await db.execute(select(Task).where(Task.task_id == task_id))
458
463
  existing_task = result.scalar_one_or_none()
459
-
464
+
460
465
  if existing_task:
461
466
  # Update existing task
462
467
  update_data = {}
@@ -472,7 +477,7 @@ class TaskQueries:
472
477
  update_data['started_at'] = func.now()
473
478
  if task_status in ["completed", "failed", "stopped"]:
474
479
  update_data['completed_at'] = func.now()
475
-
480
+
476
481
  await db.execute(
477
482
  update(Task).where(Task.task_id == task_id).values(**update_data)
478
483
  )
@@ -480,15 +485,16 @@ class TaskQueries:
480
485
  return existing_task
481
486
  else:
482
487
  # DEBUG: Log the type and content of mcp_server_config before saving
483
- logger.info(f"Creating task with mcp_server_config type: {type(mcp_server_config)}, value: {mcp_server_config}")
484
-
488
+ logger.info(
489
+ f"Creating task with mcp_server_config type: {type(mcp_server_config)}, value: {mcp_server_config}")
490
+
485
491
  # Serialize mcp_server_config to JSON string if it's a dict
486
492
  if isinstance(mcp_server_config, dict):
487
493
  mcp_server_config_json = json.dumps(mcp_server_config)
488
494
  logger.info(f"Converted dict to JSON string: {mcp_server_config_json}")
489
495
  else:
490
496
  mcp_server_config_json = mcp_server_config
491
-
497
+
492
498
  # Create new task
493
499
  task = Task(
494
500
  task_id=task_id,
@@ -503,16 +509,16 @@ class TaskQueries:
503
509
  error_message=error_message,
504
510
  report_path=report_path
505
511
  )
506
-
512
+
507
513
  db.add(task)
508
514
  await db.flush()
509
515
  await db.refresh(task)
510
516
  return task
511
-
517
+
512
518
  except Exception as e:
513
519
  logger.error(f"Failed to save task {task_id}: {e}")
514
520
  raise
515
-
521
+
516
522
  @staticmethod
517
523
  async def get_task(db: AsyncSession, task_id: str) -> Optional[Task]:
518
524
  """Get task by ID"""
@@ -522,13 +528,13 @@ class TaskQueries:
522
528
  except Exception as e:
523
529
  logger.error(f"Failed to get task {task_id}: {e}")
524
530
  raise
525
-
531
+
526
532
  @staticmethod
527
533
  async def get_tasks_by_session(
528
- db: AsyncSession,
529
- session_id: str,
530
- limit: int = 50,
531
- offset: int = 0
534
+ db: AsyncSession,
535
+ session_id: str,
536
+ limit: int = 50,
537
+ offset: int = 0
532
538
  ) -> List[Task]:
533
539
  """Get all tasks for a session"""
534
540
  try:
@@ -543,28 +549,28 @@ class TaskQueries:
543
549
  except Exception as e:
544
550
  logger.error(f"Failed to get tasks for session {session_id}: {e}")
545
551
  raise
546
-
552
+
547
553
  @staticmethod
548
554
  async def get_recent_tasks(db: AsyncSession, limit: int = -1) -> List[Task]:
549
555
  """Get recent tasks"""
550
556
  try:
551
557
  query = select(Task).order_by(desc(Task.created_at))
552
-
558
+
553
559
  # Handle -1 as "get all records"
554
560
  if limit != -1:
555
561
  query = query.limit(limit)
556
-
562
+
557
563
  result = await db.execute(query)
558
564
  return result.scalars().all()
559
565
  except Exception as e:
560
566
  logger.error(f"Failed to get recent tasks: {e}")
561
567
  raise
562
-
568
+
563
569
  @staticmethod
564
570
  async def get_all_sessions(
565
- db: AsyncSession,
566
- limit: int = -1,
567
- offset: int = 0
571
+ db: AsyncSession,
572
+ limit: int = -1,
573
+ offset: int = 0
568
574
  ) -> List[Dict[str, Any]]:
569
575
  """Get all unique sessions with task counts and metadata"""
570
576
  try:
@@ -576,17 +582,17 @@ class TaskQueries:
576
582
  func.max(Task.created_at).label('last_activity'),
577
583
  func.max(Task.status).label('latest_status')
578
584
  ).group_by(Task.session_id).order_by(desc(func.max(Task.created_at)))
579
-
585
+
580
586
  # Handle -1 as "get all records"
581
587
  if limit != -1:
582
588
  query = query.limit(limit)
583
-
589
+
584
590
  # Always apply offset if provided
585
591
  if offset > 0:
586
592
  query = query.offset(offset)
587
-
593
+
588
594
  result = await db.execute(query)
589
-
595
+
590
596
  sessions = []
591
597
  for row in result.all():
592
598
  sessions.append({
@@ -596,7 +602,7 @@ class TaskQueries:
596
602
  'last_activity': row.last_activity.isoformat() if row.last_activity else None,
597
603
  'status': row.latest_status.value if row.latest_status else 'unknown'
598
604
  })
599
-
605
+
600
606
  return sessions
601
607
  except Exception as e:
602
608
  logger.error(f"Failed to get all sessions: {e}")
@@ -604,40 +610,40 @@ class TaskQueries:
604
610
 
605
611
  @staticmethod
606
612
  async def update_task_status(
607
- db: AsyncSession,
608
- task_id: str,
609
- status: str,
610
- error_message: Optional[str] = None,
611
- task_result: Optional[str] = None,
612
- report_path: Optional[str] = None
613
+ db: AsyncSession,
614
+ task_id: str,
615
+ status: str,
616
+ error_message: Optional[str] = None,
617
+ task_result: Optional[str] = None,
618
+ report_path: Optional[str] = None
613
619
  ) -> bool:
614
620
  """Update task status"""
615
621
  try:
616
622
  update_data = {
617
623
  'status': TaskStatus(status)
618
624
  }
619
-
625
+
620
626
  if status == "running":
621
627
  update_data['started_at'] = func.now()
622
628
  elif status in ["completed", "failed", "stopped"]:
623
629
  update_data['completed_at'] = func.now()
624
-
630
+
625
631
  if error_message:
626
632
  update_data['error_message'] = error_message
627
633
  if task_result:
628
634
  update_data['task_result'] = task_result
629
635
  if report_path:
630
636
  update_data['report_path'] = report_path
631
-
637
+
632
638
  result = await db.execute(
633
639
  update(Task).where(Task.task_id == task_id).values(**update_data)
634
640
  )
635
-
641
+
636
642
  return result.rowcount > 0
637
643
  except Exception as e:
638
644
  logger.error(f"Failed to update task status {task_id}: {e}")
639
645
  raise
640
-
646
+
641
647
  @staticmethod
642
648
  async def delete_task(db: AsyncSession, task_id: str) -> bool:
643
649
  """Delete a task"""
@@ -647,7 +653,7 @@ class TaskQueries:
647
653
  except Exception as e:
648
654
  logger.error(f"Failed to delete task {task_id}: {e}")
649
655
  raise
650
-
656
+
651
657
  @staticmethod
652
658
  async def get_running_tasks(db: AsyncSession) -> List[Task]:
653
659
  """Get all currently running tasks"""
@@ -659,7 +665,7 @@ class TaskQueries:
659
665
  except Exception as e:
660
666
  logger.error(f"Failed to get running tasks: {e}")
661
667
  raise
662
-
668
+
663
669
  @staticmethod
664
670
  async def get_active_task(db: AsyncSession) -> Optional[Task]:
665
671
  """Get currently running task (for single-task model)"""
@@ -671,13 +677,13 @@ class TaskQueries:
671
677
  except Exception as e:
672
678
  logger.error(f"Failed to get active task: {e}")
673
679
  raise
674
-
680
+
675
681
  @staticmethod
676
682
  async def get_tasks_by_llm_profile(
677
- db: AsyncSession,
678
- llm_profile_name: str,
679
- limit: int = 50,
680
- offset: int = 0
683
+ db: AsyncSession,
684
+ llm_profile_name: str,
685
+ limit: int = 50,
686
+ offset: int = 0
681
687
  ) -> List[Task]:
682
688
  """Get tasks that used a specific LLM profile"""
683
689
  try:
@@ -692,15 +698,15 @@ class TaskQueries:
692
698
  except Exception as e:
693
699
  logger.error(f"Failed to get tasks for LLM profile {llm_profile_name}: {e}")
694
700
  raise
695
-
701
+
696
702
  @staticmethod
697
703
  async def update_task_completion(
698
- db: AsyncSession,
699
- task_id: str,
700
- task_result: Optional[str] = None,
701
- task_status: str = "completed",
702
- error_message: Optional[str] = None,
703
- report_path: Optional[str] = None
704
+ db: AsyncSession,
705
+ task_id: str,
706
+ task_result: Optional[str] = None,
707
+ task_status: str = "completed",
708
+ error_message: Optional[str] = None,
709
+ report_path: Optional[str] = None
704
710
  ) -> bool:
705
711
  """Update task completion status and results"""
706
712
  try:
@@ -708,23 +714,23 @@ class TaskQueries:
708
714
  'status': TaskStatus(task_status),
709
715
  'completed_at': func.now()
710
716
  }
711
-
717
+
712
718
  if task_result is not None:
713
719
  update_data['task_result'] = task_result
714
720
  if error_message is not None:
715
721
  update_data['error_message'] = error_message
716
722
  if report_path is not None:
717
723
  update_data['report_path'] = report_path
718
-
724
+
719
725
  result = await db.execute(
720
726
  update(Task).where(Task.task_id == task_id).values(**update_data)
721
727
  )
722
-
728
+
723
729
  return result.rowcount > 0
724
730
  except Exception as e:
725
731
  logger.error(f"Failed to update task completion {task_id}: {e}")
726
732
  raise
727
-
733
+
728
734
  @staticmethod
729
735
  async def get_task_counts_by_status(db: AsyncSession) -> Dict[str, int]:
730
736
  """Get count of tasks by status"""
@@ -733,30 +739,31 @@ class TaskQueries:
733
739
  select(Task.status, func.count(Task.task_id))
734
740
  .group_by(Task.status)
735
741
  )
736
-
742
+
737
743
  counts = {}
738
744
  for status, count in result.all():
739
745
  counts[status.value] = count
740
-
746
+
741
747
  return counts
742
748
  except Exception as e:
743
749
  logger.error(f"Failed to get task counts by status: {e}")
744
750
  raise
745
751
 
752
+
746
753
  class UploadedFileQueries:
747
754
  """Query operations for UploadedFile model"""
748
-
755
+
749
756
  @staticmethod
750
757
  async def create_file_record(
751
- db: AsyncSession,
752
- file_id: str,
753
- original_filename: str,
754
- stored_filename: str,
755
- file_path: str,
756
- session_id: Optional[str],
757
- file_size: int,
758
- mime_type: str,
759
- relative_path: str
758
+ db: AsyncSession,
759
+ file_id: str,
760
+ original_filename: str,
761
+ stored_filename: str,
762
+ file_path: str,
763
+ session_id: Optional[str],
764
+ file_size: int,
765
+ mime_type: str,
766
+ relative_path: str
760
767
  ) -> UploadedFile:
761
768
  """Create a new uploaded file record"""
762
769
  try:
@@ -770,7 +777,7 @@ class UploadedFileQueries:
770
777
  mime_type=mime_type,
771
778
  relative_path=relative_path
772
779
  )
773
-
780
+
774
781
  db.add(uploaded_file)
775
782
  await db.flush()
776
783
  await db.refresh(uploaded_file)
@@ -778,7 +785,7 @@ class UploadedFileQueries:
778
785
  except Exception as e:
779
786
  logger.error(f"Failed to create file record {file_id}: {e}")
780
787
  raise
781
-
788
+
782
789
  @staticmethod
783
790
  async def get_file(db: AsyncSession, file_id: str) -> Optional[UploadedFile]:
784
791
  """Get uploaded file by ID"""
@@ -792,63 +799,63 @@ class UploadedFileQueries:
792
799
  except Exception as e:
793
800
  logger.error(f"Failed to get file {file_id}: {e}")
794
801
  raise
795
-
802
+
796
803
  @staticmethod
797
804
  async def list_files(
798
- db: AsyncSession,
799
- session_id: Optional[str] = None,
800
- limit: int = -1,
801
- offset: int = 0,
802
- active_only: bool = True
805
+ db: AsyncSession,
806
+ session_id: Optional[str] = None,
807
+ limit: int = -1,
808
+ offset: int = 0,
809
+ active_only: bool = True
803
810
  ) -> List[UploadedFile]:
804
811
  """List uploaded files with optional filtering"""
805
812
  try:
806
813
  query = select(UploadedFile)
807
-
814
+
808
815
  if active_only:
809
816
  query = query.where(UploadedFile.is_deleted == False)
810
-
817
+
811
818
  if session_id is not None:
812
819
  query = query.where(UploadedFile.session_id == session_id)
813
-
820
+
814
821
  query = query.order_by(desc(UploadedFile.upload_time))
815
-
822
+
816
823
  # Handle -1 as "get all records"
817
824
  if limit != -1:
818
825
  query = query.limit(limit)
819
-
826
+
820
827
  # Always apply offset if provided
821
828
  if offset > 0:
822
829
  query = query.offset(offset)
823
-
830
+
824
831
  result = await db.execute(query)
825
832
  return result.scalars().all()
826
833
  except Exception as e:
827
834
  logger.error(f"Failed to list files: {e}")
828
835
  raise
829
-
836
+
830
837
  @staticmethod
831
838
  async def count_files(
832
- db: AsyncSession,
833
- session_id: Optional[str] = None,
834
- active_only: bool = True
839
+ db: AsyncSession,
840
+ session_id: Optional[str] = None,
841
+ active_only: bool = True
835
842
  ) -> int:
836
843
  """Count uploaded files with optional filtering"""
837
844
  try:
838
845
  query = select(func.count(UploadedFile.file_id))
839
-
846
+
840
847
  if active_only:
841
848
  query = query.where(UploadedFile.is_deleted == False)
842
-
849
+
843
850
  if session_id is not None:
844
851
  query = query.where(UploadedFile.session_id == session_id)
845
-
852
+
846
853
  result = await db.execute(query)
847
854
  return result.scalar() or 0
848
855
  except Exception as e:
849
856
  logger.error(f"Failed to count files: {e}")
850
857
  raise
851
-
858
+
852
859
  @staticmethod
853
860
  async def delete_file(db: AsyncSession, file_id: str) -> bool:
854
861
  """Soft delete uploaded file by marking as deleted"""
@@ -862,7 +869,7 @@ class UploadedFileQueries:
862
869
  except Exception as e:
863
870
  logger.error(f"Failed to delete file {file_id}: {e}")
864
871
  raise
865
-
872
+
866
873
  @staticmethod
867
874
  async def hard_delete_file(db: AsyncSession, file_id: str) -> bool:
868
875
  """Permanently delete uploaded file record"""
@@ -874,13 +881,13 @@ class UploadedFileQueries:
874
881
  except Exception as e:
875
882
  logger.error(f"Failed to hard delete file {file_id}: {e}")
876
883
  raise
877
-
884
+
878
885
  @staticmethod
879
886
  async def get_files_by_session(
880
- db: AsyncSession,
881
- session_id: str,
882
- limit: int = -1,
883
- offset: int = 0
887
+ db: AsyncSession,
888
+ session_id: str,
889
+ limit: int = -1,
890
+ offset: int = 0
884
891
  ) -> List[UploadedFile]:
885
892
  """Get all uploaded files for a specific session"""
886
893
  try:
@@ -888,27 +895,27 @@ class UploadedFileQueries:
888
895
  UploadedFile.session_id == session_id,
889
896
  UploadedFile.is_deleted == False
890
897
  )).order_by(desc(UploadedFile.upload_time))
891
-
898
+
892
899
  # Handle -1 as "get all records"
893
900
  if limit != -1:
894
901
  query = query.limit(limit)
895
-
902
+
896
903
  # Always apply offset if provided
897
904
  if offset > 0:
898
905
  query = query.offset(offset)
899
-
906
+
900
907
  result = await db.execute(query)
901
908
  return result.scalars().all()
902
909
  except Exception as e:
903
910
  logger.error(f"Failed to get files for session {session_id}: {e}")
904
911
  raise
905
-
912
+
906
913
  @staticmethod
907
914
  async def cleanup_deleted_files(db: AsyncSession, days_old: int = 30) -> int:
908
915
  """Clean up files marked as deleted for more than specified days"""
909
916
  try:
910
917
  cutoff_date = func.now() - func.make_interval(days=days_old)
911
-
918
+
912
919
  result = await db.execute(
913
920
  delete(UploadedFile)
914
921
  .where(and_(
@@ -919,4 +926,4 @@ class UploadedFileQueries:
919
926
  return result.rowcount
920
927
  except Exception as e:
921
928
  logger.error(f"Failed to cleanup deleted files: {e}")
922
- raise
929
+ raise