azure-ai-evaluation 1.10.0__py3-none-any.whl → 1.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of azure-ai-evaluation might be problematic. Click here for more details.
- azure/ai/evaluation/_common/onedp/models/_models.py +5 -0
- azure/ai/evaluation/_converters/_ai_services.py +60 -10
- azure/ai/evaluation/_converters/_models.py +75 -26
- azure/ai/evaluation/_evaluate/_eval_run.py +14 -1
- azure/ai/evaluation/_evaluate/_evaluate.py +13 -4
- azure/ai/evaluation/_evaluate/_evaluate_aoai.py +77 -33
- azure/ai/evaluation/_evaluate/_utils.py +4 -0
- azure/ai/evaluation/_evaluators/_coherence/_coherence.py +2 -1
- azure/ai/evaluation/_evaluators/_common/_base_eval.py +113 -19
- azure/ai/evaluation/_evaluators/_common/_base_prompty_eval.py +7 -2
- azure/ai/evaluation/_evaluators/_common/_base_rai_svc_eval.py +1 -1
- azure/ai/evaluation/_evaluators/_fluency/_fluency.py +2 -1
- azure/ai/evaluation/_evaluators/_groundedness/_groundedness.py +113 -3
- azure/ai/evaluation/_evaluators/_intent_resolution/_intent_resolution.py +8 -2
- azure/ai/evaluation/_evaluators/_relevance/_relevance.py +2 -1
- azure/ai/evaluation/_evaluators/_response_completeness/_response_completeness.py +10 -2
- azure/ai/evaluation/_evaluators/_retrieval/_retrieval.py +2 -1
- azure/ai/evaluation/_evaluators/_similarity/_similarity.py +2 -1
- azure/ai/evaluation/_evaluators/_task_adherence/_task_adherence.py +8 -2
- azure/ai/evaluation/_evaluators/_tool_call_accuracy/_tool_call_accuracy.py +104 -60
- azure/ai/evaluation/_evaluators/_tool_call_accuracy/tool_call_accuracy.prompty +58 -41
- azure/ai/evaluation/_exceptions.py +1 -0
- azure/ai/evaluation/_version.py +1 -1
- azure/ai/evaluation/red_team/__init__.py +2 -1
- azure/ai/evaluation/red_team/_attack_objective_generator.py +17 -0
- azure/ai/evaluation/red_team/_callback_chat_target.py +14 -1
- azure/ai/evaluation/red_team/_evaluation_processor.py +376 -0
- azure/ai/evaluation/red_team/_mlflow_integration.py +322 -0
- azure/ai/evaluation/red_team/_orchestrator_manager.py +661 -0
- azure/ai/evaluation/red_team/_red_team.py +697 -3067
- azure/ai/evaluation/red_team/_result_processor.py +610 -0
- azure/ai/evaluation/red_team/_utils/__init__.py +34 -0
- azure/ai/evaluation/red_team/_utils/_rai_service_eval_chat_target.py +3 -1
- azure/ai/evaluation/red_team/_utils/_rai_service_true_false_scorer.py +6 -0
- azure/ai/evaluation/red_team/_utils/exception_utils.py +345 -0
- azure/ai/evaluation/red_team/_utils/file_utils.py +266 -0
- azure/ai/evaluation/red_team/_utils/formatting_utils.py +115 -13
- azure/ai/evaluation/red_team/_utils/metric_mapping.py +24 -4
- azure/ai/evaluation/red_team/_utils/progress_utils.py +252 -0
- azure/ai/evaluation/red_team/_utils/retry_utils.py +218 -0
- azure/ai/evaluation/red_team/_utils/strategy_utils.py +17 -4
- azure/ai/evaluation/simulator/_adversarial_simulator.py +9 -0
- azure/ai/evaluation/simulator/_model_tools/_generated_rai_client.py +19 -5
- azure/ai/evaluation/simulator/_model_tools/_proxy_completion_model.py +4 -3
- {azure_ai_evaluation-1.10.0.dist-info → azure_ai_evaluation-1.11.0.dist-info}/METADATA +32 -2
- {azure_ai_evaluation-1.10.0.dist-info → azure_ai_evaluation-1.11.0.dist-info}/RECORD +49 -41
- {azure_ai_evaluation-1.10.0.dist-info → azure_ai_evaluation-1.11.0.dist-info}/WHEEL +1 -1
- {azure_ai_evaluation-1.10.0.dist-info → azure_ai_evaluation-1.11.0.dist-info/licenses}/NOTICE.txt +0 -0
- {azure_ai_evaluation-1.10.0.dist-info → azure_ai_evaluation-1.11.0.dist-info}/top_level.txt +0 -0
|
@@ -37,6 +37,7 @@ class AzureRAIServiceTrueFalseScorer(Scorer):
|
|
|
37
37
|
model: Optional[str] = None,
|
|
38
38
|
prompt_template_key: Optional[str] = None,
|
|
39
39
|
logger: Optional[logging.Logger] = None,
|
|
40
|
+
context: Optional[str] = None,
|
|
40
41
|
) -> None:
|
|
41
42
|
"""Initialize the scorer.
|
|
42
43
|
|
|
@@ -52,6 +53,7 @@ class AzureRAIServiceTrueFalseScorer(Scorer):
|
|
|
52
53
|
self._model = model
|
|
53
54
|
self._prompt_template_key = prompt_template_key
|
|
54
55
|
self.logger = logger
|
|
56
|
+
self.context = context
|
|
55
57
|
|
|
56
58
|
# Use AzureRAIServiceTarget for scoring as well
|
|
57
59
|
self._scoring_target_red_llm = RAIServiceEvalChatTarget(
|
|
@@ -59,6 +61,7 @@ class AzureRAIServiceTrueFalseScorer(Scorer):
|
|
|
59
61
|
credential=credential,
|
|
60
62
|
azure_ai_project=azure_ai_project,
|
|
61
63
|
risk_category=risk_category,
|
|
64
|
+
context=context,
|
|
62
65
|
)
|
|
63
66
|
|
|
64
67
|
async def score_async(
|
|
@@ -78,6 +81,9 @@ class AzureRAIServiceTrueFalseScorer(Scorer):
|
|
|
78
81
|
# Validate request_response
|
|
79
82
|
self.validate(request_response, task=task)
|
|
80
83
|
|
|
84
|
+
if not self._scoring_target_red_llm.context:
|
|
85
|
+
self._scoring_target_red_llm.context = request_response.labels.get("context", "")
|
|
86
|
+
|
|
81
87
|
unvalidated_score: UnvalidatedScore = await self._score_value_with_llm(
|
|
82
88
|
prompt_target=self._scoring_target_red_llm,
|
|
83
89
|
system_prompt="Please evaluate the following response.", # TODO: What should this be?
|
|
@@ -0,0 +1,345 @@
|
|
|
1
|
+
# ---------------------------------------------------------
|
|
2
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
3
|
+
# ---------------------------------------------------------
|
|
4
|
+
"""
|
|
5
|
+
Exception handling utilities for Red Team Agent.
|
|
6
|
+
|
|
7
|
+
This module provides centralized exception handling, error categorization,
|
|
8
|
+
and error reporting utilities for red team operations.
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
import logging
|
|
12
|
+
import traceback
|
|
13
|
+
import asyncio
|
|
14
|
+
from typing import Optional, Any, Dict, Union
|
|
15
|
+
from enum import Enum
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class ErrorCategory(Enum):
    """Classification buckets for errors raised during red team operations.

    Every member's value is the lowercase string form of its name. Member
    order is significant to anything that iterates the enum, so it is kept
    exactly as declared.
    """

    NETWORK = "network"
    AUTHENTICATION = "authentication"
    CONFIGURATION = "configuration"
    DATA_PROCESSING = "data_processing"
    ORCHESTRATOR = "orchestrator"
    EVALUATION = "evaluation"
    FILE_IO = "file_io"
    TIMEOUT = "timeout"
    # Fallback bucket for anything that matches none of the above.
    UNKNOWN = "unknown"
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class ErrorSeverity(Enum):
    """How serious an error is, from least to most severe."""

    # Warning level: the operation can continue.
    LOW = "low"
    # Error level: the task failed but the scan can continue.
    MEDIUM = "medium"
    # Critical error: the scan should be aborted.
    HIGH = "high"
    # Unrecoverable error.
    FATAL = "fatal"
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class RedTeamError(Exception):
    """Root exception type for Red Team operations.

    Carries a category, a severity, an arbitrary context mapping and the
    exception it wraps (if any) alongside the human-readable message.

    :param message: Description of the failure; also passed to ``Exception``.
    :param category: Error category, ``ErrorCategory.UNKNOWN`` by default.
    :param severity: Error severity, ``ErrorSeverity.MEDIUM`` by default.
    :param context: Optional extra information; stored as a new empty dict
        when omitted or falsy.
    :param original_exception: The underlying exception being wrapped, if any.
    """

    def __init__(
        self,
        message: str,
        category: ErrorCategory = ErrorCategory.UNKNOWN,
        severity: ErrorSeverity = ErrorSeverity.MEDIUM,
        context: Optional[Dict[str, Any]] = None,
        original_exception: Optional[Exception] = None,
    ):
        super().__init__(message)
        self.message = message
        self.category, self.severity = category, severity
        # Never share a default dict between instances.
        self.context = context if context else {}
        self.original_exception = original_exception
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class ExceptionHandler:
    """Centralized exception handling for Red Team operations.

    Wraps arbitrary exceptions in ``RedTeamError`` with a category and a
    severity, logs them at a level derived from the severity, and keeps a
    per-category tally that ``should_abort_scan`` inspects.
    """

    def __init__(self, logger: Optional[logging.Logger] = None):
        """Initialize exception handler.

        :param logger: Logger instance for error reporting; defaults to this
            module's logger when omitted.
        """
        self.logger = logger or logging.getLogger(__name__)
        # One counter per category, all starting at zero.
        self.error_counts: Dict[ErrorCategory, int] = {category: 0 for category in ErrorCategory}

    def categorize_exception(self, exception: Exception) -> ErrorCategory:
        """Categorize an exception based on its type and message.

        Checks run in this order: HTTP status code carried on
        ``exception.response``, network exception types, timeouts, file I/O
        types, and finally keyword matching on the lowercased message.

        :param exception: The exception to categorize
        :return: The appropriate error category
        """
        # HTTP status code specific errors. This has to run before the
        # type-based checks: httpx.HTTPStatusError is a subclass of
        # httpx.HTTPError, so testing the network types first would classify
        # every 401/403 as NETWORK and make this branch unreachable for it.
        response = getattr(exception, "response", None)
        status_code = getattr(response, "status_code", None)
        if isinstance(status_code, int):
            if 500 <= status_code < 600:
                return ErrorCategory.NETWORK
            if status_code == 401:
                return ErrorCategory.AUTHENTICATION
            if status_code == 403:
                return ErrorCategory.CONFIGURATION

        # Network-related errors. The HTTP client libraries are imported
        # lazily and optionally so that categorizing an error can never itself
        # fail with ImportError and mask the exception being handled.
        network_exceptions = [ConnectionError, ConnectionRefusedError, ConnectionResetError]
        try:
            import httpx
        except ImportError:
            pass
        else:
            network_exceptions.extend(
                [
                    httpx.ConnectTimeout,
                    httpx.ReadTimeout,
                    httpx.ConnectError,
                    httpx.HTTPError,
                    httpx.TimeoutException,
                ]
            )
        try:
            import httpcore
        except ImportError:
            pass
        else:
            network_exceptions.append(httpcore.ReadTimeout)

        if isinstance(exception, tuple(network_exceptions)):
            return ErrorCategory.NETWORK

        # Timeout errors (separate from network to handle asyncio.TimeoutError)
        if isinstance(exception, (TimeoutError, asyncio.TimeoutError)):
            return ErrorCategory.TIMEOUT

        # File I/O errors
        if isinstance(exception, (IOError, OSError, FileNotFoundError, PermissionError)):
            return ErrorCategory.FILE_IO

        # String-based categorization: first category (in declaration order)
        # with a keyword contained in the lowercased message wins.
        message = str(exception).lower()
        keyword_mappings = {
            ErrorCategory.AUTHENTICATION: ["authentication", "unauthorized"],
            ErrorCategory.CONFIGURATION: ["configuration", "config"],
            ErrorCategory.ORCHESTRATOR: ["orchestrator"],
            ErrorCategory.EVALUATION: ["evaluation", "evaluate", "model_error"],
            ErrorCategory.DATA_PROCESSING: ["data", "json"],
        }

        for category, keywords in keyword_mappings.items():
            if any(keyword in message for keyword in keywords):
                return category

        return ErrorCategory.UNKNOWN

    def determine_severity(
        self, exception: Exception, category: ErrorCategory, context: Optional[Dict[str, Any]] = None
    ) -> ErrorSeverity:
        """Determine the severity of an exception.

        :param exception: The exception to evaluate
        :param category: The error category
        :param context: Additional context for severity determination; only
            the ``critical_file`` key is consulted, and only for file I/O errors
        :return: The appropriate error severity
        """
        context = context or {}

        # Critical system errors
        if isinstance(exception, (MemoryError, SystemExit, KeyboardInterrupt)):
            return ErrorSeverity.FATAL

        # Authentication and configuration are typically high severity
        if category in (ErrorCategory.AUTHENTICATION, ErrorCategory.CONFIGURATION):
            return ErrorSeverity.HIGH

        # File I/O errors can be high severity if they involve critical files
        if category == ErrorCategory.FILE_IO:
            if context.get("critical_file", False):
                return ErrorSeverity.HIGH
            return ErrorSeverity.MEDIUM

        # Network and timeout errors are usually medium severity (retryable)
        if category in (ErrorCategory.NETWORK, ErrorCategory.TIMEOUT):
            return ErrorSeverity.MEDIUM

        # Task-specific errors are medium severity
        if category in (ErrorCategory.ORCHESTRATOR, ErrorCategory.EVALUATION, ErrorCategory.DATA_PROCESSING):
            return ErrorSeverity.MEDIUM

        return ErrorSeverity.LOW

    def handle_exception(
        self,
        exception: Exception,
        context: Optional[Dict[str, Any]] = None,
        task_name: Optional[str] = None,
        reraise: bool = False,
    ) -> RedTeamError:
        """Handle an exception with proper categorization and logging.

        An exception that is already a ``RedTeamError`` is logged and passed
        through unchanged; it is not re-categorized and does not touch
        ``error_counts``. Any other exception is categorized, counted, wrapped
        in a new ``RedTeamError`` and logged.

        :param exception: The exception to handle
        :param context: Additional context information
        :param task_name: Name of the task where the exception occurred
        :param reraise: Whether to reraise the exception after handling
        :return: A RedTeamError with categorized information
        :raises RedTeamError: When ``reraise`` is true
        """
        context = context or {}

        # If it's already a RedTeamError, just log and return/reraise
        if isinstance(exception, RedTeamError):
            self._log_error(exception, task_name)
            if reraise:
                raise exception
            return exception

        # Categorize the exception
        category = self.categorize_exception(exception)
        severity = self.determine_severity(exception, category, context)

        # Update error counts
        self.error_counts[category] += 1

        # Create RedTeamError
        message = f"{category.value.title()} error"
        if task_name:
            message += f" in {task_name}"
        message += f": {str(exception)}"

        red_team_error = RedTeamError(
            message=message, category=category, severity=severity, context=context, original_exception=exception
        )

        # Log the error
        self._log_error(red_team_error, task_name)

        if reraise:
            # Chain explicitly so the cause is preserved even when this method
            # is called outside of an ``except`` block.
            raise red_team_error from exception

        return red_team_error

    def _log_error(self, error: RedTeamError, task_name: Optional[str] = None) -> None:
        """Log an error with appropriate level based on severity.

        FATAL maps to CRITICAL, HIGH to ERROR, MEDIUM to WARNING and anything
        else to INFO. Context and the original traceback go to DEBUG.

        :param error: The RedTeamError to log
        :param task_name: Optional task name for context
        """
        # Determine log level based on severity
        if error.severity == ErrorSeverity.FATAL:
            log_level = logging.CRITICAL
        elif error.severity == ErrorSeverity.HIGH:
            log_level = logging.ERROR
        elif error.severity == ErrorSeverity.MEDIUM:
            log_level = logging.WARNING
        else:
            log_level = logging.INFO

        # Create log message
        message_parts = []
        if task_name:
            message_parts.append(f"[{task_name}]")
        message_parts.append(f"[{error.category.value}]")
        message_parts.append(f"[{error.severity.value}]")
        message_parts.append(error.message)

        log_message = " ".join(message_parts)

        # Log with appropriate level
        self.logger.log(log_level, log_message)

        # Log additional context if available
        if error.context:
            self.logger.debug(f"Error context: {error.context}")

        # Log original exception traceback for debugging. Format the wrapped
        # exception itself: traceback.format_exc() would report whatever
        # exception happens to be in flight (or "NoneType: None" when called
        # outside an ``except`` block) rather than the one stored on the error.
        original = error.original_exception
        if original and self.logger.isEnabledFor(logging.DEBUG):
            formatted = "".join(traceback.format_exception(type(original), original, original.__traceback__))
            self.logger.debug(f"Original exception traceback:\n{formatted}")

    def should_abort_scan(self) -> bool:
        """Determine if the scan should be aborted based on error patterns.

        :return: True if more than 2 authentication/configuration errors or
            more than 10 network errors have been counted
        """
        # Abort if we have too many high-severity errors
        high_severity_categories = [ErrorCategory.AUTHENTICATION, ErrorCategory.CONFIGURATION]
        high_severity_count = sum(self.error_counts[cat] for cat in high_severity_categories)

        if high_severity_count > 2:
            return True

        # Abort if we have too many network errors (indicates systemic issue)
        if self.error_counts[ErrorCategory.NETWORK] > 10:
            return True

        return False

    def get_error_summary(self) -> Dict[str, Any]:
        """Get a summary of all errors encountered.

        :return: Dictionary with the keys ``total_errors``,
            ``error_counts_by_category`` (a copy of the counters),
            ``most_common_category`` (``None`` when nothing was counted) and
            ``should_abort``
        """
        total_errors = sum(self.error_counts.values())

        return {
            "total_errors": total_errors,
            "error_counts_by_category": dict(self.error_counts),
            "most_common_category": max(self.error_counts, key=self.error_counts.get) if total_errors > 0 else None,
            "should_abort": self.should_abort_scan(),
        }

    def log_error_summary(self) -> None:
        """Log a summary of all errors encountered."""
        summary = self.get_error_summary()

        if summary["total_errors"] == 0:
            self.logger.info("No errors encountered during operation")
            return

        self.logger.info(f"Error Summary: {summary['total_errors']} total errors")

        # Only categories that were actually hit are listed.
        for category, count in summary["error_counts_by_category"].items():
            if count > 0:
                self.logger.info(f" {category}: {count}")

        if summary["most_common_category"]:
            self.logger.info(f"Most common error type: {summary['most_common_category']}")
|
|
302
|
+
|
|
303
|
+
|
|
304
|
+
def create_exception_handler(logger: Optional[logging.Logger] = None) -> ExceptionHandler:
    """Factory for :class:`ExceptionHandler`.

    :param logger: Logger instance the handler should report errors to
    :return: A new ExceptionHandler built with that logger
    """
    handler = ExceptionHandler(logger=logger)
    return handler
|
|
311
|
+
|
|
312
|
+
|
|
313
|
+
# Convenience context manager for handling exceptions
|
|
314
|
+
class exception_context:
    """Context manager that routes exceptions through an ExceptionHandler.

    Any exception leaving the ``with`` body is passed to
    ``handler.handle_exception`` and the resulting ``RedTeamError`` is stored
    on ``self.error``. Fatal errors are re-raised as that ``RedTeamError``
    when ``reraise_fatal`` is true; everything else is suppressed.
    """

    def __init__(
        self,
        handler: ExceptionHandler,
        task_name: str,
        context: Optional[Dict[str, Any]] = None,
        reraise_fatal: bool = True,
    ):
        self.handler = handler
        self.task_name = task_name
        self.reraise_fatal = reraise_fatal
        self.context = context if context else {}
        # Populated by __exit__ when the body raised.
        self.error: Optional[RedTeamError] = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Nothing was raised inside the block: nothing to handle or suppress.
        if exc_val is None:
            return False

        self.error = self.handler.handle_exception(
            exception=exc_val,
            context=self.context,
            task_name=self.task_name,
            reraise=False,
        )

        # Fatal errors propagate (as the wrapped error) unless disabled.
        if self.reraise_fatal and self.error.severity == ErrorSeverity.FATAL:
            raise self.error

        # The exception has been handled, so swallow the original.
        return True
|
|
@@ -0,0 +1,266 @@
|
|
|
1
|
+
# ---------------------------------------------------------
|
|
2
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
3
|
+
# ---------------------------------------------------------
|
|
4
|
+
"""
|
|
5
|
+
File operation utilities for Red Team Agent.
|
|
6
|
+
|
|
7
|
+
This module provides centralized file handling, path operations, and
|
|
8
|
+
data serialization utilities used across the red team components.
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
import json
|
|
12
|
+
import os
|
|
13
|
+
import uuid
|
|
14
|
+
from datetime import datetime
|
|
15
|
+
from pathlib import Path
|
|
16
|
+
from typing import Any, Dict, List, Optional, Union
|
|
17
|
+
|
|
18
|
+
# Try to import DefaultOpenEncoding, fallback to standard encoding
|
|
19
|
+
try:
|
|
20
|
+
from azure.ai.evaluation._common._utils import DefaultOpenEncoding
|
|
21
|
+
|
|
22
|
+
DEFAULT_ENCODING = DefaultOpenEncoding.WRITE
|
|
23
|
+
except ImportError:
|
|
24
|
+
DEFAULT_ENCODING = "utf-8"
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class FileManager:
    """Centralized file operations manager for Red Team operations.

    Groups directory creation, unique/safe filename generation and JSON /
    JSONL reading and writing. Logging is optional: every log call is
    skipped when no logger was supplied.
    """

    def __init__(self, base_output_dir: Optional[str] = None, logger=None):
        """Initialize file manager.

        :param base_output_dir: Base directory for all file operations;
            defaults to the current directory (".")
        :param logger: Logger instance for file operations
        """
        self.base_output_dir = base_output_dir or "."
        self.logger = logger

    def ensure_directory(self, path: Union[str, os.PathLike]) -> str:
        """Ensure a directory exists, creating it if necessary.

        :param path: Path to the directory
        :return: Absolute path to the directory
        """
        abs_path = os.path.abspath(path)
        os.makedirs(abs_path, exist_ok=True)
        return abs_path

    def generate_unique_filename(
        self, prefix: str = "", suffix: str = "", extension: str = "", use_timestamp: bool = False
    ) -> str:
        """Generate a unique filename.

        The parts are joined with underscores in the order prefix, timestamp,
        UUID, suffix; the extension is appended last.

        :param prefix: Prefix for the filename
        :param suffix: Suffix for the filename
        :param extension: File extension (with or without dot)
        :param use_timestamp: Whether to include timestamp in filename
        :return: Unique filename
        """
        parts = []

        if prefix:
            parts.append(prefix)

        if use_timestamp:
            parts.append(datetime.now().strftime("%Y%m%d_%H%M%S"))

        # Always include UUID for uniqueness
        parts.append(str(uuid.uuid4()))

        if suffix:
            parts.append(suffix)

        filename = "_".join(parts)

        if extension:
            if not extension.startswith("."):
                extension = "." + extension
            filename += extension

        return filename

    def get_scan_output_path(self, scan_id: str, filename: str = "") -> str:
        """Get path for scan output files.

        The scan directory is created under ``base_output_dir``. Unless the
        ``DEBUG`` environment variable is truthy ("true", "1", "yes" or "y",
        case-insensitive), the directory name is prefixed with a dot and a
        ``.gitignore`` containing ``*`` is written into it.

        :param scan_id: Unique scan identifier
        :param filename: Optional filename to append
        :return: Full path for scan output
        """
        # Create scan directory based on DEBUG environment
        is_debug = os.environ.get("DEBUG", "").lower() in ("true", "1", "yes", "y")
        folder_prefix = "" if is_debug else "."

        scan_dir = os.path.join(self.base_output_dir, f"{folder_prefix}{scan_id}")
        self.ensure_directory(scan_dir)

        # Create .gitignore in scan directory if not debug mode
        if not is_debug:
            gitignore_path = os.path.join(scan_dir, ".gitignore")
            if not os.path.exists(gitignore_path):
                with open(gitignore_path, "w", encoding="utf-8") as f:
                    f.write("*\n")

        if filename:
            return os.path.join(scan_dir, filename)
        return scan_dir

    def write_json(self, data: Any, filepath: Union[str, os.PathLike], indent: int = 2, ensure_dir: bool = True) -> str:
        """Write data to JSON file.

        :param data: Data to write
        :param filepath: Path to write the file
        :param indent: JSON indentation
        :param ensure_dir: Whether to ensure directory exists
        :return: Absolute path of written file
        """
        abs_path = os.path.abspath(filepath)

        if ensure_dir:
            self.ensure_directory(os.path.dirname(abs_path))

        # DEFAULT_ENCODING is the module-level write encoding.
        with open(abs_path, "w", encoding=DEFAULT_ENCODING) as f:
            json.dump(data, f, indent=indent)

        if self.logger:
            self.logger.debug(f"Successfully wrote JSON to {abs_path}")

        return abs_path

    def read_json(self, filepath: Union[str, os.PathLike]) -> Any:
        """Read data from JSON file.

        Failures are logged (when a logger is set) and then re-raised.

        :param filepath: Path to the JSON file
        :return: Parsed JSON data
        """
        abs_path = os.path.abspath(filepath)

        try:
            with open(abs_path, "r", encoding="utf-8") as f:
                data = json.load(f)

            if self.logger:
                self.logger.debug(f"Successfully read JSON from {abs_path}")

            return data
        except Exception as e:
            if self.logger:
                self.logger.error(f"Failed to read JSON from {abs_path}: {str(e)}")
            raise

    def read_jsonl(self, filepath: Union[str, os.PathLike]) -> List[Dict]:
        """Read data from JSONL file.

        Blank lines are ignored. Lines that are not valid JSON are skipped
        (with a warning when a logger is set) rather than aborting the read;
        any other failure is logged and re-raised.

        :param filepath: Path to the JSONL file
        :return: List of parsed JSON objects
        """
        abs_path = os.path.abspath(filepath)
        data = []

        try:
            with open(abs_path, "r", encoding="utf-8") as f:
                for line_num, line in enumerate(f, 1):
                    line = line.strip()
                    if line:
                        try:
                            data.append(json.loads(line))
                        except json.JSONDecodeError as e:
                            if self.logger:
                                self.logger.warning(f"Skipping invalid JSON line {line_num} in {abs_path}: {str(e)}")

            if self.logger:
                self.logger.debug(f"Successfully read {len(data)} records from JSONL {abs_path}")

            return data
        except Exception as e:
            if self.logger:
                self.logger.error(f"Failed to read JSONL from {abs_path}: {str(e)}")
            raise

    def write_jsonl(self, data: List[Dict], filepath: Union[str, os.PathLike], ensure_dir: bool = True) -> str:
        """Write data to JSONL file.

        :param data: List of dictionaries to write, one JSON document per line
        :param filepath: Path to write the file
        :param ensure_dir: Whether to ensure directory exists
        :return: Absolute path of written file
        """
        abs_path = os.path.abspath(filepath)

        if ensure_dir:
            self.ensure_directory(os.path.dirname(abs_path))

        with open(abs_path, "w", encoding="utf-8") as f:
            for item in data:
                f.write(json.dumps(item) + "\n")

        if self.logger:
            self.logger.debug(f"Successfully wrote {len(data)} records to JSONL {abs_path}")

        return abs_path

    def safe_filename(self, name: str, max_length: int = 255) -> str:
        """Create a safe filename from a string.

        Characters in ``<>:"/\\|?*`` and spaces are replaced by underscores.
        A result longer than ``max_length`` is cut and given a ``...``
        marker; the returned name never exceeds ``max_length`` characters.

        :param name: Original name
        :param max_length: Maximum filename length
        :return: Safe filename
        """
        # Replace invalid characters
        invalid_chars = '<>:"/\\|?*'
        safe_name = "".join(c if c not in invalid_chars else "_" for c in name)

        # Replace spaces with underscores
        safe_name = safe_name.replace(" ", "_")

        # Truncate if too long
        if len(safe_name) > max_length:
            if max_length >= 4:
                safe_name = safe_name[: max_length - 4] + "..."
            else:
                # Too short to fit the marker. A negative slice bound would
                # count from the end and return a name longer than the limit,
                # so fall back to a plain cut.
                safe_name = safe_name[: max(max_length, 0)]

        return safe_name

    def get_file_size(self, filepath: Union[str, os.PathLike]) -> int:
        """Get file size in bytes.

        :param filepath: Path to the file
        :return: File size in bytes
        """
        return os.path.getsize(filepath)

    def file_exists(self, filepath: Union[str, os.PathLike]) -> bool:
        """Check if file exists.

        :param filepath: Path to check
        :return: True if the path is an existing regular file
        """
        return os.path.isfile(filepath)

    def cleanup_file(self, filepath: Union[str, os.PathLike], ignore_errors: bool = True) -> bool:
        """Delete a file if it exists.

        :param filepath: Path to the file to delete
        :param ignore_errors: Whether to ignore deletion errors; when false
            the error is re-raised
        :return: True if file was deleted or didn't exist, False when deletion
            failed and errors are ignored
        """
        try:
            if self.file_exists(filepath):
                os.remove(filepath)
                if self.logger:
                    self.logger.debug(f"Deleted file: {filepath}")
            return True
        except Exception as e:
            if not ignore_errors:
                raise
            if self.logger:
                self.logger.warning(f"Failed to delete file {filepath}: {str(e)}")
            return False
|
|
257
|
+
|
|
258
|
+
|
|
259
|
+
def create_file_manager(base_output_dir: Optional[str] = None, logger=None) -> FileManager:
    """Factory for :class:`FileManager`.

    :param base_output_dir: Base directory the manager should work under
    :param logger: Logger instance handed to the manager
    :return: A new FileManager built from the given arguments
    """
    manager = FileManager(base_output_dir=base_output_dir, logger=logger)
    return manager
|