aiecs 1.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiecs/__init__.py +72 -0
- aiecs/__main__.py +41 -0
- aiecs/aiecs_client.py +469 -0
- aiecs/application/__init__.py +10 -0
- aiecs/application/executors/__init__.py +10 -0
- aiecs/application/executors/operation_executor.py +363 -0
- aiecs/application/knowledge_graph/__init__.py +7 -0
- aiecs/application/knowledge_graph/builder/__init__.py +37 -0
- aiecs/application/knowledge_graph/builder/document_builder.py +375 -0
- aiecs/application/knowledge_graph/builder/graph_builder.py +356 -0
- aiecs/application/knowledge_graph/builder/schema_mapping.py +531 -0
- aiecs/application/knowledge_graph/builder/structured_pipeline.py +443 -0
- aiecs/application/knowledge_graph/builder/text_chunker.py +319 -0
- aiecs/application/knowledge_graph/extractors/__init__.py +27 -0
- aiecs/application/knowledge_graph/extractors/base.py +100 -0
- aiecs/application/knowledge_graph/extractors/llm_entity_extractor.py +327 -0
- aiecs/application/knowledge_graph/extractors/llm_relation_extractor.py +349 -0
- aiecs/application/knowledge_graph/extractors/ner_entity_extractor.py +244 -0
- aiecs/application/knowledge_graph/fusion/__init__.py +23 -0
- aiecs/application/knowledge_graph/fusion/entity_deduplicator.py +387 -0
- aiecs/application/knowledge_graph/fusion/entity_linker.py +343 -0
- aiecs/application/knowledge_graph/fusion/knowledge_fusion.py +580 -0
- aiecs/application/knowledge_graph/fusion/relation_deduplicator.py +189 -0
- aiecs/application/knowledge_graph/pattern_matching/__init__.py +21 -0
- aiecs/application/knowledge_graph/pattern_matching/pattern_matcher.py +344 -0
- aiecs/application/knowledge_graph/pattern_matching/query_executor.py +378 -0
- aiecs/application/knowledge_graph/profiling/__init__.py +12 -0
- aiecs/application/knowledge_graph/profiling/query_plan_visualizer.py +199 -0
- aiecs/application/knowledge_graph/profiling/query_profiler.py +223 -0
- aiecs/application/knowledge_graph/reasoning/__init__.py +27 -0
- aiecs/application/knowledge_graph/reasoning/evidence_synthesis.py +347 -0
- aiecs/application/knowledge_graph/reasoning/inference_engine.py +504 -0
- aiecs/application/knowledge_graph/reasoning/logic_form_parser.py +167 -0
- aiecs/application/knowledge_graph/reasoning/logic_parser/__init__.py +79 -0
- aiecs/application/knowledge_graph/reasoning/logic_parser/ast_builder.py +513 -0
- aiecs/application/knowledge_graph/reasoning/logic_parser/ast_nodes.py +630 -0
- aiecs/application/knowledge_graph/reasoning/logic_parser/ast_validator.py +654 -0
- aiecs/application/knowledge_graph/reasoning/logic_parser/error_handler.py +477 -0
- aiecs/application/knowledge_graph/reasoning/logic_parser/parser.py +390 -0
- aiecs/application/knowledge_graph/reasoning/logic_parser/query_context.py +217 -0
- aiecs/application/knowledge_graph/reasoning/logic_query_integration.py +169 -0
- aiecs/application/knowledge_graph/reasoning/query_planner.py +872 -0
- aiecs/application/knowledge_graph/reasoning/reasoning_engine.py +554 -0
- aiecs/application/knowledge_graph/retrieval/__init__.py +19 -0
- aiecs/application/knowledge_graph/retrieval/retrieval_strategies.py +596 -0
- aiecs/application/knowledge_graph/search/__init__.py +59 -0
- aiecs/application/knowledge_graph/search/hybrid_search.py +423 -0
- aiecs/application/knowledge_graph/search/reranker.py +295 -0
- aiecs/application/knowledge_graph/search/reranker_strategies.py +553 -0
- aiecs/application/knowledge_graph/search/text_similarity.py +398 -0
- aiecs/application/knowledge_graph/traversal/__init__.py +15 -0
- aiecs/application/knowledge_graph/traversal/enhanced_traversal.py +329 -0
- aiecs/application/knowledge_graph/traversal/path_scorer.py +269 -0
- aiecs/application/knowledge_graph/validators/__init__.py +13 -0
- aiecs/application/knowledge_graph/validators/relation_validator.py +189 -0
- aiecs/application/knowledge_graph/visualization/__init__.py +11 -0
- aiecs/application/knowledge_graph/visualization/graph_visualizer.py +321 -0
- aiecs/common/__init__.py +9 -0
- aiecs/common/knowledge_graph/__init__.py +17 -0
- aiecs/common/knowledge_graph/runnable.py +484 -0
- aiecs/config/__init__.py +16 -0
- aiecs/config/config.py +498 -0
- aiecs/config/graph_config.py +137 -0
- aiecs/config/registry.py +23 -0
- aiecs/core/__init__.py +46 -0
- aiecs/core/interface/__init__.py +34 -0
- aiecs/core/interface/execution_interface.py +152 -0
- aiecs/core/interface/storage_interface.py +171 -0
- aiecs/domain/__init__.py +289 -0
- aiecs/domain/agent/__init__.py +189 -0
- aiecs/domain/agent/base_agent.py +697 -0
- aiecs/domain/agent/exceptions.py +103 -0
- aiecs/domain/agent/graph_aware_mixin.py +559 -0
- aiecs/domain/agent/hybrid_agent.py +490 -0
- aiecs/domain/agent/integration/__init__.py +26 -0
- aiecs/domain/agent/integration/context_compressor.py +222 -0
- aiecs/domain/agent/integration/context_engine_adapter.py +252 -0
- aiecs/domain/agent/integration/retry_policy.py +219 -0
- aiecs/domain/agent/integration/role_config.py +213 -0
- aiecs/domain/agent/knowledge_aware_agent.py +646 -0
- aiecs/domain/agent/lifecycle.py +296 -0
- aiecs/domain/agent/llm_agent.py +300 -0
- aiecs/domain/agent/memory/__init__.py +12 -0
- aiecs/domain/agent/memory/conversation.py +197 -0
- aiecs/domain/agent/migration/__init__.py +14 -0
- aiecs/domain/agent/migration/conversion.py +160 -0
- aiecs/domain/agent/migration/legacy_wrapper.py +90 -0
- aiecs/domain/agent/models.py +317 -0
- aiecs/domain/agent/observability.py +407 -0
- aiecs/domain/agent/persistence.py +289 -0
- aiecs/domain/agent/prompts/__init__.py +29 -0
- aiecs/domain/agent/prompts/builder.py +161 -0
- aiecs/domain/agent/prompts/formatters.py +189 -0
- aiecs/domain/agent/prompts/template.py +255 -0
- aiecs/domain/agent/registry.py +260 -0
- aiecs/domain/agent/tool_agent.py +257 -0
- aiecs/domain/agent/tools/__init__.py +12 -0
- aiecs/domain/agent/tools/schema_generator.py +221 -0
- aiecs/domain/community/__init__.py +155 -0
- aiecs/domain/community/agent_adapter.py +477 -0
- aiecs/domain/community/analytics.py +481 -0
- aiecs/domain/community/collaborative_workflow.py +642 -0
- aiecs/domain/community/communication_hub.py +645 -0
- aiecs/domain/community/community_builder.py +320 -0
- aiecs/domain/community/community_integration.py +800 -0
- aiecs/domain/community/community_manager.py +813 -0
- aiecs/domain/community/decision_engine.py +879 -0
- aiecs/domain/community/exceptions.py +225 -0
- aiecs/domain/community/models/__init__.py +33 -0
- aiecs/domain/community/models/community_models.py +268 -0
- aiecs/domain/community/resource_manager.py +457 -0
- aiecs/domain/community/shared_context_manager.py +603 -0
- aiecs/domain/context/__init__.py +58 -0
- aiecs/domain/context/context_engine.py +989 -0
- aiecs/domain/context/conversation_models.py +354 -0
- aiecs/domain/context/graph_memory.py +467 -0
- aiecs/domain/execution/__init__.py +12 -0
- aiecs/domain/execution/model.py +57 -0
- aiecs/domain/knowledge_graph/__init__.py +19 -0
- aiecs/domain/knowledge_graph/models/__init__.py +52 -0
- aiecs/domain/knowledge_graph/models/entity.py +130 -0
- aiecs/domain/knowledge_graph/models/evidence.py +194 -0
- aiecs/domain/knowledge_graph/models/inference_rule.py +186 -0
- aiecs/domain/knowledge_graph/models/path.py +179 -0
- aiecs/domain/knowledge_graph/models/path_pattern.py +173 -0
- aiecs/domain/knowledge_graph/models/query.py +272 -0
- aiecs/domain/knowledge_graph/models/query_plan.py +187 -0
- aiecs/domain/knowledge_graph/models/relation.py +136 -0
- aiecs/domain/knowledge_graph/schema/__init__.py +23 -0
- aiecs/domain/knowledge_graph/schema/entity_type.py +135 -0
- aiecs/domain/knowledge_graph/schema/graph_schema.py +271 -0
- aiecs/domain/knowledge_graph/schema/property_schema.py +155 -0
- aiecs/domain/knowledge_graph/schema/relation_type.py +171 -0
- aiecs/domain/knowledge_graph/schema/schema_manager.py +496 -0
- aiecs/domain/knowledge_graph/schema/type_enums.py +205 -0
- aiecs/domain/task/__init__.py +13 -0
- aiecs/domain/task/dsl_processor.py +613 -0
- aiecs/domain/task/model.py +62 -0
- aiecs/domain/task/task_context.py +268 -0
- aiecs/infrastructure/__init__.py +24 -0
- aiecs/infrastructure/graph_storage/__init__.py +11 -0
- aiecs/infrastructure/graph_storage/base.py +601 -0
- aiecs/infrastructure/graph_storage/batch_operations.py +449 -0
- aiecs/infrastructure/graph_storage/cache.py +429 -0
- aiecs/infrastructure/graph_storage/distributed.py +226 -0
- aiecs/infrastructure/graph_storage/error_handling.py +390 -0
- aiecs/infrastructure/graph_storage/graceful_degradation.py +306 -0
- aiecs/infrastructure/graph_storage/health_checks.py +378 -0
- aiecs/infrastructure/graph_storage/in_memory.py +514 -0
- aiecs/infrastructure/graph_storage/index_optimization.py +483 -0
- aiecs/infrastructure/graph_storage/lazy_loading.py +410 -0
- aiecs/infrastructure/graph_storage/metrics.py +357 -0
- aiecs/infrastructure/graph_storage/migration.py +413 -0
- aiecs/infrastructure/graph_storage/pagination.py +471 -0
- aiecs/infrastructure/graph_storage/performance_monitoring.py +466 -0
- aiecs/infrastructure/graph_storage/postgres.py +871 -0
- aiecs/infrastructure/graph_storage/query_optimizer.py +635 -0
- aiecs/infrastructure/graph_storage/schema_cache.py +290 -0
- aiecs/infrastructure/graph_storage/sqlite.py +623 -0
- aiecs/infrastructure/graph_storage/streaming.py +495 -0
- aiecs/infrastructure/messaging/__init__.py +13 -0
- aiecs/infrastructure/messaging/celery_task_manager.py +383 -0
- aiecs/infrastructure/messaging/websocket_manager.py +298 -0
- aiecs/infrastructure/monitoring/__init__.py +34 -0
- aiecs/infrastructure/monitoring/executor_metrics.py +174 -0
- aiecs/infrastructure/monitoring/global_metrics_manager.py +213 -0
- aiecs/infrastructure/monitoring/structured_logger.py +48 -0
- aiecs/infrastructure/monitoring/tracing_manager.py +410 -0
- aiecs/infrastructure/persistence/__init__.py +24 -0
- aiecs/infrastructure/persistence/context_engine_client.py +187 -0
- aiecs/infrastructure/persistence/database_manager.py +333 -0
- aiecs/infrastructure/persistence/file_storage.py +754 -0
- aiecs/infrastructure/persistence/redis_client.py +220 -0
- aiecs/llm/__init__.py +86 -0
- aiecs/llm/callbacks/__init__.py +11 -0
- aiecs/llm/callbacks/custom_callbacks.py +264 -0
- aiecs/llm/client_factory.py +420 -0
- aiecs/llm/clients/__init__.py +33 -0
- aiecs/llm/clients/base_client.py +193 -0
- aiecs/llm/clients/googleai_client.py +181 -0
- aiecs/llm/clients/openai_client.py +131 -0
- aiecs/llm/clients/vertex_client.py +437 -0
- aiecs/llm/clients/xai_client.py +184 -0
- aiecs/llm/config/__init__.py +51 -0
- aiecs/llm/config/config_loader.py +275 -0
- aiecs/llm/config/config_validator.py +236 -0
- aiecs/llm/config/model_config.py +151 -0
- aiecs/llm/utils/__init__.py +10 -0
- aiecs/llm/utils/validate_config.py +91 -0
- aiecs/main.py +363 -0
- aiecs/scripts/__init__.py +3 -0
- aiecs/scripts/aid/VERSION_MANAGEMENT.md +97 -0
- aiecs/scripts/aid/__init__.py +19 -0
- aiecs/scripts/aid/version_manager.py +215 -0
- aiecs/scripts/dependance_check/DEPENDENCY_SYSTEM_SUMMARY.md +242 -0
- aiecs/scripts/dependance_check/README_DEPENDENCY_CHECKER.md +310 -0
- aiecs/scripts/dependance_check/__init__.py +17 -0
- aiecs/scripts/dependance_check/dependency_checker.py +938 -0
- aiecs/scripts/dependance_check/dependency_fixer.py +391 -0
- aiecs/scripts/dependance_check/download_nlp_data.py +396 -0
- aiecs/scripts/dependance_check/quick_dependency_check.py +270 -0
- aiecs/scripts/dependance_check/setup_nlp_data.sh +217 -0
- aiecs/scripts/dependance_patch/__init__.py +7 -0
- aiecs/scripts/dependance_patch/fix_weasel/README_WEASEL_PATCH.md +126 -0
- aiecs/scripts/dependance_patch/fix_weasel/__init__.py +11 -0
- aiecs/scripts/dependance_patch/fix_weasel/fix_weasel_validator.py +128 -0
- aiecs/scripts/dependance_patch/fix_weasel/fix_weasel_validator.sh +82 -0
- aiecs/scripts/dependance_patch/fix_weasel/patch_weasel_library.sh +188 -0
- aiecs/scripts/dependance_patch/fix_weasel/run_weasel_patch.sh +41 -0
- aiecs/scripts/tools_develop/README.md +449 -0
- aiecs/scripts/tools_develop/TOOL_AUTO_DISCOVERY.md +234 -0
- aiecs/scripts/tools_develop/__init__.py +21 -0
- aiecs/scripts/tools_develop/check_type_annotations.py +259 -0
- aiecs/scripts/tools_develop/validate_tool_schemas.py +422 -0
- aiecs/scripts/tools_develop/verify_tools.py +356 -0
- aiecs/tasks/__init__.py +1 -0
- aiecs/tasks/worker.py +172 -0
- aiecs/tools/__init__.py +299 -0
- aiecs/tools/apisource/__init__.py +99 -0
- aiecs/tools/apisource/intelligence/__init__.py +19 -0
- aiecs/tools/apisource/intelligence/data_fusion.py +381 -0
- aiecs/tools/apisource/intelligence/query_analyzer.py +413 -0
- aiecs/tools/apisource/intelligence/search_enhancer.py +388 -0
- aiecs/tools/apisource/monitoring/__init__.py +9 -0
- aiecs/tools/apisource/monitoring/metrics.py +303 -0
- aiecs/tools/apisource/providers/__init__.py +115 -0
- aiecs/tools/apisource/providers/base.py +664 -0
- aiecs/tools/apisource/providers/census.py +401 -0
- aiecs/tools/apisource/providers/fred.py +564 -0
- aiecs/tools/apisource/providers/newsapi.py +412 -0
- aiecs/tools/apisource/providers/worldbank.py +357 -0
- aiecs/tools/apisource/reliability/__init__.py +12 -0
- aiecs/tools/apisource/reliability/error_handler.py +375 -0
- aiecs/tools/apisource/reliability/fallback_strategy.py +391 -0
- aiecs/tools/apisource/tool.py +850 -0
- aiecs/tools/apisource/utils/__init__.py +9 -0
- aiecs/tools/apisource/utils/validators.py +338 -0
- aiecs/tools/base_tool.py +201 -0
- aiecs/tools/docs/__init__.py +121 -0
- aiecs/tools/docs/ai_document_orchestrator.py +599 -0
- aiecs/tools/docs/ai_document_writer_orchestrator.py +2403 -0
- aiecs/tools/docs/content_insertion_tool.py +1333 -0
- aiecs/tools/docs/document_creator_tool.py +1317 -0
- aiecs/tools/docs/document_layout_tool.py +1166 -0
- aiecs/tools/docs/document_parser_tool.py +994 -0
- aiecs/tools/docs/document_writer_tool.py +1818 -0
- aiecs/tools/knowledge_graph/__init__.py +17 -0
- aiecs/tools/knowledge_graph/graph_reasoning_tool.py +734 -0
- aiecs/tools/knowledge_graph/graph_search_tool.py +923 -0
- aiecs/tools/knowledge_graph/kg_builder_tool.py +476 -0
- aiecs/tools/langchain_adapter.py +542 -0
- aiecs/tools/schema_generator.py +275 -0
- aiecs/tools/search_tool/__init__.py +100 -0
- aiecs/tools/search_tool/analyzers.py +589 -0
- aiecs/tools/search_tool/cache.py +260 -0
- aiecs/tools/search_tool/constants.py +128 -0
- aiecs/tools/search_tool/context.py +216 -0
- aiecs/tools/search_tool/core.py +749 -0
- aiecs/tools/search_tool/deduplicator.py +123 -0
- aiecs/tools/search_tool/error_handler.py +271 -0
- aiecs/tools/search_tool/metrics.py +371 -0
- aiecs/tools/search_tool/rate_limiter.py +178 -0
- aiecs/tools/search_tool/schemas.py +277 -0
- aiecs/tools/statistics/__init__.py +80 -0
- aiecs/tools/statistics/ai_data_analysis_orchestrator.py +643 -0
- aiecs/tools/statistics/ai_insight_generator_tool.py +505 -0
- aiecs/tools/statistics/ai_report_orchestrator_tool.py +694 -0
- aiecs/tools/statistics/data_loader_tool.py +564 -0
- aiecs/tools/statistics/data_profiler_tool.py +658 -0
- aiecs/tools/statistics/data_transformer_tool.py +573 -0
- aiecs/tools/statistics/data_visualizer_tool.py +495 -0
- aiecs/tools/statistics/model_trainer_tool.py +487 -0
- aiecs/tools/statistics/statistical_analyzer_tool.py +459 -0
- aiecs/tools/task_tools/__init__.py +86 -0
- aiecs/tools/task_tools/chart_tool.py +732 -0
- aiecs/tools/task_tools/classfire_tool.py +922 -0
- aiecs/tools/task_tools/image_tool.py +447 -0
- aiecs/tools/task_tools/office_tool.py +684 -0
- aiecs/tools/task_tools/pandas_tool.py +635 -0
- aiecs/tools/task_tools/report_tool.py +635 -0
- aiecs/tools/task_tools/research_tool.py +392 -0
- aiecs/tools/task_tools/scraper_tool.py +715 -0
- aiecs/tools/task_tools/stats_tool.py +688 -0
- aiecs/tools/temp_file_manager.py +130 -0
- aiecs/tools/tool_executor/__init__.py +37 -0
- aiecs/tools/tool_executor/tool_executor.py +881 -0
- aiecs/utils/LLM_output_structor.py +445 -0
- aiecs/utils/__init__.py +34 -0
- aiecs/utils/base_callback.py +47 -0
- aiecs/utils/cache_provider.py +695 -0
- aiecs/utils/execution_utils.py +184 -0
- aiecs/utils/logging.py +1 -0
- aiecs/utils/prompt_loader.py +14 -0
- aiecs/utils/token_usage_repository.py +323 -0
- aiecs/ws/__init__.py +0 -0
- aiecs/ws/socket_server.py +52 -0
- aiecs-1.5.1.dist-info/METADATA +608 -0
- aiecs-1.5.1.dist-info/RECORD +302 -0
- aiecs-1.5.1.dist-info/WHEEL +5 -0
- aiecs-1.5.1.dist-info/entry_points.txt +10 -0
- aiecs-1.5.1.dist-info/licenses/LICENSE +225 -0
- aiecs-1.5.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,383 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import logging
|
|
3
|
+
import time
|
|
4
|
+
import uuid
|
|
5
|
+
from typing import Dict, List, Any
|
|
6
|
+
from celery import Celery
|
|
7
|
+
from celery.exceptions import TimeoutError as CeleryTimeoutError
|
|
8
|
+
from asyncio import TimeoutError as AsyncioTimeoutError
|
|
9
|
+
|
|
10
|
+
# Removed direct import to avoid circular dependency
|
|
11
|
+
# Tasks are referenced by string names instead
|
|
12
|
+
from aiecs.domain.execution.model import TaskStatus, ErrorCode
|
|
13
|
+
|
|
14
|
+
logger = logging.getLogger(__name__)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class CeleryTaskManager:
    """
    Specialized handler for Celery distributed task scheduling and execution.

    Work is dispatched with ``send_task`` using string task names (the worker
    module is not imported directly, which avoids a circular import). Every
    dispatch targets one of two generic worker entry points chosen by queue:
    ``heavy_tasks`` goes to ``execute_heavy_task``; any other queue goes to
    ``execute_task``.
    """

    # Routing: queue names and the worker entry point used for each.
    _FAST_QUEUE = "fast_tasks"
    _HEAVY_QUEUE = "heavy_tasks"
    _FAST_WORKER_TASK = "aiecs.tasks.worker.execute_task"
    _HEAVY_WORKER_TASK = "aiecs.tasks.worker.execute_heavy_task"

    def __init__(self, config: Dict[str, Any]):
        """
        Create the manager and its Celery application.

        Args:
            config: Settings dict. Keys read by this class: ``broker_url``,
                ``backend_url``, ``task_serializer``, ``accept_content``,
                ``result_serializer``, ``timezone``, ``enable_utc``,
                ``task_queues``, ``worker_concurrency``,
                ``task_timeout_seconds``, ``result_poll_interval_seconds``,
                ``batch_size`` and ``rate_limit_requests_per_second``.

        Raises:
            Exception: Propagated from ``_init_celery`` if the app cannot be
                created or configured.
        """
        self.config = config
        self.celery_app = None
        self._init_celery()

    def _init_celery(self):
        """Create and configure the Celery application from ``self.config``.

        Raises:
            Exception: Any error from constructing or configuring the app is
                logged and re-raised.
        """
        try:
            self.celery_app = Celery(
                "service_executor",
                broker=self.config.get("broker_url", "redis://redis:6379/0"),
                backend=self.config.get("backend_url", "redis://redis:6379/0"),
            )

            # Configure Celery
            self.celery_app.conf.update(
                task_serializer=self.config.get("task_serializer", "json"),
                accept_content=self.config.get("accept_content", ["json"]),
                result_serializer=self.config.get("result_serializer", "json"),
                timezone=self.config.get("timezone", "UTC"),
                enable_utc=self.config.get("enable_utc", True),
                task_queues=self.config.get(
                    "task_queues",
                    {
                        "fast_tasks": {
                            "exchange": "fast_tasks",
                            "routing_key": "fast_tasks",
                        },
                        "heavy_tasks": {
                            "exchange": "heavy_tasks",
                            "routing_key": "heavy_tasks",
                        },
                    },
                ),
                # NOTE(review): Celery's ``worker_concurrency`` setting is an
                # int; this dict default looks harmless only as long as no
                # worker is started from this app's configuration — confirm.
                worker_concurrency=self.config.get(
                    "worker_concurrency",
                    {"fast_worker": 10, "heavy_worker": 2},
                ),
            )

            logger.info("Celery application initialized successfully")
        except Exception as e:
            logger.error(f"Failed to initialize Celery: {e}")
            raise

    @classmethod
    def _worker_task_name(cls, queue: str) -> str:
        """Return the worker entry point for *queue* (anything but heavy is fast)."""
        if queue == cls._HEAVY_QUEUE:
            return cls._HEAVY_WORKER_TASK
        return cls._FAST_WORKER_TASK

    def _send_worker_task(self, queue: str, **task_kwargs: Any):
        """Send *task_kwargs* to the worker entry point for *queue*.

        Returns:
            The Celery ``AsyncResult`` returned by ``send_task``.
        """
        return self.celery_app.send_task(
            self._worker_task_name(queue),
            kwargs=task_kwargs,
            queue=queue,
        )

    async def _wait_until_ready(self, async_result: Any, timeout_seconds: float) -> bool:
        """Poll *async_result* until it is ready, without blocking the event loop.

        The poll interval is ``config["result_poll_interval_seconds"]``
        (default 0.5s).

        Returns:
            True if the result became ready, False if *timeout_seconds*
            elapsed first.
        """
        poll_interval = self.config.get("result_poll_interval_seconds", 0.5)
        # Monotonic clock: unaffected by wall-clock adjustments during the wait.
        deadline = time.monotonic() + timeout_seconds
        while not async_result.ready():
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                return False
            await asyncio.sleep(min(poll_interval, remaining))
        return True

    def execute_celery_task(
        self,
        task_name: str,
        queue: str,
        user_id: str,
        task_id: str,
        step: int,
        mode: str,
        service: str,
        input_data: Dict[str, Any],
        context: Dict[str, Any],
    ):
        """
        Queue a Celery task without waiting for it.

        Args:
            task_name: Task name
            queue: Queue name ('fast_tasks' or 'heavy_tasks')
            user_id: User ID
            task_id: Task ID
            step: Step number
            mode: Service mode
            service: Service name
            input_data: Input data
            context: Context information

        Returns:
            Celery AsyncResult object
        """
        logger.info(
            f"Queueing task {task_name} to {queue} for user {user_id}, task {task_id}, step {step}"
        )
        return self._send_worker_task(
            queue,
            task_name=task_name,
            user_id=user_id,
            task_id=task_id,
            step=step,
            mode=mode,
            service=service,
            input_data=input_data,
            context=context,
        )

    async def execute_task(
        self,
        task_name: str,
        input_data: Dict[str, Any],
        context: Dict[str, Any],
    ) -> Any:
        """
        Execute a single task through Celery and wait for its result.

        Routing metadata is read from ``input_data``: ``task_id`` (random UUID
        if absent), ``step``, ``mode``, ``service`` and ``queue`` (default
        ``fast_tasks``). ``user_id`` comes from ``context`` (default
        ``"anonymous"``). The wait is bounded by
        ``config["task_timeout_seconds"]`` (default 300). It polls
        asynchronously, so concurrent callers (e.g. ``batch_execute_tasks``)
        are not serialized behind a blocking ``get``.

        Returns:
            The worker's return value on success. On timeout, a dict with
            ``TaskStatus.TIMED_OUT``; on any other error (including a failure
            re-raised from the worker), a dict with ``TaskStatus.FAILED``.
        """
        user_id = context.get("user_id", "anonymous")
        task_id = input_data.get("task_id", str(uuid.uuid4()))
        step = input_data.get("step", 0)
        mode = input_data.get("mode", "default")
        service = input_data.get("service", "default")
        queue = input_data.get("queue", self._FAST_QUEUE)
        timeout = self.config.get("task_timeout_seconds", 300)

        try:
            async_result = self._send_worker_task(
                queue,
                task_name=task_name,
                user_id=user_id,
                task_id=task_id,
                step=step,
                mode=mode,
                service=service,
                input_data=input_data,
                context=context,
            )

            if not await self._wait_until_ready(async_result, timeout):
                raise CeleryTimeoutError(
                    f"Task {task_name} timed out after {timeout} seconds"
                )

            # The result is ready, so this returns promptly; it re-raises the
            # worker-side exception if the task failed.
            return async_result.get(timeout=timeout)

        except CeleryTimeoutError as e:
            logger.error(f"Timeout executing Celery task {task_name}: {e}")
            return {
                "status": TaskStatus.TIMED_OUT,
                "error_code": ErrorCode.TIMEOUT_ERROR,
                "error_message": str(e),
            }
        except Exception as e:
            logger.error(f"Error executing Celery task {task_name}: {e}", exc_info=True)
            return {
                "status": TaskStatus.FAILED,
                "error_code": ErrorCode.EXECUTION_ERROR,
                "error_message": str(e),
            }

    async def execute_heavy_task(self, task_name: str, input_data: Dict, context: Dict) -> Any:
        """
        Execute a task on the ``heavy_tasks`` queue.

        The caller's ``input_data`` is not modified. A copy with
        ``queue="heavy_tasks"`` is what gets routed and forwarded to the
        worker.
        """
        heavy_input = {**input_data, "queue": self._HEAVY_QUEUE}
        return await self.execute_task(task_name, heavy_input, context)

    async def _resolve_task_type(self, task_name: str, context: Dict) -> str:
        """Ask the worker for *task_name*'s type, falling back to ``"fast"``.

        NOTE(review): this dispatches a real worker task (payload
        ``{"get_task_type": True}``) on the fast queue. It can wait up to
        ``task_timeout_seconds`` before the actual step is sent.
        """
        try:
            reply = await self.execute_task(task_name, {"get_task_type": True}, context)
        except Exception:
            logger.warning(f"Could not determine task type for {task_name}, defaulting to 'fast'")
            return "fast"
        if isinstance(reply, dict) and "task_type" in reply:
            return reply["task_type"]
        return "fast"

    async def execute_dsl_task_step(
        self, step: Dict, input_data: Dict, context: Dict
    ) -> Dict[str, Any]:
        """
        Execute one DSL step and wrap its outcome in a step-result dict.

        Args:
            step: DSL step definition; its ``"task"`` key names the task.
            input_data: Payload forwarded to the worker.
            context: Execution context. ``user_id``, ``task_id``, ``step``,
                ``mode`` and ``service`` are read from it.

        Returns:
            A dict with ``step``, ``result``, ``completed``, ``message`` and
            ``status`` keys, plus ``error_code``/``error_message`` on failure.
            A successful worker result that is already a dict containing
            ``"step"`` is returned unchanged.
        """
        task_name = step.get("task")
        category = "process"

        if not task_name:
            return {
                "step": "unknown",
                "result": None,
                "completed": False,
                "message": "Invalid DSL step: missing task name",
                "status": TaskStatus.FAILED,
                "error_code": ErrorCode.VALIDATION_ERROR,
                "error_message": "Task name is required",
            }

        step_label = f"{category}/{task_name}"
        task_type = await self._resolve_task_type(task_name, context)
        queue = self._HEAVY_QUEUE if task_type == "heavy" else self._FAST_QUEUE

        try:
            timeout_seconds = self.config.get("task_timeout_seconds", 300)
            # Dispatch inside the try so broker errors become a failed step
            # result instead of escaping to the caller.
            celery_task = self._send_worker_task(
                queue,
                task_name=task_name,
                user_id=context.get("user_id", str(uuid.uuid4())),
                task_id=context.get("task_id", str(uuid.uuid4())),
                step=context.get("step", 0),
                mode=context.get("mode", "multi_task"),
                service=context.get("service", "summarizer"),
                input_data=input_data,
                context=context,
            )

            if not await self._wait_until_ready(celery_task, timeout_seconds):
                raise AsyncioTimeoutError(
                    f"Task {task_name} timed out after {timeout_seconds} seconds"
                )

            if celery_task.successful():
                result = celery_task.get()
                if isinstance(result, dict) and "step" in result:
                    return result
                return {
                    "step": step_label,
                    "result": result,
                    "completed": True,
                    "message": f"Completed task {task_name}",
                    "status": TaskStatus.COMPLETED,
                }

            # propagate=False returns the worker's exception instead of raising it.
            error = celery_task.get(propagate=False)
            timed_out = isinstance(error, CeleryTimeoutError)
            return {
                "step": step_label,
                "result": None,
                "completed": False,
                "message": f"Failed to execute task: {error}",
                "status": TaskStatus.TIMED_OUT if timed_out else TaskStatus.FAILED,
                "error_code": ErrorCode.TIMEOUT_ERROR if timed_out else ErrorCode.EXECUTION_ERROR,
                "error_message": str(error),
            }

        except AsyncioTimeoutError as e:
            logger.error(f"Timeout executing DSL step {step_label}: {e}")
            return {
                "step": step_label,
                "result": None,
                "completed": False,
                "message": "Task execution timed out",
                "status": TaskStatus.TIMED_OUT,
                "error_code": ErrorCode.TIMEOUT_ERROR,
                "error_message": str(e),
            }
        except Exception as e:
            logger.error(f"Error executing DSL step {step_label}: {e}", exc_info=True)
            return {
                "step": step_label,
                "result": None,
                "completed": False,
                "message": f"Failed to execute {category}/{task_name}",
                "status": TaskStatus.FAILED,
                "error_code": ErrorCode.EXECUTION_ERROR,
                "error_message": str(e),
            }

    def get_task_result(self, task_id: str):
        """
        Return a status snapshot for the Celery task *task_id*.

        ``result``, ``successful`` and ``failed`` are None until the task is
        ready. Readiness is checked once, so all three fields describe the
        same moment. Errors are logged and reported as ``status="ERROR"``.
        """
        try:
            result = self.celery_app.AsyncResult(task_id)
            is_ready = result.ready()
            return {
                "task_id": task_id,
                "status": result.status,
                "result": result.result if is_ready else None,
                "successful": result.successful() if is_ready else None,
                "failed": result.failed() if is_ready else None,
            }
        except Exception as e:
            logger.error(f"Error getting task result for {task_id}: {e}")
            return {"task_id": task_id, "status": "ERROR", "error": str(e)}

    def cancel_task(self, task_id: str):
        """
        Revoke *task_id* with ``terminate=True``.

        Returns:
            True if the revoke call did not raise, False otherwise (the
            error is logged).
        """
        try:
            self.celery_app.control.revoke(task_id, terminate=True)
            logger.info(f"Task {task_id} cancelled")
            return True
        except Exception as e:
            logger.error(f"Error cancelling task {task_id}: {e}")
            return False

    async def batch_execute_tasks(self, tasks: List[Dict[str, Any]]) -> List[Any]:
        """
        Execute tasks in concurrent batches, pausing between batches.

        ``config["batch_size"]`` (default 10) tasks run concurrently per
        batch. Batches are separated by ``1 / rate_limit_requests_per_second``
        seconds (default 5). A non-positive rate disables the pause.

        Args:
            tasks: Dicts with a required ``task_name`` and optional
                ``input_data`` / ``context`` (each defaults to ``{}``).

        Returns:
            One entry per task, in input order: the ``execute_task`` result,
            or the exception instance if that coroutine raised.
        """
        # A batch size below 1 would make range() raise; treat it as 1.
        batch_size = max(1, self.config.get("batch_size", 10))
        rate_limit = self.config.get("rate_limit_requests_per_second", 5)
        # A non-positive rate limit disables throttling instead of dividing by zero.
        pause = 1.0 / rate_limit if rate_limit and rate_limit > 0 else 0.0

        results: List[Any] = []
        for start in range(0, len(tasks), batch_size):
            batch = tasks[start : start + batch_size]
            batch_results = await asyncio.gather(
                *(
                    self.execute_task(
                        task["task_name"],
                        task.get("input_data", {}),
                        task.get("context", {}),
                    )
                    for task in batch
                ),
                return_exceptions=True,
            )
            results.extend(batch_results)
            # Throttle between batches only; nothing follows the last one.
            if pause and start + batch_size < len(tasks):
                await asyncio.sleep(pause)

        return results

    def get_queue_info(self) -> Dict[str, Any]:
        """
        Return the active, scheduled and reserved tasks reported by workers.

        Each value is whatever Celery's inspect API returned. On error, the
        result is ``{"error": <message>}``.
        """
        try:
            inspect = self.celery_app.control.inspect()
            active_tasks = inspect.active()
            scheduled_tasks = inspect.scheduled()
            reserved_tasks = inspect.reserved()

            return {
                "active_tasks": active_tasks,
                "scheduled_tasks": scheduled_tasks,
                "reserved_tasks": reserved_tasks,
            }
        except Exception as e:
            logger.error(f"Error getting queue info: {e}")
            return {"error": str(e)}

    def get_worker_stats(self) -> Dict[str, Any]:
        """
        Return worker statistics from Celery's inspect API.

        An empty reply becomes ``{}``; an error becomes ``{"error": <message>}``.
        """
        try:
            inspect = self.celery_app.control.inspect()
            stats = inspect.stats()
            return stats or {}
        except Exception as e:
            logger.error(f"Error getting worker stats: {e}")
            return {"error": str(e)}
|
|
@@ -0,0 +1,298 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import json
|
|
3
|
+
import logging
|
|
4
|
+
import uuid
|
|
5
|
+
import websockets
|
|
6
|
+
from typing import Dict, Any, Set, Optional, Callable
|
|
7
|
+
from websockets import serve, ServerConnection
|
|
8
|
+
from pydantic import BaseModel
|
|
9
|
+
|
|
10
|
+
logger = logging.getLogger(__name__)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class UserConfirmation(BaseModel):
    """Pydantic model holding a proceed/feedback decision.

    NOTE(review): presumably a user's reply to a confirmation prompt received
    over the WebSocket — confirm against the message handlers.
    """

    # Required boolean decision.
    proceed: bool
    # Optional free-text comment; None when omitted.
    feedback: Optional[str] = None
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class TaskStepResult(BaseModel):
    """Pydantic model describing the outcome of one task step.

    Its field names match the keys of the step-result dicts returned by
    ``CeleryTaskManager.execute_dsl_task_step``.
    """

    # Step identifier (required); the task manager produces values such as
    # "<category>/<task_name>" or "unknown".
    step: str
    # Step output of any type; None when absent.
    result: Any = None
    # Completion flag; defaults to False.
    completed: bool = False
    # Human-readable summary (required).
    message: str
    # Status value (required).
    # NOTE(review): the task manager supplies TaskStatus members here — confirm
    # they validate as str.
    status: str
    # Optional error details; both default to None.
    error_code: Optional[str] = None
    error_message: Optional[str] = None
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class WebSocketManager:
|
|
29
|
+
"""
|
|
30
|
+
Specialized handler for WebSocket server and client communication
|
|
31
|
+
"""
|
|
32
|
+
|
|
33
|
+
def __init__(self, host: str = "python-middleware-api", port: int = 8765):
    """Record the bind address; no server is started here.

    Args:
        host: Hostname or interface the server will bind to.
        port: TCP port the server will listen on.
    """
    self.host, self.port = host, port
    # Set by start_server(); None means no server object is held.
    self.server = None
    # Name -> callable mapping, initially empty.
    self.callback_registry: Dict[str, Callable] = {}
    # Connections currently being served by _handle_client_connection.
    self.active_connections: Set[ServerConnection] = set()
    self._running = False
|
|
40
|
+
|
|
41
|
+
async def start_server(self):
|
|
42
|
+
"""Start WebSocket server"""
|
|
43
|
+
if self.server:
|
|
44
|
+
logger.warning("WebSocket server is already running")
|
|
45
|
+
return self.server
|
|
46
|
+
|
|
47
|
+
try:
|
|
48
|
+
self.server = await serve(self._handle_client_connection, self.host, self.port)
|
|
49
|
+
self._running = True
|
|
50
|
+
logger.info(f"WebSocket server started on {self.host}:{self.port}")
|
|
51
|
+
return self.server
|
|
52
|
+
except Exception as e:
|
|
53
|
+
logger.error(f"Failed to start WebSocket server: {e}")
|
|
54
|
+
raise
|
|
55
|
+
|
|
56
|
+
async def stop_server(self):
|
|
57
|
+
"""Stop WebSocket server"""
|
|
58
|
+
if self.server:
|
|
59
|
+
self.server.close()
|
|
60
|
+
await self.server.wait_closed()
|
|
61
|
+
self._running = False
|
|
62
|
+
logger.info("WebSocket server stopped")
|
|
63
|
+
|
|
64
|
+
# Close all active connections
|
|
65
|
+
if self.active_connections:
|
|
66
|
+
await asyncio.gather(
|
|
67
|
+
*[conn.close() for conn in self.active_connections],
|
|
68
|
+
return_exceptions=True,
|
|
69
|
+
)
|
|
70
|
+
self.active_connections.clear()
|
|
71
|
+
|
|
72
|
+
async def _handle_client_connection(self, websocket: ServerConnection, path: str):
|
|
73
|
+
"""Handle client connection"""
|
|
74
|
+
self.active_connections.add(websocket)
|
|
75
|
+
client_addr = websocket.remote_address
|
|
76
|
+
logger.info(f"New WebSocket connection from {client_addr}")
|
|
77
|
+
|
|
78
|
+
try:
|
|
79
|
+
async for message in websocket:
|
|
80
|
+
await self._handle_client_message(websocket, message)
|
|
81
|
+
except websockets.exceptions.ConnectionClosed:
|
|
82
|
+
logger.info(f"WebSocket connection closed: {client_addr}")
|
|
83
|
+
except Exception as e:
|
|
84
|
+
logger.error(f"WebSocket error for {client_addr}: {e}")
|
|
85
|
+
finally:
|
|
86
|
+
self.active_connections.discard(websocket)
|
|
87
|
+
if not websocket.closed:
|
|
88
|
+
await websocket.close()
|
|
89
|
+
|
|
90
|
+
async def _handle_client_message(self, websocket: ServerConnection, message: str):
|
|
91
|
+
"""Handle client message"""
|
|
92
|
+
try:
|
|
93
|
+
data = json.loads(message)
|
|
94
|
+
action = data.get("action")
|
|
95
|
+
|
|
96
|
+
if action == "confirm":
|
|
97
|
+
await self._handle_confirmation(data)
|
|
98
|
+
elif action == "cancel":
|
|
99
|
+
await self._handle_cancellation(data)
|
|
100
|
+
elif action == "ping":
|
|
101
|
+
await self._handle_ping(websocket, data)
|
|
102
|
+
elif action == "subscribe":
|
|
103
|
+
await self._handle_subscription(websocket, data)
|
|
104
|
+
else:
|
|
105
|
+
logger.warning(f"Unknown action received: {action}")
|
|
106
|
+
await self._send_error(websocket, f"Unknown action: {action}")
|
|
107
|
+
|
|
108
|
+
except json.JSONDecodeError as e:
|
|
109
|
+
logger.error(f"Invalid JSON received: {e}")
|
|
110
|
+
await self._send_error(websocket, "Invalid JSON format")
|
|
111
|
+
except Exception as e:
|
|
112
|
+
logger.error(f"Error handling client message: {e}")
|
|
113
|
+
await self._send_error(websocket, f"Internal error: {str(e)}")
|
|
114
|
+
|
|
115
|
+
async def _handle_confirmation(self, data: Dict[str, Any]):
|
|
116
|
+
"""Handle user confirmation"""
|
|
117
|
+
callback_id = data.get("callback_id")
|
|
118
|
+
if callback_id and callback_id in self.callback_registry:
|
|
119
|
+
callback = self.callback_registry[callback_id]
|
|
120
|
+
confirmation = UserConfirmation(
|
|
121
|
+
proceed=data.get("proceed", False),
|
|
122
|
+
feedback=data.get("feedback"),
|
|
123
|
+
)
|
|
124
|
+
try:
|
|
125
|
+
callback(confirmation)
|
|
126
|
+
del self.callback_registry[callback_id]
|
|
127
|
+
logger.debug(f"Processed confirmation for callback {callback_id}")
|
|
128
|
+
except Exception as e:
|
|
129
|
+
logger.error(f"Error processing confirmation callback: {e}")
|
|
130
|
+
else:
|
|
131
|
+
logger.warning(f"No callback found for confirmation ID: {callback_id}")
|
|
132
|
+
|
|
133
|
+
async def _handle_cancellation(self, data: Dict[str, Any]):
|
|
134
|
+
"""Handle task cancellation"""
|
|
135
|
+
user_id = data.get("user_id")
|
|
136
|
+
task_id = data.get("task_id")
|
|
137
|
+
|
|
138
|
+
if user_id and task_id:
|
|
139
|
+
# Task cancellation logic can be added here
|
|
140
|
+
# Since database manager access is needed, this functionality may
|
|
141
|
+
# need to be implemented through callbacks
|
|
142
|
+
logger.info(f"Task cancellation requested: user={user_id}, task={task_id}")
|
|
143
|
+
await self.broadcast_message(
|
|
144
|
+
{
|
|
145
|
+
"type": "task_cancelled",
|
|
146
|
+
"user_id": user_id,
|
|
147
|
+
"task_id": task_id,
|
|
148
|
+
"timestamp": asyncio.get_event_loop().time(),
|
|
149
|
+
}
|
|
150
|
+
)
|
|
151
|
+
else:
|
|
152
|
+
logger.warning("Invalid cancellation request: missing user_id or task_id")
|
|
153
|
+
|
|
154
|
+
async def _handle_ping(self, websocket: ServerConnection, data: Dict[str, Any]):
|
|
155
|
+
"""Handle heartbeat detection"""
|
|
156
|
+
pong_data = {
|
|
157
|
+
"type": "pong",
|
|
158
|
+
"timestamp": asyncio.get_event_loop().time(),
|
|
159
|
+
"original_data": data,
|
|
160
|
+
}
|
|
161
|
+
await self._send_to_client(websocket, pong_data)
|
|
162
|
+
|
|
163
|
+
async def _handle_subscription(self, websocket: ServerConnection, data: Dict[str, Any]):
|
|
164
|
+
"""Handle subscription request"""
|
|
165
|
+
user_id = data.get("user_id")
|
|
166
|
+
if user_id:
|
|
167
|
+
# User-specific subscription logic can be implemented here
|
|
168
|
+
logger.info(f"User {user_id} subscribed to updates")
|
|
169
|
+
await self._send_to_client(
|
|
170
|
+
websocket,
|
|
171
|
+
{"type": "subscription_confirmed", "user_id": user_id},
|
|
172
|
+
)
|
|
173
|
+
|
|
174
|
+
async def _send_error(self, websocket: ServerConnection, error_message: str):
|
|
175
|
+
"""Send error message to client"""
|
|
176
|
+
error_data = {
|
|
177
|
+
"type": "error",
|
|
178
|
+
"message": error_message,
|
|
179
|
+
"timestamp": asyncio.get_event_loop().time(),
|
|
180
|
+
}
|
|
181
|
+
await self._send_to_client(websocket, error_data)
|
|
182
|
+
|
|
183
|
+
async def _send_to_client(self, websocket: ServerConnection, data: Dict[str, Any]):
|
|
184
|
+
"""Send data to specific client"""
|
|
185
|
+
try:
|
|
186
|
+
if not websocket.closed:
|
|
187
|
+
await websocket.send(json.dumps(data))
|
|
188
|
+
except Exception as e:
|
|
189
|
+
logger.error(f"Failed to send message to client: {e}")
|
|
190
|
+
|
|
191
|
+
async def notify_user(
|
|
192
|
+
self,
|
|
193
|
+
step_result: TaskStepResult,
|
|
194
|
+
user_id: str,
|
|
195
|
+
task_id: str,
|
|
196
|
+
step: int,
|
|
197
|
+
) -> UserConfirmation:
|
|
198
|
+
"""Notify user of task step result"""
|
|
199
|
+
callback_id = str(uuid.uuid4())
|
|
200
|
+
confirmation_future = asyncio.Future()
|
|
201
|
+
|
|
202
|
+
# Register callback
|
|
203
|
+
self.callback_registry[callback_id] = lambda confirmation: confirmation_future.set_result(
|
|
204
|
+
confirmation
|
|
205
|
+
)
|
|
206
|
+
|
|
207
|
+
# Prepare notification data
|
|
208
|
+
notification_data = {
|
|
209
|
+
"type": "task_step_result",
|
|
210
|
+
"callback_id": callback_id,
|
|
211
|
+
"step": step,
|
|
212
|
+
"message": step_result.message,
|
|
213
|
+
"result": step_result.result,
|
|
214
|
+
"status": step_result.status,
|
|
215
|
+
"error_code": step_result.error_code,
|
|
216
|
+
"error_message": step_result.error_message,
|
|
217
|
+
"user_id": user_id,
|
|
218
|
+
"task_id": task_id,
|
|
219
|
+
"timestamp": asyncio.get_event_loop().time(),
|
|
220
|
+
}
|
|
221
|
+
|
|
222
|
+
try:
|
|
223
|
+
# Broadcast to all connected clients (can be optimized to send only
|
|
224
|
+
# to specific users)
|
|
225
|
+
await self.broadcast_message(notification_data)
|
|
226
|
+
|
|
227
|
+
# Wait for user confirmation with timeout
|
|
228
|
+
try:
|
|
229
|
+
# 5 minute timeout
|
|
230
|
+
return await asyncio.wait_for(confirmation_future, timeout=300)
|
|
231
|
+
except asyncio.TimeoutError:
|
|
232
|
+
logger.warning(f"User confirmation timeout for callback {callback_id}")
|
|
233
|
+
# Clean up callback
|
|
234
|
+
self.callback_registry.pop(callback_id, None)
|
|
235
|
+
return UserConfirmation(proceed=True) # Default to proceed
|
|
236
|
+
|
|
237
|
+
except Exception as e:
|
|
238
|
+
logger.error(f"WebSocket notification error: {e}")
|
|
239
|
+
# Clean up callback
|
|
240
|
+
self.callback_registry.pop(callback_id, None)
|
|
241
|
+
return UserConfirmation(proceed=True) # Default to proceed
|
|
242
|
+
|
|
243
|
+
async def send_heartbeat(self, user_id: str, task_id: str, interval: int = 30):
|
|
244
|
+
"""Send heartbeat message"""
|
|
245
|
+
heartbeat_data = {
|
|
246
|
+
"type": "heartbeat",
|
|
247
|
+
"status": "heartbeat",
|
|
248
|
+
"message": "Task is still executing...",
|
|
249
|
+
"user_id": user_id,
|
|
250
|
+
"task_id": task_id,
|
|
251
|
+
"timestamp": asyncio.get_event_loop().time(),
|
|
252
|
+
}
|
|
253
|
+
|
|
254
|
+
while self._running:
|
|
255
|
+
try:
|
|
256
|
+
await self.broadcast_message(heartbeat_data)
|
|
257
|
+
await asyncio.sleep(interval)
|
|
258
|
+
except Exception as e:
|
|
259
|
+
logger.error(f"WebSocket heartbeat error: {e}")
|
|
260
|
+
break
|
|
261
|
+
|
|
262
|
+
async def broadcast_message(self, message: Dict[str, Any]):
|
|
263
|
+
"""Broadcast message to all connected clients"""
|
|
264
|
+
if not self.active_connections:
|
|
265
|
+
logger.debug("No active WebSocket connections for broadcast")
|
|
266
|
+
return
|
|
267
|
+
|
|
268
|
+
# Filter out closed connections
|
|
269
|
+
active_connections = [conn for conn in self.active_connections if not conn.closed]
|
|
270
|
+
self.active_connections = set(active_connections)
|
|
271
|
+
|
|
272
|
+
if active_connections:
|
|
273
|
+
await asyncio.gather(
|
|
274
|
+
*[self._send_to_client(conn, message) for conn in active_connections],
|
|
275
|
+
return_exceptions=True,
|
|
276
|
+
)
|
|
277
|
+
logger.debug(f"Broadcasted message to {len(active_connections)} clients")
|
|
278
|
+
|
|
279
|
+
async def send_to_user(self, user_id: str, message: Dict[str, Any]):
|
|
280
|
+
"""Send message to specific user (requires user connection mapping implementation)"""
|
|
281
|
+
# User ID to WebSocket connection mapping can be implemented here
|
|
282
|
+
# Currently simplified to broadcast
|
|
283
|
+
message["target_user_id"] = user_id
|
|
284
|
+
await self.broadcast_message(message)
|
|
285
|
+
|
|
286
|
+
def get_connection_count(self) -> int:
|
|
287
|
+
"""Get active connection count"""
|
|
288
|
+
return len([conn for conn in self.active_connections if not conn.closed])
|
|
289
|
+
|
|
290
|
+
def get_status(self) -> Dict[str, Any]:
|
|
291
|
+
"""Get WebSocket manager status"""
|
|
292
|
+
return {
|
|
293
|
+
"running": self._running,
|
|
294
|
+
"host": self.host,
|
|
295
|
+
"port": self.port,
|
|
296
|
+
"active_connections": self.get_connection_count(),
|
|
297
|
+
"pending_callbacks": len(self.callback_registry),
|
|
298
|
+
}
|