service-forge 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- service_forge/api/http_api.py +138 -0
- service_forge/api/kafka_api.py +126 -0
- service_forge/api/task_manager.py +141 -0
- service_forge/api/websocket_api.py +86 -0
- service_forge/api/websocket_manager.py +425 -0
- service_forge/db/__init__.py +1 -0
- service_forge/db/database.py +119 -0
- service_forge/llm/__init__.py +62 -0
- service_forge/llm/llm.py +56 -0
- service_forge/main.py +121 -0
- service_forge/model/__init__.py +0 -0
- service_forge/model/websocket.py +13 -0
- service_forge/proto/foo_input.py +5 -0
- service_forge/service.py +111 -0
- service_forge/service_config.py +115 -0
- service_forge/sft/cli.py +91 -0
- service_forge/sft/cmd/config_command.py +67 -0
- service_forge/sft/cmd/deploy_service.py +124 -0
- service_forge/sft/cmd/list_tars.py +41 -0
- service_forge/sft/cmd/service_command.py +149 -0
- service_forge/sft/cmd/upload_service.py +36 -0
- service_forge/sft/config/injector.py +87 -0
- service_forge/sft/config/injector_default_files.py +97 -0
- service_forge/sft/config/sf_metadata.py +30 -0
- service_forge/sft/config/sft_config.py +125 -0
- service_forge/sft/file/__init__.py +0 -0
- service_forge/sft/file/ignore_pattern.py +80 -0
- service_forge/sft/file/sft_file_manager.py +107 -0
- service_forge/sft/kubernetes/kubernetes_manager.py +257 -0
- service_forge/sft/util/assert_util.py +25 -0
- service_forge/sft/util/logger.py +16 -0
- service_forge/sft/util/name_util.py +2 -0
- service_forge/utils/__init__.py +0 -0
- service_forge/utils/default_type_converter.py +12 -0
- service_forge/utils/register.py +39 -0
- service_forge/utils/type_converter.py +74 -0
- service_forge/workflow/__init__.py +1 -0
- service_forge/workflow/context.py +13 -0
- service_forge/workflow/edge.py +31 -0
- service_forge/workflow/node.py +179 -0
- service_forge/workflow/nodes/__init__.py +7 -0
- service_forge/workflow/nodes/control/if_node.py +29 -0
- service_forge/workflow/nodes/input/console_input_node.py +26 -0
- service_forge/workflow/nodes/llm/query_llm_node.py +41 -0
- service_forge/workflow/nodes/nested/workflow_node.py +28 -0
- service_forge/workflow/nodes/output/kafka_output_node.py +27 -0
- service_forge/workflow/nodes/output/print_node.py +29 -0
- service_forge/workflow/nodes/test/if_console_input_node.py +33 -0
- service_forge/workflow/nodes/test/time_consuming_node.py +61 -0
- service_forge/workflow/port.py +86 -0
- service_forge/workflow/trigger.py +20 -0
- service_forge/workflow/triggers/__init__.py +4 -0
- service_forge/workflow/triggers/fast_api_trigger.py +125 -0
- service_forge/workflow/triggers/kafka_api_trigger.py +44 -0
- service_forge/workflow/triggers/once_trigger.py +20 -0
- service_forge/workflow/triggers/period_trigger.py +26 -0
- service_forge/workflow/workflow.py +251 -0
- service_forge/workflow/workflow_factory.py +227 -0
- service_forge/workflow/workflow_group.py +23 -0
- service_forge/workflow/workflow_type.py +52 -0
- service_forge-0.1.0.dist-info/METADATA +93 -0
- service_forge-0.1.0.dist-info/RECORD +64 -0
- service_forge-0.1.0.dist-info/WHEEL +4 -0
- service_forge-0.1.0.dist-info/entry_points.txt +2 -0
|
@@ -0,0 +1,138 @@
|
|
|
1
|
+
from fastapi import FastAPI
|
|
2
|
+
import uvicorn
|
|
3
|
+
from typing import Optional
|
|
4
|
+
from fastapi import APIRouter
|
|
5
|
+
from loguru import logger
|
|
6
|
+
from urllib.parse import urlparse
|
|
7
|
+
from fastapi import HTTPException, Request
|
|
8
|
+
from fastapi.middleware.cors import CORSMiddleware
|
|
9
|
+
from service_forge.api.websocket_api import router as websocket_router
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def is_trusted_origin(origin_host: str, host: str, trusted_root: str = "ring.shiweinan.com") -> bool:
    """
    Check whether a request origin should be trusted, based on domain matching.

    An origin is trusted when its hostname equals the server's host, or when it
    is the trusted root domain itself or one of its subdomains. Only the
    *origin* is matched against the trusted root: the server's own host being
    under that root says nothing about who sent the request, so it must not
    make arbitrary origins trusted.

    Comparison is case-insensitive. An empty or ``None`` origin hostname (which
    is what ``urlparse(...).hostname`` yields for a URL without a netloc) is
    never trusted.

    Args:
        origin_host: The hostname from the Origin (or Referer) header.
        host: The hostname from the Host header.
        trusted_root: The trusted root domain (can be customized).

    Returns:
        bool: True if the origin is trusted, False otherwise.
    """
    if not origin_host:
        return False

    # Convert to lowercase to avoid case sensitivity issues
    origin_host = origin_host.lower()
    host = (host or "").lower()
    root = (trusted_root or "").lower().lstrip(".")

    # Same host is always trusted.
    if origin_host == host:
        return True

    # Without a root, "." + root would match any hostname ending in ".".
    if not root:
        return False

    # The root domain itself, or any subdomain of it ("x.root" but not "xroot").
    return origin_host == root or origin_host.endswith("." + root)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def create_app(
    app: Optional[FastAPI] = None,
    routers: Optional[list[APIRouter]] = None,
    cors_origins: Optional[list[str]] = None,
    enable_auth_middleware: bool = True,
    trusted_domain: str = "ring.shiweinan.com"
) -> FastAPI:
    """
    Create or configure a FastAPI app with common middleware and configuration.

    Always installs CORS middleware and includes the WebSocket router. When
    ``enable_auth_middleware`` is true, every request whose path starts with
    ``/api`` must either pass the same-origin check or carry a non-empty
    ``X-User-ID`` header. A request passes the same-origin check when the
    Origin/Referer hostname and port equal the Host header's and
    ``is_trusted_origin`` accepts it. Requests that fail both get a 401 JSON
    response ``{"detail": "Unauthorized"}``. The resolved user ID is stored on
    ``request.state.user_id``; same-origin requests get ``"0"``.

    Args:
        app: Optional existing FastAPI instance. If None, creates a new one.
        routers: List of APIRouter instances to include
        cors_origins: List of allowed CORS origins. Defaults to ["*"]
        enable_auth_middleware: Whether to enable authentication middleware
        trusted_domain: Trusted domain for origin validation

    Returns:
        FastAPI: Configured FastAPI application instance
    """
    from fastapi.responses import JSONResponse

    if app is None:
        app = FastAPI()

    # Configure CORS middleware
    if cors_origins is None:
        cors_origins = ["*"]

    # NOTE(review): with allow_origins=["*"] and allow_credentials=True,
    # Starlette's CORSMiddleware reflects the caller's Origin on credentialed
    # requests. Pass an explicit cors_origins list in production.
    app.add_middleware(
        CORSMiddleware,
        allow_origins=cors_origins,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    # Include routers if provided
    if routers:
        for router in routers:
            app.include_router(router)

    # Always include WebSocket router
    app.include_router(websocket_router)

    # Add authentication middleware if enabled
    if enable_auth_middleware:

        def _is_same_origin(request: Request) -> bool:
            """Return True if Origin/Referer matches the Host header (hostname and port) and is trusted."""
            origin = request.headers.get("origin") or request.headers.get("referer")
            host = request.headers.get("host", "")
            logger.debug(f"origin {origin}, host:{host}")

            if not (origin and host):
                return False
            try:
                parsed_origin = urlparse(origin)
                parsed_host = urlparse(f"{request.url.scheme}://{host}")
                return bool(
                    parsed_origin.hostname == parsed_host.hostname
                    and parsed_origin.port == parsed_host.port
                    and is_trusted_origin(parsed_origin.hostname, parsed_host.hostname, trusted_domain)
                )
            except Exception as exc:
                # Best effort: malformed headers (e.g. a non-numeric port) simply
                # mean "not same-origin"; record why instead of silently passing.
                logger.debug(f"Same-origin check could not parse headers: {exc}")
                return False

        @app.middleware("http")
        async def auth_middleware(request: Request, call_next):
            """
            Authentication middleware for API routes.

            Validates user authentication for /api routes with origin-based
            trust verification and X-User-ID header validation.
            """
            # NOTE(review): Origin/Referer and X-User-ID are plain request headers
            # that non-browser clients can set freely, so this check is not
            # authentication on its own. Confirm an upstream gateway enforces identity.
            if request.url.path.startswith("/api"):
                if _is_same_origin(request):
                    # Same-origin requests can skip auth, but still set default user_id
                    request.state.user_id = "0"
                else:
                    user_id = request.headers.get("X-User-ID")
                    if not user_id:
                        # An HTTPException raised inside HTTP middleware is not seen by
                        # FastAPI's exception handlers and would surface as a 500, so
                        # build the 401 response directly.
                        return JSONResponse(status_code=401, content={"detail": "Unauthorized"})
                    request.state.user_id = user_id

            return await call_next(request)

    return app
|
|
122
|
+
|
|
123
|
+
async def start_fastapi_server(host: str, port: int, app: Optional[FastAPI] = None):
    """
    Serve a FastAPI application with uvicorn until the server exits.

    Args:
        host: Interface to bind.
        port: Port to bind; passed through ``int()`` so numeric strings work.
        app: Application to serve. Defaults to the module-level ``fastapi_app``,
            so existing callers keep serving the same app.

    Raises:
        Any exception from building the config or running the server is logged
        with its traceback and re-raised.
    """
    target_app = fastapi_app if app is None else app
    try:
        config = uvicorn.Config(
            target_app,
            host=host,
            port=int(port),
            log_level="info",
            access_log=True
        )
        server = uvicorn.Server(config)
        await server.serve()
    except Exception as e:
        # logger.exception keeps the traceback that logger.error discarded.
        logger.exception(f"Server error: {e}")
        raise
|
|
137
|
+
|
|
138
|
+
# Module-level app served by start_fastapi_server; built with the auth middleware disabled.
fastapi_app = create_app(enable_auth_middleware=False)
|
|
@@ -0,0 +1,126 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
from typing import Callable, Any
|
|
3
|
+
from aiokafka import AIOKafkaConsumer, AIOKafkaProducer, ConsumerRecord
|
|
4
|
+
import asyncio
|
|
5
|
+
import json
|
|
6
|
+
import inspect
|
|
7
|
+
from pydantic import BaseModel
|
|
8
|
+
from loguru import logger
|
|
9
|
+
|
|
10
|
+
class KafkaApp:
    """
    Small Kafka application wrapper.

    Topic handlers are registered with the ``kafka_input`` decorator; ``start()``
    launches one consumer task per registered topic plus a shared producer.
    Incoming values are decoded with ``data_type().ParseFromString(...)`` first,
    falling back to ``data_type(**json)``. Handlers may be sync or async.
    """

    def __init__(self, bootstrap_servers: str | None = None):
        self.bootstrap_servers = bootstrap_servers
        # topic -> (handler, data_type, group_id); one handler per topic.
        self._handlers: dict[str, tuple[Callable, type, str]] = {}
        # topic -> running consumer task, kept so stop() can cancel it.
        self._consumer_tasks: dict[str, asyncio.Task] = {}
        self._producer: AIOKafkaProducer | None = None
        self._lock = asyncio.Lock()
        self._running = False

    def kafka_input(self, topic: str, data_type: type, group_id: str):
        """
        Return a decorator registering ``func`` as the handler for ``topic``.

        Registering a topic again replaces its handler. If the app is already
        running, a consumer for the topic is started immediately, replacing any
        existing consumer for it. The new task is tracked so ``stop()`` cancels it.
        """
        def decorator(func: Callable):
            self._handlers[topic] = (func, data_type, group_id)
            logger.info(f"Registered Kafka input handler for topic '{topic}', data_type: {data_type}")

            if self._running:
                previous = self._consumer_tasks.pop(topic, None)
                if previous is not None:
                    # Only one handler per topic is kept, so retire the old consumer
                    # instead of letting two consumers split the topic's messages.
                    previous.cancel()
                # Keep a strong reference: the event loop holds tasks only weakly.
                self._consumer_tasks[topic] = asyncio.create_task(
                    self._start_consumer(topic, func, data_type, group_id)
                )
            return func
        return decorator

    def set_bootstrap_servers(self, bootstrap_servers: str) -> None:
        """Set the Kafka bootstrap servers used by subsequently created clients."""
        self.bootstrap_servers = bootstrap_servers

    async def start(self):
        """
        Start the producer and a consumer per registered topic, then block.

        Returns only after ``stop()`` clears the running flag.

        Raises:
            ValueError: if ``bootstrap_servers`` is not set.
        """
        if not self.bootstrap_servers:
            raise ValueError("bootstrap_servers 未设置")

        logger.info(f"🚀 KafkaApp started with servers: {self.bootstrap_servers}")

        await self._start_producer()

        async with self._lock:
            for topic, (handler, data_type, group_id) in self._handlers.items():
                if topic not in self._consumer_tasks:
                    self._consumer_tasks[topic] = asyncio.create_task(
                        self._start_consumer(topic, handler, data_type, group_id)
                    )

        self._running = True
        # Keep the caller alive until stop() flips the flag.
        while self._running:
            await asyncio.sleep(1)

    async def _start_consumer(self, topic: str, handler: Callable, data_type: type, group_id: str):
        """
        Consume ``topic`` until cancelled, dispatching each record to ``handler``.

        A failure while decoding or handling one message is logged and skipped,
        so a single bad message does not terminate the consumer. The consumer is
        always stopped on exit; cancellation is logged and re-raised.
        """
        consumer = AIOKafkaConsumer(
            topic,
            bootstrap_servers=self.bootstrap_servers,
            group_id=group_id,
            enable_auto_commit=True,
            auto_offset_reset="latest",
        )
        try:
            await consumer.start()
        except Exception:
            # Surface startup failures now; otherwise they sit unobserved in the task.
            logger.exception(f"❌ Failed to start consumer for topic: {topic}")
            raise
        logger.info(f"✅ Started consumer for topic: {topic}")

        try:
            async for msg in consumer:
                try:
                    await self._dispatch_message(handler, msg, data_type)
                except Exception:
                    logger.exception(
                        f"Failed to handle message from topic '{topic}' "
                        f"(partition {msg.partition}, offset {msg.offset})"
                    )
        except asyncio.CancelledError:
            logger.warning(f"🛑 Consumer for {topic} cancelled")
            # Propagate so the task is reported as cancelled, as asyncio expects.
            raise
        finally:
            await consumer.stop()

    async def _dispatch_message(self, handler: Callable, msg: ConsumerRecord, data_type: type):
        """
        Decode ``msg.value`` into ``data_type`` and call ``handler`` with it.

        Tries ``data_type().ParseFromString(value)`` first; if that raises, falls
        back to ``data_type(**json.loads(value.decode("utf-8")))``. A coroutine
        returned by ``handler`` is awaited. Errors from the JSON fallback or the
        handler propagate to the caller.
        """
        try:
            data = data_type()
            data.ParseFromString(msg.value)
        except Exception as e:
            logger.debug(f"Binary decode as {data_type} failed ({e}); falling back to JSON")
            data = data_type(**json.loads(msg.value.decode("utf-8")))
        result = handler(data)
        if inspect.iscoroutine(result):
            await result

    async def _start_producer(self):
        """Create and start the shared producer if it does not exist yet."""
        if self._producer is None:
            self._producer = AIOKafkaProducer(
                bootstrap_servers=self.bootstrap_servers,
                # Values are expected to expose SerializeToString().
                value_serializer=lambda v: v.SerializeToString(),
            )
            await self._producer.start()
            logger.info("✅ Kafka producer started")

    async def _stop_producer(self):
        """Stop and discard the shared producer, if any."""
        if self._producer is not None:
            await self._producer.stop()
            self._producer = None
            logger.info("✅ Kafka producer stopped")

    async def send_message(self, topic: str, data_type: type, data: Any) -> None:
        """
        Send ``data`` to ``topic`` and wait for the broker acknowledgement.

        ``data_type`` is used only for logging.

        Raises:
            RuntimeError: if the app is not running or the producer is missing.
            Any error from the producer is logged and re-raised.
        """
        if not self._running:
            raise RuntimeError("KafkaApp is not running. Call start() first.")

        if self._producer is None:
            raise RuntimeError("Kafka producer is not initialized.")

        try:
            await self._producer.send_and_wait(topic, data)
            logger.info(f"✅ 已发送消息到 topic '{topic}', type: {data_type}")

        except Exception as e:
            logger.error(f"❌ 发送消息到 topic '{topic}' 失败: {e}")
            raise

    async def stop(self):
        """Cancel all consumer tasks, wait for them to finish, then stop the producer."""
        logger.info("Stopping KafkaApp ...")
        self._running = False

        tasks = list(self._consumer_tasks.values())
        self._consumer_tasks.clear()
        for t in tasks:
            t.cancel()
        # Wait for each consumer's cleanup (consumer.stop()) instead of a fixed
        # sleep; return_exceptions keeps one failed task from aborting shutdown.
        await asyncio.gather(*tasks, return_exceptions=True)

        await self._stop_producer()

        logger.info("✅ KafkaApp stopped")
|
|
121
|
+
|
|
122
|
+
# Process-wide KafkaApp instance; start_kafka_server configures its servers and runs it.
kafka_app = KafkaApp()
|
|
123
|
+
|
|
124
|
+
async def start_kafka_server(bootstrap_servers: str):
    """
    Point the shared ``kafka_app`` at ``bootstrap_servers`` and run it.

    Awaits ``KafkaApp.start()``, which blocks until ``stop()`` is called (or raises).
    """
    app = kafka_app
    app.set_bootstrap_servers(bootstrap_servers)
    await app.start()
|
|
@@ -0,0 +1,141 @@
|
|
|
1
|
+
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
import asyncio
|
|
4
|
+
import uuid
|
|
5
|
+
import datetime
|
|
6
|
+
from typing import Dict, List, Set, Any, Optional
|
|
7
|
+
|
|
8
|
+
class TaskManager:
    """Task manager that tracks task status and queue information."""

    def __init__(self):
        # All task records: {task_id: task_info}
        self.tasks: Dict[uuid.UUID, Dict[str, Any]] = {}
        # Pending tasks, in the order they were added
        self.task_queue: List[uuid.UUID] = []
        # IDs of tasks currently running
        self.running_tasks: Set[uuid.UUID] = set()
        # IDs of tasks that completed successfully
        self.completed_tasks: Set[uuid.UUID] = set()
        # Client-to-task mapping: {client_id: set(task_id)}
        self.client_tasks: Dict[str, Set[uuid.UUID]] = {}

    @staticmethod
    def _now_iso() -> str:
        """Return the current local wall-clock time as an ISO-8601 string."""
        # Deliberately not asyncio's loop.time(): that is a monotonic clock with an
        # arbitrary reference point, so datetime.fromtimestamp() on it yields
        # meaningless dates (near 1970). It also required an event loop.
        return datetime.datetime.now().isoformat()

    def _remove_from_queue(self, task_id: uuid.UUID) -> None:
        """Remove task_id from the pending queue (if present) and renumber the rest from 1."""
        if task_id not in self.task_queue:
            return
        self.task_queue.remove(task_id)
        for position, queued_id in enumerate(self.task_queue, start=1):
            self.tasks[queued_id]["queue_position"] = position

    def add_task(self, task_id: uuid.UUID, client_id: str, workflow_name: str, steps: int) -> Dict[str, Any]:
        """Append a new pending task to the end of the queue and return its record."""
        task_info = {
            "task_id": task_id,
            "client_id": client_id,
            "workflow_name": workflow_name,
            "steps": steps,
            "current_step": 0,  # current step, starting from 0
            "status": "pending",  # pending, running, completed, failed
            "created_at": self._now_iso(),
            "queue_position": len(self.task_queue) + 1,
        }
        self.tasks[task_id] = task_info
        self.task_queue.append(task_id)

        # Update the client -> tasks mapping
        self.client_tasks.setdefault(client_id, set()).add(task_id)

        return task_info

    def start_task(self, task_id: uuid.UUID) -> bool:
        """Mark a task as running and take it out of the pending queue.

        Returns False if the task is unknown.
        """
        if task_id not in self.tasks:
            return False

        task = self.tasks[task_id]
        task["status"] = "running"
        task["started_at"] = self._now_iso()
        task["current_step"] = 1  # execution starts at step 1
        self.running_tasks.add(task_id)

        self._remove_from_queue(task_id)
        return True

    def complete_task(self, task_id: uuid.UUID) -> bool:
        """Mark a task as completed (all steps done).

        Returns False if the task is unknown.
        """
        if task_id not in self.tasks:
            return False

        task = self.tasks[task_id]
        task["status"] = "completed"
        task["completed_at"] = self._now_iso()
        task["current_step"] = task["steps"]  # all steps finished
        self.running_tasks.discard(task_id)
        self.completed_tasks.add(task_id)
        # A task completed without being started must not linger in the queue.
        self._remove_from_queue(task_id)

        return True

    def fail_task(self, task_id: uuid.UUID, error: Optional[str] = None) -> bool:
        """Mark a task as failed, recording ``error`` when given.

        Returns False if the task is unknown.
        """
        if task_id not in self.tasks:
            return False

        task = self.tasks[task_id]
        task["status"] = "failed"
        task["failed_at"] = self._now_iso()
        if error:
            task["error"] = error
        self.running_tasks.discard(task_id)
        # A task that fails while still pending must leave the queue too,
        # otherwise waiting counts and other tasks' positions stay inflated.
        self._remove_from_queue(task_id)

        return True

    def get_client_tasks(self, client_id: str) -> List[Dict[str, Any]]:
        """Return the records of all tasks belonging to ``client_id``."""
        if client_id not in self.client_tasks:
            return []

        return [
            self.tasks[task_id]
            for task_id in self.client_tasks[client_id]
            if task_id in self.tasks
        ]

    def get_queue_position(self, task_id: uuid.UUID) -> int:
        """Return the task's 1-based position in the pending queue, or -1 if it is not queued."""
        try:
            return self.task_queue.index(task_id) + 1
        except ValueError:
            # Unknown, running, completed or failed tasks are not in the queue.
            return -1

    def get_global_queue_info(self) -> Dict[str, int]:
        """Return global counts of waiting, running and total active tasks."""
        return {
            "total": len(self.running_tasks) + len(self.task_queue),
            "waiting": len(self.task_queue),
            "running": len(self.running_tasks),
        }

    def get_task_info(self, task_id: uuid.UUID) -> Optional[Dict[str, Any]]:
        """Return a copy of the task's record with ``task_id`` as a string, or None if unknown."""
        task_info = self.tasks.get(task_id)
        if task_info:
            # Copy the record and convert the UUID to a string
            task_copy = task_info.copy()
            task_copy["task_id"] = str(task_id)
            return task_copy
        return None

    def update_current_step(self, task_id: uuid.UUID, step: int) -> bool:
        """Set the task's current step; returns False if unknown or outside 0..steps."""
        if task_id not in self.tasks:
            return False

        # Ensure the step is within the valid range
        if step < 0 or step > self.tasks[task_id]["steps"]:
            return False

        self.tasks[task_id]["current_step"] = step
        return True
|
|
@@ -0,0 +1,86 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
import uuid
|
|
3
|
+
import asyncio
|
|
4
|
+
import json
|
|
5
|
+
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Depends, Request, Query
|
|
6
|
+
from fastapi.responses import JSONResponse
|
|
7
|
+
from loguru import logger
|
|
8
|
+
from typing import Dict, Any, Optional
|
|
9
|
+
from .websocket_manager import websocket_manager
|
|
10
|
+
|
|
11
|
+
# Router for this module's WebSocket routes; every path is served under the /ws prefix.
router = APIRouter(prefix="/ws", tags=["websocket"])
|
|
12
|
+
|
|
13
|
+
@router.websocket("/connect")
async def websocket_endpoint(websocket: WebSocket, client_id: Optional[str] = Query(None)):
    """
    WebSocket endpoint; the client may pass its own ``client_id`` query parameter.

    The connection is registered via ``websocket_manager.connect``, and the ID it
    returns is used for the rest of the session. Each text frame is parsed as
    JSON and routed to ``handle_client_message``. Frames that are not valid JSON
    get an ``{"error": "Invalid JSON format"}`` reply. On disconnect, or any other
    error, the client is unregistered from the manager.
    """
    session_id = await websocket_manager.connect(websocket, client_id)
    try:
        while True:
            raw_text = await websocket.receive_text()
            try:
                await handle_client_message(session_id, json.loads(raw_text))
            except json.JSONDecodeError:
                logger.error(f"从客户端 {session_id} 收到无效JSON消息: {raw_text}")
                error_payload = json.dumps({"error": "Invalid JSON format"})
                await websocket_manager.send_personal_message(error_payload, session_id)
    except Exception as exc:
        # A WebSocketDisconnect is a normal close; any other failure is logged first.
        if not isinstance(exc, WebSocketDisconnect):
            logger.error(f"WebSocket连接处理异常: {exc}")
        websocket_manager.disconnect(session_id)
|
|
35
|
+
|
|
36
|
+
async def _send_json(client_id: str, payload: Dict[str, Any]) -> None:
    """Serialize ``payload`` as JSON and send it to ``client_id`` only."""
    await websocket_manager.send_personal_message(json.dumps(payload), client_id)


async def handle_client_message(client_id: str, message: Dict[str, Any]):
    """
    Handle one decoded JSON message from a client.

    Supported messages:
        {"type": "subscribe", "task_id": "<uuid>"}
        {"type": "unsubscribe", "task_id": "<uuid>"}

    A valid request is answered with ``{"success": <manager result>}``. Malformed
    requests (non-object payload, unknown type, missing or invalid task_id) get
    an ``{"error": ...}`` reply instead of raising.
    """
    if not isinstance(message, dict):
        # json.loads can also yield a list, string or number; .get() would raise on those.
        await _send_json(client_id, {"error": "Invalid message format"})
        return

    message_type = message.get("type")
    if message_type == "subscribe":
        action = websocket_manager.subscribe_to_task
    elif message_type == "unsubscribe":
        action = websocket_manager.unsubscribe_from_task
    else:
        await _send_json(client_id, {"error": f"Unknown message type: {message_type}"})
        return

    task_id_str = message.get("task_id")
    if not task_id_str:
        await _send_json(client_id, {"error": f"Missing task_id in {message_type} message"})
        return

    try:
        # str() makes non-string IDs (numbers, objects) fail with ValueError
        # instead of escaping as AttributeError/TypeError.
        task_id = uuid.UUID(str(task_id_str))
    except ValueError:
        await _send_json(client_id, {"error": "Invalid task_id format"})
        return

    # Outside the try, so an error raised by the manager is not misreported as a bad task_id.
    success = await action(client_id, task_id)
    await _send_json(client_id, {"success": success})
|