azure-ai-evaluation 1.10.0__py3-none-any.whl → 1.11.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of azure-ai-evaluation might be problematic. Click here for more details.
- azure/ai/evaluation/_common/onedp/models/_models.py +5 -0
- azure/ai/evaluation/_converters/_ai_services.py +60 -10
- azure/ai/evaluation/_converters/_models.py +75 -26
- azure/ai/evaluation/_evaluate/_eval_run.py +14 -1
- azure/ai/evaluation/_evaluate/_evaluate.py +13 -4
- azure/ai/evaluation/_evaluate/_evaluate_aoai.py +104 -35
- azure/ai/evaluation/_evaluate/_utils.py +4 -0
- azure/ai/evaluation/_evaluators/_coherence/_coherence.py +2 -1
- azure/ai/evaluation/_evaluators/_common/_base_eval.py +113 -19
- azure/ai/evaluation/_evaluators/_common/_base_prompty_eval.py +7 -2
- azure/ai/evaluation/_evaluators/_common/_base_rai_svc_eval.py +1 -1
- azure/ai/evaluation/_evaluators/_fluency/_fluency.py +2 -1
- azure/ai/evaluation/_evaluators/_groundedness/_groundedness.py +113 -3
- azure/ai/evaluation/_evaluators/_intent_resolution/_intent_resolution.py +8 -2
- azure/ai/evaluation/_evaluators/_relevance/_relevance.py +2 -1
- azure/ai/evaluation/_evaluators/_response_completeness/_response_completeness.py +10 -2
- azure/ai/evaluation/_evaluators/_retrieval/_retrieval.py +2 -1
- azure/ai/evaluation/_evaluators/_similarity/_similarity.py +2 -1
- azure/ai/evaluation/_evaluators/_task_adherence/_task_adherence.py +8 -2
- azure/ai/evaluation/_evaluators/_tool_call_accuracy/_tool_call_accuracy.py +104 -60
- azure/ai/evaluation/_evaluators/_tool_call_accuracy/tool_call_accuracy.prompty +58 -41
- azure/ai/evaluation/_exceptions.py +1 -0
- azure/ai/evaluation/_version.py +1 -1
- azure/ai/evaluation/red_team/__init__.py +2 -1
- azure/ai/evaluation/red_team/_attack_objective_generator.py +17 -0
- azure/ai/evaluation/red_team/_callback_chat_target.py +14 -1
- azure/ai/evaluation/red_team/_evaluation_processor.py +376 -0
- azure/ai/evaluation/red_team/_mlflow_integration.py +322 -0
- azure/ai/evaluation/red_team/_orchestrator_manager.py +661 -0
- azure/ai/evaluation/red_team/_red_team.py +697 -3067
- azure/ai/evaluation/red_team/_result_processor.py +610 -0
- azure/ai/evaluation/red_team/_utils/__init__.py +34 -0
- azure/ai/evaluation/red_team/_utils/_rai_service_eval_chat_target.py +3 -1
- azure/ai/evaluation/red_team/_utils/_rai_service_true_false_scorer.py +6 -0
- azure/ai/evaluation/red_team/_utils/exception_utils.py +345 -0
- azure/ai/evaluation/red_team/_utils/file_utils.py +266 -0
- azure/ai/evaluation/red_team/_utils/formatting_utils.py +115 -13
- azure/ai/evaluation/red_team/_utils/metric_mapping.py +24 -4
- azure/ai/evaluation/red_team/_utils/progress_utils.py +252 -0
- azure/ai/evaluation/red_team/_utils/retry_utils.py +218 -0
- azure/ai/evaluation/red_team/_utils/strategy_utils.py +17 -4
- azure/ai/evaluation/simulator/_adversarial_simulator.py +9 -0
- azure/ai/evaluation/simulator/_model_tools/_generated_rai_client.py +19 -5
- azure/ai/evaluation/simulator/_model_tools/_proxy_completion_model.py +4 -3
- {azure_ai_evaluation-1.10.0.dist-info → azure_ai_evaluation-1.11.1.dist-info}/METADATA +39 -3
- {azure_ai_evaluation-1.10.0.dist-info → azure_ai_evaluation-1.11.1.dist-info}/RECORD +49 -41
- {azure_ai_evaluation-1.10.0.dist-info → azure_ai_evaluation-1.11.1.dist-info}/WHEEL +1 -1
- {azure_ai_evaluation-1.10.0.dist-info → azure_ai_evaluation-1.11.1.dist-info/licenses}/NOTICE.txt +0 -0
- {azure_ai_evaluation-1.10.0.dist-info → azure_ai_evaluation-1.11.1.dist-info}/top_level.txt +0 -0
|
@@ -3,27 +3,29 @@ Utility functions for formatting, conversion, and processing in Red Team Agent.
|
|
|
3
3
|
"""
|
|
4
4
|
|
|
5
5
|
import json
|
|
6
|
-
import pandas as pd
|
|
7
6
|
import math
|
|
8
|
-
|
|
9
|
-
|
|
7
|
+
import itertools
|
|
8
|
+
import os
|
|
9
|
+
import logging
|
|
10
|
+
from typing import Dict, List, Union, Any
|
|
11
|
+
from pathlib import Path
|
|
12
|
+
from pyrit.models import ChatMessage
|
|
13
|
+
from pyrit.memory import CentralMemory
|
|
10
14
|
from .._attack_strategy import AttackStrategy
|
|
11
15
|
from .._red_team_result import RedTeamResult
|
|
12
|
-
from pyrit.models import ChatMessage
|
|
13
16
|
|
|
14
17
|
|
|
15
|
-
def message_to_dict(message: ChatMessage, context: Union[str, None] = None) -> Dict[str, Union[str, None]]:
    """Convert a ChatMessage and its optional context to a JSON-serializable dictionary.

    :param message: The chat message to convert
    :type message: ChatMessage
    :param context: Additional context to include in the dictionary; ``None`` is
        stored unchanged (it serializes to JSON ``null``)
    :type context: Optional[str]
    :return: Dictionary with ``role``, ``content`` and ``context`` keys
    :rtype: Dict[str, Optional[str]]
    """
    return {"role": message.role, "content": message.content, "context": context}
|
|
27
29
|
|
|
28
30
|
|
|
29
31
|
def get_strategy_name(attack_strategy: Union[AttackStrategy, List[AttackStrategy]]) -> str:
|
|
@@ -89,7 +91,7 @@ def get_attack_success(result: str) -> bool:
|
|
|
89
91
|
"""
|
|
90
92
|
Based on EVALUATION_PASS_FAIL_MAPPING, returns True denoting attack success if result matches the value of the EVALUATION_PASS_FAIL_MAPPING key False.
|
|
91
93
|
"""
|
|
92
|
-
from
|
|
94
|
+
from azure.ai.evaluation._constants import EVALUATION_PASS_FAIL_MAPPING
|
|
93
95
|
|
|
94
96
|
return result == EVALUATION_PASS_FAIL_MAPPING.get(False, "true")
|
|
95
97
|
|
|
@@ -117,7 +119,7 @@ def format_scorecard(redteam_result: RedTeamResult) -> str:
|
|
|
117
119
|
separator = "-" * 132
|
|
118
120
|
output.append(separator)
|
|
119
121
|
output.append(
|
|
120
|
-
f"{'Risk Category':<
|
|
122
|
+
f"{'Risk Category':<21}| {'Baseline ASR':<14} | {'Easy-Complexity Attacks ASR':<28} | {'Moderate-Complexity Attacks ASR':<30} | {'Difficult-Complexity Attacks ASR':<30}"
|
|
121
123
|
)
|
|
122
124
|
output.append(separator)
|
|
123
125
|
|
|
@@ -134,7 +136,7 @@ def format_scorecard(redteam_result: RedTeamResult) -> str:
|
|
|
134
136
|
moderate = "N/A" if is_none_or_nan(moderate_val) else f"{moderate_val}%"
|
|
135
137
|
difficult = "N/A" if is_none_or_nan(difficult_val) else f"{difficult_val}%"
|
|
136
138
|
|
|
137
|
-
output.append(f"{risk_category:<
|
|
139
|
+
output.append(f"{risk_category:<21}| {baseline:<14} | {easy:<28} | {moderate:<31} | {difficult:<30}")
|
|
138
140
|
|
|
139
141
|
return "\n".join(output)
|
|
140
142
|
|
|
@@ -160,3 +162,103 @@ def list_mean_nan_safe(data_list: List[Any]) -> float:
|
|
|
160
162
|
if not filtered_list:
|
|
161
163
|
return 0.0
|
|
162
164
|
return sum(filtered_list) / len(filtered_list)
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
def _group_pieces_into_conversations(
    pieces: List[Any], prompt_to_context: Dict[str, str]
) -> List[List[Any]]:
    """Group PyRIT prompt request pieces by ``conversation_id``.

    Unlike ``itertools.groupby``, this does not require pieces of the same
    conversation to be adjacent in ``pieces``. Conversations keep first-seen
    order, and pieces keep their original order within each conversation.

    :param pieces: Prompt request pieces as returned by PyRIT memory
    :param prompt_to_context: Mapping of original prompt values to their context
    :return: One list per conversation of ``(chat_message, context)`` pairs
    """
    grouped: Dict[Any, List[Any]] = {}
    for item in pieces:
        # Prefer the caller-supplied context; fall back to the piece's "context" label.
        context = prompt_to_context.get(item.original_value, "") or (item.labels or {}).get("context", "")
        grouped.setdefault(item.conversation_id, []).append((item.to_chat_message(), context))
    return list(grouped.values())


def _conversations_to_jsonl(conversations: List[List[Any]]) -> List[str]:
    """Serialize conversations to newline-terminated JSON lines.

    Conversations whose first message has the ``system`` role are skipped entirely.

    :param conversations: Conversations as produced by ``_group_pieces_into_conversations``
    :return: One JSON line per written conversation
    """
    return [
        json.dumps(
            {"conversation": {"messages": [message_to_dict(message, context) for message, context in conversation]}}
        )
        + "\n"
        for conversation in conversations
        if conversation[0][0].role != "system"
    ]


def _write_jsonl(output_path: str, json_lines: List[str]) -> None:
    """Overwrite ``output_path`` with the given JSON lines."""
    with Path(output_path).open("w", encoding="utf-8") as f:
        f.writelines(json_lines)


def write_pyrit_outputs_to_file(
    *,
    output_path: str,
    logger: logging.Logger,
    prompt_to_context: Dict[str, str],
) -> str:
    """Write the PyRIT conversations recorded for ``output_path`` to a JSONL file.

    Prompt pieces are fetched from PyRIT's ``CentralMemory`` using the label
    ``{"risk_strategy_path": output_path}``. They are grouped into conversations
    by ``conversation_id`` and written one conversation per line as
    ``{"conversation": {"messages": [...]}}``. Conversations whose first message
    has the ``system`` role are skipped.

    If ``output_path`` already exists, it is replaced only when the new data would
    produce more lines than the file currently has. Otherwise the existing file
    is kept. In that branch, read and write errors are logged as warnings and
    not raised.

    :param output_path: Path of the output file; also used as the memory label value
    :type output_path: str
    :param logger: Logger instance for logging
    :type logger: logging.Logger
    :param prompt_to_context: Mapping of prompts to their context
    :type prompt_to_context: Dict[str, str]
    :return: Path to the output file
    :rtype: str
    :raises IOError: If a new output file cannot be written
    :raises PermissionError: If there are insufficient permissions to create a new output file
    :raises Exception: For unexpected errors during memory retrieval
    """
    logger.debug(f"Writing PyRIT outputs to file: {output_path}")
    memory = CentralMemory.get_memory_instance()

    memory_label = {"risk_strategy_path": output_path}
    prompts_request_pieces = memory.get_prompt_request_pieces(labels=memory_label)

    conversations = _group_pieces_into_conversations(prompts_request_pieces, prompt_to_context)
    json_lines = _conversations_to_jsonl(conversations)
    # Compare against what would actually be written: system-led conversations are dropped.
    new_line_count = len(json_lines)

    if os.path.exists(output_path):
        try:
            with open(output_path, "r", encoding="utf-8") as existing_file:
                existing_line_count = sum(1 for _ in existing_file)
        except Exception as e:
            # Best-effort: keep whatever is on disk if it cannot be inspected.
            logger.warning(f"Failed to read existing file {output_path}: {str(e)}")
            return str(output_path)

        if new_line_count <= existing_line_count:
            logger.debug(
                f"Existing file has {existing_line_count} lines, new data has {new_line_count} prompts. Keeping existing file."
            )
            return str(output_path)

        logger.debug(
            f"Found more prompts ({new_line_count}) than existing file lines ({existing_line_count}). Replacing content."
        )
        try:
            _write_jsonl(output_path, json_lines)
        except Exception as e:
            logger.warning(f"Failed to write existing file {output_path}: {str(e)}")
            return str(output_path)
        logger.debug(
            f"Successfully wrote {new_line_count - existing_line_count} new conversation(s) to {output_path}"
        )
        return str(output_path)

    logger.debug(f"Creating new file: {output_path}")
    _write_jsonl(output_path, json_lines)
    logger.debug(f"Successfully wrote {new_line_count} conversations to {output_path}")
    return str(output_path)
|
|
@@ -2,6 +2,7 @@
|
|
|
2
2
|
Mapping utilities for metrics and risk categories in Red Team Agent.
|
|
3
3
|
"""
|
|
4
4
|
|
|
5
|
+
from typing import Union
|
|
5
6
|
from .._attack_objective_generator import _InternalRiskCategory, RiskCategory
|
|
6
7
|
from azure.ai.evaluation._constants import EvaluationMetrics, _InternalEvaluationMetrics
|
|
7
8
|
from azure.ai.evaluation._common.constants import Tasks, _InternalAnnotationTasks
|
|
@@ -13,6 +14,8 @@ RISK_CATEGORY_METRIC_MAP = {
|
|
|
13
14
|
RiskCategory.Sexual: EvaluationMetrics.SEXUAL,
|
|
14
15
|
RiskCategory.SelfHarm: EvaluationMetrics.SELF_HARM,
|
|
15
16
|
RiskCategory.ProtectedMaterial: EvaluationMetrics.PROTECTED_MATERIAL,
|
|
17
|
+
RiskCategory.UngroundedAttributes: EvaluationMetrics.UNGROUNDED_ATTRIBUTES,
|
|
18
|
+
RiskCategory.IndirectAttack: EvaluationMetrics.XPIA,
|
|
16
19
|
_InternalRiskCategory.ECI: _InternalEvaluationMetrics.ECI,
|
|
17
20
|
RiskCategory.CodeVulnerability: EvaluationMetrics.CODE_VULNERABILITY,
|
|
18
21
|
}
|
|
@@ -23,28 +26,45 @@ RISK_CATEGORY_ANNOTATION_TASK_MAP = {
|
|
|
23
26
|
RiskCategory.Sexual: Tasks.CONTENT_HARM,
|
|
24
27
|
RiskCategory.SelfHarm: Tasks.CONTENT_HARM,
|
|
25
28
|
RiskCategory.ProtectedMaterial: Tasks.PROTECTED_MATERIAL,
|
|
29
|
+
RiskCategory.UngroundedAttributes: Tasks.UNGROUNDED_ATTRIBUTES,
|
|
30
|
+
RiskCategory.IndirectAttack: Tasks.XPIA,
|
|
26
31
|
_InternalRiskCategory.ECI: _InternalAnnotationTasks.ECI,
|
|
27
32
|
RiskCategory.CodeVulnerability: Tasks.CODE_VULNERABILITY,
|
|
28
33
|
}
|
|
29
34
|
|
|
30
35
|
|
|
31
|
-
def get_metric_from_risk_category(risk_category: Union[RiskCategory, _InternalRiskCategory]) -> str:
    """Get the appropriate evaluation metric for a given risk category.

    Internal categories (for example ``_InternalRiskCategory.ECI``) are accepted
    because ``RISK_CATEGORY_METRIC_MAP`` contains them as keys.

    :param risk_category: The risk category to map to an evaluation metric
    :type risk_category: Union[RiskCategory, _InternalRiskCategory]
    :return: The corresponding evaluation metric, or ``EvaluationMetrics.HATE_FAIRNESS``
        when the category is not in the map
    :rtype: str
    """
    return RISK_CATEGORY_METRIC_MAP.get(risk_category, EvaluationMetrics.HATE_FAIRNESS)
|
|
40
45
|
|
|
41
46
|
|
|
42
|
-
def get_annotation_task_from_risk_category(risk_category: Union[RiskCategory, _InternalRiskCategory]) -> str:
    """
    Get the appropriate annotation task for a given risk category.

    Internal categories (for example ``_InternalRiskCategory.ECI``) are accepted
    because ``RISK_CATEGORY_ANNOTATION_TASK_MAP`` contains them as keys.

    :param risk_category: The risk category to map to an annotation task
    :type risk_category: Union[RiskCategory, _InternalRiskCategory]
    :return: The corresponding annotation task, or ``Tasks.CONTENT_HARM`` when the
        category is not in the map
    :rtype: str
    """
    return RISK_CATEGORY_ANNOTATION_TASK_MAP.get(risk_category, Tasks.CONTENT_HARM)
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def get_attack_objective_from_risk_category(risk_category: Union[RiskCategory, _InternalRiskCategory]) -> str:
    """Get the attack objective string for a given risk category.

    Ungrounded attributes map to ``"isa"`` and indirect attacks map to ``"xpia"``.
    Every other category uses its enum ``value``.

    :param risk_category: The risk category to map to an attack objective
    :type risk_category: Union[RiskCategory, _InternalRiskCategory]
    :return: The corresponding attack objective string
    :rtype: str
    """
    # Equality checks (not a dict lookup) so ``.value`` is only read when needed.
    if risk_category == RiskCategory.UngroundedAttributes:
        return "isa"
    if risk_category == RiskCategory.IndirectAttack:
        return "xpia"
    return risk_category.value
|
|
@@ -0,0 +1,252 @@
|
|
|
1
|
+
# ---------------------------------------------------------
|
|
2
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
3
|
+
# ---------------------------------------------------------
|
|
4
|
+
"""
|
|
5
|
+
Progress and status management utilities for Red Team Agent.
|
|
6
|
+
|
|
7
|
+
This module provides centralized progress tracking, task status management,
|
|
8
|
+
and user feedback utilities for red team operations.
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
import asyncio
|
|
12
|
+
import time
|
|
13
|
+
from datetime import datetime
|
|
14
|
+
from typing import Dict, Optional, Any
|
|
15
|
+
from tqdm import tqdm
|
|
16
|
+
|
|
17
|
+
from .constants import TASK_STATUS
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class ProgressManager:
    """Centralized progress and status tracking for Red Team operations.

    Keeps a per-task status map, counts of completed / failed / timed-out tasks,
    and wall-clock timing, and can optionally render a ``tqdm`` progress bar.
    It also works as a synchronous context manager that calls :meth:`start` on
    entry and :meth:`stop` on exit.
    """

    def __init__(
        self, total_tasks: int = 0, logger=None, show_progress_bar: bool = True, progress_desc: str = "Processing"
    ):
        """Initialize progress manager.

        :param total_tasks: Total number of tasks to track
        :param logger: Logger instance for progress messages (optional)
        :param show_progress_bar: Whether to show a progress bar
        :param progress_desc: Description for the progress bar
        """
        self.total_tasks = total_tasks
        self.completed_tasks = 0
        self.failed_tasks = 0
        self.timeout_tasks = 0
        self.logger = logger
        self.show_progress_bar = show_progress_bar
        self.progress_desc = progress_desc

        # Task status tracking: task key -> latest status value
        self.task_statuses: Dict[str, str] = {}
        # Task keys that have already advanced the progress bar. A task moving
        # between counted statuses must not advance the bar a second time.
        self._progress_counted_tasks = set()

        # Timing, stored as time.time() timestamps
        self.start_time: Optional[float] = None
        self.end_time: Optional[float] = None

        # Progress bar (created by start()) and a lock that serializes bar
        # updates coming from concurrently running coroutines.
        self.progress_bar: Optional[tqdm] = None
        self.progress_lock = asyncio.Lock()

    def start(self) -> None:
        """Start progress tracking and create the progress bar if enabled.

        The bar is created only when ``show_progress_bar`` is true and
        ``total_tasks`` is positive.
        """
        self.start_time = time.time()

        if self.show_progress_bar and self.total_tasks > 0:
            if self.progress_bar is not None:
                # Without this, calling start() twice would leak the previous bar.
                self.progress_bar.close()
            self.progress_bar = tqdm(
                total=self.total_tasks,
                desc=f"{self.progress_desc}: ",
                ncols=100,
                unit="task",
                bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]",
            )
            self.progress_bar.set_postfix({"current": "initializing"})

    def stop(self) -> None:
        """Stop progress tracking and close the progress bar, if any."""
        self.end_time = time.time()

        if self.progress_bar is not None:
            self.progress_bar.close()
            self.progress_bar = None

    @staticmethod
    def _counter_for_status(status: Optional[str]) -> Optional[str]:
        """Return the name of the counter attribute for ``status``, or ``None`` if it is not counted."""
        if status == TASK_STATUS["COMPLETED"]:
            return "completed_tasks"
        if status == TASK_STATUS["FAILED"]:
            return "failed_tasks"
        if status == TASK_STATUS["TIMEOUT"]:
            return "timeout_tasks"
        return None

    async def update_task_status(self, task_key: str, status: str, details: Optional[str] = None) -> None:
        """Update the status of a specific task.

        Moving a task into COMPLETED, FAILED or TIMEOUT increments the matching
        counter. Moving it out of one of those statuses decrements the previous
        counter, so each task appears in at most one counter. The progress bar
        advances only the first time a given task reaches a counted status.

        :param task_key: Unique identifier for the task
        :param status: New status for the task
        :param details: Optional details about the status change; when given and a
            logger is configured, the transition is logged at debug level
        """
        old_status = self.task_statuses.get(task_key)
        self.task_statuses[task_key] = status

        # Update counters based on status change
        if old_status != status:
            old_counter = self._counter_for_status(old_status)
            new_counter = self._counter_for_status(status)
            if old_counter is not None:
                setattr(self, old_counter, getattr(self, old_counter) - 1)
            if new_counter is not None:
                setattr(self, new_counter, getattr(self, new_counter) + 1)
                first_time = task_key not in self._progress_counted_tasks
                self._progress_counted_tasks.add(task_key)
                await self._update_progress_bar(advance=first_time)

        # Log status change
        if self.logger and details:
            self.logger.debug(f"Task {task_key}: {old_status} -> {status} ({details})")

    async def _update_progress_bar(self, advance: bool = True) -> None:
        """Refresh the progress bar display.

        :param advance: Whether to move the bar forward by one task before
            refreshing the postfix statistics
        """
        if self.progress_bar is None:
            return

        async with self.progress_lock:
            if advance:
                self.progress_bar.update(1)

            completion_pct = (self.completed_tasks / self.total_tasks) * 100 if self.total_tasks > 0 else 0
            postfix = {
                "completed": f"{completion_pct:.1f}%",
                "failed": self.failed_tasks,
                "timeout": self.timeout_tasks,
            }

            # ETA = average elapsed time per completed task * tasks not yet finished.
            if self.start_time is not None and self.completed_tasks > 0:
                elapsed_time = time.time() - self.start_time
                avg_time_per_task = elapsed_time / self.completed_tasks
                remaining_tasks = self.total_tasks - self.completed_tasks - self.failed_tasks - self.timeout_tasks
                est_remaining_time = avg_time_per_task * remaining_tasks if remaining_tasks > 0 else 0
                if est_remaining_time > 0:
                    postfix["eta"] = f"{est_remaining_time/60:.1f}m"

            self.progress_bar.set_postfix(postfix)

    def write_progress_message(self, message: str) -> None:
        """Write a message without corrupting an active progress bar.

        :param message: Message to display
        """
        if self.progress_bar is not None:
            tqdm.write(message)
        else:
            print(message)

    def log_task_completion(
        self, task_name: str, duration: float, success: bool = True, details: Optional[str] = None
    ) -> None:
        """Log the completion of a task to the console and, if configured, the logger.

        :param task_name: Name of the completed task
        :param duration: Duration in seconds
        :param success: Whether the task completed successfully (logged at info
            level if so, warning level otherwise)
        :param details: Optional additional details
        """
        status_icon = "✅" if success else "❌"
        message = f"{status_icon} {task_name} completed in {duration:.1f}s"

        if details:
            message += f" - {details}"

        self.write_progress_message(message)

        if self.logger:
            if success:
                self.logger.info(message)
            else:
                self.logger.warning(message)

    def log_task_timeout(self, task_name: str, timeout_duration: float) -> None:
        """Log a task timeout.

        :param task_name: Name of the timed out task
        :param timeout_duration: Timeout duration in seconds
        """
        message = f"⚠️ TIMEOUT: {task_name} after {timeout_duration}s"
        self.write_progress_message(message)

        if self.logger:
            self.logger.warning(message)

    def log_task_error(self, task_name: str, error: Exception) -> None:
        """Log a task error.

        :param task_name: Name of the failed task
        :param error: The exception that occurred
        """
        message = f"❌ ERROR: {task_name} - {error.__class__.__name__}: {str(error)}"
        self.write_progress_message(message)

        if self.logger:
            self.logger.error(message)

    def get_summary(self) -> Dict[str, Any]:
        """Get a summary of progress and statistics.

        ``total_time_seconds`` is ``None`` until :meth:`start` has been called.
        Before :meth:`stop`, it is measured up to the current time.

        :return: Dictionary containing progress summary
        """
        total_time = None
        if self.start_time is not None:
            end_time = self.end_time if self.end_time is not None else time.time()
            total_time = end_time - self.start_time

        return {
            "total_tasks": self.total_tasks,
            "completed_tasks": self.completed_tasks,
            "failed_tasks": self.failed_tasks,
            "timeout_tasks": self.timeout_tasks,
            "success_rate": (self.completed_tasks / self.total_tasks) * 100 if self.total_tasks > 0 else 0,
            "total_time_seconds": total_time,
            "average_time_per_task": (
                total_time / self.completed_tasks if total_time is not None and self.completed_tasks > 0 else None
            ),
            "task_statuses": self.task_statuses.copy(),
        }

    def print_summary(self) -> None:
        """Print a formatted summary of the progress."""
        summary = self.get_summary()

        self.write_progress_message("\n" + "=" * 60)
        self.write_progress_message("EXECUTION SUMMARY")
        self.write_progress_message("=" * 60)
        self.write_progress_message(f"Total Tasks: {summary['total_tasks']}")
        self.write_progress_message(f"Completed: {summary['completed_tasks']}")
        self.write_progress_message(f"Failed: {summary['failed_tasks']}")
        self.write_progress_message(f"Timeouts: {summary['timeout_tasks']}")
        self.write_progress_message(f"Success Rate: {summary['success_rate']:.1f}%")

        # Explicit None checks: a legitimately measured 0.0s must still be shown.
        if summary["total_time_seconds"] is not None:
            self.write_progress_message(f"Total Time: {summary['total_time_seconds']:.1f}s")

        if summary["average_time_per_task"] is not None:
            self.write_progress_message(f"Avg Time/Task: {summary['average_time_per_task']:.1f}s")

        self.write_progress_message("=" * 60)

    def __enter__(self):
        """Context manager entry: start tracking and return this manager."""
        self.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit: stop tracking; exceptions are not suppressed."""
        self.stop()
|
|
237
|
+
|
|
238
|
+
|
|
239
|
+
def create_progress_manager(
    total_tasks: int = 0, logger=None, show_progress_bar: bool = True, progress_desc: str = "Processing"
) -> ProgressManager:
    """Factory for :class:`ProgressManager`; it forwards every argument unchanged.

    :param total_tasks: Number of tasks the manager should expect
    :param logger: Optional logger used for progress messages
    :param show_progress_bar: Whether a progress bar should be displayed
    :param progress_desc: Label shown in front of the progress bar
    :return: A new, not-yet-started ProgressManager
    """
    settings = dict(
        total_tasks=total_tasks,
        logger=logger,
        show_progress_bar=show_progress_bar,
        progress_desc=progress_desc,
    )
    return ProgressManager(**settings)
|