strands-agents-evals 0.1.3__py3-none-any.whl → 0.1.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {strands_agents_evals-0.1.3.dist-info → strands_agents_evals-0.1.5.dist-info}/METADATA +2 -1
- {strands_agents_evals-0.1.3.dist-info → strands_agents_evals-0.1.5.dist-info}/RECORD +25 -18
- strands_evals/evaluators/__init__.py +4 -0
- strands_evals/evaluators/conciseness_evaluator.py +139 -0
- strands_evals/evaluators/evaluator.py +4 -0
- strands_evals/evaluators/faithfulness_evaluator.py +21 -16
- strands_evals/evaluators/goal_success_rate_evaluator.py +21 -16
- strands_evals/evaluators/harmfulness_evaluator.py +21 -16
- strands_evals/evaluators/helpfulness_evaluator.py +21 -16
- strands_evals/evaluators/interactions_evaluator.py +6 -4
- strands_evals/evaluators/output_evaluator.py +6 -4
- strands_evals/evaluators/prompt_templates/conciseness/__init__.py +11 -0
- strands_evals/evaluators/prompt_templates/conciseness/conciseness_v0.py +9 -0
- strands_evals/evaluators/prompt_templates/response_relevance/__init__.py +11 -0
- strands_evals/evaluators/prompt_templates/response_relevance/response_relevance_v0.py +29 -0
- strands_evals/evaluators/response_relevance_evaluator.py +144 -0
- strands_evals/evaluators/tool_parameter_accuracy_evaluator.py +19 -8
- strands_evals/evaluators/tool_selection_accuracy_evaluator.py +19 -8
- strands_evals/evaluators/trajectory_evaluator.py +6 -4
- strands_evals/experiment.py +281 -90
- strands_evals/extractors/trace_extractor.py +13 -1
- strands_evals/utils.py +37 -0
- {strands_agents_evals-0.1.3.dist-info → strands_agents_evals-0.1.5.dist-info}/WHEEL +0 -0
- {strands_agents_evals-0.1.3.dist-info → strands_agents_evals-0.1.5.dist-info}/licenses/LICENSE +0 -0
- {strands_agents_evals-0.1.3.dist-info → strands_agents_evals-0.1.5.dist-info}/licenses/NOTICE +0 -0
strands_evals/experiment.py
CHANGED
@@ -4,8 +4,16 @@ import logging
 import os
 from collections.abc import Callable
 from pathlib import Path
+from typing import cast
 
 from opentelemetry.trace import format_trace_id
+from tenacity import (
+    RetryError,
+    retry,
+    retry_if_exception,
+    stop_after_attempt,
+    wait_exponential,
+)
 from typing_extensions import Any, Generic, TypeVar
 
 from .case import Case
@@ -17,6 +25,7 @@ from .telemetry import get_tracer, serialize
 from .telemetry._cloudwatch_logger import _send_to_cloudwatch
 from .types.evaluation import EvaluationData
 from .types.evaluation_report import EvaluationReport
+from .utils import is_throttling_error
 
 InputT = TypeVar("InputT")
 OutputT = TypeVar("OutputT")
@@ -24,6 +33,11 @@ OutputT = TypeVar("OutputT")
 logger = logging.getLogger()
 logger.setLevel(logging.INFO)
 
+# Retry configuration for handling throttling
+_MAX_RETRY_ATTEMPTS = 6
+_INITIAL_RETRY_DELAY = 4
+_MAX_RETRY_DELAY = 240  # 4 minutes
+
 
 def _get_label_from_score(evaluator: Evaluator, score: float) -> str:
     """
@@ -139,6 +153,34 @@ class Experiment(Generic[InputT, OutputT]):
         """
         self._evaluators = new_evaluators
 
+    def _record_evaluator_result(
+        self,
+        evaluator_data: dict[str, dict[str, list]],
+        eval_name: str,
+        case_data: dict,
+        test_pass: bool,
+        score: float,
+        reason: str,
+        detailed_results: list,
+    ):
+        """
+        Record a single evaluator result in the evaluator_data dictionary.
+
+        Args:
+            evaluator_data: Dictionary to store evaluator results
+            eval_name: Name of the evaluator
+            case_data: Case data (already serialized as dict)
+            test_pass: Whether the test passed
+            score: Evaluation score
+            reason: Reason/explanation for the result
+            detailed_results: Detailed evaluation outputs
+        """
+        evaluator_data[eval_name]["cases"].append(case_data)
+        evaluator_data[eval_name]["test_passes"].append(test_pass)
+        evaluator_data[eval_name]["scores"].append(score)
+        evaluator_data[eval_name]["reasons"].append(reason)
+        evaluator_data[eval_name]["detailed_results"].append(detailed_results)
+
     def _run_task(
         self, task: Callable[[Case[InputT, OutputT]], OutputT | dict[str, Any]], case: Case[InputT, OutputT]
     ) -> EvaluationData[InputT, OutputT]:
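
The new `_record_evaluator_result` helper centralizes how per-evaluator results are accumulated during a run. As a rough sketch of the structure it appends into (the keyed lists mirror the appends in the method body above; the evaluator name and sample values are hypothetical):

    # Sketch only: one entry per evaluator, each holding parallel lists of results.
    evaluator_data: dict[str, dict[str, list]] = {
        "HelpfulnessEvaluator": {  # hypothetical evaluator name
            "cases": [],
            "test_passes": [],
            "scores": [],
            "reasons": [],
            "detailed_results": [],
        }
    }

    # Recording one result is five parallel appends:
    entry = evaluator_data["HelpfulnessEvaluator"]
    entry["cases"].append({"input": "What is 2 + 2?", "expected_output": "4"})  # hypothetical case dict
    entry["test_passes"].append(True)
    entry["scores"].append(1.0)
    entry["reasons"].append("Response answers the question directly.")
    entry["detailed_results"].append([])
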
@@ -240,92 +282,166 @@ class Experiment(Generic[InputT, OutputT]):
             trace_id = None
 
             try:
- … (14 removed lines not shown)
+
+                @retry(
+                    retry=retry_if_exception(is_throttling_error),
+                    stop=stop_after_attempt(_MAX_RETRY_ATTEMPTS),
+                    wait=wait_exponential(multiplier=_INITIAL_RETRY_DELAY, max=_MAX_RETRY_DELAY),
+                    reraise=True,
+                )
+                async def _run_task_with_retry(task=task, case=case):
+                    return await self._run_task_async(task, case)
+
+                try:
+                    with self._tracer.start_as_current_span(
+                        f"execute_case {case_name}",
+                    ) as case_span:
+                        evaluation_context = await _run_task_with_retry()
+                        case_span.set_attributes(
+                            {
+                                "gen_ai.evaluation.data.input": serialize(evaluation_context.input),
+                                "gen_ai.evaluation.data.expected_output": serialize(evaluation_context.expected_output),
+                                "gen_ai.evaluation.data.actual_output": serialize(evaluation_context.actual_output),
+                                "gen_ai.evaluation.data.has_trajectory": (
+                                    evaluation_context.actual_trajectory is not None
+                                ),
+                                "gen_ai.evaluation.data.has_interactions": (
+                                    evaluation_context.actual_interactions is not None
+                                ),
+                            }
+                        )
+                        trace_id = format_trace_id(case_span.get_span_context().trace_id)
+                except RetryError as e:
+                    # Max retries exceeded
+                    original_exception = e.last_attempt.exception()
+                    if original_exception is None:
+                        original_exception = Exception(f"Task execution failed after {_MAX_RETRY_ATTEMPTS} retries")
+                    logger.error(
+                        f"Max retry attempts ({_MAX_RETRY_ATTEMPTS}) exceeded for task execution "
+                        f"on case {case_name}. Last error: {str(original_exception)}"
                     )
- … (1 removed line not shown)
+                    raise original_exception from e
 
             # Evaluate with each evaluator
             evaluator_results = []
             for evaluator in self._evaluators:
-                with self._tracer.start_as_current_span(
-                    f"evaluator {evaluator.get_type_name()}",
-                ) as eval_span:
-                    evaluation_outputs = await evaluator.evaluate_async(evaluation_context)
-                    (aggregate_score, aggregate_pass, aggregate_reason) = evaluator.aggregator(evaluation_outputs)
 
- … (4 removed lines not shown)
+                @retry(
+                    retry=retry_if_exception(is_throttling_error),
+                    stop=stop_after_attempt(_MAX_RETRY_ATTEMPTS),
+                    wait=wait_exponential(multiplier=_INITIAL_RETRY_DELAY, max=_MAX_RETRY_DELAY),
+                    reraise=True,
+                )
+                async def _evaluate_with_retry(evaluator=evaluator, evaluation_context=evaluation_context):
+                    outputs = await evaluator.evaluate_async(evaluation_context)
+                    (score, passed, reason) = evaluator.aggregator(outputs)
+                    return outputs, float(score), passed, reason
+
+                try:
+                    with self._tracer.start_as_current_span(
+                        f"evaluator {evaluator.get_type_name()}",
+                    ) as eval_span:
+                        (
+                            evaluation_outputs,
+                            aggregate_score,
+                            aggregate_pass,
+                            aggregate_reason,
+                        ) = await _evaluate_with_retry()
+
+                        try:
+                            label = _get_label_from_score(evaluator, aggregate_score)
+                        except Exception:
+                            label = "UNKNOWN"
+
+                        eval_span.set_attributes(
+                            {
+                                "gen_ai.evaluation.score.label": label,
+                                "gen_ai.evaluation.score.value": str(aggregate_score),
+                                "gen_ai.evaluation.test_pass": aggregate_pass,
+                                "gen_ai.evaluation.explanation": aggregate_reason or "",
+                            }
+                        )
+
+                        evaluator_results.append(
+                            {
+                                "evaluator_name": evaluator.get_type_name(),
+                                "test_pass": aggregate_pass,
+                                "score": aggregate_score,
+                                "reason": aggregate_reason or "",
+                                "detailed_results": evaluation_outputs,
+                            }
+                        )
+
+                        # CloudWatch logging for this evaluator
+                        try:
+                            evaluator_full_name = f"Custom.{evaluator.get_type_name()}"
+                            region = os.environ.get("AWS_REGION", "us-east-1")
+                            _config_arn = (
+                                f"arn:aws:strands:{region}::strands-evaluation-empty-config/{self._config_id}"
+                            )
+                            _evaluator_arn = f"arn:aws:strands-evals:::evaluator/{evaluator_full_name}"
 
- … (1 removed line not shown)
+                            log_data = {
+                                "gen_ai.evaluation.name": evaluator_full_name,
+                                "gen_ai.evaluation.score.value": str(aggregate_score),
+                                "gen_ai.evaluation.explanation": aggregate_reason or "",
+                                "gen_ai.evaluation.score.label": label,
+                                "gen_ai.response.id": trace_id,
+                                "aws.bedrock_agentcore.evaluator.rating_scale": "Numerical",
+                                "aws.bedrock_agentcore.evaluation_level": evaluator.evaluation_level or "Trace",
+                                "event.name": "gen_ai.evaluation.result",
+                                "aws.bedrock_agentcore.online_evaluation_config.arn": _config_arn,
+                                "aws.bedrock_agentcore.online_evaluation_config.name": "strands-local-evaluation",
+                                "aws.bedrock_agentcore.evaluator.arn": _evaluator_arn,
+                                "session.id": case.session_id,
+                            }
+
+                            agent_observability_enabled = os.environ.get("AGENT_OBSERVABILITY_ENABLED", "")
+                            if agent_observability_enabled:
+                                _send_to_cloudwatch(
+                                    message="gen_ai.evaluation.result",
+                                    log_data=log_data,
+                                    trace_id=trace_id,
+                                    evaluator_name=evaluator_full_name,
+                                    score=cast(float, aggregate_score),
+                                    config_id=self._config_id,
+                                    label=label,
+                                )
+                        except Exception as e:
+                            logger.debug(f"Skipping CloudWatch logging: {str(e)}")
+
+                except RetryError as e:
+                    # Max retries exceeded
+                    original_exception = e.last_attempt.exception()
+                    if original_exception is None:
+                        original_exception = Exception(
+                            f"Evaluator {evaluator.get_type_name()} failed after {_MAX_RETRY_ATTEMPTS} retries"
+                        )
+                    logger.error(
+                        f"Max retry attempts ({_MAX_RETRY_ATTEMPTS}) exceeded for evaluator "
+                        f"{evaluator.get_type_name()} on case {case_name}. Last error: {str(original_exception)}"
+                    )
+                    evaluator_results.append(
                         {
-                            "
-                            "
-                            "
-                            "
+                            "evaluator_name": evaluator.get_type_name(),
+                            "test_pass": False,
+                            "score": 0,
+                            "reason": f"Evaluator error: {str(original_exception)}",
+                            "detailed_results": [],
                         }
                     )
- … (1 removed line not shown)
+                except Exception as e:
+                    # Catch non-throttling errors and record as failure (error isolation)
                     evaluator_results.append(
                         {
                             "evaluator_name": evaluator.get_type_name(),
-                            "test_pass":
-                            "score":
-                            "reason":
-                            "detailed_results":
+                            "test_pass": False,
+                            "score": 0,
+                            "reason": f"Evaluator error: {str(e)}",
+                            "detailed_results": [],
                         }
                     )
 
-                # CloudWatch logging for this evaluator
-                try:
-                    evaluator_full_name = f"Custom.{evaluator.get_type_name()}"
-                    region = os.environ.get("AWS_REGION", "us-east-1")
-                    _config_arn = f"arn:aws:strands:{region}::strands-evaluation-empty-config/{self._config_id}"
-                    _evaluator_arn = f"arn:aws:strands-evals:::evaluator/{evaluator_full_name}"
-
-                    log_data = {
-                        "gen_ai.evaluation.name": evaluator_full_name,
-                        "gen_ai.evaluation.score.value": str(aggregate_score),
-                        "gen_ai.evaluation.explanation": aggregate_reason or "",
-                        "gen_ai.evaluation.score.label": label,
-                        "gen_ai.response.id": trace_id,
-                        "aws.bedrock_agentcore.evaluator.rating_scale": "Numerical",
-                        "aws.bedrock_agentcore.evaluation_level": evaluator.evaluation_level or "Trace",
-                        "event.name": "gen_ai.evaluation.result",
-                        "aws.bedrock_agentcore.online_evaluation_config.arn": _config_arn,
-                        "aws.bedrock_agentcore.online_evaluation_config.name": "strands-local-evaluation",
-                        "aws.bedrock_agentcore.evaluator.arn": _evaluator_arn,
-                        "session.id": case.session_id,
-                    }
-
-                    agent_observability_enabled = os.environ.get("AGENT_OBSERVABILITY_ENABLED", "")
-                    if agent_observability_enabled:
-                        _send_to_cloudwatch(
-                            message="gen_ai.evaluation.result",
-                            log_data=log_data,
-                            trace_id=trace_id,
-                            evaluator_name=evaluator_full_name,
-                            score=aggregate_score,
-                            config_id=self._config_id,
-                            label=label,
-                        )
-                except Exception as e:
-                    logger.debug(f"Skipping CloudWatch logging: {str(e)}")
-
             # Store results
             results.append(
                 {
@@ -391,7 +507,16 @@ class Experiment(Generic[InputT, OutputT]):
                     "gen_ai.evaluation.case.input": serialize(case.input),
                 },
             ) as case_span:
-                # Task execution
+                # Task execution with retry logic
+                @retry(
+                    retry=retry_if_exception(is_throttling_error),
+                    stop=stop_after_attempt(_MAX_RETRY_ATTEMPTS),
+                    wait=wait_exponential(multiplier=_INITIAL_RETRY_DELAY, max=_MAX_RETRY_DELAY),
+                    reraise=True,
+                )
+                def _run_task_with_retry(task=task, case=case):
+                    return self._run_task(task, case)
+
                 try:
                     with self._tracer.start_as_current_span(
                         "task_execution",
@@ -400,7 +525,7 @@ class Experiment(Generic[InputT, OutputT]):
                             "gen_ai.evaluation.case.name": case_name,
                         },
                     ) as task_span:
-                        evaluation_context =
+                        evaluation_context = _run_task_with_retry()
                         task_span.set_attributes(
                             {
                                 "gen_ai.evaluation.data.input": serialize(evaluation_context.input),
@@ -414,20 +539,59 @@ class Experiment(Generic[InputT, OutputT]):
                                 ),
                             }
                         )
+                except RetryError as e:
+                    # Max retries exceeded
+                    original_exception = e.last_attempt.exception()
+                    if original_exception is None:
+                        original_exception = Exception(f"Task execution failed after {_MAX_RETRY_ATTEMPTS} retries")
+                    logger.error(
+                        f"Max retry attempts ({_MAX_RETRY_ATTEMPTS}) exceeded for task execution "
+                        f"on case {case_name}. Last error: {str(original_exception)}"
+                    )
+                    case_span.record_exception(original_exception)
+                    for evaluator in self._evaluators:
+                        eval_name = evaluator.get_type_name()
+                        self._record_evaluator_result(
+                            evaluator_data=evaluator_data,
+                            eval_name=eval_name,
+                            case_data=case.model_dump(),
+                            test_pass=False,
+                            score=0,
+                            reason=f"Task execution error: {str(original_exception)}",
+                            detailed_results=[],
+                        )
+                    continue
                 except Exception as e:
                     case_span.record_exception(e)
                     for evaluator in self._evaluators:
                         eval_name = evaluator.get_type_name()
- … (5 removed lines not shown)
+                        self._record_evaluator_result(
+                            evaluator_data=evaluator_data,
+                            eval_name=eval_name,
+                            case_data=case.model_dump(),
+                            test_pass=False,
+                            score=0,
+                            reason=f"Task execution error: {str(e)}",
+                            detailed_results=[],
+                        )
                     continue
 
                 # Evaluate with each evaluator using the same task output
                 for evaluator in self._evaluators:
                     eval_name = evaluator.get_type_name()
+
+                    # Evaluator execution with retry logic
+                    @retry(
+                        retry=retry_if_exception(is_throttling_error),
+                        stop=stop_after_attempt(_MAX_RETRY_ATTEMPTS),
+                        wait=wait_exponential(multiplier=_INITIAL_RETRY_DELAY, max=_MAX_RETRY_DELAY),
+                        reraise=True,
+                    )
+                    def _evaluate_with_retry(evaluator=evaluator, evaluation_context=evaluation_context):
+                        outputs = evaluator.evaluate(evaluation_context)
+                        (score, passed, reason) = evaluator.aggregator(outputs)
+                        return outputs, float(score), passed, reason
+
                     try:
                         with self._tracer.start_as_current_span(
                             f"evaluator {evaluator.get_type_name()}",
@@ -436,9 +600,8 @@ class Experiment(Generic[InputT, OutputT]):
                                 "gen_ai.evaluation.case.name": case_name,
                             },
                         ) as eval_span:
-                            evaluation_outputs =
- … (1 removed line not shown)
-                            evaluation_outputs
+                            evaluation_outputs, aggregate_score, aggregate_pass, aggregate_reason = (
+                                _evaluate_with_retry()
                             )
                             eval_span.set_attributes(
                                 {
@@ -448,17 +611,45 @@ class Experiment(Generic[InputT, OutputT]):
                                 }
                             )
 
- … (5 removed lines not shown)
+                            self._record_evaluator_result(
+                                evaluator_data=evaluator_data,
+                                eval_name=eval_name,
+                                case_data=evaluation_context.model_dump(),
+                                test_pass=aggregate_pass,
+                                score=aggregate_score,
+                                reason=aggregate_reason or "",
+                                detailed_results=evaluation_outputs,
+                            )
+                    except RetryError as e:
+                        # Max retries exceeded
+                        original_exception = e.last_attempt.exception()
+                        if original_exception is None:
+                            original_exception = Exception(
+                                f"Evaluator {evaluator.get_type_name()} failed after {_MAX_RETRY_ATTEMPTS} retries"
+                            )
+                        logger.error(
+                            f"Max retry attempts ({_MAX_RETRY_ATTEMPTS}) exceeded for evaluator "
+                            f"{evaluator.get_type_name()} on case {case_name}. Last error: {str(original_exception)}"
+                        )
+                        self._record_evaluator_result(
+                            evaluator_data=evaluator_data,
+                            eval_name=eval_name,
+                            case_data=evaluation_context.model_dump(),
+                            test_pass=False,
+                            score=0,
+                            reason=f"Evaluator error: {str(original_exception)}",
+                            detailed_results=[],
+                        )
                     except Exception as e:
- … (5 removed lines not shown)
+                        self._record_evaluator_result(
+                            evaluator_data=evaluator_data,
+                            eval_name=eval_name,
+                            case_data=evaluation_context.model_dump(),
+                            test_pass=False,
+                            score=0,
+                            reason=f"Evaluator error: {str(e)}",
+                            detailed_results=[],
+                        )
 
         reports = []
         for evaluator in self._evaluators:
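
The same retry decorator configuration appears in four places in this file. A minimal standalone sketch of the pattern, using tenacity exactly as the diff does but with a made-up throttling exception and much shorter waits so it runs quickly (`FlakyThrottlingError`, `looks_like_throttling`, and `flaky_call` are hypothetical; the package's real predicate is `is_throttling_error` from strands_evals/utils.py):

    from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential


    class FlakyThrottlingError(Exception):
        """Hypothetical stand-in for a provider throttling exception."""


    def looks_like_throttling(exc: BaseException) -> bool:
        # Stand-in predicate; the package uses is_throttling_error instead.
        return isinstance(exc, FlakyThrottlingError)


    @retry(
        retry=retry_if_exception(looks_like_throttling),  # only retry throttling-style errors
        stop=stop_after_attempt(3),                       # the package uses 6 attempts
        wait=wait_exponential(multiplier=0.01, max=0.1),  # the package waits up to 240 s
        reraise=True,                                     # surface the original exception, not RetryError
    )
    def flaky_call(attempts: list[int]) -> str:
        attempts.append(1)
        if len(attempts) < 3:
            raise FlakyThrottlingError("rate limited")
        return "ok"


    tries: list[int] = []
    print(flaky_call(tries), "after", len(tries), "attempts")  # ok after 3 attempts

With the package's settings (multiplier 4, cap 240 seconds, 6 attempts), the waits between attempts roughly double from a few seconds up to the cap before the last error is surfaced.
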
strands_evals/extractors/trace_extractor.py
CHANGED
@@ -45,9 +45,11 @@ class TraceExtractor:
     def _extract_trace_level(self, session: Session) -> list[TraceLevelInput]:
         """Extract trace-level inputs with session history up to each turn."""
         evaluation_inputs: list[TraceLevelInput] = []
-        previous_turns: list[Union[UserMessage, AssistantMessage]] = []
+        previous_turns: list[Union[UserMessage, list[ToolExecution], AssistantMessage]] = []
 
         for trace in session.traces:
+            tool_spans = self._find_tool_execution_spans(trace)
+
             for span in trace.spans:
                 if not isinstance(span, AgentInvocationSpan):
                     continue
@@ -59,6 +61,16 @@ class TraceExtractor:
                     logger.warning(f"Failed to create user message: {e}")
                     continue
 
+                # Include tool executions in session history
+                if tool_spans:
+                    try:
+                        tool_executions = [
+                            ToolExecution(tool_call=ts.tool_call, tool_result=ts.tool_result) for ts in tool_spans
+                        ]
+                        previous_turns.append(tool_executions)
+                    except (AttributeError, TypeError, ValueError) as e:
+                        logger.warning(f"Failed to create tool executions: {e}")
+
                 trace_input = TraceLevelInput(
                     span_info=span.span_info,
                     agent_response=TextContent(text=span.agent_response),
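
For context on the `previous_turns` change above: the session history the extractor builds can now interleave a list of tool executions between a user message and the assistant reply. A schematic sketch of that shape, using hypothetical stand-in dataclasses (the real `UserMessage`, `AssistantMessage`, and `ToolExecution` types live in the package and are not defined here):

    from dataclasses import dataclass
    from typing import Union


    @dataclass
    class UserMessage:       # stand-in for the package type
        text: str


    @dataclass
    class AssistantMessage:  # stand-in for the package type
        text: str


    @dataclass
    class ToolExecution:     # stand-in; the real type is built from tool_call/tool_result spans
        tool_call: dict
        tool_result: dict


    Turn = Union[UserMessage, list[ToolExecution], AssistantMessage]

    previous_turns: list[Turn] = [
        UserMessage("What's the weather in Seattle?"),
        [
            ToolExecution(
                tool_call={"name": "get_weather", "input": {"city": "Seattle"}},  # hypothetical tool call
                tool_result={"temperature_f": 57},
            )
        ],
        AssistantMessage("It's currently 57°F in Seattle."),
    ]
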
strands_evals/utils.py
ADDED
@@ -0,0 +1,37 @@
+from botocore.exceptions import ClientError
+from strands.types.exceptions import EventLoopException, ModelThrottledException
+
+THROTTLING_ERROR_CODES = {
+    "ThrottlingException",
+    "TooManyRequestsException",
+    "RequestLimitExceeded",
+    "ServiceUnavailable",
+    "ProvisionedThroughputExceededException",
+}
+
+
+def is_throttling_error(exception: BaseException) -> bool:
+    """
+    Check if an exception is a throttling/rate limiting error.
+
+    Args:
+        exception: The exception to check
+
+    Returns:
+        True if the exception indicates throttling, False otherwise
+    """
+    # Check for Strands-specific throttling exceptions
+    if isinstance(exception, (ModelThrottledException, EventLoopException)):
+        return True
+
+    # Check for botocore.errorfactory.ThrottlingException (dynamically generated)
+    if type(exception).__name__ == "ThrottlingException":
+        return True
+
+    # Check for botocore ClientError with throttling error codes
+    if isinstance(exception, ClientError):
+        error_code = exception.response.get("Error", {}).get("Code", "")
+        if error_code in THROTTLING_ERROR_CODES:
+            return True
+
+    return False
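
A trimmed, standalone version of the botocore branch of `is_throttling_error`, runnable with botocore alone (the Strands-specific exception checks are omitted here, and the `InvokeModel` operation name is just an illustrative value):

    from botocore.exceptions import ClientError

    THROTTLING_ERROR_CODES = {
        "ThrottlingException",
        "TooManyRequestsException",
        "RequestLimitExceeded",
        "ServiceUnavailable",
        "ProvisionedThroughputExceededException",
    }


    def is_botocore_throttling(exc: BaseException) -> bool:
        # Dynamically generated botocore.errorfactory classes are matched by name.
        if type(exc).__name__ == "ThrottlingException":
            return True
        if isinstance(exc, ClientError):
            return exc.response.get("Error", {}).get("Code", "") in THROTTLING_ERROR_CODES
        return False


    err = ClientError(
        error_response={"Error": {"Code": "ThrottlingException", "Message": "Rate exceeded"}},
        operation_name="InvokeModel",
    )
    print(is_botocore_throttling(err))           # True
    print(is_botocore_throttling(ValueError()))  # False
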
{strands_agents_evals-0.1.3.dist-info → strands_agents_evals-0.1.5.dist-info}/WHEEL
RENAMED
File without changes

{strands_agents_evals-0.1.3.dist-info → strands_agents_evals-0.1.5.dist-info}/licenses/LICENSE
RENAMED
File without changes

{strands_agents_evals-0.1.3.dist-info → strands_agents_evals-0.1.5.dist-info}/licenses/NOTICE
RENAMED
File without changes