service-forge 0.1.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of service-forge might be problematic; see the package registry's advisory page for more details.

Files changed (75) hide show
  1. service_forge/api/deprecated_websocket_api.py +86 -0
  2. service_forge/api/deprecated_websocket_manager.py +425 -0
  3. service_forge/api/http_api.py +148 -0
  4. service_forge/api/http_api_doc.py +455 -0
  5. service_forge/api/kafka_api.py +126 -0
  6. service_forge/api/routers/service/__init__.py +4 -0
  7. service_forge/api/routers/service/service_router.py +137 -0
  8. service_forge/api/routers/websocket/websocket_manager.py +83 -0
  9. service_forge/api/routers/websocket/websocket_router.py +78 -0
  10. service_forge/api/task_manager.py +141 -0
  11. service_forge/db/__init__.py +1 -0
  12. service_forge/db/database.py +240 -0
  13. service_forge/llm/__init__.py +62 -0
  14. service_forge/llm/llm.py +56 -0
  15. service_forge/model/__init__.py +0 -0
  16. service_forge/model/websocket.py +13 -0
  17. service_forge/proto/foo_input.py +5 -0
  18. service_forge/service.py +288 -0
  19. service_forge/service_config.py +158 -0
  20. service_forge/sft/cli.py +91 -0
  21. service_forge/sft/cmd/config_command.py +67 -0
  22. service_forge/sft/cmd/deploy_service.py +123 -0
  23. service_forge/sft/cmd/list_tars.py +41 -0
  24. service_forge/sft/cmd/service_command.py +149 -0
  25. service_forge/sft/cmd/upload_service.py +36 -0
  26. service_forge/sft/config/injector.py +119 -0
  27. service_forge/sft/config/injector_default_files.py +131 -0
  28. service_forge/sft/config/sf_metadata.py +30 -0
  29. service_forge/sft/config/sft_config.py +153 -0
  30. service_forge/sft/file/__init__.py +0 -0
  31. service_forge/sft/file/ignore_pattern.py +80 -0
  32. service_forge/sft/file/sft_file_manager.py +107 -0
  33. service_forge/sft/kubernetes/kubernetes_manager.py +257 -0
  34. service_forge/sft/util/assert_util.py +25 -0
  35. service_forge/sft/util/logger.py +16 -0
  36. service_forge/sft/util/name_util.py +8 -0
  37. service_forge/sft/util/yaml_utils.py +57 -0
  38. service_forge/utils/__init__.py +0 -0
  39. service_forge/utils/default_type_converter.py +12 -0
  40. service_forge/utils/register.py +39 -0
  41. service_forge/utils/type_converter.py +99 -0
  42. service_forge/utils/workflow_clone.py +124 -0
  43. service_forge/workflow/__init__.py +1 -0
  44. service_forge/workflow/context.py +14 -0
  45. service_forge/workflow/edge.py +24 -0
  46. service_forge/workflow/node.py +184 -0
  47. service_forge/workflow/nodes/__init__.py +8 -0
  48. service_forge/workflow/nodes/control/if_node.py +29 -0
  49. service_forge/workflow/nodes/control/switch_node.py +28 -0
  50. service_forge/workflow/nodes/input/console_input_node.py +26 -0
  51. service_forge/workflow/nodes/llm/query_llm_node.py +41 -0
  52. service_forge/workflow/nodes/nested/workflow_node.py +28 -0
  53. service_forge/workflow/nodes/output/kafka_output_node.py +27 -0
  54. service_forge/workflow/nodes/output/print_node.py +29 -0
  55. service_forge/workflow/nodes/test/if_console_input_node.py +33 -0
  56. service_forge/workflow/nodes/test/time_consuming_node.py +62 -0
  57. service_forge/workflow/port.py +89 -0
  58. service_forge/workflow/trigger.py +24 -0
  59. service_forge/workflow/triggers/__init__.py +6 -0
  60. service_forge/workflow/triggers/a2a_api_trigger.py +255 -0
  61. service_forge/workflow/triggers/fast_api_trigger.py +169 -0
  62. service_forge/workflow/triggers/kafka_api_trigger.py +44 -0
  63. service_forge/workflow/triggers/once_trigger.py +20 -0
  64. service_forge/workflow/triggers/period_trigger.py +26 -0
  65. service_forge/workflow/triggers/websocket_api_trigger.py +184 -0
  66. service_forge/workflow/workflow.py +210 -0
  67. service_forge/workflow/workflow_callback.py +141 -0
  68. service_forge/workflow/workflow_event.py +15 -0
  69. service_forge/workflow/workflow_factory.py +246 -0
  70. service_forge/workflow/workflow_group.py +27 -0
  71. service_forge/workflow/workflow_type.py +52 -0
  72. service_forge-0.1.11.dist-info/METADATA +98 -0
  73. service_forge-0.1.11.dist-info/RECORD +75 -0
  74. service_forge-0.1.11.dist-info/WHEEL +4 -0
  75. service_forge-0.1.11.dist-info/entry_points.txt +2 -0
@@ -0,0 +1,137 @@
1
+ from fastapi import APIRouter, HTTPException, UploadFile, File, Form
2
+ from fastapi.responses import JSONResponse
3
+ from loguru import logger
4
+ from typing import Optional, TYPE_CHECKING
5
+ import tempfile
6
+ import os
7
+ from pydantic import BaseModel
8
+ from omegaconf import OmegaConf
9
+
10
+ if TYPE_CHECKING:
11
+ from service_forge.service import Service
12
+
13
# TODO: refactor this, do not use global variable
_current_service: Optional['Service'] = None


def set_service(service: 'Service') -> None:
    """Register the process-wide Service instance used by the router endpoints."""
    global _current_service
    _current_service = service


def get_service() -> Optional['Service']:
    """Return the registered Service instance, or None when none has been set."""
    return _current_service
22
+
23
# Router for service-control endpoints; mounted by the HTTP API under /sdk/service.
service_router = APIRouter(prefix="/sdk/service", tags=["service"])
24
+
25
class WorkflowStatusResponse(BaseModel):
    """Response schema for GET /sdk/service/status."""

    name: str              # service name
    version: str           # service version string
    description: str       # human-readable service description
    workflows: list[dict]  # per-workflow status entries (shape defined by Service.get_service_status)
30
+
31
class WorkflowActionResponse(BaseModel):
    """Generic success/message payload for workflow start/stop/upload endpoints."""

    success: bool  # True when the requested action succeeded
    message: str   # human-readable outcome description
34
+
35
@service_router.get("/status", response_model=WorkflowStatusResponse)
async def get_service_status():
    """Return the service's name, version, description, and workflow statuses."""
    service = get_service()
    if service is None:
        raise HTTPException(status_code=503, detail="Service not initialized")

    try:
        return service.get_service_status()
    except Exception as e:
        logger.error(f"Error getting service status: {e}")
        raise HTTPException(status_code=500, detail=str(e))
47
+
48
@service_router.post("/workflow/{workflow_name}/start", response_model=WorkflowActionResponse)
async def start_workflow(workflow_name: str):
    """Start the named workflow on the current service."""
    service = get_service()
    if service is None:
        raise HTTPException(status_code=503, detail="Service not initialized")

    try:
        started = await service.start_workflow(workflow_name)
        outcome = (
            f"Workflow {workflow_name} started successfully"
            if started
            else f"Failed to start workflow {workflow_name}"
        )
        return WorkflowActionResponse(success=bool(started), message=outcome)
    except Exception as e:
        logger.error(f"Error starting workflow {workflow_name}: {e}")
        raise HTTPException(status_code=500, detail=str(e))
63
+
64
@service_router.post("/workflow/{workflow_name}/stop", response_model=WorkflowActionResponse)
async def stop_workflow(workflow_name: str):
    """Stop the named workflow on the current service."""
    service = get_service()
    if service is None:
        raise HTTPException(status_code=503, detail="Service not initialized")

    try:
        stopped = await service.stop_workflow(workflow_name)
        outcome = (
            f"Workflow {workflow_name} stopped successfully"
            if stopped
            else f"Failed to stop workflow {workflow_name}"
        )
        return WorkflowActionResponse(success=bool(stopped), message=outcome)
    except Exception as e:
        logger.error(f"Error stopping workflow {workflow_name}: {e}")
        raise HTTPException(status_code=500, detail=str(e))
79
+
80
@service_router.post("/workflow/upload", response_model=WorkflowActionResponse)
async def upload_workflow_config(
    file: Optional[UploadFile] = File(None),
    config_content: Optional[str] = Form(None),
    workflow_name: Optional[str] = Form(None)
):
    """Upload a workflow configuration as a YAML file or as inline YAML text.

    Exactly one of *file* (a .yaml/.yml upload) or *config_content* (raw YAML
    string) must be provided; *workflow_name* optionally overrides the name
    under which the workflow is loaded. Returns a WorkflowActionResponse on
    success; raises HTTPException 400 for bad input, 503 when the service is
    not initialized, and 500 when loading fails.
    """
    service = get_service()
    if service is None:
        raise HTTPException(status_code=503, detail="Service not initialized")

    if file is None and config_content is None:
        raise HTTPException(status_code=400, detail="Either file or config_content must be provided")
    if file is not None and config_content is not None:
        raise HTTPException(status_code=400, detail="Cannot provide both file and config_content")

    temp_file_path = None
    try:
        if file is not None:
            if not file.filename or not file.filename.endswith(('.yaml', '.yml')):
                raise HTTPException(status_code=400, detail="Only YAML files are supported")

            # Persist the upload to a temp file so the service can load it by path.
            suffix = '.yaml' if file.filename.endswith('.yaml') else '.yml'
            with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix=suffix) as temp_file:
                temp_file.write(await file.read())
                temp_file_path = temp_file.name

            success = await service.load_workflow_from_config(config_path=temp_file_path, workflow_name=workflow_name)
        else:
            try:
                # Parse inline YAML into plain Python containers.
                config = OmegaConf.to_object(OmegaConf.create(config_content))
            except Exception as e:
                raise HTTPException(status_code=400, detail=f"Invalid YAML format: {str(e)}")

            success = await service.load_workflow_from_config(config=config, workflow_name=workflow_name)

        if success:
            return WorkflowActionResponse(
                success=True,
                # was an f-string with no placeholders
                message="Workflow configuration uploaded and loaded successfully"
            )
        raise HTTPException(status_code=500, detail="Failed to load workflow configuration")

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error uploading workflow config: {e}")
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # Best-effort cleanup of the temp file; never mask the real outcome.
        if temp_file_path and os.path.exists(temp_file_path):
            try:
                os.unlink(temp_file_path)
            except Exception as e:
                logger.warning(f"Failed to delete temp file {temp_file_path}: {e}")
137
+
@@ -0,0 +1,83 @@
1
+ from fastapi import WebSocket
2
+ from typing import Dict, Set
3
+ import uuid
4
+ import json
5
+ import asyncio
6
+ from loguru import logger
7
+
8
class WebSocketManager:
    """Tracks websocket subscriptions per task id and broadcasts task messages."""

    def __init__(self):
        # Per-task subscribers: {task_id: set(websocket)}
        self.task_connections: Dict[uuid.UUID, Set[WebSocket]] = {}
        # Reverse index: {websocket: set(task_id)}
        self.websocket_tasks: Dict[WebSocket, Set[uuid.UUID]] = {}
        # Sockets subscribed to every task.
        self.all_task_subscribers: Set[WebSocket] = set()
        self.websocket_subscribes_all: Dict[WebSocket, bool] = {}
        # Guards all subscription bookkeeping above.
        self._lock = asyncio.Lock()

    async def subscribe(self, websocket: WebSocket, task_id: uuid.UUID | None) -> bool:
        """Subscribe to one task, or to every task when task_id is None."""
        async with self._lock:
            if task_id is None:
                self.all_task_subscribers.add(websocket)
                self.websocket_subscribes_all[websocket] = True
                return True
            self.task_connections.setdefault(task_id, set()).add(websocket)
            self.websocket_tasks.setdefault(websocket, set()).add(task_id)
            return True

    async def unsubscribe(self, websocket: WebSocket, task_id: uuid.UUID | None) -> bool:
        """Remove one subscription; empty index entries are pruned."""
        async with self._lock:
            if task_id is None:
                self.all_task_subscribers.discard(websocket)
                self.websocket_subscribes_all.pop(websocket, None)
                return True
            subscribers = self.task_connections.get(task_id)
            if subscribers is not None:
                subscribers.discard(websocket)
                if not subscribers:
                    del self.task_connections[task_id]
            owned = self.websocket_tasks.get(websocket)
            if owned is not None:
                owned.discard(task_id)
                if not owned:
                    del self.websocket_tasks[websocket]
            return True

    async def disconnect(self, websocket: WebSocket):
        """Drop every subscription held by *websocket*."""
        async with self._lock:
            for task_id in self.websocket_tasks.pop(websocket, set()):
                subscribers = self.task_connections.get(task_id)
                if subscribers is None:
                    continue
                subscribers.discard(websocket)
                if not subscribers:
                    del self.task_connections[task_id]
            self.all_task_subscribers.discard(websocket)
            self.websocket_subscribes_all.pop(websocket, None)

    async def send_to_task(self, task_id: uuid.UUID, message: dict):
        """JSON-encode *message* and send it to every subscriber of *task_id*."""
        async with self._lock:
            # Snapshot the recipients under the lock; the sends happen outside
            # it so a failed socket can be disconnect()ed (which re-acquires
            # the non-reentrant lock) without deadlocking.
            targets = set(self.task_connections.get(task_id, set())) | set(self.all_task_subscribers)

        payload = json.dumps(message)
        broken = set()
        for ws in targets:
            try:
                await ws.send_text(payload)
            except Exception as e:
                logger.error(f"向 task_id {task_id} 的 websocket 发送消息失败: {e}")
                broken.add(ws)

        for ws in broken:
            await self.disconnect(ws)


websocket_manager = WebSocketManager()
@@ -0,0 +1,78 @@
1
+ from fastapi import WebSocket, WebSocketDisconnect
2
+ from fastapi.routing import APIRouter
3
+ from loguru import logger
4
+ import json
5
+ import uuid
6
+ from .websocket_manager import websocket_manager
7
+
8
# Router exposing the SDK websocket endpoint; mounted by the HTTP API.
websocket_router = APIRouter()
9
+
10
async def _handle_subscription_change(websocket: WebSocket, message: dict, action: str) -> None:
    """Apply one subscribe/unsubscribe request and reply on the socket.

    *action* is "subscribe" or "unsubscribe"; the two flows differ only in the
    manager method invoked and the response type, so they share this helper
    instead of two copy-pasted branches.
    """
    task_id_str = message.get("task_id")
    if not task_id_str:
        await websocket.send_text(
            json.dumps({"error": f"Missing task_id in {action} message"})
        )
        return

    apply_change = (
        websocket_manager.subscribe if action == "subscribe" else websocket_manager.unsubscribe
    )

    if task_id_str.lower() == "all":
        # None means "every task" to the manager.
        success = await apply_change(websocket, None)
        await websocket.send_text(
            json.dumps({"success": success, "type": f"{action}_response", "task_id": "all"})
        )
        return

    try:
        task_id = uuid.UUID(task_id_str)
    except ValueError:
        await websocket.send_text(
            json.dumps({"error": "Invalid task_id format"})
        )
        return

    success = await apply_change(websocket, task_id)
    await websocket.send_text(
        json.dumps({"success": success, "type": f"{action}_response", "task_id": task_id_str})
    )


@websocket_router.websocket("/sdk/ws")
async def sdk_websocket_endpoint(websocket: WebSocket):
    """Accept an SDK websocket and service subscribe/unsubscribe messages.

    Each inbound message is a JSON object with "type" ("subscribe" or
    "unsubscribe") and "task_id" (a UUID string, or "all" for every task).
    The connection is removed from the manager on disconnect or on any
    unexpected error.
    """
    await websocket.accept()
    try:
        while True:
            data = await websocket.receive_text()
            try:
                message = json.loads(data)
            except json.JSONDecodeError:
                logger.error(f"收到无效JSON消息: {data}")
                await websocket.send_text(
                    json.dumps({"error": "Invalid JSON format"})
                )
                continue

            message_type = message.get("type")
            if message_type in ("subscribe", "unsubscribe"):
                await _handle_subscription_change(websocket, message, message_type)
            else:
                await websocket.send_text(
                    json.dumps({"error": f"Unknown message type: {message_type}"})
                )
    except WebSocketDisconnect:
        await websocket_manager.disconnect(websocket)
    except Exception as e:
        logger.error(f"SDK WebSocket连接处理异常: {e}")
        await websocket_manager.disconnect(websocket)
78
+
@@ -0,0 +1,141 @@
1
+
2
+ from __future__ import annotations
3
+ import asyncio
4
+ import uuid
5
+ import datetime
6
+ from typing import Dict, List, Set, Any, Optional
7
+
8
class TaskManager:
    """Tracks workflow task lifecycle (pending -> running -> completed/failed)
    and the position of waiting tasks in the execution queue."""

    def __init__(self):
        # All known tasks: {task_id: task_info}
        self.tasks: Dict[uuid.UUID, Dict[str, Any]] = {}
        # Waiting tasks, in arrival order.
        self.task_queue: List[uuid.UUID] = []
        # Ids of currently executing tasks.
        self.running_tasks: Set[uuid.UUID] = set()
        # Ids of finished tasks.
        self.completed_tasks: Set[uuid.UUID] = set()
        # Owner mapping: {client_id: set(task_id)}
        self.client_tasks: Dict[str, Set[uuid.UUID]] = {}

    @staticmethod
    def _now_iso() -> str:
        """Current wall-clock time as an ISO-8601 string.

        BUG FIX: the original fed asyncio.get_event_loop().time() — a
        monotonic clock with an arbitrary epoch — into
        datetime.fromtimestamp(), producing timestamps near 1970.
        """
        return datetime.datetime.now().isoformat()

    def add_task(self, task_id: uuid.UUID, client_id: str, workflow_name: str, steps: int) -> Dict[str, Any]:
        """Register a new task, append it to the queue, and return its info dict."""
        task_info = {
            "task_id": task_id,
            "client_id": client_id,
            "workflow_name": workflow_name,
            "steps": steps,
            "current_step": 0,          # 0 = not started yet
            "status": "pending",        # pending, running, completed, failed
            "created_at": self._now_iso(),
            "queue_position": len(self.task_queue) + 1,
        }
        self.tasks[task_id] = task_info
        self.task_queue.append(task_id)

        # Track ownership for get_client_tasks().
        self.client_tasks.setdefault(client_id, set()).add(task_id)
        return task_info

    def start_task(self, task_id: uuid.UUID) -> bool:
        """Mark a task as running and renumber the tasks still waiting."""
        if task_id not in self.tasks:
            return False

        info = self.tasks[task_id]
        info["status"] = "running"
        info["started_at"] = self._now_iso()
        info["current_step"] = 1  # first step begins immediately
        self.running_tasks.add(task_id)

        if task_id in self.task_queue:
            self.task_queue.remove(task_id)

        # Queue positions are 1-based and must be recomputed after removal.
        for position, queued_id in enumerate(self.task_queue, start=1):
            self.tasks[queued_id]["queue_position"] = position
        return True

    def complete_task(self, task_id: uuid.UUID) -> bool:
        """Mark a task as completed; current_step jumps to the final step."""
        if task_id not in self.tasks:
            return False

        info = self.tasks[task_id]
        info["status"] = "completed"
        info["completed_at"] = self._now_iso()
        info["current_step"] = info["steps"]
        self.running_tasks.discard(task_id)
        self.completed_tasks.add(task_id)
        return True

    def fail_task(self, task_id: uuid.UUID, error: Optional[str] = None) -> bool:
        """Mark a task as failed, optionally recording an error message."""
        if task_id not in self.tasks:
            return False

        info = self.tasks[task_id]
        info["status"] = "failed"
        info["failed_at"] = self._now_iso()
        if error:
            info["error"] = error
        self.running_tasks.discard(task_id)
        return True

    def get_client_tasks(self, client_id: str) -> List[Dict[str, Any]]:
        """Return all task-info dicts owned by *client_id* (empty list if none)."""
        return [
            self.tasks[task_id]
            for task_id in self.client_tasks.get(client_id, set())
            if task_id in self.tasks
        ]

    def get_queue_position(self, task_id: uuid.UUID) -> int:
        """1-based queue position of a task; -1 when unknown or not queued."""
        if task_id not in self.tasks:
            return -1
        return self.tasks[task_id].get("queue_position", -1)

    def get_global_queue_info(self) -> Dict[str, int]:
        """Aggregate counters: total (waiting + running), waiting, running."""
        return {
            "total": len(self.running_tasks) + len(self.task_queue),
            "waiting": len(self.task_queue),
            "running": len(self.running_tasks),
        }

    def get_task_info(self, task_id: uuid.UUID) -> Optional[Dict[str, Any]]:
        """Copy of a task's info with task_id rendered as a string, or None."""
        task_info = self.tasks.get(task_id)
        if task_info is None:
            return None
        # Copy so callers can serialize/mutate without touching our state.
        task_copy = task_info.copy()
        task_copy["task_id"] = str(task_id)
        return task_copy

    def update_current_step(self, task_id: uuid.UUID, step: int) -> bool:
        """Set the current step; rejects unknown tasks and out-of-range steps."""
        if task_id not in self.tasks:
            return False
        if step < 0 or step > self.tasks[task_id]["steps"]:
            return False
        self.tasks[task_id]["current_step"] = step
        return True
@@ -0,0 +1 @@
1
+ from .database import DatabaseManager, PostgresDatabase, MongoDatabase, RedisDatabase
@@ -0,0 +1,240 @@
1
+ from __future__ import annotations
2
+
3
+ import redis
4
+ import pymongo
5
+ import psycopg2
6
+ from typing import AsyncGenerator
7
+ from loguru import logger
8
+ from omegaconf import OmegaConf
9
+ from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine
10
+ from sqlalchemy import text
11
+
12
class PostgresDatabase:
    """Async SQLAlchemy wrapper around one configured PostgreSQL database."""

    def __init__(
        self,
        name: str,
        postgres_user: str,
        postgres_password: str,
        postgres_host: str,
        postgres_port: int,
        postgres_db: str,
    ) -> None:
        self.name = name
        self.postgres_user = postgres_user
        self.postgres_password = postgres_password
        self.postgres_host = postgres_host
        self.postgres_port = postgres_port
        self.postgres_db = postgres_db
        # Created lazily by init(); reset by close().
        self.engine = None
        self.session_factory = None
        # Best-effort connectivity probe; failures are logged, not raised.
        self.test_connection()

    @property
    def database_url(self) -> str:
        """asyncpg URL for the configured database."""
        return f"postgresql+asyncpg://{self.postgres_user}:{self.postgres_password}@{self.postgres_host}:{self.postgres_port}/{self.postgres_db}"

    @property
    def database_base_url(self) -> str:
        """asyncpg URL for the server's default 'postgres' database."""
        return f"postgresql+asyncpg://{self.postgres_user}:{self.postgres_password}@{self.postgres_host}:{self.postgres_port}/postgres"

    async def init(self) -> None:
        """Create the async engine and session factory (idempotent)."""
        if self.engine is None:
            self.engine = await self.create_engine()
            self.session_factory = async_sessionmaker(bind=self.engine, class_=AsyncSession, expire_on_commit=False)

    async def close(self) -> None:
        """Dispose the engine and drop the session factory."""
        if self.engine:
            await self.engine.dispose()
            self.engine = None
            self.session_factory = None
            logger.info("Database connection closed")

    async def create_engine(self) -> AsyncEngine:
        """Validate configuration and build the async engine.

        Raises:
            ValueError: when a required connection setting is missing.
        """
        if not all([self.postgres_user, self.postgres_host, self.postgres_port, self.postgres_db]):
            raise ValueError("Missing required database configuration. Please check your .env file or configuration.")
        logger.info(f"Creating database engine: {self.database_url}")
        return create_async_engine(self.database_url)

    async def get_async_session(self) -> AsyncGenerator[AsyncSession, None]:
        """Yield a single AsyncSession, rolling back on error.

        BUG FIX: the original yielded the session a SECOND time after the
        finally block had already closed it. An async-generator dependency
        must yield exactly once; the `async with` block closes the session.
        """
        if self.session_factory is None:
            await self.init()

        if self.session_factory is None:
            raise RuntimeError("Session factory is not initialized")

        async with self.session_factory() as session:
            try:
                yield session
            except Exception:
                await session.rollback()
                raise

    async def get_session_factory(self) -> async_sessionmaker[AsyncSession]:
        """Return the session factory, initializing it lazily."""
        if self.engine is None:
            await self.init()

        if self.session_factory is None:
            raise RuntimeError("Session factory is not initialized")

        return self.session_factory

    def test_connection(self) -> bool:
        """Probe connectivity with a short-lived synchronous psycopg2 connection."""
        try:
            conn = psycopg2.connect(
                host=self.postgres_host,
                port=self.postgres_port,
                user=self.postgres_user,
                password=self.postgres_password,
                database=self.postgres_db,
                connect_timeout=5,
            )
            conn.close()
            logger.info(f"PostgreSQL connection test successful for database '{self.name}'")
            return True
        except Exception as e:
            logger.warning(f"PostgreSQL connection test failed for database '{self.name}': {e}")
            return False
100
+
101
class MongoDatabase:
    """Thin pymongo wrapper for one configured MongoDB instance."""

    def __init__(
        self,
        name: str,
        mongo_host: str,
        mongo_port: int,
        mongo_user: str,
        mongo_password: str,
        mongo_db: str,
    ) -> None:
        self.name = name
        self.mongo_host = mongo_host
        self.mongo_port = mongo_port
        self.mongo_user = mongo_user
        self.mongo_password = mongo_password
        # Normalize a falsy db name to the empty string.
        self.mongo_db = mongo_db or ""
        self.client = pymongo.MongoClient(self.database_url)
        # Best-effort probe; result is logged, not raised.
        self.test_connection()

    @property
    def database_url(self) -> str:
        """mongodb:// connection URL built from the configured credentials."""
        return f"mongodb://{self.mongo_user}:{self.mongo_password}@{self.mongo_host}:{self.mongo_port}/{self.mongo_db}"

    def test_connection(self) -> bool:
        """Ping the server; log the outcome and return False instead of raising."""
        try:
            self.client.admin.command('ping')
        except Exception as e:
            logger.error(f"MongoDB connection test failed for database '{self.name}': {e}")
            return False
        logger.info(f"MongoDB connection test successful for database '{self.name}'")
        return True
132
+
133
class RedisDatabase:
    """Thin redis-py wrapper for one configured Redis instance."""

    def __init__(
        self,
        name: str,
        redis_host: str,
        redis_port: int,
        redis_password: str,
    ) -> None:
        self.name = name
        self.redis_host = redis_host
        self.redis_port = redis_port
        self.redis_password = redis_password
        self.client = redis.Redis(host=redis_host, port=redis_port, password=redis_password)
        # Best-effort probe; result is logged, not raised.
        self.test_connection()

    def test_connection(self) -> bool:
        """Ping the server; log the outcome and return False instead of raising."""
        try:
            self.client.ping()
        except Exception as e:
            logger.error(f"Redis connection test failed for database '{self.name}': {e}")
            return False
        logger.info(f"Redis connection test successful for database '{self.name}'")
        return True
156
+
157
+
158
class DatabaseManager:
    """Registry of the service's configured Postgres, Mongo, and Redis databases."""

    def __init__(
        self,
        postgres_databases: list[PostgresDatabase],
        mongo_databases: list[MongoDatabase],
        redis_databases: list[RedisDatabase],
    ) -> None:
        self.postgres_databases = postgres_databases
        self.mongo_databases = mongo_databases
        self.redis_databases = redis_databases

    def get_database(self, name: str) -> PostgresDatabase | MongoDatabase | RedisDatabase | None:
        """Look up a database by name across all backends.

        BUG FIX: the original searched only the Postgres list even though the
        signature advertises all three backends.
        """
        for database in (*self.postgres_databases, *self.mongo_databases, *self.redis_databases):
            if database.name == name:
                return database
        return None

    def get_default_postgres_database(self) -> PostgresDatabase | None:
        """First configured Postgres database, or None."""
        return self.postgres_databases[0] if self.postgres_databases else None

    def get_default_mongo_database(self) -> MongoDatabase | None:
        """First configured Mongo database, or None."""
        return self.mongo_databases[0] if self.mongo_databases else None

    def get_default_redis_database(self) -> RedisDatabase | None:
        """First configured Redis database, or None."""
        return self.redis_databases[0] if self.redis_databases else None

    @staticmethod
    def _configured_backends(database_config) -> int:
        """Count how many backend host keys are present and non-null in an entry."""
        return sum(
            key in database_config and database_config[key] is not None
            for key in ('postgres_host', 'mongo_host', 'redis_host')
        )

    @staticmethod
    def from_config(config_path: str = None, config = None) -> DatabaseManager:
        """Build a DatabaseManager from a YAML config file or an already-parsed config.

        Each entry under 'databases' must configure exactly one backend
        (postgres_host, mongo_host, or redis_host).

        Raises:
            ValueError: when an entry configures zero or multiple backends.
        """
        if config is None:
            config = OmegaConf.to_object(OmegaConf.load(config_path))

        postgres_databases = []
        mongo_databases = []
        redis_databases = []
        databases_config = config.get('databases', None)
        if databases_config is not None:
            for database_config in databases_config:
                backends = DatabaseManager._configured_backends(database_config)
                if backends == 0:
                    raise ValueError(f"Database '{database_config['name']}' is missing required configuration. Please check your service.yaml file.")
                if backends > 1:
                    raise ValueError(f"Database '{database_config['name']}' has multiple host configurations. Please check your service.yaml file.")

                if database_config.get('postgres_host') is not None:
                    postgres_databases.append(PostgresDatabase(
                        name=database_config['name'],
                        postgres_user=database_config['postgres_user'],
                        postgres_password=database_config['postgres_password'],
                        postgres_host=database_config['postgres_host'],
                        postgres_port=database_config['postgres_port'],
                        postgres_db=database_config['postgres_db'],
                    ))
                elif database_config.get('mongo_host') is not None:
                    mongo_databases.append(MongoDatabase(
                        name=database_config['name'],
                        mongo_host=database_config['mongo_host'],
                        mongo_port=database_config['mongo_port'],
                        mongo_user=database_config['mongo_user'],
                        mongo_password=database_config['mongo_password'],
                        mongo_db=database_config['mongo_db'],
                    ))
                elif database_config.get('redis_host') is not None:
                    redis_databases.append(RedisDatabase(
                        name=database_config['name'],
                        redis_host=database_config['redis_host'],
                        redis_port=database_config['redis_port'],
                        redis_password=database_config['redis_password'],
                    ))

        return DatabaseManager(postgres_databases=postgres_databases, mongo_databases=mongo_databases, redis_databases=redis_databases)
238
+
239
def create_database_manager(config_path: str = None, config = None) -> DatabaseManager:
    """Module-level convenience wrapper around DatabaseManager.from_config()."""
    return DatabaseManager.from_config(config_path=config_path, config=config)