arthur-common 2.1.58__py3-none-any.whl → 2.4.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. arthur_common/aggregations/aggregator.py +73 -9
  2. arthur_common/aggregations/functions/agentic_aggregations.py +260 -85
  3. arthur_common/aggregations/functions/categorical_count.py +15 -15
  4. arthur_common/aggregations/functions/confusion_matrix.py +24 -26
  5. arthur_common/aggregations/functions/inference_count.py +5 -9
  6. arthur_common/aggregations/functions/inference_count_by_class.py +16 -27
  7. arthur_common/aggregations/functions/inference_null_count.py +10 -13
  8. arthur_common/aggregations/functions/mean_absolute_error.py +12 -18
  9. arthur_common/aggregations/functions/mean_squared_error.py +12 -18
  10. arthur_common/aggregations/functions/multiclass_confusion_matrix.py +13 -20
  11. arthur_common/aggregations/functions/multiclass_inference_count_by_class.py +1 -1
  12. arthur_common/aggregations/functions/numeric_stats.py +13 -15
  13. arthur_common/aggregations/functions/numeric_sum.py +12 -15
  14. arthur_common/aggregations/functions/shield_aggregations.py +457 -215
  15. arthur_common/models/common_schemas.py +214 -0
  16. arthur_common/models/connectors.py +10 -2
  17. arthur_common/models/constants.py +24 -0
  18. arthur_common/models/datasets.py +0 -9
  19. arthur_common/models/enums.py +177 -0
  20. arthur_common/models/metric_schemas.py +63 -0
  21. arthur_common/models/metrics.py +2 -9
  22. arthur_common/models/request_schemas.py +870 -0
  23. arthur_common/models/response_schemas.py +785 -0
  24. arthur_common/models/schema_definitions.py +6 -1
  25. arthur_common/models/task_job_specs.py +3 -12
  26. arthur_common/tools/duckdb_data_loader.py +34 -2
  27. arthur_common/tools/duckdb_utils.py +3 -6
  28. arthur_common/tools/schema_inferer.py +3 -6
  29. {arthur_common-2.1.58.dist-info → arthur_common-2.4.13.dist-info}/METADATA +12 -4
  30. arthur_common-2.4.13.dist-info/RECORD +49 -0
  31. arthur_common/models/shield.py +0 -642
  32. arthur_common-2.1.58.dist-info/RECORD +0 -44
  33. {arthur_common-2.1.58.dist-info → arthur_common-2.4.13.dist-info}/WHEEL +0 -0
arthur_common/aggregations/functions/shield_aggregations.py
@@ -9,7 +9,7 @@ from arthur_common.aggregations.aggregator import (
     NumericAggregationFunction,
     SketchAggregationFunction,
 )
-from arthur_common.models.datasets import ModelProblemType
+from arthur_common.models.enums import ModelProblemType
 from arthur_common.models.metrics import (
     BaseReportedAggregation,
     DatasetReference,

@@ -25,6 +25,7 @@ from arthur_common.models.schema_definitions import (

 class ShieldInferencePassFailCountAggregation(NumericAggregationFunction):
     METRIC_NAME = "inference_count"
+    FEATURE_FLAG_NAME = "SHIELD_INFERENCE_PASS_FAIL_COUNT_AGGREGATION_SEGMENTATION"

     @staticmethod
     def id() -> UUID:
@@ -71,17 +72,41 @@ class ShieldInferencePassFailCountAggregation(NumericAggregationFunction):
             ),
         ],
     ) -> list[NumericMetric]:
-        results = ddb_conn.sql(
-            f"select time_bucket(INTERVAL '5 minutes', to_timestamp(created_at / 1000)) as ts, count(*) as count, \
-            result, \
-            inference_prompt.result AS prompt_result, \
-            inference_response.result AS response_result \
-            from {dataset.dataset_table_name} \
-            group by ts, result, prompt_result, response_result \
-            order by ts desc; \
-            ",
-        ).df()
-        group_by_dims = ["result", "prompt_result", "response_result"]
+        # Build SELECT clause
+        select_cols = [
+            "time_bucket(INTERVAL '5 minutes', to_timestamp(created_at / 1000)) as ts",
+            "count(*) as count",
+            "result",
+            "inference_prompt.result AS prompt_result",
+            "inference_response.result AS response_result",
+        ]
+
+        # Build GROUP BY clause
+        group_by_cols = ["ts", "result", "prompt_result", "response_result"]
+
+        # Conditionally add conversation_id and user_id based on segmentation flag
+        if self.is_feature_flag_enabled(self.FEATURE_FLAG_NAME):
+            select_cols.extend(["conversation_id", "user_id as user_id"])
+            group_by_cols.extend(["conversation_id", "user_id"])
+
+        query = f"""
+            select {", ".join(select_cols)}
+            from {dataset.dataset_table_name}
+            group by {", ".join(group_by_cols)}
+            order by ts desc;
+        """
+
+        results = ddb_conn.sql(query).df()
+
+        # Build group_by_dims list
+        group_by_dims = [
+            "result",
+            "prompt_result",
+            "response_result",
+        ]
+        if self.is_feature_flag_enabled(self.FEATURE_FLAG_NAME):
+            group_by_dims.extend(["conversation_id", "user_id"])
+
         series = self.group_query_results_to_numeric_metrics(
             results,
             "count",
@@ -94,6 +119,7 @@ class ShieldInferencePassFailCountAggregation(NumericAggregationFunction):

 class ShieldInferenceRuleCountAggregation(NumericAggregationFunction):
     METRIC_NAME = "rule_count"
+    FEATURE_FLAG_NAME = "SHIELD_INFERENCE_RULE_COUNT_AGGREGATION_SEGMENTATION"

     @staticmethod
     def id() -> UUID:
@@ -140,40 +166,72 @@ class ShieldInferenceRuleCountAggregation(NumericAggregationFunction):
             ),
         ],
     ) -> list[NumericMetric]:
-        results = ddb_conn.sql(
-            f" \
-            with unnessted_prompt_rules as (select unnest(inference_prompt.prompt_rule_results) as rule, \
-            'prompt' as location, \
-            time_bucket(INTERVAL '5 minutes', to_timestamp(created_at / 1000)) as ts \
-            from {dataset.dataset_table_name}), \
-            unnessted_result_rules as (select unnest(inference_response.response_rule_results) as rule,\
-            'response' as location, \
-            time_bucket(INTERVAL '5 minutes', to_timestamp(created_at / 1000)) as ts \
-            from {dataset.dataset_table_name}) \
-            select ts, \
-            count(*) as count, \
-            location, \
-            rule.rule_type, \
-            rule.result, \
-            rule.name, \
-            rule.id \
-            from unnessted_prompt_rules \
-            group by ts, location, rule.rule_type, rule.result, rule.name, rule.id \
-            UNION ALL \
-            select ts, \
-            count(*) as count, \
-            location, \
-            rule.rule_type, \
-            rule.result, \
-            rule.name, \
-            rule.id \
-            from unnessted_result_rules \
-            group by ts, location, rule.rule_type, rule.result, rule.name, rule.id \
-            order by ts desc, location, rule.rule_type, rule.result; \
-            ",
-        ).df()
-
-        group_by_dims = ["location", "rule_type", "result", "name", "id"]
+        # Build CTE select columns
+        prompt_cte_select = [
+            "unnest(inference_prompt.prompt_rule_results) as rule",
+            "'prompt' as location",
+            "time_bucket(INTERVAL '5 minutes', to_timestamp(created_at / 1000)) as ts",
+        ]
+        response_cte_select = [
+            "unnest(inference_response.response_rule_results) as rule",
+            "'response' as location",
+            "time_bucket(INTERVAL '5 minutes', to_timestamp(created_at / 1000)) as ts",
+        ]
+
+        # Build main select columns
+        main_select_cols = [
+            "ts",
+            "count(*) as count",
+            "location",
+            "rule.rule_type",
+            "rule.result",
+            "rule.name",
+            "rule.id",
+        ]
+
+        # Build group by columns
+        group_by_cols = [
+            "ts",
+            "location",
+            "rule.rule_type",
+            "rule.result",
+            "rule.name",
+            "rule.id",
+        ]
+
+        # Conditionally add conversation_id and user_id
+        if self.is_feature_flag_enabled(self.FEATURE_FLAG_NAME):
+            prompt_cte_select.extend(["conversation_id", "user_id"])
+            response_cte_select.extend(["conversation_id", "user_id"])
+            main_select_cols.extend(["conversation_id", "user_id"])
+            group_by_cols.extend(["conversation_id", "user_id"])
+
+        query = f"""
+            with unnessted_prompt_rules as (select {", ".join(prompt_cte_select)}
+            from {dataset.dataset_table_name}),
+            unnessted_result_rules as (select {", ".join(response_cte_select)}
+            from {dataset.dataset_table_name})
+            select {", ".join(main_select_cols)}
+            from unnessted_prompt_rules
+            group by {", ".join(group_by_cols)}
+            UNION ALL
+            select {", ".join(main_select_cols)}
+            from unnessted_result_rules
+            group by {", ".join(group_by_cols)}
+            order by ts desc, location, rule.rule_type, rule.result;
+        """
+
+        results = ddb_conn.sql(query).df()
+
+        group_by_dims = [
+            "location",
+            "rule_type",
+            "result",
+            "name",
+            "id",
+        ]
+        if self.is_feature_flag_enabled(self.FEATURE_FLAG_NAME):
+            group_by_dims.extend(["conversation_id", "user_id"])
         series = self.group_query_results_to_numeric_metrics(
             results,
             "count",
@@ -186,6 +244,7 @@ class ShieldInferenceRuleCountAggregation(NumericAggregationFunction):

 class ShieldInferenceHallucinationCountAggregation(NumericAggregationFunction):
     METRIC_NAME = "hallucination_count"
+    FEATURE_FLAG_NAME = "SHIELD_INFERENCE_HALLUCINATION_COUNT_AGGREGATION_SEGMENTATION"

     @staticmethod
     def id() -> UUID:
@@ -232,24 +291,46 @@ class ShieldInferenceHallucinationCountAggregation(NumericAggregationFunction):
             ),
         ],
     ) -> list[NumericMetric]:
-        results = ddb_conn.sql(
-            f" \
-            select time_bucket(INTERVAL '5 minutes', to_timestamp(created_at / 1000)) as ts, \
-            count(*) as count \
-            from {dataset.dataset_table_name} \
-            where length(list_filter(inference_response.response_rule_results, x -> (x.rule_type = 'ModelHallucinationRuleV2' or x.rule_type = 'ModelHallucinationRule') and x.result = 'Fail')) > 0 \
-            group by ts \
-            order by ts desc; \
-            ",
-        ).df()
-
-        series = self.group_query_results_to_numeric_metrics(results, "count", [], "ts")
+        # Build SELECT clause
+        select_cols = [
+            "time_bucket(INTERVAL '5 minutes', to_timestamp(created_at / 1000)) as ts",
+            "count(*) as count",
+        ]
+
+        # Build GROUP BY clause
+        group_by_cols = ["ts"]
+
+        # Conditionally add conversation_id and user_id
+        if self.is_feature_flag_enabled(self.FEATURE_FLAG_NAME):
+            select_cols.extend(["conversation_id", "user_id"])
+            group_by_cols.extend(["conversation_id", "user_id"])
+
+        query = f"""
+            select {", ".join(select_cols)}
+            from {dataset.dataset_table_name}
+            where length(list_filter(inference_response.response_rule_results, x -> (x.rule_type = 'ModelHallucinationRuleV2' or x.rule_type = 'ModelHallucinationRule') and x.result = 'Fail')) > 0
+            group by {", ".join(group_by_cols)}
+            order by ts desc;
+        """
+
+        results = ddb_conn.sql(query).df()
+
+        group_by_dims = []
+        if self.is_feature_flag_enabled(self.FEATURE_FLAG_NAME):
+            group_by_dims.extend(["conversation_id", "user_id"])
+        series = self.group_query_results_to_numeric_metrics(
+            results,
+            "count",
+            group_by_dims,
+            "ts",
+        )
         metric = self.series_to_metric(self.METRIC_NAME, series)
         return [metric]


 class ShieldInferenceRuleToxicityScoreAggregation(SketchAggregationFunction):
     METRIC_NAME = "toxicity_score"
+    FEATURE_FLAG_NAME = "SHIELD_INFERENCE_RULE_TOXICITY_SCORE_AGGREGATION_SEGMENTATION"

     @staticmethod
     def id() -> UUID:
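
The WHERE clause in the hallucination hunk above counts an inference when DuckDB's list_filter keeps at least one failing hallucination-rule entry from the response_rule_results struct list. A self-contained illustration of that predicate against hypothetical in-memory data:

    import duckdb

    # Two rule results; only the hallucination-rule failure should match.
    print(duckdb.sql("""
        select length(list_filter(
            [{'rule_type': 'ModelHallucinationRuleV2', 'result': 'Fail'},
             {'rule_type': 'PIIDataRule', 'result': 'Fail'}],
            x -> (x.rule_type = 'ModelHallucinationRuleV2'
                  or x.rule_type = 'ModelHallucinationRule')
                 and x.result = 'Fail'
        )) as hallucination_fails
    """).fetchall())  # [(1,)]
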
@@ -296,37 +377,57 @@ class ShieldInferenceRuleToxicityScoreAggregation(SketchAggregationFunction):
             ),
         ],
     ) -> list[SketchMetric]:
-        results = ddb_conn.sql(
-            f"\
-            with unnested_prompt_results as (select to_timestamp(created_at / 1000) as ts, \
-            unnest(inference_prompt.prompt_rule_results) as rule_results, \
-            'prompt' as location \
-            from {dataset.dataset_table_name}), \
-            unnested_response_results as (select to_timestamp(created_at / 1000) as ts, \
-            unnest(inference_response.response_rule_results) as rule_results, \
-            'response' as location \
-            from {dataset.dataset_table_name}) \
-            select ts as timestamp, \
-            rule_results.details.toxicity_score::DOUBLE as toxicity_score, \
-            rule_results.result as result, \
-            location \
-            from unnested_prompt_results \
-            where rule_results.details.toxicity_score IS NOT NULL \
-            UNION ALL \
-            select ts as timestamp, \
-            rule_results.details.toxicity_score::DOUBLE as toxicity_score, \
-            rule_results.result as result, \
-            location \
-            from unnested_response_results \
-            where rule_results.details.toxicity_score IS NOT NULL \
-            order by ts desc; \
-            ",
-        ).df()
+        # Build CTE select columns
+        prompt_cte_select = [
+            "to_timestamp(created_at / 1000) as ts",
+            "unnest(inference_prompt.prompt_rule_results) as rule_results",
+            "'prompt' as location",
+        ]
+        response_cte_select = [
+            "to_timestamp(created_at / 1000) as ts",
+            "unnest(inference_response.response_rule_results) as rule_results",
+            "'response' as location",
+        ]
+
+        # Build main select columns
+        main_select_cols = [
+            "ts as timestamp",
+            "rule_results.details.toxicity_score::DOUBLE as toxicity_score",
+            "rule_results.result as result",
+            "location",
+        ]
+
+        # Conditionally add conversation_id and user_id
+        if self.is_feature_flag_enabled(self.FEATURE_FLAG_NAME):
+            prompt_cte_select.extend(["conversation_id", "user_id"])
+            response_cte_select.extend(["conversation_id", "user_id"])
+            main_select_cols.extend(["conversation_id", "user_id"])
+
+        query = f"""
+            with unnested_prompt_results as (select {", ".join(prompt_cte_select)}
+            from {dataset.dataset_table_name}),
+            unnested_response_results as (select {", ".join(response_cte_select)}
+            from {dataset.dataset_table_name})
+            select {", ".join(main_select_cols)}
+            from unnested_prompt_results
+            where rule_results.details.toxicity_score IS NOT NULL
+            UNION ALL
+            select {", ".join(main_select_cols)}
+            from unnested_response_results
+            where rule_results.details.toxicity_score IS NOT NULL
+            order by ts desc;
+        """
+
+        results = ddb_conn.sql(query).df()
+
+        group_by_dims = ["result", "location"]
+        if self.is_feature_flag_enabled(self.FEATURE_FLAG_NAME):
+            group_by_dims.extend(["conversation_id", "user_id"])

         series = self.group_query_results_to_sketch_metrics(
             results,
             "toxicity_score",
-            ["result", "location"],
+            group_by_dims,
             "timestamp",
         )
         metric = self.series_to_metric(self.METRIC_NAME, series)
@@ -335,6 +436,7 @@ class ShieldInferenceRuleToxicityScoreAggregation(SketchAggregationFunction):

 class ShieldInferenceRulePIIDataScoreAggregation(SketchAggregationFunction):
     METRIC_NAME = "pii_score"
+    FEATURE_FLAG_NAME = "SHIELD_INFERENCE_RULE_PII_DATA_SCORE_AGGREGATION_SEGMENTATION"

     @staticmethod
     def id() -> UUID:
@@ -381,43 +483,71 @@ class ShieldInferenceRulePIIDataScoreAggregation(SketchAggregationFunction):
             ),
         ],
     ) -> list[SketchMetric]:
-        results = ddb_conn.sql(
-            f"\
-            with unnested_prompt_results as (select time_bucket(INTERVAL '5 minutes', to_timestamp(created_at / 1000)) as ts, \
-            unnest(inference_prompt.prompt_rule_results) as rule_results, \
-            'prompt' as location \
-            from {dataset.dataset_table_name}), \
-            unnested_response_results as (select time_bucket(INTERVAL '5 minutes', to_timestamp(created_at / 1000)) as ts, \
-            unnest(inference_response.response_rule_results) as rule_results, \
-            'response' as location \
-            from {dataset.dataset_table_name}), \
-            unnested_entites as (select ts, \
-            rule_results.result, \
-            rule_results.rule_type, \
-            location, \
-            unnest(rule_results.details.pii_entities) as pii_entity \
-            from unnested_response_results \
-            where rule_results.rule_type = 'PIIDataRule' \
-            \
-            UNION ALL \
-            \
-            select ts, \
-            rule_results.result, \
-            rule_results.rule_type, \
-            location, \
-            unnest(rule_results.details.pii_entities) as pii_entity \
-            from unnested_prompt_results \
-            where rule_results.rule_type = 'PIIDataRule') \
-            select ts as timestamp, result, rule_type, location, TRY_CAST(pii_entity.confidence AS FLOAT) as pii_score, pii_entity.entity as entity \
-            from unnested_entites \
-            order by ts desc; \
-            ",
-        ).df()
+        # Build CTE select columns
+        prompt_cte_select = [
+            "time_bucket(INTERVAL '5 minutes', to_timestamp(created_at / 1000)) as ts",
+            "unnest(inference_prompt.prompt_rule_results) as rule_results",
+            "'prompt' as location",
+        ]
+        response_cte_select = [
+            "time_bucket(INTERVAL '5 minutes', to_timestamp(created_at / 1000)) as ts",
+            "unnest(inference_response.response_rule_results) as rule_results",
+            "'response' as location",
+        ]
+
+        # Build unnested_entities select columns
+        entities_select_cols = [
+            "ts",
+            "rule_results.result",
+            "rule_results.rule_type",
+            "location",
+            "unnest(rule_results.details.pii_entities) as pii_entity",
+        ]
+
+        # Build final select columns
+        final_select_cols = [
+            "ts as timestamp",
+            "result",
+            "rule_type",
+            "location",
+            "TRY_CAST(pii_entity.confidence AS FLOAT) as pii_score",
+            "pii_entity.entity as entity",
+        ]
+
+        # Conditionally add conversation_id and user_id
+        if self.is_feature_flag_enabled(self.FEATURE_FLAG_NAME):
+            prompt_cte_select.extend(["conversation_id", "user_id"])
+            response_cte_select.extend(["conversation_id", "user_id"])
+            entities_select_cols.extend(["conversation_id", "user_id"])
+            final_select_cols.extend(["conversation_id", "user_id"])
+
+        query = f"""
+            with unnested_prompt_results as (select {", ".join(prompt_cte_select)}
+            from {dataset.dataset_table_name}),
+            unnested_response_results as (select {", ".join(response_cte_select)}
+            from {dataset.dataset_table_name}),
+            unnested_entites as (select {", ".join(entities_select_cols)}
+            from unnested_response_results
+            where rule_results.rule_type = 'PIIDataRule'
+            UNION ALL
+            select {", ".join(entities_select_cols)}
+            from unnested_prompt_results
+            where rule_results.rule_type = 'PIIDataRule')
+            select {", ".join(final_select_cols)}
+            from unnested_entites
+            order by ts desc;
+        """
+
+        results = ddb_conn.sql(query).df()
+
+        group_by_dims = ["result", "location", "entity"]
+        if self.is_feature_flag_enabled(self.FEATURE_FLAG_NAME):
+            group_by_dims.extend(["conversation_id", "user_id"])

         series = self.group_query_results_to_sketch_metrics(
             results,
             "pii_score",
-            ["result", "location", "entity"],
+            group_by_dims,
             "timestamp",
         )
         metric = self.series_to_metric(self.METRIC_NAME, series)
@@ -426,6 +556,7 @@ order by ts desc;

 class ShieldInferenceRuleClaimCountAggregation(SketchAggregationFunction):
     METRIC_NAME = "claim_count"
+    FEATURE_FLAG_NAME = "SHIELD_INFERENCE_RULE_CLAIM_COUNT_AGGREGATION_SEGMENTATION"

     @staticmethod
     def id() -> UUID:
@@ -472,25 +603,44 @@ class ShieldInferenceRuleClaimCountAggregation(SketchAggregationFunction):
             ),
         ],
     ) -> list[SketchMetric]:
-        results = ddb_conn.sql(
-            f"\
-            with unnested_results as (select to_timestamp(created_at / 1000) as ts, \
-            unnest(inference_response.response_rule_results) as rule_results \
-            from {dataset.dataset_table_name}) \
-            select ts as timestamp, \
-            length(rule_results.details.claims) as num_claims, \
-            rule_results.result as result \
-            from unnested_results \
-            where rule_results.rule_type = 'ModelHallucinationRuleV2' \
-            and rule_results.result != 'Skipped' \
-            order by ts desc; \
-            ",
-        ).df()
+        # Build CTE select columns
+        cte_select = [
+            "to_timestamp(created_at / 1000) as ts",
+            "unnest(inference_response.response_rule_results) as rule_results",
+        ]
+
+        # Build main select columns
+        main_select_cols = [
+            "ts as timestamp",
+            "length(rule_results.details.claims) as num_claims",
+            "rule_results.result as result",
+        ]
+
+        # Conditionally add conversation_id and user_id
+        if self.is_feature_flag_enabled(self.FEATURE_FLAG_NAME):
+            cte_select.extend(["conversation_id", "user_id"])
+            main_select_cols.extend(["conversation_id", "user_id"])
+
+        query = f"""
+            with unnested_results as (select {", ".join(cte_select)}
+            from {dataset.dataset_table_name})
+            select {", ".join(main_select_cols)}
+            from unnested_results
+            where rule_results.rule_type = 'ModelHallucinationRuleV2'
+            and rule_results.result != 'Skipped'
+            order by ts desc;
+        """
+
+        results = ddb_conn.sql(query).df()
+
+        group_by_dims = ["result"]
+        if self.is_feature_flag_enabled(self.FEATURE_FLAG_NAME):
+            group_by_dims.extend(["conversation_id", "user_id"])

         series = self.group_query_results_to_sketch_metrics(
             results,
             "num_claims",
-            ["result"],
+            group_by_dims,
             "timestamp",
         )
         metric = self.series_to_metric(self.METRIC_NAME, series)
@@ -499,6 +649,9 @@ class ShieldInferenceRuleClaimCountAggregation(SketchAggregationFunction):

 class ShieldInferenceRuleClaimPassCountAggregation(SketchAggregationFunction):
     METRIC_NAME = "claim_valid_count"
+    FEATURE_FLAG_NAME = (
+        "SHIELD_INFERENCE_RULE_CLAIM_PASS_COUNT_AGGREGATION_SEGMENTATION"
+    )

     @staticmethod
     def id() -> UUID:
@@ -545,25 +698,44 @@ class ShieldInferenceRuleClaimPassCountAggregation(SketchAggregationFunction):
             ),
         ],
     ) -> list[SketchMetric]:
-        results = ddb_conn.sql(
-            f"\
-            with unnested_results as (select to_timestamp(created_at / 1000) as ts, \
-            unnest(inference_response.response_rule_results) as rule_results \
-            from {dataset.dataset_table_name}) \
-            select ts as timestamp, \
-            length(list_filter(rule_results.details.claims, x -> x.valid)) as num_valid_claims, \
-            rule_results.result as result \
-            from unnested_results \
-            where rule_results.rule_type = 'ModelHallucinationRuleV2' \
-            and rule_results.result != 'Skipped' \
-            order by ts desc; \
-            ",
-        ).df()
+        # Build CTE select columns
+        cte_select = [
+            "to_timestamp(created_at / 1000) as ts",
+            "unnest(inference_response.response_rule_results) as rule_results",
+        ]
+
+        # Build main select columns
+        main_select_cols = [
+            "ts as timestamp",
+            "length(list_filter(rule_results.details.claims, x -> x.valid)) as num_valid_claims",
+            "rule_results.result as result",
+        ]
+
+        # Conditionally add conversation_id and user_id
+        if self.is_feature_flag_enabled(self.FEATURE_FLAG_NAME):
+            cte_select.extend(["conversation_id", "user_id"])
+            main_select_cols.extend(["conversation_id", "user_id"])
+
+        query = f"""
+            with unnested_results as (select {", ".join(cte_select)}
+            from {dataset.dataset_table_name})
+            select {", ".join(main_select_cols)}
+            from unnested_results
+            where rule_results.rule_type = 'ModelHallucinationRuleV2'
+            and rule_results.result != 'Skipped'
+            order by ts desc;
+        """
+
+        results = ddb_conn.sql(query).df()
+
+        group_by_dims = ["result"]
+        if self.is_feature_flag_enabled(self.FEATURE_FLAG_NAME):
+            group_by_dims.extend(["conversation_id", "user_id"])

         series = self.group_query_results_to_sketch_metrics(
             results,
             "num_valid_claims",
-            ["result"],
+            group_by_dims,
             "timestamp",
         )
         metric = self.series_to_metric(self.METRIC_NAME, series)
@@ -572,6 +744,9 @@ class ShieldInferenceRuleClaimPassCountAggregation(SketchAggregationFunction):

 class ShieldInferenceRuleClaimFailCountAggregation(SketchAggregationFunction):
     METRIC_NAME = "claim_invalid_count"
+    FEATURE_FLAG_NAME = (
+        "SHIELD_INFERENCE_RULE_CLAIM_FAIL_COUNT_AGGREGATION_SEGMENTATION"
+    )

     @staticmethod
     def id() -> UUID:
@@ -618,25 +793,44 @@ class ShieldInferenceRuleClaimFailCountAggregation(SketchAggregationFunction):
             ),
         ],
     ) -> list[SketchMetric]:
-        results = ddb_conn.sql(
-            f"\
-            with unnested_results as (select to_timestamp(created_at / 1000) as ts, \
-            unnest(inference_response.response_rule_results) as rule_results \
-            from {dataset.dataset_table_name}) \
-            select ts as timestamp, \
-            length(list_filter(rule_results.details.claims, x -> not x.valid)) as num_failed_claims, \
-            rule_results.result as result \
-            from unnested_results \
-            where rule_results.rule_type = 'ModelHallucinationRuleV2' \
-            and rule_results.result != 'Skipped' \
-            order by ts desc; \
-            ",
-        ).df()
+        # Build CTE select columns
+        cte_select = [
+            "to_timestamp(created_at / 1000) as ts",
+            "unnest(inference_response.response_rule_results) as rule_results",
+        ]
+
+        # Build main select columns
+        main_select_cols = [
+            "ts as timestamp",
+            "length(list_filter(rule_results.details.claims, x -> not x.valid)) as num_failed_claims",
+            "rule_results.result as result",
+        ]
+
+        # Conditionally add conversation_id and user_id
+        if self.is_feature_flag_enabled(self.FEATURE_FLAG_NAME):
+            cte_select.extend(["conversation_id", "user_id"])
+            main_select_cols.extend(["conversation_id", "user_id"])
+
+        query = f"""
+            with unnested_results as (select {", ".join(cte_select)}
+            from {dataset.dataset_table_name})
+            select {", ".join(main_select_cols)}
+            from unnested_results
+            where rule_results.rule_type = 'ModelHallucinationRuleV2'
+            and rule_results.result != 'Skipped'
+            order by ts desc;
+        """
+
+        results = ddb_conn.sql(query).df()
+
+        group_by_dims = ["result"]
+        if self.is_feature_flag_enabled(self.FEATURE_FLAG_NAME):
+            group_by_dims.extend(["conversation_id", "user_id"])

         series = self.group_query_results_to_sketch_metrics(
             results,
             "num_failed_claims",
-            ["result"],
+            group_by_dims,
             "timestamp",
         )
         metric = self.series_to_metric(self.METRIC_NAME, series)
@@ -645,6 +839,7 @@ class ShieldInferenceRuleClaimFailCountAggregation(SketchAggregationFunction):

 class ShieldInferenceRuleLatencyAggregation(SketchAggregationFunction):
     METRIC_NAME = "rule_latency"
+    FEATURE_FLAG_NAME = "SHIELD_INFERENCE_RULE_LATENCY_AGGREGATION_SEGMENTATION"

     @staticmethod
     def id() -> UUID:
@@ -691,36 +886,55 @@ class ShieldInferenceRuleLatencyAggregation(SketchAggregationFunction):
             ),
         ],
     ) -> list[SketchMetric]:
-        results = ddb_conn.sql(
-            f" \
-            with unnested_prompt_rules as (select unnest(inference_prompt.prompt_rule_results) as rule, \
-            'prompt' as location, \
-            to_timestamp(created_at / 1000) as ts, \
-            from {dataset.dataset_table_name}), \
-            unnested_response_rules as (select unnest(inference_response.response_rule_results) as rule,\
-            'response' as location, \
-            to_timestamp(created_at / 1000) as ts, \
-            from {dataset.dataset_table_name}) \
-            select ts, \
-            location, \
-            rule.rule_type, \
-            rule.result, \
-            rule.latency_ms \
-            from unnested_prompt_rules \
-            UNION ALL \
-            select ts, \
-            location, \
-            rule.rule_type, \
-            rule.result, \
-            rule.latency_ms \
-            from unnested_response_rules \
-            ",
-        ).df()
+        # Build CTE select columns
+        prompt_cte_select = [
+            "unnest(inference_prompt.prompt_rule_results) as rule",
+            "'prompt' as location",
+            "to_timestamp(created_at / 1000) as ts",
+        ]
+        response_cte_select = [
+            "unnest(inference_response.response_rule_results) as rule",
+            "'response' as location",
+            "to_timestamp(created_at / 1000) as ts",
+        ]
+
+        # Build main select columns
+        main_select_cols = [
+            "ts",
+            "location",
+            "rule.rule_type",
+            "rule.result",
+            "rule.latency_ms",
+        ]
+
+        # Conditionally add conversation_id and user_id
+        if self.is_feature_flag_enabled(self.FEATURE_FLAG_NAME):
+            prompt_cte_select.extend(["conversation_id", "user_id"])
+            response_cte_select.extend(["conversation_id", "user_id"])
+            main_select_cols.extend(["conversation_id", "user_id"])
+
+        query = f"""
+            with unnested_prompt_rules as (select {", ".join(prompt_cte_select)}
+            from {dataset.dataset_table_name}),
+            unnested_response_rules as (select {", ".join(response_cte_select)}
+            from {dataset.dataset_table_name})
+            select {", ".join(main_select_cols)}
+            from unnested_prompt_rules
+            UNION ALL
+            select {", ".join(main_select_cols)}
+            from unnested_response_rules
+        """
+
+        results = ddb_conn.sql(query).df()
+
+        group_by_dims = ["result", "rule_type", "location"]
+        if self.is_feature_flag_enabled(self.FEATURE_FLAG_NAME):
+            group_by_dims.extend(["conversation_id", "user_id"])

         series = self.group_query_results_to_sketch_metrics(
             results,
             "latency_ms",
-            ["result", "rule_type", "location"],
+            group_by_dims,
             "ts",
         )
         metric = self.series_to_metric(self.METRIC_NAME, series)
@@ -729,6 +943,7 @@ class ShieldInferenceRuleLatencyAggregation(SketchAggregationFunction):

 class ShieldInferenceTokenCountAggregation(NumericAggregationFunction):
     METRIC_NAME = "token_count"
+    FEATURE_FLAG_NAME = "SHIELD_INFERENCE_TOKEN_COUNT_AGGREGATION_SEGMENTATION"
     SUPPORTED_MODELS = [
         "gpt-4o",
         "gpt-4o-mini",
@@ -799,28 +1014,52 @@ class ShieldInferenceTokenCountAggregation(NumericAggregationFunction):
             ),
         ],
     ) -> list[NumericMetric]:
-        results = ddb_conn.sql(
-            f" \
-            select \
-            time_bucket(INTERVAL '5 minutes', to_timestamp(created_at / 1000)) as ts, \
-            COALESCE(sum(inference_prompt.tokens), 0) as tokens, \
-            'prompt' as location \
-            from {dataset.dataset_table_name} \
-            group by time_bucket(INTERVAL '5 minutes', to_timestamp(created_at / 1000)), location \
-            UNION ALL \
-            select \
-            time_bucket(INTERVAL '5 minutes', to_timestamp(created_at / 1000)) as ts, \
-            COALESCE(sum(inference_response.tokens), 0) as tokens, \
-            'response' as location \
-            from {dataset.dataset_table_name} \
-            group by time_bucket(INTERVAL '5 minutes', to_timestamp(created_at / 1000)), location; \
-            ",
-        ).df()
+        # Build SELECT clause for prompt
+        prompt_select_cols = [
+            "time_bucket(INTERVAL '5 minutes', to_timestamp(created_at / 1000)) as ts",
+            "COALESCE(sum(inference_prompt.tokens), 0) as tokens",
+            "'prompt' as location",
+        ]
+
+        # Build SELECT clause for response
+        response_select_cols = [
+            "time_bucket(INTERVAL '5 minutes', to_timestamp(created_at / 1000)) as ts",
+            "COALESCE(sum(inference_response.tokens), 0) as tokens",
+            "'response' as location",
+        ]
+
+        # Build GROUP BY clause
+        group_by_cols = [
+            "time_bucket(INTERVAL '5 minutes', to_timestamp(created_at / 1000))",
+            "location",
+        ]
+
+        # Conditionally add conversation_id and user_id
+        if self.is_feature_flag_enabled(self.FEATURE_FLAG_NAME):
+            prompt_select_cols.extend(["conversation_id", "user_id"])
+            response_select_cols.extend(["conversation_id", "user_id"])
+            group_by_cols.extend(["conversation_id", "user_id"])
+
+        query = f"""
+            select {", ".join(prompt_select_cols)}
+            from {dataset.dataset_table_name}
+            group by {", ".join(group_by_cols)}
+            UNION ALL
+            select {", ".join(response_select_cols)}
+            from {dataset.dataset_table_name}
+            group by {", ".join(group_by_cols)};
+        """
+
+        results = ddb_conn.sql(query).df()
+
+        group_by_dims = ["location"]
+        if self.is_feature_flag_enabled(self.FEATURE_FLAG_NAME):
+            group_by_dims.extend(["conversation_id", "user_id"])

         series = self.group_query_results_to_numeric_metrics(
             results,
             "tokens",
-            ["location"],
+            group_by_dims,
             "ts",
         )
         metric = self.series_to_metric(self.METRIC_NAME, series)
@@ -839,18 +1078,21 @@ class ShieldInferenceTokenCountAggregation(NumericAggregationFunction):
             for tokens, loc_type in zip(results["tokens"], location_type)
         ]

-        model_df = pd.DataFrame(
-            {
-                "ts": results["ts"],
-                "cost": cost_values,
-                "location": results["location"],
-            },
-        )
+        model_df_dict = {
+            "ts": results["ts"],
+            "cost": cost_values,
+            "location": results["location"],
+        }
+        if self.is_feature_flag_enabled(self.FEATURE_FLAG_NAME):
+            model_df_dict["conversation_id"] = results["conversation_id"]
+            model_df_dict["user_id"] = results["user_id"]
+
+        model_df = pd.DataFrame(model_df_dict)

         model_series = self.group_query_results_to_numeric_metrics(
             model_df,
             "cost",
-            ["location"],
+            group_by_dims,
             "ts",
         )
         resp.append(
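
The cost branch in this last hunk builds the DataFrame from a dict so segmentation columns can be attached conditionally. A standalone sketch of that shape, with hypothetical inputs in place of the aggregation's real query results:

    import pandas as pd

    def build_cost_frame(
        results: pd.DataFrame, cost_values: list[float], segmented: bool
    ) -> pd.DataFrame:
        data = {
            "ts": results["ts"],
            "cost": cost_values,
            "location": results["location"],
        }
        if segmented:
            # Copy the segmentation columns only when the upstream query selected
            # them; otherwise results[...] would raise a KeyError.
            data["conversation_id"] = results["conversation_id"]
            data["user_id"] = results["user_id"]
        return pd.DataFrame(data)

    frame = build_cost_frame(
        pd.DataFrame(
            {"ts": [0], "location": ["prompt"], "conversation_id": ["c1"], "user_id": ["u1"]}
        ),
        cost_values=[0.02],
        segmented=True,
    )
    print(frame.columns.tolist())  # ['ts', 'cost', 'location', 'conversation_id', 'user_id']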