service-forge 0.1.18__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of service-forge might be problematic. See the registry's release page for more details.

Files changed (83):
  1. service_forge/api/deprecated_websocket_api.py +86 -0
  2. service_forge/api/deprecated_websocket_manager.py +425 -0
  3. service_forge/api/http_api.py +152 -0
  4. service_forge/api/http_api_doc.py +455 -0
  5. service_forge/api/kafka_api.py +126 -0
  6. service_forge/api/routers/feedback/feedback_router.py +148 -0
  7. service_forge/api/routers/service/service_router.py +127 -0
  8. service_forge/api/routers/websocket/websocket_manager.py +83 -0
  9. service_forge/api/routers/websocket/websocket_router.py +78 -0
  10. service_forge/api/task_manager.py +141 -0
  11. service_forge/current_service.py +14 -0
  12. service_forge/db/__init__.py +1 -0
  13. service_forge/db/database.py +237 -0
  14. service_forge/db/migrations/feedback_migration.py +154 -0
  15. service_forge/db/models/__init__.py +0 -0
  16. service_forge/db/models/feedback.py +33 -0
  17. service_forge/llm/__init__.py +67 -0
  18. service_forge/llm/llm.py +56 -0
  19. service_forge/model/__init__.py +0 -0
  20. service_forge/model/feedback.py +30 -0
  21. service_forge/model/websocket.py +13 -0
  22. service_forge/proto/foo_input.py +5 -0
  23. service_forge/service.py +280 -0
  24. service_forge/service_config.py +44 -0
  25. service_forge/sft/cli.py +91 -0
  26. service_forge/sft/cmd/config_command.py +67 -0
  27. service_forge/sft/cmd/deploy_service.py +123 -0
  28. service_forge/sft/cmd/list_tars.py +41 -0
  29. service_forge/sft/cmd/service_command.py +149 -0
  30. service_forge/sft/cmd/upload_service.py +36 -0
  31. service_forge/sft/config/injector.py +129 -0
  32. service_forge/sft/config/injector_default_files.py +131 -0
  33. service_forge/sft/config/sf_metadata.py +30 -0
  34. service_forge/sft/config/sft_config.py +200 -0
  35. service_forge/sft/file/__init__.py +0 -0
  36. service_forge/sft/file/ignore_pattern.py +80 -0
  37. service_forge/sft/file/sft_file_manager.py +107 -0
  38. service_forge/sft/kubernetes/kubernetes_manager.py +257 -0
  39. service_forge/sft/util/assert_util.py +25 -0
  40. service_forge/sft/util/logger.py +16 -0
  41. service_forge/sft/util/name_util.py +8 -0
  42. service_forge/sft/util/yaml_utils.py +57 -0
  43. service_forge/storage/__init__.py +5 -0
  44. service_forge/storage/feedback_storage.py +245 -0
  45. service_forge/utils/__init__.py +0 -0
  46. service_forge/utils/default_type_converter.py +12 -0
  47. service_forge/utils/register.py +39 -0
  48. service_forge/utils/type_converter.py +99 -0
  49. service_forge/utils/workflow_clone.py +124 -0
  50. service_forge/workflow/__init__.py +1 -0
  51. service_forge/workflow/context.py +14 -0
  52. service_forge/workflow/edge.py +24 -0
  53. service_forge/workflow/node.py +184 -0
  54. service_forge/workflow/nodes/__init__.py +8 -0
  55. service_forge/workflow/nodes/control/if_node.py +29 -0
  56. service_forge/workflow/nodes/control/switch_node.py +28 -0
  57. service_forge/workflow/nodes/input/console_input_node.py +26 -0
  58. service_forge/workflow/nodes/llm/query_llm_node.py +41 -0
  59. service_forge/workflow/nodes/nested/workflow_node.py +28 -0
  60. service_forge/workflow/nodes/output/kafka_output_node.py +27 -0
  61. service_forge/workflow/nodes/output/print_node.py +29 -0
  62. service_forge/workflow/nodes/test/if_console_input_node.py +33 -0
  63. service_forge/workflow/nodes/test/time_consuming_node.py +62 -0
  64. service_forge/workflow/port.py +89 -0
  65. service_forge/workflow/trigger.py +28 -0
  66. service_forge/workflow/triggers/__init__.py +6 -0
  67. service_forge/workflow/triggers/a2a_api_trigger.py +257 -0
  68. service_forge/workflow/triggers/fast_api_trigger.py +201 -0
  69. service_forge/workflow/triggers/kafka_api_trigger.py +47 -0
  70. service_forge/workflow/triggers/once_trigger.py +23 -0
  71. service_forge/workflow/triggers/period_trigger.py +29 -0
  72. service_forge/workflow/triggers/websocket_api_trigger.py +189 -0
  73. service_forge/workflow/workflow.py +227 -0
  74. service_forge/workflow/workflow_callback.py +141 -0
  75. service_forge/workflow/workflow_config.py +66 -0
  76. service_forge/workflow/workflow_event.py +15 -0
  77. service_forge/workflow/workflow_factory.py +246 -0
  78. service_forge/workflow/workflow_group.py +51 -0
  79. service_forge/workflow/workflow_type.py +52 -0
  80. service_forge-0.1.18.dist-info/METADATA +98 -0
  81. service_forge-0.1.18.dist-info/RECORD +83 -0
  82. service_forge-0.1.18.dist-info/WHEEL +4 -0
  83. service_forge-0.1.18.dist-info/entry_points.txt +2 -0
@@ -0,0 +1,148 @@
1
+ import os
2
+ import httpx
3
+ from fastapi import APIRouter, HTTPException, Query, BackgroundTasks
4
+ from loguru import logger
5
+ from typing import Optional
6
+
7
+ from service_forge.model.feedback import FeedbackCreate, FeedbackResponse, FeedbackListResponse
8
+ from service_forge.storage.feedback_storage import feedback_storage
9
+ from service_forge.db.database import create_database_manager
10
+ from service_forge.current_service import get_service
11
+
12
router = APIRouter(prefix="/sdk/feedback", tags=["feedback"])


def _feedback_config():
    """Return the current service's feedback config, or None when there is none.

    None is returned both when no service has been registered and when the
    service's ``config.feedback`` is falsy.
    """
    service = get_service()
    if not service:
        return None
    feedback_config = service.config.feedback
    return feedback_config if feedback_config else None


def get_forward_api_url():
    """Return the external feedback API URL, or None if feedback is not configured."""
    feedback_config = _feedback_config()
    return None if feedback_config is None else feedback_config.api_url


def get_forward_api_timeout():
    """Return the external feedback API timeout, or None if feedback is not configured."""
    feedback_config = _feedback_config()
    return None if feedback_config is None else feedback_config.api_timeout
27
+
28
async def forward_feedback_to_api(feedback_data: dict):
    """
    Forward a feedback record to the external feedback API, if one is configured.

    Scheduled as a background task by ``create_feedback``; every failure is
    logged and swallowed so it never surfaces to the client.

    Args:
        feedback_data: feedback record dict as returned by the storage layer.
    """
    forward_api_url = get_forward_api_url()
    forward_api_timeout = get_forward_api_timeout()
    # Route diagnostics through the logger rather than printing to stdout.
    logger.debug(f"feedback forward config: url={forward_api_url}, timeout={forward_api_timeout}")
    if not forward_api_url:
        logger.debug("未配置转发 API URL,跳过转发")
        return
    try:
        # httpx's json= encoder cannot serialize datetime/date objects, so
        # convert every top-level value exposing isoformat() to an ISO-8601
        # string. Values that are already strings (or None) pass through
        # unchanged instead of crashing on .isoformat().
        serializable_data = {
            key: value.isoformat() if hasattr(value, "isoformat") else value
            for key, value in feedback_data.items()
        }

        # NOTE: httpx interprets timeout=None as "no timeout at all".
        async with httpx.AsyncClient(timeout=forward_api_timeout) as client:
            response = await client.post(
                forward_api_url,
                json=serializable_data,
                headers={"Content-Type": "application/json"},
            )
            response.raise_for_status()
            logger.info(f"反馈转发成功: feedback_id={feedback_data.get('feedback_id')}, status={response.status_code}")
    except httpx.TimeoutException:
        logger.warning(f"反馈转发超时: {forward_api_url}")
    except httpx.ConnectError as e:
        logger.warning(f"反馈转发连接失败: {forward_api_url} - {e}")
    except httpx.HTTPStatusError as e:
        logger.error(f"反馈转发失败: status={e.response.status_code}, detail={e.response.text}")
    except Exception as e:
        logger.error(f"反馈转发异常: {type(e).__name__}: {e}")
64
+
65
+
66
+
67
@router.post("/", response_model=FeedbackResponse, summary="创建工作流反馈")
async def create_feedback(feedback: FeedbackCreate, background_tasks: BackgroundTasks):
    """
    Create user feedback for a finished workflow run.

    - **task_id**: workflow task ID
    - **workflow_name**: workflow name
    - **rating**: optional rating (1-5)
    - **comment**: optional user comment
    - **metadata**: optional extra metadata

    The record is saved first; forwarding to the external feedback API is
    scheduled as a background task so it does not block the response.
    """
    # TODO: decide whether a trace_id should be stored with the feedback
    # (open question previously left in the public API description).
    try:
        feedback_data = await feedback_storage.create_feedback(
            task_id=feedback.task_id,
            workflow_name=feedback.workflow_name,
            rating=feedback.rating,
            comment=feedback.comment,
            metadata=feedback.metadata,
        )
        # Build the response before scheduling the forward, so a record that
        # fails response validation is never forwarded.
        response = FeedbackResponse(**feedback_data)
    except Exception as e:
        # logger.exception keeps the traceback for diagnosis.
        logger.exception(f"创建反馈失败: {e}")
        raise HTTPException(status_code=500, detail=f"创建反馈失败: {str(e)}") from e

    # Forward to the external API in the background (does not block the response).
    background_tasks.add_task(forward_feedback_to_api, feedback_data)
    return response
95
+
96
+
97
@router.get("/{feedback_id}", response_model=FeedbackResponse, summary="获取单个反馈")
async def get_feedback(feedback_id: str):
    """
    Return a single feedback record by its ID.

    - **feedback_id**: feedback ID

    Responds with 404 when no record is found.
    """
    record = await feedback_storage.get_feedback(feedback_id)
    if record:
        return FeedbackResponse(**record)
    raise HTTPException(status_code=404, detail="反馈不存在")
108
+
109
+
110
@router.get("/", response_model=FeedbackListResponse, summary="获取反馈列表")
async def list_feedbacks(
    task_id: Optional[str] = Query(None, description="按任务ID筛选"),
    workflow_name: Optional[str] = Query(None, description="按工作流名称筛选"),
):
    """
    List feedback records, optionally filtered.

    - **task_id**: optional; filter by task ID (takes precedence)
    - **workflow_name**: optional; filter by workflow name
    """
    try:
        # Pick exactly one storage query: task filter, then workflow filter, else everything.
        if task_id:
            pending_query = feedback_storage.get_feedbacks_by_task(task_id)
        elif workflow_name:
            pending_query = feedback_storage.get_feedbacks_by_workflow(workflow_name)
        else:
            pending_query = feedback_storage.get_all_feedbacks()
        records = await pending_query

        record_count = len(records)
        items = [FeedbackResponse(**record) for record in records]
        return FeedbackListResponse(total=record_count, feedbacks=items)
    except Exception as exc:
        logger.error(f"获取反馈列表失败: {exc}")
        raise HTTPException(status_code=500, detail=f"获取反馈列表失败: {str(exc)}")
136
+
137
+
138
@router.delete("/{feedback_id}", summary="删除反馈")
async def delete_feedback(feedback_id: str):
    """
    Delete one feedback record.

    - **feedback_id**: feedback ID

    Responds with 404 when the storage layer reports nothing was deleted.
    """
    deleted = await feedback_storage.delete_feedback(feedback_id)
    if deleted:
        return {"message": "反馈删除成功", "feedback_id": feedback_id}
    raise HTTPException(status_code=404, detail="反馈不存在")
@@ -0,0 +1,127 @@
1
+ import os
2
+ import uuid
3
+ import tempfile
4
+ from fastapi import APIRouter, HTTPException, UploadFile, File, Form
5
+ from fastapi.responses import JSONResponse
6
+ from loguru import logger
7
+ from typing import Optional, TYPE_CHECKING
8
+ from pydantic import BaseModel
9
+ from omegaconf import OmegaConf
10
+ from service_forge.current_service import get_service
11
+
12
service_router = APIRouter(prefix="/sdk/service", tags=["service"])


# Response model for GET /sdk/service/status.
# (Plain comments instead of a class docstring: pydantic would otherwise copy
# a docstring into the generated JSON schema description.)
class WorkflowStatusResponse(BaseModel):
    name: str
    version: str
    description: str
    # Per-workflow entries; their layout comes from
    # Service.get_service_status(), which is not visible in this module.
    workflows: list[dict]


# Response model shared by the workflow start / stop / upload endpoints.
class WorkflowActionResponse(BaseModel):
    workflow_id: str  # required field
    success: bool
    message: str
24
+
25
@service_router.get("/status", response_model=WorkflowStatusResponse)
async def get_service_status():
    """Return the current service's status.

    Responds 503 if no service is registered, 500 if collecting status raises.
    """
    current = get_service()
    if current is None:
        raise HTTPException(status_code=503, detail="Service not initialized")

    try:
        return current.get_service_status()
    except Exception as exc:
        logger.error(f"Error getting service status: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
37
+
38
@service_router.post("/workflow/{workflow_id}/start", response_model=WorkflowActionResponse)
async def start_workflow(workflow_id: str):
    """
    Start a loaded workflow identified by its UUID.

    Raises:
        HTTPException 503: no service is registered.
        HTTPException 400: ``workflow_id`` is not a valid UUID.
        HTTPException 500: starting the workflow raised.
    """
    service = get_service()
    if service is None:
        raise HTTPException(status_code=503, detail="Service not initialized")

    # Malformed IDs are a client error, not a server failure.
    try:
        parsed_id = uuid.UUID(workflow_id)
    except ValueError:
        raise HTTPException(status_code=400, detail=f"Invalid workflow_id: {workflow_id}") from None

    try:
        # NOTE(review): unlike stop_workflow_by_id this call is not awaited;
        # confirm that Service.start_workflow_by_id is synchronous.
        success = service.start_workflow_by_id(parsed_id)
    except Exception as e:
        logger.error(f"Error starting workflow {workflow_id}: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e

    # WorkflowActionResponse requires workflow_id, so it must always be supplied.
    if success:
        return WorkflowActionResponse(
            workflow_id=str(parsed_id),
            success=True,
            message=f"Workflow {workflow_id} started successfully",
        )
    return WorkflowActionResponse(
        workflow_id=str(parsed_id),
        success=False,
        message=f"Failed to start workflow {workflow_id}",
    )
53
+
54
@service_router.post("/workflow/{workflow_id}/stop", response_model=WorkflowActionResponse)
async def stop_workflow(workflow_id: str):
    """
    Stop a running workflow identified by its UUID.

    Raises:
        HTTPException 503: no service is registered.
        HTTPException 400: ``workflow_id`` is not a valid UUID.
        HTTPException 500: stopping the workflow raised.
    """
    service = get_service()
    if service is None:
        raise HTTPException(status_code=503, detail="Service not initialized")

    # Malformed IDs are a client error, not a server failure.
    try:
        parsed_id = uuid.UUID(workflow_id)
    except ValueError:
        raise HTTPException(status_code=400, detail=f"Invalid workflow_id: {workflow_id}") from None

    try:
        success = await service.stop_workflow_by_id(parsed_id)
    except Exception as e:
        logger.error(f"Error stopping workflow {workflow_id}: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e

    # WorkflowActionResponse requires workflow_id, so it must always be supplied.
    if success:
        return WorkflowActionResponse(
            workflow_id=str(parsed_id),
            success=True,
            message=f"Workflow {workflow_id} stopped successfully",
        )
    return WorkflowActionResponse(
        workflow_id=str(parsed_id),
        success=False,
        message=f"Failed to stop workflow {workflow_id}",
    )
69
+
70
@service_router.post("/workflow/upload", response_model=WorkflowActionResponse)
async def upload_workflow_config(
    file: Optional[UploadFile] = File(None),
    config_content: Optional[str] = Form(None),
):
    """
    Load a workflow from an uploaded YAML file or an inline YAML string.

    Exactly one of ``file`` / ``config_content`` must be provided. An uploaded
    file is written to a temporary path (removed afterwards) and loaded via
    ``config_path``; inline content is parsed with OmegaConf and loaded via
    ``config``.

    Raises:
        HTTPException 503: no service is registered.
        HTTPException 400: neither/both inputs given, empty inline content,
            a non-.yaml/.yml file name, or unparsable inline YAML.
        HTTPException 500: loading raised or returned no workflow id.
    """
    service = get_service()
    if service is None:
        raise HTTPException(status_code=503, detail="Service not initialized")

    if file is None and config_content is None:
        raise HTTPException(status_code=400, detail="Either file or config_content must be provided")

    if file is not None and config_content is not None:
        raise HTTPException(status_code=400, detail="Cannot provide both file and config_content")

    temp_file_path = None
    try:
        if file is not None:
            # Extension check is case-insensitive so "CONFIG.YAML" is accepted too.
            filename = (file.filename or "").lower()
            if not filename.endswith(('.yaml', '.yml')):
                raise HTTPException(status_code=400, detail="Only YAML files are supported")

            suffix = '.yaml' if filename.endswith('.yaml') else '.yml'
            # Read the upload before creating the temp file so a failed read
            # cannot leave an orphaned file behind; record the path before
            # writing so the finally-block cleans up even if the write fails.
            content = await file.read()
            with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix=suffix) as temp_file:
                temp_file_path = temp_file.name
                temp_file.write(content)

            workflow_id = await service.load_workflow_from_config(config_path=temp_file_path)
        else:
            if not config_content.strip():
                raise HTTPException(status_code=400, detail="config_content must not be empty")
            try:
                config = OmegaConf.to_object(OmegaConf.create(config_content))
            except Exception as e:
                raise HTTPException(status_code=400, detail=f"Invalid YAML format: {str(e)}") from e

            workflow_id = await service.load_workflow_from_config(config=config)

        if workflow_id:
            return WorkflowActionResponse(
                workflow_id=str(workflow_id),
                success=True,
                message="Workflow configuration uploaded and loaded successfully",
            )
        raise HTTPException(status_code=500, detail="Failed to load workflow configuration")

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error uploading workflow config: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e

    finally:
        if temp_file_path and os.path.exists(temp_file_path):
            try:
                os.unlink(temp_file_path)
            except Exception as e:
                logger.warning(f"Failed to delete temp file {temp_file_path}: {e}")
127
+
@@ -0,0 +1,83 @@
1
+ from fastapi import WebSocket
2
+ from typing import Dict, Set
3
+ import uuid
4
+ import json
5
+ import asyncio
6
+ from loguru import logger
7
+
8
class WebSocketManager:
    """Registry of websocket subscriptions to task updates, with message fan-out.

    A websocket may subscribe to individual task ids or to all tasks
    (``task_id=None``). All registry mutations happen under one asyncio lock.
    """

    def __init__(self):
        # task id -> websockets subscribed to that specific task
        self.task_connections: Dict[uuid.UUID, Set[WebSocket]] = {}
        # reverse index: websocket -> task ids it subscribed to individually
        self.websocket_tasks: Dict[WebSocket, Set[uuid.UUID]] = {}
        # websockets subscribed to every task
        self.all_task_subscribers: Set[WebSocket] = set()
        # websocket -> True for "all tasks" subscribers (mirrors the set above)
        self.websocket_subscribes_all: Dict[WebSocket, bool] = {}
        self._lock = asyncio.Lock()

    @staticmethod
    def _drop_link(index, key, member):
        """Discard ``member`` from ``index[key]`` and drop the key once its set is empty."""
        members = index.get(key)
        if members is None:
            return
        members.discard(member)
        if not members:
            del index[key]

    async def subscribe(self, websocket: WebSocket, task_id: uuid.UUID | None) -> bool:
        """Subscribe ``websocket`` to one task, or to all tasks when ``task_id`` is None.

        Always returns True.
        """
        async with self._lock:
            if task_id is None:
                self.all_task_subscribers.add(websocket)
                self.websocket_subscribes_all[websocket] = True
                return True
            self.task_connections.setdefault(task_id, set()).add(websocket)
            self.websocket_tasks.setdefault(websocket, set()).add(task_id)
            return True

    async def unsubscribe(self, websocket: WebSocket, task_id: uuid.UUID | None) -> bool:
        """Undo a subscription made with the same ``task_id`` (None = the "all" subscription).

        Always returns True, even if no such subscription existed.
        """
        async with self._lock:
            if task_id is None:
                self.all_task_subscribers.discard(websocket)
                self.websocket_subscribes_all.pop(websocket, None)
            else:
                self._drop_link(self.task_connections, task_id, websocket)
                self._drop_link(self.websocket_tasks, websocket, task_id)
            return True

    async def disconnect(self, websocket: WebSocket):
        """Remove every subscription held by ``websocket``."""
        async with self._lock:
            subscribed_ids = self.websocket_tasks.pop(websocket, None)
            if subscribed_ids is not None:
                for subscribed_id in subscribed_ids:
                    self._drop_link(self.task_connections, subscribed_id, websocket)

            self.all_task_subscribers.discard(websocket)
            self.websocket_subscribes_all.pop(websocket, None)

    async def send_to_task(self, task_id: uuid.UUID, message: dict):
        """JSON-encode ``message`` and send it to the task's subscribers and all-task subscribers.

        Sockets whose send raises are disconnected afterwards. Sending happens
        outside the lock, on a snapshot of the recipients.
        """
        async with self._lock:
            recipients = {*self.task_connections.get(task_id, set()), *self.all_task_subscribers}

        broken = set()
        payload = json.dumps(message)

        for recipient in recipients:
            try:
                await recipient.send_text(payload)
            except Exception as e:
                logger.error(f"向 task_id {task_id} 的 websocket 发送消息失败: {e}")
                broken.add(recipient)

        for recipient in broken:
            await self.disconnect(recipient)


websocket_manager = WebSocketManager()
@@ -0,0 +1,78 @@
1
+ from fastapi import WebSocket, WebSocketDisconnect
2
+ from fastapi.routing import APIRouter
3
+ from loguru import logger
4
+ import json
5
+ import uuid
6
+ from .websocket_manager import websocket_manager
7
+
8
websocket_router = APIRouter()


async def _send_json(websocket: WebSocket, payload: dict) -> None:
    """Serialize ``payload`` as JSON and send it as a text frame."""
    await websocket.send_text(json.dumps(payload))


async def _handle_subscription(websocket: WebSocket, message: dict, action: str) -> None:
    """Handle a ``subscribe`` / ``unsubscribe`` request for one task id or ``"all"``.

    Args:
        websocket: the client connection.
        message: decoded client message (a dict).
        action: either ``"subscribe"`` or ``"unsubscribe"``.
    """
    task_id_str = message.get("task_id")
    if not task_id_str:
        await _send_json(websocket, {"error": f"Missing task_id in {action} message"})
        return
    # A non-string task_id (e.g. a JSON number) used to raise AttributeError on
    # .lower() and tear down the whole connection; report it as a bad id instead.
    if not isinstance(task_id_str, str):
        await _send_json(websocket, {"error": "Invalid task_id format"})
        return

    if task_id_str.lower() == "all":
        target_id = None
        reported_id = "all"
    else:
        try:
            target_id = uuid.UUID(task_id_str)
        except ValueError:
            await _send_json(websocket, {"error": "Invalid task_id format"})
            return
        reported_id = task_id_str

    if action == "subscribe":
        success = await websocket_manager.subscribe(websocket, target_id)
    else:
        success = await websocket_manager.unsubscribe(websocket, target_id)
    await _send_json(websocket, {"success": success, "type": f"{action}_response", "task_id": reported_id})


@websocket_router.websocket("/sdk/ws")
async def sdk_websocket_endpoint(websocket: WebSocket):
    """SDK websocket endpoint: clients subscribe/unsubscribe to task updates.

    Client messages are JSON objects of the form
    ``{"type": "subscribe" | "unsubscribe", "task_id": "<uuid>" | "all"}``.
    Malformed messages get an ``{"error": ...}`` reply and the connection stays
    open; on disconnect or unexpected errors all subscriptions are removed.
    """
    await websocket.accept()
    try:
        while True:
            data = await websocket.receive_text()
            try:
                message = json.loads(data)
            except json.JSONDecodeError:
                logger.error(f"收到无效JSON消息: {data}")
                await _send_json(websocket, {"error": "Invalid JSON format"})
                continue

            # Valid JSON that is not an object (list, number, string) has no .get().
            if not isinstance(message, dict):
                await _send_json(websocket, {"error": "Invalid message format: expected a JSON object"})
                continue

            message_type = message.get("type")
            if message_type in ("subscribe", "unsubscribe"):
                await _handle_subscription(websocket, message, message_type)
            else:
                await _send_json(websocket, {"error": f"Unknown message type: {message_type}"})
    except WebSocketDisconnect:
        await websocket_manager.disconnect(websocket)
    except Exception as e:
        logger.error(f"SDK WebSocket连接处理异常: {e}")
        await websocket_manager.disconnect(websocket)
78
+
@@ -0,0 +1,141 @@
1
+
2
+ from __future__ import annotations
3
+ import asyncio
4
+ import uuid
5
+ import datetime
6
+ from typing import Dict, List, Set, Any, Optional
7
+
8
class TaskManager:
    """In-memory task tracker: task state, a FIFO wait queue and per-client task sets."""

    def __init__(self):
        # All known tasks: {task_id: task_info}
        self.tasks: Dict[uuid.UUID, Dict[str, Any]] = {}
        # Waiting (pending) tasks, in insertion order
        self.task_queue: List[uuid.UUID] = []
        # IDs of tasks currently executing
        self.running_tasks: Set[uuid.UUID] = set()
        # IDs of completed tasks
        self.completed_tasks: Set[uuid.UUID] = set()
        # Client-to-task mapping: {client_id: set(task_id)}
        self.client_tasks: Dict[str, Set[uuid.UUID]] = {}

    @staticmethod
    def _now_iso() -> str:
        """Return the current local wall-clock time as an ISO-8601 string."""
        # asyncio's loop.time() is a monotonic clock with an arbitrary origin,
        # so feeding it to datetime.fromtimestamp() yields meaningless dates
        # (and requires an event loop). Use the real wall clock instead.
        return datetime.datetime.now().isoformat()

    def _remove_from_queue(self, task_id: uuid.UUID) -> None:
        """Remove ``task_id`` from the wait queue, if present, and renumber the rest (1-based)."""
        if task_id not in self.task_queue:
            return
        self.task_queue.remove(task_id)
        for position, queued_id in enumerate(self.task_queue, start=1):
            self.tasks[queued_id]["queue_position"] = position

    def add_task(self, task_id: uuid.UUID, client_id: str, workflow_name: str, steps: int) -> Dict[str, Any]:
        """Register a new pending task at the end of the queue and return its info dict."""
        task_info = {
            "task_id": task_id,
            "client_id": client_id,
            "workflow_name": workflow_name,
            "steps": steps,
            "current_step": 0,  # current step, starting from 0
            "status": "pending",  # pending, running, completed, failed
            "created_at": self._now_iso(),
            "queue_position": len(self.task_queue) + 1,
        }
        self.tasks[task_id] = task_info
        self.task_queue.append(task_id)

        # Update the client -> tasks mapping
        self.client_tasks.setdefault(client_id, set()).add(task_id)

        return task_info

    def start_task(self, task_id: uuid.UUID) -> bool:
        """Mark a task as running and take it off the wait queue. False if unknown."""
        if task_id not in self.tasks:
            return False

        task_info = self.tasks[task_id]
        task_info["status"] = "running"
        task_info["started_at"] = self._now_iso()
        task_info["current_step"] = 1  # execution starts at step 1
        self.running_tasks.add(task_id)

        self._remove_from_queue(task_id)
        return True

    def complete_task(self, task_id: uuid.UUID) -> bool:
        """Mark a task as completed (all steps done). False if unknown."""
        if task_id not in self.tasks:
            return False

        task_info = self.tasks[task_id]
        task_info["status"] = "completed"
        task_info["completed_at"] = self._now_iso()
        task_info["current_step"] = task_info["steps"]  # all steps finished
        self.running_tasks.discard(task_id)
        self.completed_tasks.add(task_id)
        # A task completed straight from pending must not linger in the queue.
        self._remove_from_queue(task_id)

        return True

    def fail_task(self, task_id: uuid.UUID, error: Optional[str] = None) -> bool:
        """Mark a task as failed, optionally recording ``error``. False if unknown."""
        if task_id not in self.tasks:
            return False

        task_info = self.tasks[task_id]
        task_info["status"] = "failed"
        task_info["failed_at"] = self._now_iso()
        if error:
            task_info["error"] = error
        self.running_tasks.discard(task_id)
        # A pending task that fails must not keep counting as "waiting".
        self._remove_from_queue(task_id)

        return True

    def get_client_tasks(self, client_id: str) -> List[Dict[str, Any]]:
        """Return the info dicts of all tasks registered by ``client_id``."""
        if client_id not in self.client_tasks:
            return []

        return [
            self.tasks[task_id]
            for task_id in self.client_tasks[client_id]
            if task_id in self.tasks
        ]

    def get_queue_position(self, task_id: uuid.UUID) -> int:
        """Return the task's 1-based position in the wait queue, or -1 if it is not waiting."""
        # The stored queue_position is not cleared when a task leaves the
        # queue, so check membership to honour the "-1 if not queued" contract.
        if task_id not in self.tasks or task_id not in self.task_queue:
            return -1

        return self.tasks[task_id].get("queue_position", -1)

    def get_global_queue_info(self) -> Dict[str, int]:
        """Return counts of waiting, running and total (waiting + running) tasks."""
        return {
            "total": len(self.running_tasks) + len(self.task_queue),
            "waiting": len(self.task_queue),
            "running": len(self.running_tasks),
        }

    def get_task_info(self, task_id: uuid.UUID) -> Optional[Dict[str, Any]]:
        """Return a shallow copy of the task's info with ``task_id`` as a string, or None."""
        task_info = self.tasks.get(task_id)
        if task_info:
            task_copy = task_info.copy()
            task_copy["task_id"] = str(task_id)
            return task_copy
        return None

    def update_current_step(self, task_id: uuid.UUID, step: int) -> bool:
        """Set the task's current step; False if unknown or ``step`` is outside ``0..steps``."""
        if task_id not in self.tasks:
            return False

        # Keep the step within the valid range
        if step < 0 or step > self.tasks[task_id]["steps"]:
            return False

        self.tasks[task_id]["current_step"] = step
        return True
@@ -0,0 +1,14 @@
1
+ from __future__ import annotations
2
+ from typing import TYPE_CHECKING
3
+
4
+ if TYPE_CHECKING:
5
+ from service_forge.service import Service
6
+
7
# Process-wide handle to the running Service; None until set_service() is called.
_current_service: Service | None = None


def set_service(service: Service) -> None:
    """Register ``service`` as the current service, replacing any previous one."""
    global _current_service
    _current_service = service


def get_service() -> Service | None:
    """Return the service last passed to set_service(), or None if there is none."""
    global _current_service
    return _current_service
@@ -0,0 +1 @@
1
+ from .database import DatabaseManager, PostgresDatabase, MongoDatabase, RedisDatabase