arthur-common 2.1.58__py3-none-any.whl → 2.4.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arthur_common/aggregations/aggregator.py +73 -9
- arthur_common/aggregations/functions/agentic_aggregations.py +260 -85
- arthur_common/aggregations/functions/categorical_count.py +15 -15
- arthur_common/aggregations/functions/confusion_matrix.py +24 -26
- arthur_common/aggregations/functions/inference_count.py +5 -9
- arthur_common/aggregations/functions/inference_count_by_class.py +16 -27
- arthur_common/aggregations/functions/inference_null_count.py +10 -13
- arthur_common/aggregations/functions/mean_absolute_error.py +12 -18
- arthur_common/aggregations/functions/mean_squared_error.py +12 -18
- arthur_common/aggregations/functions/multiclass_confusion_matrix.py +13 -20
- arthur_common/aggregations/functions/multiclass_inference_count_by_class.py +1 -1
- arthur_common/aggregations/functions/numeric_stats.py +13 -15
- arthur_common/aggregations/functions/numeric_sum.py +12 -15
- arthur_common/aggregations/functions/shield_aggregations.py +457 -215
- arthur_common/models/common_schemas.py +214 -0
- arthur_common/models/connectors.py +10 -2
- arthur_common/models/constants.py +24 -0
- arthur_common/models/datasets.py +0 -9
- arthur_common/models/enums.py +177 -0
- arthur_common/models/metric_schemas.py +63 -0
- arthur_common/models/metrics.py +2 -9
- arthur_common/models/request_schemas.py +870 -0
- arthur_common/models/response_schemas.py +785 -0
- arthur_common/models/schema_definitions.py +6 -1
- arthur_common/models/task_job_specs.py +3 -12
- arthur_common/tools/duckdb_data_loader.py +34 -2
- arthur_common/tools/duckdb_utils.py +3 -6
- arthur_common/tools/schema_inferer.py +3 -6
- {arthur_common-2.1.58.dist-info → arthur_common-2.4.13.dist-info}/METADATA +12 -4
- arthur_common-2.4.13.dist-info/RECORD +49 -0
- arthur_common/models/shield.py +0 -642
- arthur_common-2.1.58.dist-info/RECORD +0 -44
- {arthur_common-2.1.58.dist-info → arthur_common-2.4.13.dist-info}/WHEEL +0 -0
|
@@ -9,7 +9,7 @@ from arthur_common.aggregations.aggregator import (
|
|
|
9
9
|
NumericAggregationFunction,
|
|
10
10
|
SketchAggregationFunction,
|
|
11
11
|
)
|
|
12
|
-
from arthur_common.models.
|
|
12
|
+
from arthur_common.models.enums import ModelProblemType
|
|
13
13
|
from arthur_common.models.metrics import (
|
|
14
14
|
BaseReportedAggregation,
|
|
15
15
|
DatasetReference,
|
|
@@ -25,6 +25,7 @@ from arthur_common.models.schema_definitions import (
|
|
|
25
25
|
|
|
26
26
|
class ShieldInferencePassFailCountAggregation(NumericAggregationFunction):
|
|
27
27
|
METRIC_NAME = "inference_count"
|
|
28
|
+
FEATURE_FLAG_NAME = "SHIELD_INFERENCE_PASS_FAIL_COUNT_AGGREGATION_SEGMENTATION"
|
|
28
29
|
|
|
29
30
|
@staticmethod
|
|
30
31
|
def id() -> UUID:
|
|
@@ -71,17 +72,41 @@ class ShieldInferencePassFailCountAggregation(NumericAggregationFunction):
|
|
|
71
72
|
),
|
|
72
73
|
],
|
|
73
74
|
) -> list[NumericMetric]:
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
75
|
+
# Build SELECT clause
|
|
76
|
+
select_cols = [
|
|
77
|
+
"time_bucket(INTERVAL '5 minutes', to_timestamp(created_at / 1000)) as ts",
|
|
78
|
+
"count(*) as count",
|
|
79
|
+
"result",
|
|
80
|
+
"inference_prompt.result AS prompt_result",
|
|
81
|
+
"inference_response.result AS response_result",
|
|
82
|
+
]
|
|
83
|
+
|
|
84
|
+
# Build GROUP BY clause
|
|
85
|
+
group_by_cols = ["ts", "result", "prompt_result", "response_result"]
|
|
86
|
+
|
|
87
|
+
# Conditionally add conversation_id and user_id based on segmentation flag
|
|
88
|
+
if self.is_feature_flag_enabled(self.FEATURE_FLAG_NAME):
|
|
89
|
+
select_cols.extend(["conversation_id", "user_id as user_id"])
|
|
90
|
+
group_by_cols.extend(["conversation_id", "user_id"])
|
|
91
|
+
|
|
92
|
+
query = f"""
|
|
93
|
+
select {", ".join(select_cols)}
|
|
94
|
+
from {dataset.dataset_table_name}
|
|
95
|
+
group by {", ".join(group_by_cols)}
|
|
96
|
+
order by ts desc;
|
|
97
|
+
"""
|
|
98
|
+
|
|
99
|
+
results = ddb_conn.sql(query).df()
|
|
100
|
+
|
|
101
|
+
# Build group_by_dims list
|
|
102
|
+
group_by_dims = [
|
|
103
|
+
"result",
|
|
104
|
+
"prompt_result",
|
|
105
|
+
"response_result",
|
|
106
|
+
]
|
|
107
|
+
if self.is_feature_flag_enabled(self.FEATURE_FLAG_NAME):
|
|
108
|
+
group_by_dims.extend(["conversation_id", "user_id"])
|
|
109
|
+
|
|
85
110
|
series = self.group_query_results_to_numeric_metrics(
|
|
86
111
|
results,
|
|
87
112
|
"count",
|
|
@@ -94,6 +119,7 @@ class ShieldInferencePassFailCountAggregation(NumericAggregationFunction):
|
|
|
94
119
|
|
|
95
120
|
class ShieldInferenceRuleCountAggregation(NumericAggregationFunction):
|
|
96
121
|
METRIC_NAME = "rule_count"
|
|
122
|
+
FEATURE_FLAG_NAME = "SHIELD_INFERENCE_RULE_COUNT_AGGREGATION_SEGMENTATION"
|
|
97
123
|
|
|
98
124
|
@staticmethod
|
|
99
125
|
def id() -> UUID:
|
|
@@ -140,40 +166,72 @@ class ShieldInferenceRuleCountAggregation(NumericAggregationFunction):
|
|
|
140
166
|
),
|
|
141
167
|
],
|
|
142
168
|
) -> list[NumericMetric]:
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
",
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
169
|
+
# Build CTE select columns
|
|
170
|
+
prompt_cte_select = [
|
|
171
|
+
"unnest(inference_prompt.prompt_rule_results) as rule",
|
|
172
|
+
"'prompt' as location",
|
|
173
|
+
"time_bucket(INTERVAL '5 minutes', to_timestamp(created_at / 1000)) as ts",
|
|
174
|
+
]
|
|
175
|
+
response_cte_select = [
|
|
176
|
+
"unnest(inference_response.response_rule_results) as rule",
|
|
177
|
+
"'response' as location",
|
|
178
|
+
"time_bucket(INTERVAL '5 minutes', to_timestamp(created_at / 1000)) as ts",
|
|
179
|
+
]
|
|
180
|
+
|
|
181
|
+
# Build main select columns
|
|
182
|
+
main_select_cols = [
|
|
183
|
+
"ts",
|
|
184
|
+
"count(*) as count",
|
|
185
|
+
"location",
|
|
186
|
+
"rule.rule_type",
|
|
187
|
+
"rule.result",
|
|
188
|
+
"rule.name",
|
|
189
|
+
"rule.id",
|
|
190
|
+
]
|
|
191
|
+
|
|
192
|
+
# Build group by columns
|
|
193
|
+
group_by_cols = [
|
|
194
|
+
"ts",
|
|
195
|
+
"location",
|
|
196
|
+
"rule.rule_type",
|
|
197
|
+
"rule.result",
|
|
198
|
+
"rule.name",
|
|
199
|
+
"rule.id",
|
|
200
|
+
]
|
|
201
|
+
|
|
202
|
+
# Conditionally add conversation_id and user_id
|
|
203
|
+
if self.is_feature_flag_enabled(self.FEATURE_FLAG_NAME):
|
|
204
|
+
prompt_cte_select.extend(["conversation_id", "user_id"])
|
|
205
|
+
response_cte_select.extend(["conversation_id", "user_id"])
|
|
206
|
+
main_select_cols.extend(["conversation_id", "user_id"])
|
|
207
|
+
group_by_cols.extend(["conversation_id", "user_id"])
|
|
208
|
+
|
|
209
|
+
query = f"""
|
|
210
|
+
with unnessted_prompt_rules as (select {", ".join(prompt_cte_select)}
|
|
211
|
+
from {dataset.dataset_table_name}),
|
|
212
|
+
unnessted_result_rules as (select {", ".join(response_cte_select)}
|
|
213
|
+
from {dataset.dataset_table_name})
|
|
214
|
+
select {", ".join(main_select_cols)}
|
|
215
|
+
from unnessted_prompt_rules
|
|
216
|
+
group by {", ".join(group_by_cols)}
|
|
217
|
+
UNION ALL
|
|
218
|
+
select {", ".join(main_select_cols)}
|
|
219
|
+
from unnessted_result_rules
|
|
220
|
+
group by {", ".join(group_by_cols)}
|
|
221
|
+
order by ts desc, location, rule.rule_type, rule.result;
|
|
222
|
+
"""
|
|
223
|
+
|
|
224
|
+
results = ddb_conn.sql(query).df()
|
|
225
|
+
|
|
226
|
+
group_by_dims = [
|
|
227
|
+
"location",
|
|
228
|
+
"rule_type",
|
|
229
|
+
"result",
|
|
230
|
+
"name",
|
|
231
|
+
"id",
|
|
232
|
+
]
|
|
233
|
+
if self.is_feature_flag_enabled(self.FEATURE_FLAG_NAME):
|
|
234
|
+
group_by_dims.extend(["conversation_id", "user_id"])
|
|
177
235
|
series = self.group_query_results_to_numeric_metrics(
|
|
178
236
|
results,
|
|
179
237
|
"count",
|
|
@@ -186,6 +244,7 @@ class ShieldInferenceRuleCountAggregation(NumericAggregationFunction):
|
|
|
186
244
|
|
|
187
245
|
class ShieldInferenceHallucinationCountAggregation(NumericAggregationFunction):
|
|
188
246
|
METRIC_NAME = "hallucination_count"
|
|
247
|
+
FEATURE_FLAG_NAME = "SHIELD_INFERENCE_HALLUCINATION_COUNT_AGGREGATION_SEGMENTATION"
|
|
189
248
|
|
|
190
249
|
@staticmethod
|
|
191
250
|
def id() -> UUID:
|
|
@@ -232,24 +291,46 @@ class ShieldInferenceHallucinationCountAggregation(NumericAggregationFunction):
|
|
|
232
291
|
),
|
|
233
292
|
],
|
|
234
293
|
) -> list[NumericMetric]:
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
count(*) as count
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
294
|
+
# Build SELECT clause
|
|
295
|
+
select_cols = [
|
|
296
|
+
"time_bucket(INTERVAL '5 minutes', to_timestamp(created_at / 1000)) as ts",
|
|
297
|
+
"count(*) as count",
|
|
298
|
+
]
|
|
299
|
+
|
|
300
|
+
# Build GROUP BY clause
|
|
301
|
+
group_by_cols = ["ts"]
|
|
302
|
+
|
|
303
|
+
# Conditionally add conversation_id and user_id
|
|
304
|
+
if self.is_feature_flag_enabled(self.FEATURE_FLAG_NAME):
|
|
305
|
+
select_cols.extend(["conversation_id", "user_id"])
|
|
306
|
+
group_by_cols.extend(["conversation_id", "user_id"])
|
|
307
|
+
|
|
308
|
+
query = f"""
|
|
309
|
+
select {", ".join(select_cols)}
|
|
310
|
+
from {dataset.dataset_table_name}
|
|
311
|
+
where length(list_filter(inference_response.response_rule_results, x -> (x.rule_type = 'ModelHallucinationRuleV2' or x.rule_type = 'ModelHallucinationRule') and x.result = 'Fail')) > 0
|
|
312
|
+
group by {", ".join(group_by_cols)}
|
|
313
|
+
order by ts desc;
|
|
314
|
+
"""
|
|
315
|
+
|
|
316
|
+
results = ddb_conn.sql(query).df()
|
|
317
|
+
|
|
318
|
+
group_by_dims = []
|
|
319
|
+
if self.is_feature_flag_enabled(self.FEATURE_FLAG_NAME):
|
|
320
|
+
group_by_dims.extend(["conversation_id", "user_id"])
|
|
321
|
+
series = self.group_query_results_to_numeric_metrics(
|
|
322
|
+
results,
|
|
323
|
+
"count",
|
|
324
|
+
group_by_dims,
|
|
325
|
+
"ts",
|
|
326
|
+
)
|
|
247
327
|
metric = self.series_to_metric(self.METRIC_NAME, series)
|
|
248
328
|
return [metric]
|
|
249
329
|
|
|
250
330
|
|
|
251
331
|
class ShieldInferenceRuleToxicityScoreAggregation(SketchAggregationFunction):
|
|
252
332
|
METRIC_NAME = "toxicity_score"
|
|
333
|
+
FEATURE_FLAG_NAME = "SHIELD_INFERENCE_RULE_TOXICITY_SCORE_AGGREGATION_SEGMENTATION"
|
|
253
334
|
|
|
254
335
|
@staticmethod
|
|
255
336
|
def id() -> UUID:
|
|
@@ -296,37 +377,57 @@ class ShieldInferenceRuleToxicityScoreAggregation(SketchAggregationFunction):
|
|
|
296
377
|
),
|
|
297
378
|
],
|
|
298
379
|
) -> list[SketchMetric]:
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
",
|
|
324
|
-
|
|
380
|
+
# Build CTE select columns
|
|
381
|
+
prompt_cte_select = [
|
|
382
|
+
"to_timestamp(created_at / 1000) as ts",
|
|
383
|
+
"unnest(inference_prompt.prompt_rule_results) as rule_results",
|
|
384
|
+
"'prompt' as location",
|
|
385
|
+
]
|
|
386
|
+
response_cte_select = [
|
|
387
|
+
"to_timestamp(created_at / 1000) as ts",
|
|
388
|
+
"unnest(inference_response.response_rule_results) as rule_results",
|
|
389
|
+
"'response' as location",
|
|
390
|
+
]
|
|
391
|
+
|
|
392
|
+
# Build main select columns
|
|
393
|
+
main_select_cols = [
|
|
394
|
+
"ts as timestamp",
|
|
395
|
+
"rule_results.details.toxicity_score::DOUBLE as toxicity_score",
|
|
396
|
+
"rule_results.result as result",
|
|
397
|
+
"location",
|
|
398
|
+
]
|
|
399
|
+
|
|
400
|
+
# Conditionally add conversation_id and user_id
|
|
401
|
+
if self.is_feature_flag_enabled(self.FEATURE_FLAG_NAME):
|
|
402
|
+
prompt_cte_select.extend(["conversation_id", "user_id"])
|
|
403
|
+
response_cte_select.extend(["conversation_id", "user_id"])
|
|
404
|
+
main_select_cols.extend(["conversation_id", "user_id"])
|
|
405
|
+
|
|
406
|
+
query = f"""
|
|
407
|
+
with unnested_prompt_results as (select {", ".join(prompt_cte_select)}
|
|
408
|
+
from {dataset.dataset_table_name}),
|
|
409
|
+
unnested_response_results as (select {", ".join(response_cte_select)}
|
|
410
|
+
from {dataset.dataset_table_name})
|
|
411
|
+
select {", ".join(main_select_cols)}
|
|
412
|
+
from unnested_prompt_results
|
|
413
|
+
where rule_results.details.toxicity_score IS NOT NULL
|
|
414
|
+
UNION ALL
|
|
415
|
+
select {", ".join(main_select_cols)}
|
|
416
|
+
from unnested_response_results
|
|
417
|
+
where rule_results.details.toxicity_score IS NOT NULL
|
|
418
|
+
order by ts desc;
|
|
419
|
+
"""
|
|
420
|
+
|
|
421
|
+
results = ddb_conn.sql(query).df()
|
|
422
|
+
|
|
423
|
+
group_by_dims = ["result", "location"]
|
|
424
|
+
if self.is_feature_flag_enabled(self.FEATURE_FLAG_NAME):
|
|
425
|
+
group_by_dims.extend(["conversation_id", "user_id"])
|
|
325
426
|
|
|
326
427
|
series = self.group_query_results_to_sketch_metrics(
|
|
327
428
|
results,
|
|
328
429
|
"toxicity_score",
|
|
329
|
-
|
|
430
|
+
group_by_dims,
|
|
330
431
|
"timestamp",
|
|
331
432
|
)
|
|
332
433
|
metric = self.series_to_metric(self.METRIC_NAME, series)
|
|
@@ -335,6 +436,7 @@ class ShieldInferenceRuleToxicityScoreAggregation(SketchAggregationFunction):
|
|
|
335
436
|
|
|
336
437
|
class ShieldInferenceRulePIIDataScoreAggregation(SketchAggregationFunction):
|
|
337
438
|
METRIC_NAME = "pii_score"
|
|
439
|
+
FEATURE_FLAG_NAME = "SHIELD_INFERENCE_RULE_PII_DATA_SCORE_AGGREGATION_SEGMENTATION"
|
|
338
440
|
|
|
339
441
|
@staticmethod
|
|
340
442
|
def id() -> UUID:
|
|
@@ -381,43 +483,71 @@ class ShieldInferenceRulePIIDataScoreAggregation(SketchAggregationFunction):
|
|
|
381
483
|
),
|
|
382
484
|
],
|
|
383
485
|
) -> list[SketchMetric]:
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
|
|
486
|
+
# Build CTE select columns
|
|
487
|
+
prompt_cte_select = [
|
|
488
|
+
"time_bucket(INTERVAL '5 minutes', to_timestamp(created_at / 1000)) as ts",
|
|
489
|
+
"unnest(inference_prompt.prompt_rule_results) as rule_results",
|
|
490
|
+
"'prompt' as location",
|
|
491
|
+
]
|
|
492
|
+
response_cte_select = [
|
|
493
|
+
"time_bucket(INTERVAL '5 minutes', to_timestamp(created_at / 1000)) as ts",
|
|
494
|
+
"unnest(inference_response.response_rule_results) as rule_results",
|
|
495
|
+
"'response' as location",
|
|
496
|
+
]
|
|
497
|
+
|
|
498
|
+
# Build unnested_entities select columns
|
|
499
|
+
entities_select_cols = [
|
|
500
|
+
"ts",
|
|
501
|
+
"rule_results.result",
|
|
502
|
+
"rule_results.rule_type",
|
|
503
|
+
"location",
|
|
504
|
+
"unnest(rule_results.details.pii_entities) as pii_entity",
|
|
505
|
+
]
|
|
506
|
+
|
|
507
|
+
# Build final select columns
|
|
508
|
+
final_select_cols = [
|
|
509
|
+
"ts as timestamp",
|
|
510
|
+
"result",
|
|
511
|
+
"rule_type",
|
|
512
|
+
"location",
|
|
513
|
+
"TRY_CAST(pii_entity.confidence AS FLOAT) as pii_score",
|
|
514
|
+
"pii_entity.entity as entity",
|
|
515
|
+
]
|
|
516
|
+
|
|
517
|
+
# Conditionally add conversation_id and user_id
|
|
518
|
+
if self.is_feature_flag_enabled(self.FEATURE_FLAG_NAME):
|
|
519
|
+
prompt_cte_select.extend(["conversation_id", "user_id"])
|
|
520
|
+
response_cte_select.extend(["conversation_id", "user_id"])
|
|
521
|
+
entities_select_cols.extend(["conversation_id", "user_id"])
|
|
522
|
+
final_select_cols.extend(["conversation_id", "user_id"])
|
|
523
|
+
|
|
524
|
+
query = f"""
|
|
525
|
+
with unnested_prompt_results as (select {", ".join(prompt_cte_select)}
|
|
526
|
+
from {dataset.dataset_table_name}),
|
|
527
|
+
unnested_response_results as (select {", ".join(response_cte_select)}
|
|
528
|
+
from {dataset.dataset_table_name}),
|
|
529
|
+
unnested_entites as (select {", ".join(entities_select_cols)}
|
|
530
|
+
from unnested_response_results
|
|
531
|
+
where rule_results.rule_type = 'PIIDataRule'
|
|
532
|
+
UNION ALL
|
|
533
|
+
select {", ".join(entities_select_cols)}
|
|
534
|
+
from unnested_prompt_results
|
|
535
|
+
where rule_results.rule_type = 'PIIDataRule')
|
|
536
|
+
select {", ".join(final_select_cols)}
|
|
537
|
+
from unnested_entites
|
|
538
|
+
order by ts desc;
|
|
539
|
+
"""
|
|
540
|
+
|
|
541
|
+
results = ddb_conn.sql(query).df()
|
|
542
|
+
|
|
543
|
+
group_by_dims = ["result", "location", "entity"]
|
|
544
|
+
if self.is_feature_flag_enabled(self.FEATURE_FLAG_NAME):
|
|
545
|
+
group_by_dims.extend(["conversation_id", "user_id"])
|
|
416
546
|
|
|
417
547
|
series = self.group_query_results_to_sketch_metrics(
|
|
418
548
|
results,
|
|
419
549
|
"pii_score",
|
|
420
|
-
|
|
550
|
+
group_by_dims,
|
|
421
551
|
"timestamp",
|
|
422
552
|
)
|
|
423
553
|
metric = self.series_to_metric(self.METRIC_NAME, series)
|
|
@@ -426,6 +556,7 @@ order by ts desc;
|
|
|
426
556
|
|
|
427
557
|
class ShieldInferenceRuleClaimCountAggregation(SketchAggregationFunction):
|
|
428
558
|
METRIC_NAME = "claim_count"
|
|
559
|
+
FEATURE_FLAG_NAME = "SHIELD_INFERENCE_RULE_CLAIM_COUNT_AGGREGATION_SEGMENTATION"
|
|
429
560
|
|
|
430
561
|
@staticmethod
|
|
431
562
|
def id() -> UUID:
|
|
@@ -472,25 +603,44 @@ class ShieldInferenceRuleClaimCountAggregation(SketchAggregationFunction):
|
|
|
472
603
|
),
|
|
473
604
|
],
|
|
474
605
|
) -> list[SketchMetric]:
|
|
475
|
-
|
|
476
|
-
|
|
477
|
-
|
|
478
|
-
|
|
479
|
-
|
|
480
|
-
|
|
481
|
-
|
|
482
|
-
|
|
483
|
-
|
|
484
|
-
|
|
485
|
-
|
|
486
|
-
|
|
487
|
-
|
|
488
|
-
|
|
606
|
+
# Build CTE select columns
|
|
607
|
+
cte_select = [
|
|
608
|
+
"to_timestamp(created_at / 1000) as ts",
|
|
609
|
+
"unnest(inference_response.response_rule_results) as rule_results",
|
|
610
|
+
]
|
|
611
|
+
|
|
612
|
+
# Build main select columns
|
|
613
|
+
main_select_cols = [
|
|
614
|
+
"ts as timestamp",
|
|
615
|
+
"length(rule_results.details.claims) as num_claims",
|
|
616
|
+
"rule_results.result as result",
|
|
617
|
+
]
|
|
618
|
+
|
|
619
|
+
# Conditionally add conversation_id and user_id
|
|
620
|
+
if self.is_feature_flag_enabled(self.FEATURE_FLAG_NAME):
|
|
621
|
+
cte_select.extend(["conversation_id", "user_id"])
|
|
622
|
+
main_select_cols.extend(["conversation_id", "user_id"])
|
|
623
|
+
|
|
624
|
+
query = f"""
|
|
625
|
+
with unnested_results as (select {", ".join(cte_select)}
|
|
626
|
+
from {dataset.dataset_table_name})
|
|
627
|
+
select {", ".join(main_select_cols)}
|
|
628
|
+
from unnested_results
|
|
629
|
+
where rule_results.rule_type = 'ModelHallucinationRuleV2'
|
|
630
|
+
and rule_results.result != 'Skipped'
|
|
631
|
+
order by ts desc;
|
|
632
|
+
"""
|
|
633
|
+
|
|
634
|
+
results = ddb_conn.sql(query).df()
|
|
635
|
+
|
|
636
|
+
group_by_dims = ["result"]
|
|
637
|
+
if self.is_feature_flag_enabled(self.FEATURE_FLAG_NAME):
|
|
638
|
+
group_by_dims.extend(["conversation_id", "user_id"])
|
|
489
639
|
|
|
490
640
|
series = self.group_query_results_to_sketch_metrics(
|
|
491
641
|
results,
|
|
492
642
|
"num_claims",
|
|
493
|
-
|
|
643
|
+
group_by_dims,
|
|
494
644
|
"timestamp",
|
|
495
645
|
)
|
|
496
646
|
metric = self.series_to_metric(self.METRIC_NAME, series)
|
|
@@ -499,6 +649,9 @@ class ShieldInferenceRuleClaimCountAggregation(SketchAggregationFunction):
|
|
|
499
649
|
|
|
500
650
|
class ShieldInferenceRuleClaimPassCountAggregation(SketchAggregationFunction):
|
|
501
651
|
METRIC_NAME = "claim_valid_count"
|
|
652
|
+
FEATURE_FLAG_NAME = (
|
|
653
|
+
"SHIELD_INFERENCE_RULE_CLAIM_PASS_COUNT_AGGREGATION_SEGMENTATION"
|
|
654
|
+
)
|
|
502
655
|
|
|
503
656
|
@staticmethod
|
|
504
657
|
def id() -> UUID:
|
|
@@ -545,25 +698,44 @@ class ShieldInferenceRuleClaimPassCountAggregation(SketchAggregationFunction):
|
|
|
545
698
|
),
|
|
546
699
|
],
|
|
547
700
|
) -> list[SketchMetric]:
|
|
548
|
-
|
|
549
|
-
|
|
550
|
-
|
|
551
|
-
|
|
552
|
-
|
|
553
|
-
|
|
554
|
-
|
|
555
|
-
|
|
556
|
-
|
|
557
|
-
|
|
558
|
-
|
|
559
|
-
|
|
560
|
-
|
|
561
|
-
|
|
701
|
+
# Build CTE select columns
|
|
702
|
+
cte_select = [
|
|
703
|
+
"to_timestamp(created_at / 1000) as ts",
|
|
704
|
+
"unnest(inference_response.response_rule_results) as rule_results",
|
|
705
|
+
]
|
|
706
|
+
|
|
707
|
+
# Build main select columns
|
|
708
|
+
main_select_cols = [
|
|
709
|
+
"ts as timestamp",
|
|
710
|
+
"length(list_filter(rule_results.details.claims, x -> x.valid)) as num_valid_claims",
|
|
711
|
+
"rule_results.result as result",
|
|
712
|
+
]
|
|
713
|
+
|
|
714
|
+
# Conditionally add conversation_id and user_id
|
|
715
|
+
if self.is_feature_flag_enabled(self.FEATURE_FLAG_NAME):
|
|
716
|
+
cte_select.extend(["conversation_id", "user_id"])
|
|
717
|
+
main_select_cols.extend(["conversation_id", "user_id"])
|
|
718
|
+
|
|
719
|
+
query = f"""
|
|
720
|
+
with unnested_results as (select {", ".join(cte_select)}
|
|
721
|
+
from {dataset.dataset_table_name})
|
|
722
|
+
select {", ".join(main_select_cols)}
|
|
723
|
+
from unnested_results
|
|
724
|
+
where rule_results.rule_type = 'ModelHallucinationRuleV2'
|
|
725
|
+
and rule_results.result != 'Skipped'
|
|
726
|
+
order by ts desc;
|
|
727
|
+
"""
|
|
728
|
+
|
|
729
|
+
results = ddb_conn.sql(query).df()
|
|
730
|
+
|
|
731
|
+
group_by_dims = ["result"]
|
|
732
|
+
if self.is_feature_flag_enabled(self.FEATURE_FLAG_NAME):
|
|
733
|
+
group_by_dims.extend(["conversation_id", "user_id"])
|
|
562
734
|
|
|
563
735
|
series = self.group_query_results_to_sketch_metrics(
|
|
564
736
|
results,
|
|
565
737
|
"num_valid_claims",
|
|
566
|
-
|
|
738
|
+
group_by_dims,
|
|
567
739
|
"timestamp",
|
|
568
740
|
)
|
|
569
741
|
metric = self.series_to_metric(self.METRIC_NAME, series)
|
|
@@ -572,6 +744,9 @@ class ShieldInferenceRuleClaimPassCountAggregation(SketchAggregationFunction):
|
|
|
572
744
|
|
|
573
745
|
class ShieldInferenceRuleClaimFailCountAggregation(SketchAggregationFunction):
|
|
574
746
|
METRIC_NAME = "claim_invalid_count"
|
|
747
|
+
FEATURE_FLAG_NAME = (
|
|
748
|
+
"SHIELD_INFERENCE_RULE_CLAIM_FAIL_COUNT_AGGREGATION_SEGMENTATION"
|
|
749
|
+
)
|
|
575
750
|
|
|
576
751
|
@staticmethod
|
|
577
752
|
def id() -> UUID:
|
|
@@ -618,25 +793,44 @@ class ShieldInferenceRuleClaimFailCountAggregation(SketchAggregationFunction):
|
|
|
618
793
|
),
|
|
619
794
|
],
|
|
620
795
|
) -> list[SketchMetric]:
|
|
621
|
-
|
|
622
|
-
|
|
623
|
-
|
|
624
|
-
|
|
625
|
-
|
|
626
|
-
|
|
627
|
-
|
|
628
|
-
|
|
629
|
-
|
|
630
|
-
|
|
631
|
-
|
|
632
|
-
|
|
633
|
-
|
|
634
|
-
|
|
796
|
+
# Build CTE select columns
|
|
797
|
+
cte_select = [
|
|
798
|
+
"to_timestamp(created_at / 1000) as ts",
|
|
799
|
+
"unnest(inference_response.response_rule_results) as rule_results",
|
|
800
|
+
]
|
|
801
|
+
|
|
802
|
+
# Build main select columns
|
|
803
|
+
main_select_cols = [
|
|
804
|
+
"ts as timestamp",
|
|
805
|
+
"length(list_filter(rule_results.details.claims, x -> not x.valid)) as num_failed_claims",
|
|
806
|
+
"rule_results.result as result",
|
|
807
|
+
]
|
|
808
|
+
|
|
809
|
+
# Conditionally add conversation_id and user_id
|
|
810
|
+
if self.is_feature_flag_enabled(self.FEATURE_FLAG_NAME):
|
|
811
|
+
cte_select.extend(["conversation_id", "user_id"])
|
|
812
|
+
main_select_cols.extend(["conversation_id", "user_id"])
|
|
813
|
+
|
|
814
|
+
query = f"""
|
|
815
|
+
with unnested_results as (select {", ".join(cte_select)}
|
|
816
|
+
from {dataset.dataset_table_name})
|
|
817
|
+
select {", ".join(main_select_cols)}
|
|
818
|
+
from unnested_results
|
|
819
|
+
where rule_results.rule_type = 'ModelHallucinationRuleV2'
|
|
820
|
+
and rule_results.result != 'Skipped'
|
|
821
|
+
order by ts desc;
|
|
822
|
+
"""
|
|
823
|
+
|
|
824
|
+
results = ddb_conn.sql(query).df()
|
|
825
|
+
|
|
826
|
+
group_by_dims = ["result"]
|
|
827
|
+
if self.is_feature_flag_enabled(self.FEATURE_FLAG_NAME):
|
|
828
|
+
group_by_dims.extend(["conversation_id", "user_id"])
|
|
635
829
|
|
|
636
830
|
series = self.group_query_results_to_sketch_metrics(
|
|
637
831
|
results,
|
|
638
832
|
"num_failed_claims",
|
|
639
|
-
|
|
833
|
+
group_by_dims,
|
|
640
834
|
"timestamp",
|
|
641
835
|
)
|
|
642
836
|
metric = self.series_to_metric(self.METRIC_NAME, series)
|
|
@@ -645,6 +839,7 @@ class ShieldInferenceRuleClaimFailCountAggregation(SketchAggregationFunction):
|
|
|
645
839
|
|
|
646
840
|
class ShieldInferenceRuleLatencyAggregation(SketchAggregationFunction):
|
|
647
841
|
METRIC_NAME = "rule_latency"
|
|
842
|
+
FEATURE_FLAG_NAME = "SHIELD_INFERENCE_RULE_LATENCY_AGGREGATION_SEGMENTATION"
|
|
648
843
|
|
|
649
844
|
@staticmethod
|
|
650
845
|
def id() -> UUID:
|
|
@@ -691,36 +886,55 @@ class ShieldInferenceRuleLatencyAggregation(SketchAggregationFunction):
|
|
|
691
886
|
),
|
|
692
887
|
],
|
|
693
888
|
) -> list[SketchMetric]:
|
|
694
|
-
|
|
695
|
-
|
|
696
|
-
|
|
697
|
-
|
|
698
|
-
|
|
699
|
-
|
|
700
|
-
|
|
701
|
-
|
|
702
|
-
|
|
703
|
-
|
|
704
|
-
|
|
705
|
-
|
|
706
|
-
|
|
707
|
-
|
|
708
|
-
|
|
709
|
-
|
|
710
|
-
|
|
711
|
-
|
|
712
|
-
|
|
713
|
-
|
|
714
|
-
|
|
715
|
-
|
|
716
|
-
|
|
717
|
-
",
|
|
718
|
-
|
|
889
|
+
# Build CTE select columns
|
|
890
|
+
prompt_cte_select = [
|
|
891
|
+
"unnest(inference_prompt.prompt_rule_results) as rule",
|
|
892
|
+
"'prompt' as location",
|
|
893
|
+
"to_timestamp(created_at / 1000) as ts",
|
|
894
|
+
]
|
|
895
|
+
response_cte_select = [
|
|
896
|
+
"unnest(inference_response.response_rule_results) as rule",
|
|
897
|
+
"'response' as location",
|
|
898
|
+
"to_timestamp(created_at / 1000) as ts",
|
|
899
|
+
]
|
|
900
|
+
|
|
901
|
+
# Build main select columns
|
|
902
|
+
main_select_cols = [
|
|
903
|
+
"ts",
|
|
904
|
+
"location",
|
|
905
|
+
"rule.rule_type",
|
|
906
|
+
"rule.result",
|
|
907
|
+
"rule.latency_ms",
|
|
908
|
+
]
|
|
909
|
+
|
|
910
|
+
# Conditionally add conversation_id and user_id
|
|
911
|
+
if self.is_feature_flag_enabled(self.FEATURE_FLAG_NAME):
|
|
912
|
+
prompt_cte_select.extend(["conversation_id", "user_id"])
|
|
913
|
+
response_cte_select.extend(["conversation_id", "user_id"])
|
|
914
|
+
main_select_cols.extend(["conversation_id", "user_id"])
|
|
915
|
+
|
|
916
|
+
query = f"""
|
|
917
|
+
with unnested_prompt_rules as (select {", ".join(prompt_cte_select)}
|
|
918
|
+
from {dataset.dataset_table_name}),
|
|
919
|
+
unnested_response_rules as (select {", ".join(response_cte_select)}
|
|
920
|
+
from {dataset.dataset_table_name})
|
|
921
|
+
select {", ".join(main_select_cols)}
|
|
922
|
+
from unnested_prompt_rules
|
|
923
|
+
UNION ALL
|
|
924
|
+
select {", ".join(main_select_cols)}
|
|
925
|
+
from unnested_response_rules
|
|
926
|
+
"""
|
|
927
|
+
|
|
928
|
+
results = ddb_conn.sql(query).df()
|
|
929
|
+
|
|
930
|
+
group_by_dims = ["result", "rule_type", "location"]
|
|
931
|
+
if self.is_feature_flag_enabled(self.FEATURE_FLAG_NAME):
|
|
932
|
+
group_by_dims.extend(["conversation_id", "user_id"])
|
|
719
933
|
|
|
720
934
|
series = self.group_query_results_to_sketch_metrics(
|
|
721
935
|
results,
|
|
722
936
|
"latency_ms",
|
|
723
|
-
|
|
937
|
+
group_by_dims,
|
|
724
938
|
"ts",
|
|
725
939
|
)
|
|
726
940
|
metric = self.series_to_metric(self.METRIC_NAME, series)
|
|
@@ -729,6 +943,7 @@ class ShieldInferenceRuleLatencyAggregation(SketchAggregationFunction):
|
|
|
729
943
|
|
|
730
944
|
class ShieldInferenceTokenCountAggregation(NumericAggregationFunction):
|
|
731
945
|
METRIC_NAME = "token_count"
|
|
946
|
+
FEATURE_FLAG_NAME = "SHIELD_INFERENCE_TOKEN_COUNT_AGGREGATION_SEGMENTATION"
|
|
732
947
|
SUPPORTED_MODELS = [
|
|
733
948
|
"gpt-4o",
|
|
734
949
|
"gpt-4o-mini",
|
|
@@ -799,28 +1014,52 @@ class ShieldInferenceTokenCountAggregation(NumericAggregationFunction):
|
|
|
799
1014
|
),
|
|
800
1015
|
],
|
|
801
1016
|
) -> list[NumericMetric]:
|
|
802
|
-
|
|
803
|
-
|
|
804
|
-
|
|
805
|
-
|
|
806
|
-
|
|
807
|
-
|
|
808
|
-
|
|
809
|
-
|
|
810
|
-
|
|
811
|
-
|
|
812
|
-
|
|
813
|
-
|
|
814
|
-
|
|
815
|
-
|
|
816
|
-
|
|
817
|
-
|
|
818
|
-
|
|
1017
|
+
# Build SELECT clause for prompt
|
|
1018
|
+
prompt_select_cols = [
|
|
1019
|
+
"time_bucket(INTERVAL '5 minutes', to_timestamp(created_at / 1000)) as ts",
|
|
1020
|
+
"COALESCE(sum(inference_prompt.tokens), 0) as tokens",
|
|
1021
|
+
"'prompt' as location",
|
|
1022
|
+
]
|
|
1023
|
+
|
|
1024
|
+
# Build SELECT clause for response
|
|
1025
|
+
response_select_cols = [
|
|
1026
|
+
"time_bucket(INTERVAL '5 minutes', to_timestamp(created_at / 1000)) as ts",
|
|
1027
|
+
"COALESCE(sum(inference_response.tokens), 0) as tokens",
|
|
1028
|
+
"'response' as location",
|
|
1029
|
+
]
|
|
1030
|
+
|
|
1031
|
+
# Build GROUP BY clause
|
|
1032
|
+
group_by_cols = [
|
|
1033
|
+
"time_bucket(INTERVAL '5 minutes', to_timestamp(created_at / 1000))",
|
|
1034
|
+
"location",
|
|
1035
|
+
]
|
|
1036
|
+
|
|
1037
|
+
# Conditionally add conversation_id and user_id
|
|
1038
|
+
if self.is_feature_flag_enabled(self.FEATURE_FLAG_NAME):
|
|
1039
|
+
prompt_select_cols.extend(["conversation_id", "user_id"])
|
|
1040
|
+
response_select_cols.extend(["conversation_id", "user_id"])
|
|
1041
|
+
group_by_cols.extend(["conversation_id", "user_id"])
|
|
1042
|
+
|
|
1043
|
+
query = f"""
|
|
1044
|
+
select {", ".join(prompt_select_cols)}
|
|
1045
|
+
from {dataset.dataset_table_name}
|
|
1046
|
+
group by {", ".join(group_by_cols)}
|
|
1047
|
+
UNION ALL
|
|
1048
|
+
select {", ".join(response_select_cols)}
|
|
1049
|
+
from {dataset.dataset_table_name}
|
|
1050
|
+
group by {", ".join(group_by_cols)};
|
|
1051
|
+
"""
|
|
1052
|
+
|
|
1053
|
+
results = ddb_conn.sql(query).df()
|
|
1054
|
+
|
|
1055
|
+
group_by_dims = ["location"]
|
|
1056
|
+
if self.is_feature_flag_enabled(self.FEATURE_FLAG_NAME):
|
|
1057
|
+
group_by_dims.extend(["conversation_id", "user_id"])
|
|
819
1058
|
|
|
820
1059
|
series = self.group_query_results_to_numeric_metrics(
|
|
821
1060
|
results,
|
|
822
1061
|
"tokens",
|
|
823
|
-
|
|
1062
|
+
group_by_dims,
|
|
824
1063
|
"ts",
|
|
825
1064
|
)
|
|
826
1065
|
metric = self.series_to_metric(self.METRIC_NAME, series)
|
|
@@ -839,18 +1078,21 @@ class ShieldInferenceTokenCountAggregation(NumericAggregationFunction):
|
|
|
839
1078
|
for tokens, loc_type in zip(results["tokens"], location_type)
|
|
840
1079
|
]
|
|
841
1080
|
|
|
842
|
-
|
|
843
|
-
|
|
844
|
-
|
|
845
|
-
|
|
846
|
-
|
|
847
|
-
|
|
848
|
-
|
|
1081
|
+
model_df_dict = {
|
|
1082
|
+
"ts": results["ts"],
|
|
1083
|
+
"cost": cost_values,
|
|
1084
|
+
"location": results["location"],
|
|
1085
|
+
}
|
|
1086
|
+
if self.is_feature_flag_enabled(self.FEATURE_FLAG_NAME):
|
|
1087
|
+
model_df_dict["conversation_id"] = results["conversation_id"]
|
|
1088
|
+
model_df_dict["user_id"] = results["user_id"]
|
|
1089
|
+
|
|
1090
|
+
model_df = pd.DataFrame(model_df_dict)
|
|
849
1091
|
|
|
850
1092
|
model_series = self.group_query_results_to_numeric_metrics(
|
|
851
1093
|
model_df,
|
|
852
1094
|
"cost",
|
|
853
|
-
|
|
1095
|
+
group_by_dims,
|
|
854
1096
|
"ts",
|
|
855
1097
|
)
|
|
856
1098
|
resp.append(
|