arthur-common 1.0.1 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of arthur-common might be problematic.

Files changed (40)
  1. arthur_common/__init__.py +0 -0
  2. arthur_common/__version__.py +1 -0
  3. arthur_common/aggregations/__init__.py +2 -0
  4. arthur_common/aggregations/aggregator.py +214 -0
  5. arthur_common/aggregations/functions/README.md +26 -0
  6. arthur_common/aggregations/functions/__init__.py +25 -0
  7. arthur_common/aggregations/functions/categorical_count.py +89 -0
  8. arthur_common/aggregations/functions/confusion_matrix.py +412 -0
  9. arthur_common/aggregations/functions/inference_count.py +69 -0
  10. arthur_common/aggregations/functions/inference_count_by_class.py +206 -0
  11. arthur_common/aggregations/functions/inference_null_count.py +82 -0
  12. arthur_common/aggregations/functions/mean_absolute_error.py +110 -0
  13. arthur_common/aggregations/functions/mean_squared_error.py +110 -0
  14. arthur_common/aggregations/functions/multiclass_confusion_matrix.py +205 -0
  15. arthur_common/aggregations/functions/multiclass_inference_count_by_class.py +90 -0
  16. arthur_common/aggregations/functions/numeric_stats.py +90 -0
  17. arthur_common/aggregations/functions/numeric_sum.py +87 -0
  18. arthur_common/aggregations/functions/py.typed +0 -0
  19. arthur_common/aggregations/functions/shield_aggregations.py +752 -0
  20. arthur_common/aggregations/py.typed +0 -0
  21. arthur_common/models/__init__.py +0 -0
  22. arthur_common/models/connectors.py +41 -0
  23. arthur_common/models/datasets.py +22 -0
  24. arthur_common/models/metrics.py +227 -0
  25. arthur_common/models/py.typed +0 -0
  26. arthur_common/models/schema_definitions.py +420 -0
  27. arthur_common/models/shield.py +504 -0
  28. arthur_common/models/task_job_specs.py +78 -0
  29. arthur_common/py.typed +0 -0
  30. arthur_common/tools/__init__.py +0 -0
  31. arthur_common/tools/aggregation_analyzer.py +243 -0
  32. arthur_common/tools/aggregation_loader.py +59 -0
  33. arthur_common/tools/duckdb_data_loader.py +329 -0
  34. arthur_common/tools/functions.py +46 -0
  35. arthur_common/tools/py.typed +0 -0
  36. arthur_common/tools/schema_inferer.py +104 -0
  37. arthur_common/tools/time_utils.py +33 -0
  38. arthur_common-1.0.1.dist-info/METADATA +74 -0
  39. arthur_common-1.0.1.dist-info/RECORD +40 -0
  40. arthur_common-1.0.1.dist-info/WHEEL +4 -0
arthur_common/aggregations/functions/shield_aggregations.py
@@ -0,0 +1,752 @@
+from typing import Annotated
+from uuid import UUID
+
+import pandas as pd
+from arthur_common.aggregations.aggregator import (
+    NumericAggregationFunction,
+    SketchAggregationFunction,
+)
+from arthur_common.models.datasets import ModelProblemType
+from arthur_common.models.metrics import DatasetReference, NumericMetric, SketchMetric
+from arthur_common.models.schema_definitions import (
+    SHIELD_RESPONSE_SCHEMA,
+    MetricColumnParameterAnnotation,
+    MetricDatasetParameterAnnotation,
+)
+from duckdb import DuckDBPyConnection
+from tokencost import calculate_cost_by_tokens
+
+
+class ShieldInferencePassFailCountAggregation(NumericAggregationFunction):
+    METRIC_NAME = "inference_count"
+
+    @staticmethod
+    def id() -> UUID:
+        return UUID("00000000-0000-0000-0000-000000000001")
+
+    @staticmethod
+    def display_name() -> str:
+        return "Inference Pass/Fail Count"
+
+    @staticmethod
+    def description() -> str:
+        return "Metric that counts the number of Shield inferences grouped by the prompt, response, and overall check results."
+
+    def aggregate(
+        self,
+        ddb_conn: DuckDBPyConnection,
+        dataset: Annotated[
+            DatasetReference,
+            MetricDatasetParameterAnnotation(
+                friendly_name="Dataset",
+                description="The task inference dataset sourced from Arthur Shield.",
+                model_problem_type=ModelProblemType.ARTHUR_SHIELD,
+            ),
+        ],
+        # This parameter exists mostly to work with the aggregation matcher such that we don't need to have any special handling for shield
+        shield_response_column: Annotated[
+            str,
+            MetricColumnParameterAnnotation(
+                source_dataset_parameter_key="dataset",
+                allowed_column_types=[
+                    SHIELD_RESPONSE_SCHEMA,
+                ],
+                friendly_name="Shield Response Column",
+                description="The Shield response column from the task inference dataset.",
+            ),
+        ],
+    ) -> list[NumericMetric]:
+        results = ddb_conn.sql(
+            f"select time_bucket(INTERVAL '5 minutes', to_timestamp(created_at / 1000)) as ts, count(*) as count, \
+                result, \
+                inference_prompt.result AS prompt_result, \
+                inference_response.result AS response_result \
+            from {dataset.dataset_table_name} \
+            group by ts, result, prompt_result, response_result \
+            order by ts desc; \
+            ",
+        ).df()
+        group_by_dims = ["result", "prompt_result", "response_result"]
+        series = self.group_query_results_to_numeric_metrics(
+            results,
+            "count",
+            group_by_dims,
+            "ts",
+        )
+        metric = self.series_to_metric(self.METRIC_NAME, series)
+        return [metric]
+
+
+class ShieldInferenceRuleCountAggregation(NumericAggregationFunction):
+    METRIC_NAME = "rule_count"
+
+    @staticmethod
+    def id() -> UUID:
+        return UUID("00000000-0000-0000-0000-000000000002")
+
+    @staticmethod
+    def display_name() -> str:
+        return "Rule Result Count"
+
+    @staticmethod
+    def description() -> str:
+        return "Metric that counts the number of Shield rule evaluations grouped by whether it was on the prompt or response, the rule type, the rule evaluation result, the rule name, and the rule id."
+
+    def aggregate(
+        self,
+        ddb_conn: DuckDBPyConnection,
+        dataset: Annotated[
+            DatasetReference,
+            MetricDatasetParameterAnnotation(
+                friendly_name="Dataset",
+                description="The task inference dataset sourced from Arthur Shield.",
+                model_problem_type=ModelProblemType.ARTHUR_SHIELD,
+            ),
+        ],
+        # This parameter exists mostly to work with the aggregation matcher such that we don't need to have any special handling for shield
+        shield_response_column: Annotated[
+            str,
+            MetricColumnParameterAnnotation(
+                source_dataset_parameter_key="dataset",
+                allowed_column_types=[
+                    SHIELD_RESPONSE_SCHEMA,
+                ],
+                friendly_name="Shield Response Column",
+                description="The Shield response column from the task inference dataset.",
+            ),
+        ],
+    ) -> list[NumericMetric]:
+        results = ddb_conn.sql(
+            f" \
+            with unnessted_prompt_rules as (select unnest(inference_prompt.prompt_rule_results) as rule, \
+                'prompt' as location, \
+                time_bucket(INTERVAL '5 minutes', to_timestamp(created_at / 1000)) as ts \
+                from {dataset.dataset_table_name}), \
+            unnessted_result_rules as (select unnest(inference_response.response_rule_results) as rule, \
+                'response' as location, \
+                time_bucket(INTERVAL '5 minutes', to_timestamp(created_at / 1000)) as ts \
+                from {dataset.dataset_table_name}) \
+            select ts, \
+                count(*) as count, \
+                location, \
+                rule.rule_type, \
+                rule.result, \
+                rule.name, \
+                rule.id \
+            from unnessted_prompt_rules \
+            group by ts, location, rule.rule_type, rule.result, rule.name, rule.id \
+            UNION ALL \
+            select ts, \
+                count(*) as count, \
+                location, \
+                rule.rule_type, \
+                rule.result, \
+                rule.name, \
+                rule.id \
+            from unnessted_result_rules \
+            group by ts, location, rule.rule_type, rule.result, rule.name, rule.id \
+            order by ts desc, location, rule.rule_type, rule.result; \
+            ",
+        ).df()
+
+        group_by_dims = ["location", "rule_type", "result", "name", "id"]
+        series = self.group_query_results_to_numeric_metrics(
+            results,
+            "count",
+            group_by_dims,
+            "ts",
+        )
+        metric = self.series_to_metric(self.METRIC_NAME, series)
+        return [metric]
+
+
+class ShieldInferenceHallucinationCountAggregation(NumericAggregationFunction):
+    METRIC_NAME = "hallucination_count"
+
+    @staticmethod
+    def id() -> UUID:
+        return UUID("00000000-0000-0000-0000-000000000003")
+
+    @staticmethod
+    def display_name() -> str:
+        return "Hallucination Count"
+
+    @staticmethod
+    def description() -> str:
+        return "Metric that counts the number of Shield hallucination evaluations that failed."
+
+    def aggregate(
+        self,
+        ddb_conn: DuckDBPyConnection,
+        dataset: Annotated[
+            DatasetReference,
+            MetricDatasetParameterAnnotation(
+                friendly_name="Dataset",
+                description="The task inference dataset sourced from Arthur Shield.",
+                model_problem_type=ModelProblemType.ARTHUR_SHIELD,
+            ),
+        ],
+        # This parameter exists mostly to work with the aggregation matcher such that we don't need to have any special handling for shield
+        shield_response_column: Annotated[
+            str,
+            MetricColumnParameterAnnotation(
+                source_dataset_parameter_key="dataset",
+                allowed_column_types=[
+                    SHIELD_RESPONSE_SCHEMA,
+                ],
+                friendly_name="Shield Response Column",
+                description="The Shield response column from the task inference dataset.",
+            ),
+        ],
+    ) -> list[NumericMetric]:
+        results = ddb_conn.sql(
+            f" \
+            select time_bucket(INTERVAL '5 minutes', to_timestamp(created_at / 1000)) as ts, \
+                count(*) as count \
+            from {dataset.dataset_table_name} \
+            where length(list_filter(inference_response.response_rule_results, x -> (x.rule_type = 'ModelHallucinationRuleV2' or x.rule_type = 'ModelHallucinationRule') and x.result = 'Fail')) > 0 \
+            group by ts \
+            order by ts desc; \
+            ",
+        ).df()
+
+        series = [
+            self.dimensionless_query_results_to_numeric_metrics(results, "count", "ts"),
+        ]
+        metric = self.series_to_metric(self.METRIC_NAME, series)
+        return [metric]
+
+
+class ShieldInferenceRuleToxicityScoreAggregation(SketchAggregationFunction):
+    METRIC_NAME = "toxicity_score"
+
+    @staticmethod
+    def id() -> UUID:
+        return UUID("00000000-0000-0000-0000-000000000004")
+
+    @staticmethod
+    def display_name() -> str:
+        return "Toxicity Distribution"
+
+    @staticmethod
+    def description() -> str:
+        return "Metric that reports a distribution (data sketch) on toxicity scores returned by the Shield toxicity rule."
+
+    def aggregate(
+        self,
+        ddb_conn: DuckDBPyConnection,
+        dataset: Annotated[
+            DatasetReference,
+            MetricDatasetParameterAnnotation(
+                friendly_name="Dataset",
+                description="The task inference dataset sourced from Arthur Shield.",
+                model_problem_type=ModelProblemType.ARTHUR_SHIELD,
+            ),
+        ],
+        # This parameter exists mostly to work with the aggregation matcher such that we don't need to have any special handling for shield
+        shield_response_column: Annotated[
+            str,
+            MetricColumnParameterAnnotation(
+                source_dataset_parameter_key="dataset",
+                allowed_column_types=[
+                    SHIELD_RESPONSE_SCHEMA,
+                ],
+                friendly_name="Shield Response Column",
+                description="The Shield response column from the task inference dataset.",
+            ),
+        ],
+    ) -> list[SketchMetric]:
+        results = ddb_conn.sql(
+            f"\
+            with unnested_prompt_results as (select to_timestamp(created_at / 1000) as ts, \
+                unnest(inference_prompt.prompt_rule_results) as rule_results, \
+                'prompt' as location \
+                from {dataset.dataset_table_name}), \
+            unnested_response_results as (select to_timestamp(created_at / 1000) as ts, \
+                unnest(inference_response.response_rule_results) as rule_results, \
+                'response' as location \
+                from {dataset.dataset_table_name}) \
+            select ts as timestamp, \
+                rule_results.details.toxicity_score::DOUBLE as toxicity_score, \
+                rule_results.result as result, \
+                location \
+            from unnested_prompt_results \
+            where rule_results.details.toxicity_score IS NOT NULL \
+            UNION ALL \
+            select ts as timestamp, \
+                rule_results.details.toxicity_score::DOUBLE as toxicity_score, \
+                rule_results.result as result, \
+                location \
+            from unnested_response_results \
+            where rule_results.details.toxicity_score IS NOT NULL \
+            order by ts desc; \
+            ",
+        ).df()
+
+        series = self.group_query_results_to_sketch_metrics(
+            results,
+            "toxicity_score",
+            ["result", "location"],
+            "timestamp",
+        )
+        metric = self.series_to_metric(self.METRIC_NAME, series)
+        return [metric]
+
+
+class ShieldInferenceRulePIIDataScoreAggregation(SketchAggregationFunction):
+    METRIC_NAME = "pii_score"
+
+    @staticmethod
+    def id() -> UUID:
+        return UUID("00000000-0000-0000-0000-000000000005")
+
+    @staticmethod
+    def display_name() -> str:
+        return "PII Score Distribution"
+
+    @staticmethod
+    def description() -> str:
+        return "Metric that reports a distribution (data sketch) on PII scores returned by the Shield PII rule."
+
+    def aggregate(
+        self,
+        ddb_conn: DuckDBPyConnection,
+        dataset: Annotated[
+            DatasetReference,
+            MetricDatasetParameterAnnotation(
+                friendly_name="Dataset",
+                description="The task inference dataset sourced from Arthur Shield.",
+                model_problem_type=ModelProblemType.ARTHUR_SHIELD,
+            ),
+        ],
+        # This parameter exists mostly to work with the aggregation matcher such that we don't need to have any special handling for shield
+        shield_response_column: Annotated[
+            str,
+            MetricColumnParameterAnnotation(
+                source_dataset_parameter_key="dataset",
+                allowed_column_types=[
+                    SHIELD_RESPONSE_SCHEMA,
+                ],
+                friendly_name="Shield Response Column",
+                description="The Shield response column from the task inference dataset.",
+            ),
+        ],
+    ) -> list[SketchMetric]:
+        results = ddb_conn.sql(
+            f"\
+            with unnested_prompt_results as (select time_bucket(INTERVAL '5 minutes', to_timestamp(created_at / 1000)) as ts, \
+                unnest(inference_prompt.prompt_rule_results) as rule_results, \
+                'prompt' as location \
+                from {dataset.dataset_table_name}), \
+            unnested_response_results as (select time_bucket(INTERVAL '5 minutes', to_timestamp(created_at / 1000)) as ts, \
+                unnest(inference_response.response_rule_results) as rule_results, \
+                'response' as location \
+                from {dataset.dataset_table_name}), \
+            unnested_entites as (select ts, \
+                rule_results.result, \
+                rule_results.rule_type, \
+                location, \
+                unnest(rule_results.details.pii_entities) as pii_entity \
+                from unnested_response_results \
+                where rule_results.rule_type = 'PIIDataRule' \
+                \
+                UNION ALL \
+                \
+                select ts, \
+                rule_results.result, \
+                rule_results.rule_type, \
+                location, \
+                unnest(rule_results.details.pii_entities) as pii_entity \
+                from unnested_prompt_results \
+                where rule_results.rule_type = 'PIIDataRule') \
+            select ts as timestamp, result, rule_type, location, TRY_CAST(pii_entity.confidence AS FLOAT) as pii_score, pii_entity.entity as entity \
+            from unnested_entites \
+            order by ts desc; \
+            ",
+        ).df()
+
+        series = self.group_query_results_to_sketch_metrics(
+            results,
+            "pii_score",
+            ["result", "location", "entity"],
+            "timestamp",
+        )
+        metric = self.series_to_metric(self.METRIC_NAME, series)
+        return [metric]
+
+
+class ShieldInferenceRuleClaimCountAggregation(SketchAggregationFunction):
+    METRIC_NAME = "claim_count"
+
+    @staticmethod
+    def id() -> UUID:
+        return UUID("00000000-0000-0000-0000-000000000006")
+
+    @staticmethod
+    def display_name() -> str:
+        return "Claim Count Distribution - All Claims"
+
+    @staticmethod
+    def description() -> str:
+        return "Metric that reports a distribution (data sketch) over the number of claims identified by the Shield hallucination rule."
+
+    def aggregate(
+        self,
+        ddb_conn: DuckDBPyConnection,
+        dataset: Annotated[
+            DatasetReference,
+            MetricDatasetParameterAnnotation(
+                friendly_name="Dataset",
+                description="The task inference dataset sourced from Arthur Shield.",
+                model_problem_type=ModelProblemType.ARTHUR_SHIELD,
+            ),
+        ],
+        # This parameter exists mostly to work with the aggregation matcher such that we don't need to have any special handling for shield
+        shield_response_column: Annotated[
+            str,
+            MetricColumnParameterAnnotation(
+                source_dataset_parameter_key="dataset",
+                allowed_column_types=[
+                    SHIELD_RESPONSE_SCHEMA,
+                ],
+                friendly_name="Shield Response Column",
+                description="The Shield response column from the task inference dataset.",
+            ),
+        ],
+    ) -> list[SketchMetric]:
+        results = ddb_conn.sql(
+            f"\
+            with unnested_results as (select to_timestamp(created_at / 1000) as ts, \
+                unnest(inference_response.response_rule_results) as rule_results \
+                from {dataset.dataset_table_name}) \
+            select ts as timestamp, \
+                length(rule_results.details.claims) as num_claims, \
+                rule_results.result as result \
+            from unnested_results \
+            where rule_results.rule_type = 'ModelHallucinationRuleV2' \
+            and rule_results.result != 'Skipped' \
+            order by ts desc; \
+            ",
+        ).df()
+
+        series = self.group_query_results_to_sketch_metrics(
+            results,
+            "num_claims",
+            ["result"],
+            "timestamp",
+        )
+        metric = self.series_to_metric(self.METRIC_NAME, series)
+        return [metric]
+
+
+class ShieldInferenceRuleClaimPassCountAggregation(SketchAggregationFunction):
+    METRIC_NAME = "claim_valid_count"
+
+    @staticmethod
+    def id() -> UUID:
+        return UUID("00000000-0000-0000-0000-000000000007")
+
+    @staticmethod
+    def display_name() -> str:
+        return "Claim Count Distribution - Valid Claims"
+
+    @staticmethod
+    def description() -> str:
+        return "Metric that reports a distribution (data sketch) on the number of valid claims determined by the Shield hallucination rule."
+
+    def aggregate(
+        self,
+        ddb_conn: DuckDBPyConnection,
+        dataset: Annotated[
+            DatasetReference,
+            MetricDatasetParameterAnnotation(
+                friendly_name="Dataset",
+                description="The task inference dataset sourced from Arthur Shield.",
+                model_problem_type=ModelProblemType.ARTHUR_SHIELD,
+            ),
+        ],
+        # This parameter exists mostly to work with the aggregation matcher such that we don't need to have any special handling for shield
+        shield_response_column: Annotated[
+            str,
+            MetricColumnParameterAnnotation(
+                source_dataset_parameter_key="dataset",
+                allowed_column_types=[
+                    SHIELD_RESPONSE_SCHEMA,
+                ],
+                friendly_name="Shield Response Column",
+                description="The Shield response column from the task inference dataset.",
+            ),
+        ],
+    ) -> list[SketchMetric]:
+        results = ddb_conn.sql(
+            f"\
+            with unnested_results as (select to_timestamp(created_at / 1000) as ts, \
+                unnest(inference_response.response_rule_results) as rule_results \
+                from {dataset.dataset_table_name}) \
+            select ts as timestamp, \
+                length(list_filter(rule_results.details.claims, x -> x.valid)) as num_valid_claims, \
+                rule_results.result as result \
+            from unnested_results \
+            where rule_results.rule_type = 'ModelHallucinationRuleV2' \
+            and rule_results.result != 'Skipped' \
+            order by ts desc; \
+            ",
+        ).df()
+
+        series = self.group_query_results_to_sketch_metrics(
+            results,
+            "num_valid_claims",
+            ["result"],
+            "timestamp",
+        )
+        metric = self.series_to_metric(self.METRIC_NAME, series)
+        return [metric]
+
+
+class ShieldInferenceRuleClaimFailCountAggregation(SketchAggregationFunction):
+    METRIC_NAME = "claim_invalid_count"
+
+    @staticmethod
+    def id() -> UUID:
+        return UUID("00000000-0000-0000-0000-000000000008")
+
+    @staticmethod
+    def display_name() -> str:
+        return "Claim Count Distribution - Invalid Claims"
+
+    @staticmethod
+    def description() -> str:
+        return "Metric that reports a distribution (data sketch) on the number of invalid claims determined by the Shield hallucination rule."
+
+    def aggregate(
+        self,
+        ddb_conn: DuckDBPyConnection,
+        dataset: Annotated[
+            DatasetReference,
+            MetricDatasetParameterAnnotation(
+                friendly_name="Dataset",
+                description="The task inference dataset sourced from Arthur Shield.",
+                model_problem_type=ModelProblemType.ARTHUR_SHIELD,
+            ),
+        ],
+        # This parameter exists mostly to work with the aggregation matcher such that we don't need to have any special handling for shield
+        shield_response_column: Annotated[
+            str,
+            MetricColumnParameterAnnotation(
+                source_dataset_parameter_key="dataset",
+                allowed_column_types=[
+                    SHIELD_RESPONSE_SCHEMA,
+                ],
+                friendly_name="Shield Response Column",
+                description="The Shield response column from the task inference dataset.",
+            ),
+        ],
+    ) -> list[SketchMetric]:
+        results = ddb_conn.sql(
+            f"\
+            with unnested_results as (select to_timestamp(created_at / 1000) as ts, \
+                unnest(inference_response.response_rule_results) as rule_results \
+                from {dataset.dataset_table_name}) \
+            select ts as timestamp, \
+                length(list_filter(rule_results.details.claims, x -> not x.valid)) as num_failed_claims, \
+                rule_results.result as result \
+            from unnested_results \
+            where rule_results.rule_type = 'ModelHallucinationRuleV2' \
+            and rule_results.result != 'Skipped' \
+            order by ts desc; \
+            ",
+        ).df()
+
+        series = self.group_query_results_to_sketch_metrics(
+            results,
+            "num_failed_claims",
+            ["result"],
+            "timestamp",
+        )
+        metric = self.series_to_metric(self.METRIC_NAME, series)
+        return [metric]
+
+
+class ShieldInferenceRuleLatencyAggregation(SketchAggregationFunction):
+    METRIC_NAME = "rule_latency"
+
+    @staticmethod
+    def id() -> UUID:
+        return UUID("00000000-0000-0000-0000-000000000009")
+
+    @staticmethod
+    def display_name() -> str:
+        return "Rule Latency Distribution"
+
+    @staticmethod
+    def description() -> str:
+        return "Metric that reports a distribution (data sketch) on the latency of Shield rule evaluations. Dimensions are the rule result, rule type, and whether the rule was applicable to a prompt or response."
+
+    def aggregate(
+        self,
+        ddb_conn: DuckDBPyConnection,
+        dataset: Annotated[
+            DatasetReference,
+            MetricDatasetParameterAnnotation(
+                friendly_name="Dataset",
+                description="The task inference dataset sourced from Arthur Shield.",
+                model_problem_type=ModelProblemType.ARTHUR_SHIELD,
+            ),
+        ],
+        # This parameter exists mostly to work with the aggregation matcher such that we don't need to have any special handling for shield
+        shield_response_column: Annotated[
+            str,
+            MetricColumnParameterAnnotation(
+                source_dataset_parameter_key="dataset",
+                allowed_column_types=[
+                    SHIELD_RESPONSE_SCHEMA,
+                ],
+                friendly_name="Shield Response Column",
+                description="The Shield response column from the task inference dataset.",
+            ),
+        ],
+    ) -> list[SketchMetric]:
+        results = ddb_conn.sql(
+            f" \
+            with unnested_prompt_rules as (select unnest(inference_prompt.prompt_rule_results) as rule, \
+                'prompt' as location, \
+                to_timestamp(created_at / 1000) as ts, \
+                from {dataset.dataset_table_name}), \
+            unnested_response_rules as (select unnest(inference_response.response_rule_results) as rule, \
+                'response' as location, \
+                to_timestamp(created_at / 1000) as ts, \
+                from {dataset.dataset_table_name}) \
+            select ts, \
+                location, \
+                rule.rule_type, \
+                rule.result, \
+                rule.latency_ms \
+            from unnested_prompt_rules \
+            UNION ALL \
+            select ts, \
+                location, \
+                rule.rule_type, \
+                rule.result, \
+                rule.latency_ms \
+            from unnested_response_rules \
+            ",
+        ).df()
+
+        series = self.group_query_results_to_sketch_metrics(
+            results,
+            "latency_ms",
+            ["result", "rule_type", "location"],
+            "ts",
+        )
+        metric = self.series_to_metric(self.METRIC_NAME, series)
+        return [metric]
+
+
+class ShieldInferenceTokenCountAggregation(NumericAggregationFunction):
+    METRIC_NAME = "token_count"
+
+    @staticmethod
+    def id() -> UUID:
+        return UUID("00000000-0000-0000-0000-000000000021")
+
+    @staticmethod
+    def display_name() -> str:
+        return "Token Count"
+
+    @staticmethod
+    def description() -> str:
+        return "Metric that reports the number of tokens in the Shield response and prompt schemas, and their estimated cost."
+
+    def aggregate(
+        self,
+        ddb_conn: DuckDBPyConnection,
+        dataset: Annotated[
+            DatasetReference,
+            MetricDatasetParameterAnnotation(
+                friendly_name="Dataset",
+                description="The task inference dataset sourced from Arthur Shield.",
+                model_problem_type=ModelProblemType.ARTHUR_SHIELD,
+            ),
+        ],
+        # This parameter exists mostly to work with the aggregation matcher such that we don't need to have any special handling for shield
+        shield_response_column: Annotated[
+            str,
+            MetricColumnParameterAnnotation(
+                source_dataset_parameter_key="dataset",
+                allowed_column_types=[
+                    SHIELD_RESPONSE_SCHEMA,
+                ],
+                friendly_name="Shield Response Column",
+                description="The Shield response column from the task inference dataset.",
+            ),
+        ],
+    ) -> list[NumericMetric]:
+        results = ddb_conn.sql(
+            f" \
+            select \
+                time_bucket(INTERVAL '5 minutes', to_timestamp(created_at / 1000)) as ts, \
+                COALESCE(sum(inference_prompt.tokens), 0) as tokens, \
+                'prompt' as location \
+            from {dataset.dataset_table_name} \
+            group by time_bucket(INTERVAL '5 minutes', to_timestamp(created_at / 1000)), location \
+            UNION ALL \
+            select \
+                time_bucket(INTERVAL '5 minutes', to_timestamp(created_at / 1000)) as ts, \
+                COALESCE(sum(inference_response.tokens), 0) as tokens, \
+                'response' as location \
+            from {dataset.dataset_table_name} \
+            group by time_bucket(INTERVAL '5 minutes', to_timestamp(created_at / 1000)), location; \
+            ",
+        ).df()
+
+        series = self.group_query_results_to_numeric_metrics(
+            results,
+            "tokens",
+            ["location"],
+            "ts",
+        )
+        metric = self.series_to_metric(self.METRIC_NAME, series)
+        resp = [metric]
+
+        # Compute Cost for each model
+        models = [
+            "gpt-4o",
+            "gpt-4o-mini",
+            "gpt-3.5-turbo",
+            "o1-mini",
+            "deepseek-chat",
+            "claude-3-5-sonnet-20241022",
+            "gemini/gemini-1.5-pro",
+            "meta.llama3-1-8b-instruct-v1:0",
+            "meta.llama3-1-70b-instruct-v1:0",
+            "meta.llama3-2-11b-instruct-v1:0",
+        ]
+
+        # Precompute input/output classification to avoid recalculating in loop
+        location_type = results["location"].apply(
+            lambda x: "input" if x == "prompt" else "output",
+        )
+
+        for model in models:
+            # Efficient list comprehension instead of apply
+            cost_values = [
+                calculate_cost_by_tokens(int(tokens), model, loc_type)
+                for tokens, loc_type in zip(results["tokens"], location_type)
+            ]
+
+            model_df = pd.DataFrame(
+                {
+                    "ts": results["ts"],
+                    "cost": cost_values,
+                    "location": results["location"],
+                },
+            )
+
+            model_series = self.group_query_results_to_numeric_metrics(
+                model_df,
+                "cost",
+                ["location"],
+                "ts",
+            )
+            resp.append(self.series_to_metric(f"token_cost.{model}", model_series))
+        return resp
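
The sketches below are illustrative only and are not part of the released file. They show, with toy table, column, and token values (all assumptions, not the real Shield inference schema), the two patterns that recur throughout shield_aggregations.py: a DuckDB query that buckets epoch-millisecond created_at values into 5-minute windows and counts rows per group, and the tokencost.calculate_cost_by_tokens call that ShieldInferenceTokenCountAggregation uses to turn token counts into estimated cost, pricing prompt tokens as "input" and response tokens as "output".

# Minimal sketch of the 5-minute time_bucket counting pattern (toy table and columns).
import duckdb

conn = duckdb.connect()
conn.sql(
    """
    create table toy_inferences as
    select * from (values
        (1735689600000, 'Pass'),
        (1735689660000, 'Fail'),
        (1735690200000, 'Pass')
    ) as t(created_at, result)
    """
)
counts = conn.sql(
    """
    select time_bucket(INTERVAL '5 minutes', to_timestamp(created_at / 1000)) as ts,
           result,
           count(*) as count
    from toy_inferences
    group by ts, result
    order by ts desc
    """
).df()
print(counts)

# Minimal sketch of the per-model cost estimate, mirroring the call used above;
# the model name comes from the models list in the file, the token counts are made up.
from tokencost import calculate_cost_by_tokens

prompt_cost = calculate_cost_by_tokens(1250, "gpt-4o", "input")    # prompt tokens priced as "input"
response_cost = calculate_cost_by_tokens(480, "gpt-4o", "output")  # response tokens priced as "output"
print(prompt_cost + response_cost)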