aiecs 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of aiecs might be problematic; see the package's registry page for more details.
- aiecs/__init__.py +75 -0
- aiecs/__main__.py +41 -0
- aiecs/aiecs_client.py +295 -0
- aiecs/application/__init__.py +10 -0
- aiecs/application/executors/__init__.py +10 -0
- aiecs/application/executors/operation_executor.py +341 -0
- aiecs/config/__init__.py +15 -0
- aiecs/config/config.py +117 -0
- aiecs/config/registry.py +19 -0
- aiecs/core/__init__.py +46 -0
- aiecs/core/interface/__init__.py +34 -0
- aiecs/core/interface/execution_interface.py +150 -0
- aiecs/core/interface/storage_interface.py +214 -0
- aiecs/domain/__init__.py +20 -0
- aiecs/domain/context/__init__.py +28 -0
- aiecs/domain/context/content_engine.py +982 -0
- aiecs/domain/context/conversation_models.py +306 -0
- aiecs/domain/execution/__init__.py +12 -0
- aiecs/domain/execution/model.py +49 -0
- aiecs/domain/task/__init__.py +13 -0
- aiecs/domain/task/dsl_processor.py +460 -0
- aiecs/domain/task/model.py +50 -0
- aiecs/domain/task/task_context.py +257 -0
- aiecs/infrastructure/__init__.py +26 -0
- aiecs/infrastructure/messaging/__init__.py +13 -0
- aiecs/infrastructure/messaging/celery_task_manager.py +341 -0
- aiecs/infrastructure/messaging/websocket_manager.py +289 -0
- aiecs/infrastructure/monitoring/__init__.py +12 -0
- aiecs/infrastructure/monitoring/executor_metrics.py +138 -0
- aiecs/infrastructure/monitoring/structured_logger.py +50 -0
- aiecs/infrastructure/monitoring/tracing_manager.py +376 -0
- aiecs/infrastructure/persistence/__init__.py +12 -0
- aiecs/infrastructure/persistence/database_manager.py +286 -0
- aiecs/infrastructure/persistence/file_storage.py +671 -0
- aiecs/infrastructure/persistence/redis_client.py +162 -0
- aiecs/llm/__init__.py +54 -0
- aiecs/llm/base_client.py +99 -0
- aiecs/llm/client_factory.py +339 -0
- aiecs/llm/custom_callbacks.py +228 -0
- aiecs/llm/openai_client.py +125 -0
- aiecs/llm/vertex_client.py +186 -0
- aiecs/llm/xai_client.py +184 -0
- aiecs/main.py +351 -0
- aiecs/scripts/DEPENDENCY_SYSTEM_SUMMARY.md +241 -0
- aiecs/scripts/README_DEPENDENCY_CHECKER.md +309 -0
- aiecs/scripts/README_WEASEL_PATCH.md +126 -0
- aiecs/scripts/__init__.py +3 -0
- aiecs/scripts/dependency_checker.py +825 -0
- aiecs/scripts/dependency_fixer.py +348 -0
- aiecs/scripts/download_nlp_data.py +348 -0
- aiecs/scripts/fix_weasel_validator.py +121 -0
- aiecs/scripts/fix_weasel_validator.sh +82 -0
- aiecs/scripts/patch_weasel_library.sh +188 -0
- aiecs/scripts/quick_dependency_check.py +269 -0
- aiecs/scripts/run_weasel_patch.sh +41 -0
- aiecs/scripts/setup_nlp_data.sh +217 -0
- aiecs/tasks/__init__.py +2 -0
- aiecs/tasks/worker.py +111 -0
- aiecs/tools/__init__.py +196 -0
- aiecs/tools/base_tool.py +202 -0
- aiecs/tools/langchain_adapter.py +361 -0
- aiecs/tools/task_tools/__init__.py +82 -0
- aiecs/tools/task_tools/chart_tool.py +704 -0
- aiecs/tools/task_tools/classfire_tool.py +901 -0
- aiecs/tools/task_tools/image_tool.py +397 -0
- aiecs/tools/task_tools/office_tool.py +600 -0
- aiecs/tools/task_tools/pandas_tool.py +565 -0
- aiecs/tools/task_tools/report_tool.py +499 -0
- aiecs/tools/task_tools/research_tool.py +363 -0
- aiecs/tools/task_tools/scraper_tool.py +548 -0
- aiecs/tools/task_tools/search_api.py +7 -0
- aiecs/tools/task_tools/stats_tool.py +513 -0
- aiecs/tools/temp_file_manager.py +126 -0
- aiecs/tools/tool_executor/__init__.py +35 -0
- aiecs/tools/tool_executor/tool_executor.py +518 -0
- aiecs/utils/LLM_output_structor.py +409 -0
- aiecs/utils/__init__.py +23 -0
- aiecs/utils/base_callback.py +50 -0
- aiecs/utils/execution_utils.py +158 -0
- aiecs/utils/logging.py +1 -0
- aiecs/utils/prompt_loader.py +13 -0
- aiecs/utils/token_usage_repository.py +279 -0
- aiecs/ws/__init__.py +0 -0
- aiecs/ws/socket_server.py +41 -0
- aiecs-1.0.0.dist-info/METADATA +610 -0
- aiecs-1.0.0.dist-info/RECORD +90 -0
- aiecs-1.0.0.dist-info/WHEEL +5 -0
- aiecs-1.0.0.dist-info/entry_points.txt +7 -0
- aiecs-1.0.0.dist-info/licenses/LICENSE +225 -0
- aiecs-1.0.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,376 @@
|
|
|
1
|
+
import functools
|
|
2
|
+
import logging
|
|
3
|
+
import os
|
|
4
|
+
from typing import Dict, Any, Optional
|
|
5
|
+
import jaeger_client
|
|
6
|
+
import jaeger_client.config
|
|
7
|
+
from opentracing import tracer, Span
|
|
8
|
+
|
|
9
|
+
# Module-level logger, named after this module so handlers/levels can be
# configured hierarchically.
logger = logging.getLogger(name=__name__)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class TracingManager:
    """
    Specialized handler for distributed tracing and link tracking.

    Wraps a Jaeger tracer built with ``jaeger_client.config.Config`` and
    provides span helpers, tracing decorators and OpenTracing context
    propagation. When tracing is disabled, or the tracer cannot be created,
    every helper degrades to a no-op and span-returning helpers return
    ``None``.
    """

    def __init__(self, service_name: str = "service_executor",
                 jaeger_host: Optional[str] = None,
                 jaeger_port: Optional[int] = None,
                 enable_tracing: Optional[bool] = None):
        """
        Args:
            service_name: Name reported to Jaeger and tagged on every span.
            jaeger_host: Agent host; when falsy, ``JAEGER_AGENT_HOST`` is used
                (default ``"jaeger"``).
            jaeger_port: Agent port; when falsy, ``JAEGER_AGENT_PORT`` is used
                (default ``6831``).
            enable_tracing: Explicit on/off switch; when ``None``,
                ``JAEGER_ENABLE_TRACING`` decides (tracing is on only when it
                equals ``"true"`` case-insensitively; the default is ``"true"``).

        Raises:
            ValueError: if ``JAEGER_AGENT_PORT`` is consulted and is not an integer.
        """
        self.service_name = service_name
        # Explicit arguments win; otherwise fall back to environment variables,
        # then to built-in defaults.
        self.jaeger_host = jaeger_host or os.getenv("JAEGER_AGENT_HOST", "jaeger")
        self.jaeger_port = jaeger_port or int(os.getenv("JAEGER_AGENT_PORT", "6831"))
        if enable_tracing is not None:
            self.enable_tracing = enable_tracing
        else:
            self.enable_tracing = os.getenv("JAEGER_ENABLE_TRACING", "true").lower() == "true"
        self.tracer = None

        if self.enable_tracing:
            self._init_tracer()

    def _init_tracer(self):
        """Initialize the Jaeger tracer; on any failure, log it and disable tracing."""
        try:
            config = jaeger_client.config.Config(
                config={
                    # 'const' sampler with param 1: per the Jaeger docs, every
                    # trace is sampled.
                    'sampler': {
                        'type': 'const',
                        'param': 1,
                    },
                    'local_agent': {
                        'reporting_host': self.jaeger_host,
                        'reporting_port': self.jaeger_port,
                    },
                    'logging': True,
                },
                service_name=self.service_name,
                validate=True
            )
            tracer = config.initialize_tracer()
            if tracer is None:
                # initialize_tracer() returns None when a Jaeger tracer has
                # already been initialized in this process; build a dedicated
                # tracer so this instance still reports spans.
                tracer = config.new_tracer()
            self.tracer = tracer
            logger.info(f"Jaeger tracer initialized for service '{self.service_name}' at {self.jaeger_host}:{self.jaeger_port}")
        except Exception as e:
            logger.warning(f"Failed to initialize Jaeger tracer: {e}")
            self.tracer = None
            self.enable_tracing = False

    def start_span(self, operation_name: str, parent_span: Optional[Span] = None,
                   tags: Optional[Dict[str, Any]] = None) -> Optional[Span]:
        """
        Start a tracing span.

        Args:
            operation_name: Operation name.
            parent_span: Optional parent span (the new span is ``child_of`` it).
            tags: Initial tags. A caller-supplied ``"span.kind"`` is kept;
                otherwise it defaults to ``"server"``. ``"service.name"`` is
                always set to this manager's service name.

        Returns:
            The started span, or ``None`` if tracing is disabled or starting
            the span failed (the error is logged).
        """
        if not self.enable_tracing or not self.tracer:
            return None

        try:
            span = self.tracer.start_span(
                operation_name=operation_name,
                child_of=parent_span
            )

            # Set initial tags
            if tags:
                for key, value in tags.items():
                    span.set_tag(key, value)

            # Set service information; do not overwrite a span.kind the caller
            # chose (e.g. "client" from trace_external_call).
            span.set_tag("service.name", self.service_name)
            if not tags or "span.kind" not in tags:
                span.set_tag("span.kind", "server")

            return span
        except Exception as e:
            logger.error(f"Error starting span '{operation_name}': {e}")
            return None

    def finish_span(self, span: Optional[Span], tags: Optional[Dict[str, Any]] = None,
                    logs: Optional[Dict[str, Any]] = None, error: Optional[Exception] = None):
        """
        Finish a tracing span. No-op for ``None`` spans or when tracing is off.

        Args:
            span: Span to finish.
            tags: Additional tags set before finishing.
            logs: Key/value log record attached to the span.
            error: Exception to record (sets ``error``, ``error.kind``,
                ``error.message`` tags and an ``error`` log event).
        """
        if not span or not self.enable_tracing:
            return

        try:
            # Add additional tags
            if tags:
                for key, value in tags.items():
                    span.set_tag(key, value)

            # Record error
            if error:
                span.set_tag("error", True)
                span.set_tag("error.kind", type(error).__name__)
                span.set_tag("error.message", str(error))
                span.log_kv({"event": "error", "error.object": error})

            # Add logs
            if logs:
                span.log_kv(logs)

            span.finish()
        except Exception as e:
            logger.error(f"Error finishing span: {e}")

    def with_tracing(self, operation_name: str, tags: Optional[Dict[str, Any]] = None):
        """
        Tracing decorator for sync or async callables.

        The wrapped call runs inside a span; simple arguments are recorded as
        tags, success sets ``success=True``, and an exception is recorded on
        the span and re-raised. The span is finished exactly once.

        Args:
            operation_name: Operation name.
            tags: Initial tags.
        """
        def decorator(func):
            @functools.wraps(func)
            async def async_wrapper(*args, **kwargs):
                if not self.enable_tracing or not self.tracer:
                    return await func(*args, **kwargs)

                span = self.start_span(operation_name, tags=tags)
                error = None
                try:
                    # Add function arguments as tags
                    self._add_function_args_to_span(span, args, kwargs)

                    result = await func(*args, **kwargs)

                    # Record success
                    if span:
                        span.set_tag("success", True)

                    return result
                except Exception as e:
                    error = e
                    raise
                finally:
                    # Single finish point so the span is reported once,
                    # whether the call succeeded or failed.
                    self.finish_span(span, error=error)

            @functools.wraps(func)
            def sync_wrapper(*args, **kwargs):
                if not self.enable_tracing or not self.tracer:
                    return func(*args, **kwargs)

                span = self.start_span(operation_name, tags=tags)
                error = None
                try:
                    # Add function arguments as tags
                    self._add_function_args_to_span(span, args, kwargs)

                    result = func(*args, **kwargs)

                    # Record success
                    if span:
                        span.set_tag("success", True)

                    return result
                except Exception as e:
                    error = e
                    raise
                finally:
                    self.finish_span(span, error=error)

            # Return appropriate wrapper based on function type
            import asyncio
            if asyncio.iscoroutinefunction(func):
                return async_wrapper
            else:
                return sync_wrapper

        return decorator

    def _add_function_args_to_span(self, span: Optional[Span], args: tuple, kwargs: Dict[str, Any]):
        """
        Record call arguments on the span (best effort; errors are logged at debug).

        Scalars (str/int/float/bool) are tagged by value; small dicts (string
        form under 1000 chars, keyword args only) as their string form; any
        other value only by its type name.
        """
        if not span:
            return

        try:
            # Add positional arguments
            for i, arg in enumerate(args):
                if isinstance(arg, (str, int, float, bool)):
                    span.set_tag(f"arg_{i}", arg)
                else:
                    span.set_tag(f"arg_{i}_type", arg.__class__.__name__)

            # Add keyword arguments
            for key, value in kwargs.items():
                if isinstance(value, (str, int, float, bool)):
                    span.set_tag(key, value)
                elif isinstance(value, dict) and len(str(value)) < 1000:  # Avoid overly large dictionaries
                    span.set_tag(f"{key}_json", str(value))
                else:
                    span.set_tag(f"{key}_type", value.__class__.__name__)
        except Exception as e:
            logger.debug(f"Error adding function args to span: {e}")

    def trace_database_operation(self, operation: str, table: Optional[str] = None,
                                 query: Optional[str] = None):
        """
        Database operation tracing decorator (async functions only).

        Args:
            operation: Statement type, used in the span name ``db.<operation>``.
            table: Optional table name tag.
            query: Optional statement text, truncated to 500 characters.
        """
        def decorator(func):
            @functools.wraps(func)
            async def wrapper(*args, **kwargs):
                tags = {
                    "component": "database",
                    "db.type": "postgresql",
                    "db.statement.type": operation
                }

                if table:
                    tags["db.table"] = table
                if query:
                    tags["db.statement"] = query[:500]  # Limit query length

                span = self.start_span(f"db.{operation}", tags=tags)
                error = None
                try:
                    result = await func(*args, **kwargs)
                    if span:
                        # List results report their length; anything else counts as 1.
                        span.set_tag("db.rows_affected", len(result) if isinstance(result, list) else 1)
                    return result
                except Exception as e:
                    error = e
                    raise
                finally:
                    self.finish_span(span, error=error)

            return wrapper
        return decorator

    def trace_external_call(self, service_name: str, endpoint: Optional[str] = None):
        """
        External service call tracing decorator (async functions only).

        Produces a ``client`` span named ``http.<service_name>``; success is
        tagged as status 200 and any exception as status 500.
        """
        def decorator(func):
            @functools.wraps(func)
            async def wrapper(*args, **kwargs):
                tags = {
                    "component": "http",
                    "span.kind": "client",
                    "peer.service": service_name
                }

                if endpoint:
                    tags["http.url"] = endpoint

                span = self.start_span(f"http.{service_name}", tags=tags)
                error = None
                try:
                    result = await func(*args, **kwargs)
                    if span:
                        span.set_tag("http.status_code", 200)
                    return result
                except Exception as e:
                    error = e
                    if span:
                        span.set_tag("http.status_code", 500)
                    raise
                finally:
                    self.finish_span(span, error=error)

            return wrapper
        return decorator

    def trace_tool_execution(self, tool_name: str, operation: str):
        """
        Tool execution tracing decorator (async functions only).

        Produces a span named ``tool.<tool_name>.<operation>``, tags success or
        failure, and records ``len(result)`` when the result defines ``__len__``.
        """
        def decorator(func):
            @functools.wraps(func)
            async def wrapper(*args, **kwargs):
                tags = {
                    "component": "tool",
                    "tool.name": tool_name,
                    "tool.operation": operation
                }

                span = self.start_span(f"tool.{tool_name}.{operation}", tags=tags)
                error = None
                try:
                    result = await func(*args, **kwargs)
                    if span:
                        span.set_tag("tool.success", True)
                        if hasattr(result, '__len__'):
                            span.set_tag("tool.result_size", len(result))
                    return result
                except Exception as e:
                    error = e
                    if span:
                        span.set_tag("tool.success", False)
                    raise
                finally:
                    self.finish_span(span, error=error)

            return wrapper
        return decorator

    def create_child_span(self, parent_span: Optional[Span], operation_name: str,
                          tags: Optional[Dict[str, Any]] = None) -> Optional[Span]:
        """Create a child span of ``parent_span``; ``None`` if tracing is off or there is no parent."""
        if not self.enable_tracing or not parent_span:
            return None

        return self.start_span(operation_name, parent_span=parent_span, tags=tags)

    def inject_span_context(self, span: Optional[Span], carrier: Dict[str, str]):
        """Inject the span context into ``carrier`` (TEXT_MAP) for cross-service propagation."""
        if not self.enable_tracing or not span or not self.tracer:
            return

        try:
            from opentracing.propagation import Format
            self.tracer.inject(span.context, Format.TEXT_MAP, carrier)
        except Exception as e:
            logger.error(f"Error injecting span context: {e}")

    def extract_span_context(self, carrier: Dict[str, str]) -> Optional[Any]:
        """Extract a span context from a TEXT_MAP ``carrier``; ``None`` when unavailable."""
        if not self.enable_tracing or not self.tracer:
            return None

        try:
            from opentracing.propagation import Format
            return self.tracer.extract(Format.TEXT_MAP, carrier)
        except Exception as e:
            logger.error(f"Error extracting span context: {e}")
            return None

    def get_active_span(self) -> Optional[Span]:
        """Return the tracer's current active span, or ``None``."""
        if not self.enable_tracing or not self.tracer:
            return None

        try:
            return self.tracer.active_span
        except Exception as e:
            logger.error(f"Error getting active span: {e}")
            return None

    def close_tracer(self):
        """Close the tracer if one was created; errors are logged, not raised."""
        if self.tracer:
            try:
                self.tracer.close()
                logger.info("Tracer closed successfully")
            except Exception as e:
                logger.error(f"Error closing tracer: {e}")

    def get_tracer_info(self) -> Dict[str, Any]:
        """Return a summary of the tracing configuration and tracer state."""
        return {
            "enabled": self.enable_tracing,
            "service_name": self.service_name,
            "jaeger_host": self.jaeger_host,
            "jaeger_port": self.jaeger_port,
            "tracer_initialized": self.tracer is not None
        }
|
|
@@ -0,0 +1,286 @@
|
|
|
1
|
+
import asyncpg
|
|
2
|
+
import json
|
|
3
|
+
import logging
|
|
4
|
+
from datetime import datetime
|
|
5
|
+
from typing import Dict, List, Any, Optional
|
|
6
|
+
from aiecs.domain.execution.model import TaskStatus, TaskStepResult
|
|
7
|
+
from aiecs.config.config import get_settings
|
|
8
|
+
|
|
9
|
+
# Module-level logger, named after this module so handlers/levels can be
# configured hierarchically.
logger = logging.getLogger(name=__name__)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class DatabaseManager:
    """
    Specialized handler for database connections, operations, and task history management.

    Uses an asyncpg connection pool when ``init_connection_pool`` has been
    called; otherwise every operation opens a dedicated connection and closes
    it afterwards. Public operations lazily create the ``task_history`` schema
    on first use.
    """

    def __init__(self, db_config: Optional[Dict[str, Any]] = None):
        """
        Args:
            db_config: Keyword arguments for ``asyncpg.connect`` /
                ``asyncpg.create_pool``; defaults to
                ``get_settings().database_config``.
        """
        if db_config is None:
            settings = get_settings()
            self.db_config = settings.database_config
        else:
            self.db_config = db_config
        self.connection_pool = None
        self._initialized = False

    async def init_connection_pool(self, min_size: int = 10, max_size: int = 20):
        """Initialize the database connection pool; logs and re-raises on failure."""
        try:
            self.connection_pool = await asyncpg.create_pool(
                **self.db_config,
                min_size=min_size,
                max_size=max_size
            )
            logger.info("Database connection pool initialized successfully")
        except Exception as e:
            logger.error(f"Failed to initialize database connection pool: {e}")
            raise

    async def _get_connection(self):
        """
        Return an un-entered connection source.

        With a pool this is ``pool.acquire()`` (use with ``async with``);
        without one it is the un-awaited ``asyncpg.connect(...)`` coroutine,
        which the caller must await and later close.
        """
        if self.connection_pool:
            return self.connection_pool.acquire()
        else:
            return asyncpg.connect(**self.db_config)

    async def _run_with_connection(self, operation):
        """
        Await ``operation(conn)`` on a pooled or dedicated connection and return its result.

        A dedicated connection (no pool) is always closed afterwards.
        """
        if self.connection_pool:
            async with self.connection_pool.acquire() as conn:
                return await operation(conn)

        conn = await asyncpg.connect(**self.db_config)
        try:
            return await operation(conn)
        finally:
            await conn.close()

    async def init_database_schema(self):
        """
        Initialize the database table structure.

        Returns:
            True on success; False if creation failed (the error is logged and
            the schema will be retried on the next operation).
        """
        try:
            await self._run_with_connection(self._create_tables)

            self._initialized = True
            logger.info("Database schema initialized successfully")
            return True
        except Exception as e:
            logger.error(f"Database initialization error: {e}")
            return False

    async def _create_tables(self, conn):
        """Create the ``task_history`` table and its indexes if they do not exist."""
        await conn.execute('''
            CREATE TABLE IF NOT EXISTS task_history (
                id SERIAL PRIMARY KEY,
                user_id TEXT NOT NULL,
                task_id TEXT NOT NULL,
                step INTEGER NOT NULL,
                result JSONB NOT NULL,
                timestamp TIMESTAMP NOT NULL,
                status TEXT NOT NULL DEFAULT 'pending'
            );
            CREATE INDEX IF NOT EXISTS idx_task_history_user_id ON task_history (user_id);
            CREATE INDEX IF NOT EXISTS idx_task_history_task_id ON task_history (task_id);
            CREATE INDEX IF NOT EXISTS idx_task_history_status ON task_history (status);
            CREATE INDEX IF NOT EXISTS idx_task_history_timestamp ON task_history (timestamp);
        ''')

    async def save_task_history(self, user_id: str, task_id: str, step: int, step_result: TaskStepResult):
        """
        Save one task execution step.

        The step result is stored as JSON from ``step_result.dict()``, stamped
        with ``datetime.now()``.

        Raises:
            Exception: "Database error: ..." wrapping the underlying failure.
        """
        if not self._initialized:
            await self.init_database_schema()

        try:
            payload = json.dumps(step_result.dict())
            # NOTE(review): step_result.status is passed straight to a TEXT
            # column; presumably a str or str-valued enum — confirm in the model.
            await self._run_with_connection(lambda conn: conn.execute(
                'INSERT INTO task_history (user_id, task_id, step, result, timestamp, status) VALUES ($1, $2, $3, $4, $5, $6)',
                user_id, task_id, step, payload, datetime.now(), step_result.status
            ))

            logger.debug(f"Saved task history for user {user_id}, task {task_id}, step {step}")
            return True
        except Exception as e:
            logger.error(f"Database error saving task history: {e}")
            raise Exception(f"Database error: {e}") from e

    async def load_task_history(self, user_id: str, task_id: str) -> List[Dict]:
        """
        Load a task's execution history ordered by step.

        Returns:
            Dicts with ``step``, decoded ``result``, ISO ``timestamp`` and ``status``.

        Raises:
            Exception: "Database error: ..." wrapping the underlying failure.
        """
        if not self._initialized:
            await self.init_database_schema()

        try:
            records = await self._run_with_connection(lambda conn: conn.fetch(
                'SELECT step, result, timestamp, status FROM task_history WHERE user_id = $1 AND task_id = $2 ORDER BY step ASC',
                user_id, task_id
            ))

            return [
                {
                    "step": r['step'],
                    "result": json.loads(r['result']),
                    "timestamp": r['timestamp'].isoformat(),
                    "status": r['status']
                }
                for r in records
            ]
        except Exception as e:
            logger.error(f"Database error loading task history: {e}")
            raise Exception(f"Database error: {e}") from e

    async def mark_task_as_cancelled(self, user_id: str, task_id: str):
        """
        Set every history row of the task to ``TaskStatus.CANCELLED``.

        Raises:
            Exception: "Database error: ..." wrapping the underlying failure.
        """
        if not self._initialized:
            await self.init_database_schema()

        try:
            await self._run_with_connection(lambda conn: conn.execute(
                'UPDATE task_history SET status = $1 WHERE user_id = $2 AND task_id = $3',
                TaskStatus.CANCELLED.value, user_id, task_id
            ))

            logger.info(f"Marked task {task_id} as cancelled for user {user_id}")
            return True
        except Exception as e:
            logger.error(f"Database error marking task as cancelled: {e}")
            raise Exception(f"Database error: {e}") from e

    async def check_task_status(self, user_id: str, task_id: str) -> TaskStatus:
        """
        Return the status of the task's highest step, or ``TaskStatus.PENDING`` if none.

        Raises:
            Exception: "Database error: ..." wrapping the underlying failure
                (including an unknown stored status value).
        """
        if not self._initialized:
            await self.init_database_schema()

        try:
            record = await self._run_with_connection(lambda conn: conn.fetchrow(
                'SELECT status FROM task_history WHERE user_id = $1 AND task_id = $2 ORDER BY step DESC LIMIT 1',
                user_id, task_id
            ))

            return TaskStatus(record['status']) if record else TaskStatus.PENDING
        except Exception as e:
            logger.error(f"Database error checking task status: {e}")
            raise Exception(f"Database error: {e}") from e

    async def get_user_tasks(self, user_id: str, limit: int = 100) -> List[Dict]:
        """
        List a user's tasks, most recently updated first.

        Returns:
            Up to ``limit`` dicts with ``task_id``, ISO ``last_updated`` and the
            ``status`` of the task's highest step.

        Raises:
            Exception: "Database error: ..." wrapping the underlying failure.
        """
        if not self._initialized:
            await self.init_database_schema()

        query = '''SELECT DISTINCT task_id,
                   MAX(timestamp) as last_updated,
                   (SELECT status FROM task_history th2
                    WHERE th2.user_id = $1 AND th2.task_id = th1.task_id
                    ORDER BY step DESC LIMIT 1) as status
                   FROM task_history th1
                   WHERE user_id = $1
                   GROUP BY task_id
                   ORDER BY last_updated DESC
                   LIMIT $2'''

        try:
            records = await self._run_with_connection(
                lambda conn: conn.fetch(query, user_id, limit)
            )

            return [
                {
                    "task_id": r['task_id'],
                    "last_updated": r['last_updated'].isoformat(),
                    "status": r['status']
                }
                for r in records
            ]
        except Exception as e:
            logger.error(f"Database error getting user tasks: {e}")
            raise Exception(f"Database error: {e}") from e

    async def cleanup_old_tasks(self, days_old: int = 30):
        """
        Delete task history rows older than ``days_old`` days.

        Returns:
            True on success, False on failure (the error is logged).
        """
        from datetime import timedelta  # only needed here

        if not self._initialized:
            await self.init_database_schema()

        try:
            # asyncpg binds positional $n parameters. The cutoff is computed
            # with datetime.now(), the same clock save_task_history uses to
            # stamp rows in the TIMESTAMP column.
            cutoff = datetime.now() - timedelta(days=days_old)
            result = await self._run_with_connection(lambda conn: conn.execute(
                'DELETE FROM task_history WHERE timestamp < $1',
                cutoff
            ))

            logger.info(f"Cleaned up old task records: {result}")
            return True
        except Exception as e:
            logger.error(f"Database error during cleanup: {e}")
            return False

    async def close(self):
        """Close the connection pool, if any, and forget it so later calls fall back to direct connections."""
        if self.connection_pool:
            await self.connection_pool.close()
            self.connection_pool = None
            logger.info("Database connection pool closed")
|