service-forge 0.1.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of service-forge might be problematic. Click here for more details.
- service_forge/api/deprecated_websocket_api.py +86 -0
- service_forge/api/deprecated_websocket_manager.py +425 -0
- service_forge/api/http_api.py +148 -0
- service_forge/api/http_api_doc.py +455 -0
- service_forge/api/kafka_api.py +126 -0
- service_forge/api/routers/service/__init__.py +4 -0
- service_forge/api/routers/service/service_router.py +137 -0
- service_forge/api/routers/websocket/websocket_manager.py +83 -0
- service_forge/api/routers/websocket/websocket_router.py +78 -0
- service_forge/api/task_manager.py +141 -0
- service_forge/db/__init__.py +1 -0
- service_forge/db/database.py +240 -0
- service_forge/llm/__init__.py +62 -0
- service_forge/llm/llm.py +56 -0
- service_forge/model/__init__.py +0 -0
- service_forge/model/websocket.py +13 -0
- service_forge/proto/foo_input.py +5 -0
- service_forge/service.py +288 -0
- service_forge/service_config.py +158 -0
- service_forge/sft/cli.py +91 -0
- service_forge/sft/cmd/config_command.py +67 -0
- service_forge/sft/cmd/deploy_service.py +123 -0
- service_forge/sft/cmd/list_tars.py +41 -0
- service_forge/sft/cmd/service_command.py +149 -0
- service_forge/sft/cmd/upload_service.py +36 -0
- service_forge/sft/config/injector.py +119 -0
- service_forge/sft/config/injector_default_files.py +131 -0
- service_forge/sft/config/sf_metadata.py +30 -0
- service_forge/sft/config/sft_config.py +153 -0
- service_forge/sft/file/__init__.py +0 -0
- service_forge/sft/file/ignore_pattern.py +80 -0
- service_forge/sft/file/sft_file_manager.py +107 -0
- service_forge/sft/kubernetes/kubernetes_manager.py +257 -0
- service_forge/sft/util/assert_util.py +25 -0
- service_forge/sft/util/logger.py +16 -0
- service_forge/sft/util/name_util.py +8 -0
- service_forge/sft/util/yaml_utils.py +57 -0
- service_forge/utils/__init__.py +0 -0
- service_forge/utils/default_type_converter.py +12 -0
- service_forge/utils/register.py +39 -0
- service_forge/utils/type_converter.py +99 -0
- service_forge/utils/workflow_clone.py +124 -0
- service_forge/workflow/__init__.py +1 -0
- service_forge/workflow/context.py +14 -0
- service_forge/workflow/edge.py +24 -0
- service_forge/workflow/node.py +184 -0
- service_forge/workflow/nodes/__init__.py +8 -0
- service_forge/workflow/nodes/control/if_node.py +29 -0
- service_forge/workflow/nodes/control/switch_node.py +28 -0
- service_forge/workflow/nodes/input/console_input_node.py +26 -0
- service_forge/workflow/nodes/llm/query_llm_node.py +41 -0
- service_forge/workflow/nodes/nested/workflow_node.py +28 -0
- service_forge/workflow/nodes/output/kafka_output_node.py +27 -0
- service_forge/workflow/nodes/output/print_node.py +29 -0
- service_forge/workflow/nodes/test/if_console_input_node.py +33 -0
- service_forge/workflow/nodes/test/time_consuming_node.py +62 -0
- service_forge/workflow/port.py +89 -0
- service_forge/workflow/trigger.py +24 -0
- service_forge/workflow/triggers/__init__.py +6 -0
- service_forge/workflow/triggers/a2a_api_trigger.py +255 -0
- service_forge/workflow/triggers/fast_api_trigger.py +169 -0
- service_forge/workflow/triggers/kafka_api_trigger.py +44 -0
- service_forge/workflow/triggers/once_trigger.py +20 -0
- service_forge/workflow/triggers/period_trigger.py +26 -0
- service_forge/workflow/triggers/websocket_api_trigger.py +184 -0
- service_forge/workflow/workflow.py +210 -0
- service_forge/workflow/workflow_callback.py +141 -0
- service_forge/workflow/workflow_event.py +15 -0
- service_forge/workflow/workflow_factory.py +246 -0
- service_forge/workflow/workflow_group.py +27 -0
- service_forge/workflow/workflow_type.py +52 -0
- service_forge-0.1.11.dist-info/METADATA +98 -0
- service_forge-0.1.11.dist-info/RECORD +75 -0
- service_forge-0.1.11.dist-info/WHEEL +4 -0
- service_forge-0.1.11.dist-info/entry_points.txt +2 -0
|
@@ -0,0 +1,137 @@
|
|
|
1
|
+
from fastapi import APIRouter, HTTPException, UploadFile, File, Form
|
|
2
|
+
from fastapi.responses import JSONResponse
|
|
3
|
+
from loguru import logger
|
|
4
|
+
from typing import Optional, TYPE_CHECKING
|
|
5
|
+
import tempfile
|
|
6
|
+
import os
|
|
7
|
+
from pydantic import BaseModel
|
|
8
|
+
from omegaconf import OmegaConf
|
|
9
|
+
|
|
10
|
+
if TYPE_CHECKING:
|
|
11
|
+
from service_forge.service import Service
|
|
12
|
+
|
|
13
|
+
# TODO: refactor this, do not use global variable
# Process-wide handle to the Service instance the router endpoints act on.
# It stays None until set_service() is called.
_current_service: Optional['Service'] = None


def set_service(service: 'Service') -> None:
    """Register *service* as the instance used by the endpoints in this module."""
    global _current_service
    _current_service = service


def get_service() -> Optional['Service']:
    """Return the registered Service, or ``None`` when set_service() has not run."""
    registered = _current_service
    return registered
|
|
22
|
+
|
|
23
|
+
# Router for the service-management HTTP endpoints; every route it declares is mounted under /sdk/service.
service_router = APIRouter(prefix="/sdk/service", tags=["service"])
|
|
24
|
+
|
|
25
|
+
class WorkflowStatusResponse(BaseModel):
    # Response model for GET /sdk/service/status.
    name: str
    version: str
    description: str
    # One dict per workflow; the keys come from Service.get_service_status(),
    # which is defined elsewhere — NOTE(review): confirm the schema there.
    workflows: list[dict]
|
|
30
|
+
|
|
31
|
+
class WorkflowActionResponse(BaseModel):
    # Response model shared by the workflow start, stop and upload endpoints.
    # Whether the requested action succeeded.
    success: bool
    # Human-readable outcome description.
    message: str
|
|
34
|
+
|
|
35
|
+
@service_router.get("/status", response_model=WorkflowStatusResponse)
async def get_service_status():
    """Return the status payload reported by the registered Service.

    Raises HTTP 503 when no Service has been registered and HTTP 500 (with the
    exception text as detail) when the Service raises while building its status.
    """
    service = get_service()
    if service is None:
        raise HTTPException(status_code=503, detail="Service not initialized")

    try:
        return service.get_service_status()
    except Exception as exc:
        logger.error(f"Error getting service status: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
|
|
47
|
+
|
|
48
|
+
@service_router.post("/workflow/{workflow_name}/start", response_model=WorkflowActionResponse)
async def start_workflow(workflow_name: str):
    """Ask the registered Service to start *workflow_name*.

    The response's ``success`` flag mirrors the truthiness of the Service's
    return value. Raises HTTP 503 if no Service is registered and HTTP 500 if
    the Service raises.
    """
    service = get_service()
    if service is None:
        raise HTTPException(status_code=503, detail="Service not initialized")

    try:
        started = bool(await service.start_workflow(workflow_name))
        message = (
            f"Workflow {workflow_name} started successfully"
            if started
            else f"Failed to start workflow {workflow_name}"
        )
        return WorkflowActionResponse(success=started, message=message)
    except Exception as exc:
        logger.error(f"Error starting workflow {workflow_name}: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
|
|
63
|
+
|
|
64
|
+
@service_router.post("/workflow/{workflow_name}/stop", response_model=WorkflowActionResponse)
async def stop_workflow(workflow_name: str):
    """Ask the registered Service to stop *workflow_name*.

    The response's ``success`` flag mirrors the truthiness of the Service's
    return value. Raises HTTP 503 if no Service is registered and HTTP 500 if
    the Service raises.
    """
    service = get_service()
    if service is None:
        raise HTTPException(status_code=503, detail="Service not initialized")

    try:
        stopped = bool(await service.stop_workflow(workflow_name))
        message = (
            f"Workflow {workflow_name} stopped successfully"
            if stopped
            else f"Failed to stop workflow {workflow_name}"
        )
        return WorkflowActionResponse(success=stopped, message=message)
    except Exception as exc:
        logger.error(f"Error stopping workflow {workflow_name}: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
|
|
79
|
+
|
|
80
|
+
@service_router.post("/workflow/upload", response_model=WorkflowActionResponse)
async def upload_workflow_config(
    file: Optional[UploadFile] = File(None),
    config_content: Optional[str] = Form(None),
    workflow_name: Optional[str] = Form(None)
):
    """Load a workflow configuration into the registered Service.

    Exactly one of ``file`` (an uploaded ``.yaml``/``.yml`` file) or
    ``config_content`` (YAML text) must be supplied.

    Errors:
        503 if no Service is registered.
        400 if neither or both inputs are given, the file extension is not
            YAML, or ``config_content`` cannot be parsed.
        500 if the Service reports failure or raises.
    """
    service = get_service()
    if service is None:
        raise HTTPException(status_code=503, detail="Service not initialized")

    if file is None and config_content is None:
        raise HTTPException(status_code=400, detail="Either file or config_content must be provided")

    if file is not None and config_content is not None:
        raise HTTPException(status_code=400, detail="Cannot provide both file and config_content")

    temp_file_path = None
    try:
        if file is not None:
            # Compare case-insensitively so "CONFIG.YAML" is accepted too.
            filename = (file.filename or "").lower()
            if not filename.endswith(('.yaml', '.yml')):
                raise HTTPException(status_code=400, detail="Only YAML files are supported")

            suffix = '.yaml' if filename.endswith('.yaml') else '.yml'
            # Read the upload before creating the temp file so a failed read leaves nothing on disk.
            content = await file.read()
            with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix=suffix) as temp_file:
                # Record the path before writing: if write() fails, the finally clause
                # below must still be able to delete the (delete=False) file.
                temp_file_path = temp_file.name
                temp_file.write(content)

            # The temp file is closed here, so the loader can reopen it on any platform.
            success = await service.load_workflow_from_config(config_path=temp_file_path, workflow_name=workflow_name)
        else:
            try:
                config = OmegaConf.to_object(OmegaConf.create(config_content))
            except Exception as e:
                raise HTTPException(status_code=400, detail=f"Invalid YAML format: {str(e)}") from e

            success = await service.load_workflow_from_config(config=config, workflow_name=workflow_name)

        if success:
            return WorkflowActionResponse(
                success=True,
                message="Workflow configuration uploaded and loaded successfully"
            )
        raise HTTPException(status_code=500, detail="Failed to load workflow configuration")

    except HTTPException:
        # Already shaped for the client; do not rewrap as a generic 500.
        raise
    except Exception as e:
        logger.error(f"Error uploading workflow config: {e}")
        raise HTTPException(status_code=500, detail=str(e))

    finally:
        # Best-effort cleanup: a failed delete is logged, never raised.
        if temp_file_path and os.path.exists(temp_file_path):
            try:
                os.unlink(temp_file_path)
            except Exception as e:
                logger.warning(f"Failed to delete temp file {temp_file_path}: {e}")
|
|
137
|
+
|
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
from fastapi import WebSocket
|
|
2
|
+
from typing import Dict, Set
|
|
3
|
+
import uuid
|
|
4
|
+
import json
|
|
5
|
+
import asyncio
|
|
6
|
+
from loguru import logger
|
|
7
|
+
|
|
8
|
+
class WebSocketManager:
    """Registry mapping WebSocket connections to the task IDs they follow.

    A connection subscribes to individual task IDs or, by passing ``None``, to
    every task. Registry mutations happen under ``self._lock``. ``send_to_task``
    snapshots recipients under the lock and performs the network sends after
    releasing it.
    """

    def __init__(self):
        # task_id -> connections subscribed to that specific task
        self.task_connections: Dict[uuid.UUID, Set[WebSocket]] = {}
        # connection -> task_ids it subscribed to (reverse index used for cleanup)
        self.websocket_tasks: Dict[WebSocket, Set[uuid.UUID]] = {}
        # connections subscribed to every task
        self.all_task_subscribers: Set[WebSocket] = set()
        # connection -> True for each "all tasks" subscriber (mirrors all_task_subscribers)
        self.websocket_subscribes_all: Dict[WebSocket, bool] = {}
        self._lock = asyncio.Lock()

    def _forget_task_link(self, websocket: WebSocket, task_id: uuid.UUID) -> None:
        """Detach *websocket* from *task_id*'s subscriber set, dropping the set once empty.

        Callers must already hold ``self._lock``.
        """
        followers = self.task_connections.get(task_id)
        if followers is None:
            return
        followers.discard(websocket)
        if not followers:
            del self.task_connections[task_id]

    async def subscribe(self, websocket: WebSocket, task_id: uuid.UUID | None) -> bool:
        """Follow *task_id* (or every task when it is ``None``); always returns True."""
        async with self._lock:
            if task_id is not None:
                self.task_connections.setdefault(task_id, set()).add(websocket)
                self.websocket_tasks.setdefault(websocket, set()).add(task_id)
            else:
                self.all_task_subscribers.add(websocket)
                self.websocket_subscribes_all[websocket] = True
        return True

    async def unsubscribe(self, websocket: WebSocket, task_id: uuid.UUID | None) -> bool:
        """Stop following *task_id* (or the all-tasks feed when it is ``None``); always returns True."""
        async with self._lock:
            if task_id is None:
                self.all_task_subscribers.discard(websocket)
                self.websocket_subscribes_all.pop(websocket, None)
            else:
                self._forget_task_link(websocket, task_id)
                followed = self.websocket_tasks.get(websocket)
                if followed is not None:
                    followed.discard(task_id)
                    if not followed:
                        del self.websocket_tasks[websocket]
        return True

    async def disconnect(self, websocket: WebSocket):
        """Remove every subscription held by *websocket*."""
        async with self._lock:
            for task_id in self.websocket_tasks.pop(websocket, set()):
                self._forget_task_link(websocket, task_id)
            self.all_task_subscribers.discard(websocket)
            self.websocket_subscribes_all.pop(websocket, None)

    async def send_to_task(self, task_id: uuid.UUID, message: dict):
        """Send *message* as JSON to task-specific and all-task subscribers.

        Connections whose send raises are logged and then disconnected.
        """
        async with self._lock:
            # Snapshot recipients under the lock; the sends below run without it.
            recipients = list({*self.task_connections.get(task_id, set()), *self.all_task_subscribers})

        payload = json.dumps(message)
        broken = set()
        for websocket in recipients:
            try:
                await websocket.send_text(payload)
            except Exception as e:
                logger.error(f"向 task_id {task_id} 的 websocket 发送消息失败: {e}")
                broken.add(websocket)

        # disconnect() acquires the (non-reentrant) lock itself, so it runs after the snapshot block.
        for websocket in broken:
            await self.disconnect(websocket)
|
|
82
|
+
|
|
83
|
+
# Shared module-level instance; the SDK websocket router imports it via `from .websocket_manager import websocket_manager`.
websocket_manager = WebSocketManager()
|
|
@@ -0,0 +1,78 @@
|
|
|
1
|
+
from fastapi import WebSocket, WebSocketDisconnect
|
|
2
|
+
from fastapi.routing import APIRouter
|
|
3
|
+
from loguru import logger
|
|
4
|
+
import json
|
|
5
|
+
import uuid
|
|
6
|
+
from .websocket_manager import websocket_manager
|
|
7
|
+
|
|
8
|
+
# Router that exposes the SDK WebSocket endpoint (/sdk/ws) declared below.
websocket_router = APIRouter()
|
|
9
|
+
|
|
10
|
+
async def _handle_subscription(websocket: WebSocket, action: str, task_id_value) -> None:
    """Apply one subscribe/unsubscribe request and send the reply to the client.

    ``action`` is "subscribe" or "unsubscribe". ``task_id_value`` is the raw
    ``task_id`` field from the client message: "all" (any case) or a UUID string.
    """
    if not task_id_value:
        await websocket.send_text(
            json.dumps({"error": f"Missing task_id in {action} message"})
        )
        return
    # A non-string task_id (number, list, ...) used to raise AttributeError on .lower()
    # and tear down the whole connection; reject it with a normal error reply instead.
    if not isinstance(task_id_value, str):
        await websocket.send_text(json.dumps({"error": "Invalid task_id format"}))
        return

    if task_id_value.lower() == "all":
        task_id = None
        reported_id = "all"
    else:
        try:
            task_id = uuid.UUID(task_id_value)
        except ValueError:
            await websocket.send_text(json.dumps({"error": "Invalid task_id format"}))
            return
        reported_id = task_id_value

    handler = websocket_manager.subscribe if action == "subscribe" else websocket_manager.unsubscribe
    success = await handler(websocket, task_id)
    response = {"success": success, "type": f"{action}_response", "task_id": reported_id}
    await websocket.send_text(json.dumps(response))


@websocket_router.websocket("/sdk/ws")
async def sdk_websocket_endpoint(websocket: WebSocket):
    """Accept an SDK WebSocket and process JSON subscribe/unsubscribe messages until it closes.

    Each message is a JSON object with ``type`` ("subscribe" or "unsubscribe")
    and ``task_id`` ("all" or a UUID string). Malformed messages get an
    ``{"error": ...}`` reply and the connection stays open. On disconnect or an
    unexpected error, all of the connection's subscriptions are removed.
    """
    await websocket.accept()
    try:
        while True:
            data = await websocket.receive_text()
            try:
                message = json.loads(data)
            except json.JSONDecodeError:
                logger.error(f"收到无效JSON消息: {data}")
                await websocket.send_text(
                    json.dumps({"error": "Invalid JSON format"})
                )
                continue

            # Valid JSON that is not an object (e.g. a list or number) has no .get().
            if not isinstance(message, dict):
                await websocket.send_text(
                    json.dumps({"error": "Message must be a JSON object"})
                )
                continue

            message_type = message.get("type")
            if message_type in ("subscribe", "unsubscribe"):
                await _handle_subscription(websocket, message_type, message.get("task_id"))
            else:
                await websocket.send_text(
                    json.dumps({"error": f"Unknown message type: {message_type}"})
                )
    except WebSocketDisconnect:
        await websocket_manager.disconnect(websocket)
    except Exception as e:
        logger.error(f"SDK WebSocket连接处理异常: {e}")
        await websocket_manager.disconnect(websocket)
|
|
78
|
+
|
|
@@ -0,0 +1,141 @@
|
|
|
1
|
+
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
import asyncio
|
|
4
|
+
import uuid
|
|
5
|
+
import datetime
|
|
6
|
+
from typing import Dict, List, Set, Any, Optional
|
|
7
|
+
|
|
8
|
+
class TaskManager:
    """Task manager that tracks task status and queue information.

    Tasks enter a FIFO pending queue via ``add_task``. They leave it when they
    are started, completed or failed. Queue positions are 1-based.
    """

    def __init__(self):
        # All task records: {task_id: task_info}
        self.tasks: Dict[uuid.UUID, Dict[str, Any]] = {}
        # Pending tasks, in the order they were added
        self.task_queue: List[uuid.UUID] = []
        # IDs of tasks currently executing
        self.running_tasks: Set[uuid.UUID] = set()
        # IDs of tasks that finished successfully
        self.completed_tasks: Set[uuid.UUID] = set()
        # Client-to-task mapping: {client_id: set(task_id)}
        self.client_tasks: Dict[str, Set[uuid.UUID]] = {}

    @staticmethod
    def _now_iso() -> str:
        """Return the current local wall-clock time as an ISO-8601 string.

        The asyncio loop clock is monotonic with an arbitrary epoch, so it cannot
        be converted into a calendar date; ``datetime.now()`` can.
        """
        return datetime.datetime.now().isoformat()

    def _dequeue(self, task_id: uuid.UUID) -> None:
        """Remove *task_id* from the pending queue, if present, and renumber the rest."""
        if task_id not in self.task_queue:
            return
        self.task_queue.remove(task_id)
        for position, queued_id in enumerate(self.task_queue, start=1):
            self.tasks[queued_id]["queue_position"] = position

    def add_task(self, task_id: uuid.UUID, client_id: str, workflow_name: str, steps: int) -> Dict[str, Any]:
        """Append a new pending task to the end of the queue and return its record."""
        task_info = {
            "task_id": task_id,
            "client_id": client_id,
            "workflow_name": workflow_name,
            "steps": steps,
            "current_step": 0,  # current step; starts at 0
            "status": "pending",  # pending, running, completed, failed
            "created_at": self._now_iso(),
            "queue_position": len(self.task_queue) + 1,
        }
        self.tasks[task_id] = task_info
        self.task_queue.append(task_id)

        # Update the client -> tasks mapping
        self.client_tasks.setdefault(client_id, set()).add(task_id)

        return task_info

    def start_task(self, task_id: uuid.UUID) -> bool:
        """Mark a task as running and remove it from the queue; False if unknown."""
        if task_id not in self.tasks:
            return False

        task_info = self.tasks[task_id]
        task_info["status"] = "running"
        task_info["started_at"] = self._now_iso()
        task_info["current_step"] = 1  # execution begins with step 1
        self.running_tasks.add(task_id)
        self._dequeue(task_id)
        return True

    def complete_task(self, task_id: uuid.UUID) -> bool:
        """Mark a task as completed (all steps done); False if unknown."""
        if task_id not in self.tasks:
            return False

        task_info = self.tasks[task_id]
        task_info["status"] = "completed"
        task_info["completed_at"] = self._now_iso()
        task_info["current_step"] = task_info["steps"]  # every step finished
        self.running_tasks.discard(task_id)
        self.completed_tasks.add(task_id)
        # A task completed without start_task() would otherwise stay queued forever.
        self._dequeue(task_id)
        return True

    def fail_task(self, task_id: uuid.UUID, error: Optional[str] = None) -> bool:
        """Mark a task as failed, recording *error* when given; False if unknown."""
        if task_id not in self.tasks:
            return False

        task_info = self.tasks[task_id]
        task_info["status"] = "failed"
        task_info["failed_at"] = self._now_iso()
        if error:
            task_info["error"] = error
        self.running_tasks.discard(task_id)
        # A task failing while still pending must leave the queue too.
        self._dequeue(task_id)
        return True

    def get_client_tasks(self, client_id: str) -> List[Dict[str, Any]]:
        """Return the task records belonging to *client_id* (empty list if none)."""
        if client_id not in self.client_tasks:
            return []

        return [
            self.tasks[task_id]
            for task_id in self.client_tasks[client_id]
            if task_id in self.tasks
        ]

    def get_queue_position(self, task_id: uuid.UUID) -> int:
        """Return the task's 1-based queue position, or -1 if it is not queued."""
        task_info = self.tasks.get(task_id)
        # Only pending tasks are in the queue; the stored position of a task that
        # has since started/finished is stale.
        if task_info is None or task_info.get("status") != "pending":
            return -1
        return task_info.get("queue_position", -1)

    def get_global_queue_info(self) -> Dict[str, int]:
        """Return counts of running, waiting and total (running + waiting) tasks."""
        return {
            "total": len(self.running_tasks) + len(self.task_queue),
            "waiting": len(self.task_queue),
            "running": len(self.running_tasks),
        }

    def get_task_info(self, task_id: uuid.UUID) -> Optional[Dict[str, Any]]:
        """Return a copy of the task record with ``task_id`` as a string, or None."""
        task_info = self.tasks.get(task_id)
        if task_info:
            # Copy so callers cannot mutate internal state; stringify the UUID for serialization.
            task_copy = task_info.copy()
            task_copy["task_id"] = str(task_id)
            return task_copy
        return None

    def update_current_step(self, task_id: uuid.UUID, step: int) -> bool:
        """Set the task's current step; False if unknown or outside ``0..steps``."""
        if task_id not in self.tasks:
            return False

        # Keep the step within the valid range
        if step < 0 or step > self.tasks[task_id]["steps"]:
            return False

        self.tasks[task_id]["current_step"] = step
        return True
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .database import DatabaseManager, PostgresDatabase, MongoDatabase, RedisDatabase
|
|
@@ -0,0 +1,240 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import redis
|
|
4
|
+
import pymongo
|
|
5
|
+
import psycopg2
|
|
6
|
+
from typing import AsyncGenerator
|
|
7
|
+
from loguru import logger
|
|
8
|
+
from omegaconf import OmegaConf
|
|
9
|
+
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine
|
|
10
|
+
from sqlalchemy import text
|
|
11
|
+
|
|
12
|
+
class PostgresDatabase:
    """A named PostgreSQL database with a lazily created SQLAlchemy async engine.

    Construction runs a synchronous psycopg2 connectivity probe
    (``test_connection``). A failed probe is logged as a warning, not raised.
    """

    def __init__(
        self,
        name: str,
        postgres_user: str,
        postgres_password: str,
        postgres_host: str,
        postgres_port: int,
        postgres_db: str,
    ) -> None:
        self.name = name
        self.postgres_user = postgres_user
        self.postgres_password = postgres_password
        self.postgres_host = postgres_host
        self.postgres_port = postgres_port
        self.postgres_db = postgres_db
        self.engine = None
        self.session_factory = None
        self.test_connection()

    def _build_url(self, database: str) -> str:
        """Return an asyncpg SQLAlchemy URL pointing at *database*.

        The user name and password are percent-escaped with ``quote_plus``, so
        characters such as ``@``, ``:`` or ``/`` in a password do not corrupt
        the URL. test_connection() passes them unescaped to psycopg2, which is
        correct there.
        """
        from urllib.parse import quote_plus

        credentials = quote_plus(str(self.postgres_user))
        if self.postgres_password is not None:
            credentials += ":" + quote_plus(str(self.postgres_password))
        return f"postgresql+asyncpg://{credentials}@{self.postgres_host}:{self.postgres_port}/{database}"

    @property
    def database_url(self) -> str:
        """Async engine URL for the configured database."""
        return self._build_url(self.postgres_db)

    @property
    def database_base_url(self) -> str:
        """Async engine URL for the server's built-in ``postgres`` database."""
        return self._build_url("postgres")

    async def init(self) -> None:
        """Create the engine and session factory on first use (idempotent)."""
        if self.engine is None:
            self.engine = await self.create_engine()
            self.session_factory = async_sessionmaker(bind=self.engine, class_=AsyncSession, expire_on_commit=False)

    async def close(self) -> None:
        """Dispose of the engine, if any, and reset the session factory."""
        if self.engine:
            await self.engine.dispose()
            self.engine = None
            self.session_factory = None
            logger.info("Database connection closed")

    async def create_engine(self) -> AsyncEngine:
        """Create the async engine.

        Raises:
            ValueError: if user, host, port or database name is missing.
        """
        if not all([self.postgres_user, self.postgres_host, self.postgres_port, self.postgres_db]):
            raise ValueError("Missing required database configuration. Please check your .env file or configuration.")
        # Do not log database_url itself: it embeds the password.
        logger.info(
            f"Creating database engine for '{self.name}' at "
            f"{self.postgres_host}:{self.postgres_port}/{self.postgres_db}"
        )
        return create_async_engine(self.database_url)

    async def get_async_session(self) -> AsyncGenerator[AsyncSession, None]:
        """Yield exactly one session; roll back on error and always close it.

        Yielding a second time would hand out an already-closed session and make
        single-yield wrappers (e.g. ``contextlib.asynccontextmanager``) fail with
        "generator didn't stop".

        Raises:
            RuntimeError: if the session factory could not be initialised.
        """
        if self.session_factory is None:
            await self.init()

        if self.session_factory is None:
            raise RuntimeError("Session factory is not initialized")

        async with self.session_factory() as session:
            try:
                yield session
            except Exception:
                await session.rollback()
                raise
            finally:
                await session.close()

    async def get_session_factory(self) -> async_sessionmaker[AsyncSession]:
        """Return the session factory, initialising the engine if needed.

        Raises:
            RuntimeError: if the session factory could not be initialised.
        """
        if self.engine is None:
            await self.init()

        if self.session_factory is None:
            raise RuntimeError("Session factory is not initialized")

        return self.session_factory

    def test_connection(self) -> bool:
        """Open and close a psycopg2 connection (5 s timeout); return whether it worked.

        Failures are logged as warnings and reported as ``False``, never raised.
        """
        try:
            conn = psycopg2.connect(
                host=self.postgres_host,
                port=self.postgres_port,
                user=self.postgres_user,
                password=self.postgres_password,
                database=self.postgres_db,
                connect_timeout=5
            )
            conn.close()
            logger.info(f"PostgreSQL connection test successful for database '{self.name}'")
            return True
        except Exception as e:
            logger.warning(f"PostgreSQL connection test failed for database '{self.name}': {e}")
            return False
|
|
100
|
+
|
|
101
|
+
class MongoDatabase:
    """A named MongoDB database backed by a ``pymongo.MongoClient``.

    Construction pings the server (``test_connection``). A failed ping is
    logged as an error, not raised.
    """

    def __init__(
        self,
        name: str,
        mongo_host: str,
        mongo_port: int,
        mongo_user: str,
        mongo_password: str,
        mongo_db: str,
    ) -> None:
        self.name = name
        self.mongo_host = mongo_host
        self.mongo_port = mongo_port
        self.mongo_user = mongo_user
        self.mongo_password = mongo_password
        self.mongo_db = mongo_db or ""
        self.client = pymongo.MongoClient(self.database_url)
        self.test_connection()

    @property
    def database_url(self) -> str:
        """Build the ``mongodb://`` connection URI.

        Credentials are percent-escaped with ``quote_plus``, because characters
        such as ``@`` or ``/`` are not allowed raw in the userinfo part of a
        MongoDB URI. When no user is configured, the credentials section is
        omitted entirely rather than rendered as ``None:None@``.
        """
        from urllib.parse import quote_plus

        credentials = ""
        if self.mongo_user:
            credentials = quote_plus(str(self.mongo_user))
            if self.mongo_password is not None:
                credentials += ":" + quote_plus(str(self.mongo_password))
            credentials += "@"
        return f"mongodb://{credentials}{self.mongo_host}:{self.mongo_port}/{self.mongo_db}"

    def test_connection(self) -> bool:
        """Run the ``ping`` admin command; return whether it succeeded.

        Failures are logged as errors and reported as ``False``, never raised.
        """
        try:
            self.client.admin.command('ping')
            logger.info(f"MongoDB connection test successful for database '{self.name}'")
            return True
        except Exception as e:
            logger.error(f"MongoDB connection test failed for database '{self.name}': {e}")
            return False
|
|
132
|
+
|
|
133
|
+
class RedisDatabase:
    """A named Redis server reached through a synchronous ``redis.Redis`` client.

    The constructor issues a PING via ``test_connection``. A failure is logged,
    not raised.
    """

    def __init__(
        self,
        name: str,
        redis_host: str,
        redis_port: int,
        redis_password: str,
    ) -> None:
        self.name = name
        self.redis_host = redis_host
        self.redis_port = redis_port
        self.redis_password = redis_password
        self.client = redis.Redis(
            host=self.redis_host,
            port=self.redis_port,
            password=self.redis_password,
        )
        self.test_connection()

    def test_connection(self) -> bool:
        """PING the server; return True on success, or False after logging any error."""
        try:
            self.client.ping()
            logger.info(f"Redis connection test successful for database '{self.name}'")
        except Exception as e:
            logger.error(f"Redis connection test failed for database '{self.name}': {e}")
            return False
        return True
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
class DatabaseManager:
    """Registry of the configured Postgres, MongoDB and Redis databases."""

    def __init__(
        self,
        postgres_databases: list[PostgresDatabase],
        mongo_databases: list[MongoDatabase],
        redis_databases: list[RedisDatabase],
    ) -> None:
        self.postgres_databases = postgres_databases
        self.mongo_databases = mongo_databases
        self.redis_databases = redis_databases

    def get_database(self, name: str) -> PostgresDatabase | MongoDatabase | RedisDatabase | None:
        """Return the first database called *name*, or None.

        Postgres databases are searched first, then MongoDB, then Redis. The
        previous version looked only at Postgres, despite the declared return type.
        """
        for database in (*self.postgres_databases, *self.mongo_databases, *self.redis_databases):
            if database.name == name:
                return database
        return None

    def get_default_postgres_database(self) -> PostgresDatabase | None:
        """Return the first configured Postgres database, or None."""
        return self.postgres_databases[0] if self.postgres_databases else None

    def get_default_mongo_database(self) -> MongoDatabase | None:
        """Return the first configured MongoDB database, or None."""
        return self.mongo_databases[0] if self.mongo_databases else None

    def get_default_redis_database(self) -> RedisDatabase | None:
        """Return the first configured Redis database, or None."""
        return self.redis_databases[0] if self.redis_databases else None

    @staticmethod
    def from_config(config_path: str | None = None, config=None) -> DatabaseManager:
        """Build a manager from a loaded config mapping, or from the YAML file at *config_path*.

        Each entry of ``config['databases']`` must set exactly one of
        ``postgres_host``, ``mongo_host`` or ``redis_host`` to a non-null value.
        That key selects the database type. Constructing each database runs its
        connection test.

        Raises:
            ValueError: if an entry has no host key or more than one.
        """
        if config is None:
            config = OmegaConf.to_object(OmegaConf.load(config_path))

        databases_config = config.get('databases', None)
        postgres_databases: list[PostgresDatabase] = []
        mongo_databases: list[MongoDatabase] = []
        redis_databases: list[RedisDatabase] = []

        if databases_config is not None:
            for database_config in databases_config:
                db_name = database_config.get('name')
                host_count = sum(
                    database_config.get(key) is not None
                    for key in ('postgres_host', 'mongo_host', 'redis_host')
                )
                if host_count == 0:
                    raise ValueError(f"Database '{db_name}' is missing required configuration. Please check your service.yaml file.")
                if host_count > 1:
                    raise ValueError(f"Database '{db_name}' has multiple host configurations. Please check your service.yaml file.")

                # Exactly one host key is set at this point.
                if database_config.get('postgres_host') is not None:
                    postgres_databases.append(PostgresDatabase(
                        name=database_config['name'],
                        postgres_user=database_config['postgres_user'],
                        postgres_password=database_config['postgres_password'],
                        postgres_host=database_config['postgres_host'],
                        postgres_port=database_config['postgres_port'],
                        postgres_db=database_config['postgres_db'],
                    ))
                elif database_config.get('mongo_host') is not None:
                    mongo_databases.append(MongoDatabase(
                        name=database_config['name'],
                        mongo_host=database_config['mongo_host'],
                        mongo_port=database_config['mongo_port'],
                        mongo_user=database_config['mongo_user'],
                        mongo_password=database_config['mongo_password'],
                        mongo_db=database_config['mongo_db'],
                    ))
                else:
                    redis_databases.append(RedisDatabase(
                        name=database_config['name'],
                        redis_host=database_config['redis_host'],
                        redis_port=database_config['redis_port'],
                        redis_password=database_config['redis_password'],
                    ))

        return DatabaseManager(postgres_databases=postgres_databases, mongo_databases=mongo_databases, redis_databases=redis_databases)
|
|
238
|
+
|
|
239
|
+
def create_database_manager(config_path: str = None, config = None) -> DatabaseManager:
    """Build a DatabaseManager from a YAML file path or an already-loaded config.

    Thin convenience wrapper: *config* takes precedence, and *config_path* is
    only read when *config* is None (see DatabaseManager.from_config).
    """
    return DatabaseManager.from_config(config_path=config_path, config=config)
|