arthur-common 1.0.1__py3-none-any.whl → 2.1.48__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of arthur-common might be problematic.

@@ -1,4 +1,4 @@
-from typing import Annotated
+from typing import Annotated, Optional
 from uuid import UUID
 
 from arthur_common.aggregations.aggregator import NumericAggregationFunction
@@ -7,6 +7,7 @@ from arthur_common.models.schema_definitions import (
     DType,
     MetricColumnParameterAnnotation,
     MetricDatasetParameterAnnotation,
+    MetricMultipleColumnParameterAnnotation,
     ScalarType,
     ScopeSchemaTag,
 )
@@ -60,23 +61,62 @@ class InferenceNullCountAggregationFunction(NumericAggregationFunction):
                 description="A column containing nullable values to count.",
             ),
         ],
+        segmentation_cols: Annotated[
+            Optional[list[str]],
+            MetricMultipleColumnParameterAnnotation(
+                source_dataset_parameter_key="dataset",
+                allowed_column_types=[
+                    ScalarType(dtype=DType.INT),
+                    ScalarType(dtype=DType.BOOL),
+                    ScalarType(dtype=DType.STRING),
+                    ScalarType(dtype=DType.UUID),
+                ],
+                tag_hints=[],
+                friendly_name="Segmentation Columns",
+                description="All columns to include as dimensions for segmentation.",
+                optional=True,
+            ),
+        ] = None,
     ) -> list[NumericMetric]:
+        """Executed SQL with no segmentation columns:
+        select time_bucket(INTERVAL '5 minutes', {escaped_timestamp_col}) as ts, \
+            count(*) as count \
+            from {dataset.dataset_table_name} where {escaped_nullable_col} is null \
+            group by ts \
+        """
+        segmentation_cols = [] if not segmentation_cols else segmentation_cols
         escaped_timestamp_col = escape_identifier(timestamp_col)
         escaped_nullable_col = escape_identifier(nullable_col)
-        count_query = f" \
-            select time_bucket(INTERVAL '5 minutes', {escaped_timestamp_col}) as ts, \
-            count(*) as count \
-            from {dataset.dataset_table_name} where {escaped_nullable_col} is null \
-            group by ts \
-        "
+
+        # build query components with segmentation columns
+        escaped_segmentation_cols = [
+            escape_identifier(col) for col in segmentation_cols
+        ]
+        all_select_clause_cols = [
+            f"time_bucket(INTERVAL '5 minutes', {escaped_timestamp_col}) as ts",
+            f"count(*) as count",
+        ] + escaped_segmentation_cols
+        all_group_by_cols = ["ts"] + escaped_segmentation_cols
+
+        # build query
+        count_query = f"""
+            select {", ".join(all_select_clause_cols)}
+            from {dataset.dataset_table_name}
+            where {escaped_nullable_col} is null
+            group by {", ".join(all_group_by_cols)}
+        """
+
         results = ddb_conn.sql(count_query).df()
 
-        series = self.dimensionless_query_results_to_numeric_metrics(
+        series = self.group_query_results_to_numeric_metrics(
             results,
             "count",
+            segmentation_cols,
             "ts",
         )
-        series.dimensions = [Dimension(name="column_name", value=nullable_col)]
+        # preserve dimension that identifies the name of the nullable column used for the aggregation
+        for point in series:
+            point.dimensions.append(Dimension(name="column_name", value=nullable_col))
 
-        metric = self.series_to_metric(self.METRIC_NAME, [series])
+        metric = self.series_to_metric(self.METRIC_NAME, series)
         return [metric]
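Across this release the pattern is the same: a fixed SQL string becomes a query assembled from column lists, so each segmentation column lands in both the select and group-by clauses. Below is a minimal standalone sketch of that assembly; the `escape_identifier` shown is a hypothetical DuckDB-style stand-in, since the real helper lives elsewhere in arthur_common and is not part of this diff.

```python
from typing import Optional


def escape_identifier(name: str) -> str:
    # hypothetical stand-in: DuckDB-style double-quote escaping
    return '"' + name.replace('"', '""') + '"'


def build_null_count_query(
    table: str,
    timestamp_col: str,
    nullable_col: str,
    segmentation_cols: Optional[list[str]] = None,
) -> str:
    segmentation_cols = segmentation_cols or []
    escaped_ts = escape_identifier(timestamp_col)
    escaped_null = escape_identifier(nullable_col)
    escaped_segs = [escape_identifier(col) for col in segmentation_cols]
    # segmentation columns go into both the select list and the group-by
    select_cols = [
        f"time_bucket(INTERVAL '5 minutes', {escaped_ts}) as ts",
        "count(*) as count",
    ] + escaped_segs
    group_by_cols = ["ts"] + escaped_segs
    return (
        f"select {', '.join(select_cols)} "
        f"from {table} "
        f"where {escaped_null} is null "
        f"group by {', '.join(group_by_cols)}"
    )


print(build_null_count_query("inferences", "ts_col", "score", ["region"]))
# select time_bucket(INTERVAL '5 minutes', "ts_col") as ts, count(*) as count,
# "region" from inferences where "score" is null group by ts, "region"
```

With an empty list the query degenerates to the 1.0.1 behavior, which is why the parameter can safely default to None.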
@@ -1,4 +1,4 @@
-from typing import Annotated
+from typing import Annotated, Optional
 from uuid import UUID
 
 from arthur_common.aggregations.aggregator import NumericAggregationFunction
@@ -8,6 +8,7 @@ from arthur_common.models.schema_definitions import (
     DType,
     MetricColumnParameterAnnotation,
     MetricDatasetParameterAnnotation,
+    MetricMultipleColumnParameterAnnotation,
     ScalarType,
     ScopeSchemaTag,
 )
@@ -75,36 +76,75 @@ class MeanAbsoluteErrorAggregationFunction(NumericAggregationFunction):
                 description="A column containing float typed ground truth values.",
             ),
         ],
+        segmentation_cols: Annotated[
+            Optional[list[str]],
+            MetricMultipleColumnParameterAnnotation(
+                source_dataset_parameter_key="dataset",
+                allowed_column_types=[
+                    ScalarType(dtype=DType.INT),
+                    ScalarType(dtype=DType.BOOL),
+                    ScalarType(dtype=DType.STRING),
+                    ScalarType(dtype=DType.UUID),
+                ],
+                tag_hints=[],
+                friendly_name="Segmentation Columns",
+                description="All columns to include as dimensions for segmentation.",
+                optional=True,
+            ),
+        ] = None,
     ) -> list[NumericMetric]:
+        """Executed SQL with no segmentation columns:
+        SELECT time_bucket(INTERVAL '5 minutes', {escaped_timestamp_col}) as ts, \
+            SUM(ABS({escaped_prediction_col} - {escaped_ground_truth_col})) as ae, \
+            COUNT(*) as count \
+            FROM {dataset.dataset_table_name} \
+            WHERE {escaped_prediction_col} IS NOT NULL \
+            AND {escaped_ground_truth_col} IS NOT NULL \
+            GROUP BY ts order by ts desc \
+        """
+        segmentation_cols = [] if not segmentation_cols else segmentation_cols
         escaped_timestamp_col = escape_identifier(timestamp_col)
         escaped_prediction_col = escape_identifier(prediction_col)
         escaped_ground_truth_col = escape_identifier(ground_truth_col)
-        count_query = f" \
-            SELECT time_bucket(INTERVAL '5 minutes', {escaped_timestamp_col}) as ts, \
-            SUM(ABS({escaped_prediction_col} - {escaped_ground_truth_col})) as ae, \
-            COUNT(*) as count \
-            FROM {dataset.dataset_table_name} \
-            WHERE {escaped_prediction_col} IS NOT NULL \
-            AND {escaped_ground_truth_col} IS NOT NULL \
-            GROUP BY ts order by ts desc \
-        "
 
-        results = ddb_conn.sql(count_query).df()
-        count_series = self.dimensionless_query_results_to_numeric_metrics(
+        # build query components with segmentation columns
+        escaped_segmentation_cols = [
+            escape_identifier(col) for col in segmentation_cols
+        ]
+        all_select_clause_cols = [
+            f"time_bucket(INTERVAL '5 minutes', {escaped_timestamp_col}) as ts",
+            f"SUM(ABS({escaped_prediction_col} - {escaped_ground_truth_col})) as ae",
+            f"COUNT(*) as count",
+        ] + escaped_segmentation_cols
+        all_group_by_cols = ["ts"] + escaped_segmentation_cols
+
+        # build query
+        mae_query = f"""
+            SELECT {", ".join(all_select_clause_cols)}
+            FROM {dataset.dataset_table_name}
+            WHERE {escaped_prediction_col} IS NOT NULL
+            AND {escaped_ground_truth_col} IS NOT NULL
+            GROUP BY {", ".join(all_group_by_cols)} order by ts desc
+        """
+
+        results = ddb_conn.sql(mae_query).df()
+        count_series = self.group_query_results_to_numeric_metrics(
             results,
             "count",
+            segmentation_cols,
             "ts",
         )
-        absolute_error_series = self.dimensionless_query_results_to_numeric_metrics(
+        absolute_error_series = self.group_query_results_to_numeric_metrics(
            results,
            "ae",
+            segmentation_cols,
            "ts",
        )
 
-        count_metric = self.series_to_metric("absolute_error_count", [count_series])
+        count_metric = self.series_to_metric("absolute_error_count", count_series)
         absolute_error_metric = self.series_to_metric(
             "absolute_error_sum",
-            [absolute_error_series],
+            absolute_error_series,
         )
 
         return [count_metric, absolute_error_metric]
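Note that the aggregation emits absolute_error_sum and absolute_error_count rather than MAE itself, which suggests the ratio is taken downstream. A runnable sketch of the grouped query shape against an in-memory DuckDB table; table and column names are illustrative, and the final division is an assumption about how the two series are meant to be combined:

```python
import duckdb  # assumes the duckdb and pandas packages are installed

con = duckdb.connect()
con.sql(
    """
    create table inferences as
    select * from (values
        (timestamp '2024-01-01 00:01:00', 1.0, 2.0, 'us'),
        (timestamp '2024-01-01 00:02:00', 3.0, 1.0, 'eu'),
        (timestamp '2024-01-01 00:03:00', 2.0, 2.5, 'us')
    ) t(ts_col, prediction, ground_truth, region)
    """
)

# Same shape as the mae_query assembled above, with one segmentation column.
df = con.sql(
    """
    select time_bucket(interval '5 minutes', ts_col) as ts,
           sum(abs(prediction - ground_truth)) as ae,
           count(*) as count,
           region
    from inferences
    where prediction is not null and ground_truth is not null
    group by ts, region order by ts desc
    """
).df()

# Deriving MAE per (bucket, segment) is an assumption about downstream use.
df["mae"] = df["ae"] / df["count"]
print(df)
```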
@@ -1,4 +1,4 @@
-from typing import Annotated
+from typing import Annotated, Optional
 from uuid import UUID
 
 from arthur_common.aggregations.aggregator import NumericAggregationFunction
@@ -8,6 +8,7 @@ from arthur_common.models.schema_definitions import (
     DType,
     MetricColumnParameterAnnotation,
     MetricDatasetParameterAnnotation,
+    MetricMultipleColumnParameterAnnotation,
     ScalarType,
     ScopeSchemaTag,
 )
@@ -75,36 +76,75 @@ class MeanSquaredErrorAggregationFunction(NumericAggregationFunction):
                 description="A column containing float typed ground truth values.",
             ),
         ],
+        segmentation_cols: Annotated[
+            Optional[list[str]],
+            MetricMultipleColumnParameterAnnotation(
+                source_dataset_parameter_key="dataset",
+                allowed_column_types=[
+                    ScalarType(dtype=DType.INT),
+                    ScalarType(dtype=DType.BOOL),
+                    ScalarType(dtype=DType.STRING),
+                    ScalarType(dtype=DType.UUID),
+                ],
+                tag_hints=[],
+                friendly_name="Segmentation Columns",
+                description="All columns to include as dimensions for segmentation.",
+                optional=True,
+            ),
+        ] = None,
     ) -> list[NumericMetric]:
+        """Executed SQL with no segmentation columns:
+        SELECT time_bucket(INTERVAL '5 minutes', {escaped_timestamp_col}) as ts, \
+            SUM(POW({escaped_prediction_col} - {escaped_ground_truth_col}, 2)) as squared_error, \
+            COUNT(*) as count \
+            FROM {dataset.dataset_table_name} \
+            WHERE {escaped_prediction_col} IS NOT NULL \
+            AND {escaped_ground_truth_col} IS NOT NULL \
+            GROUP BY ts order by ts desc \
+        """
+        segmentation_cols = [] if not segmentation_cols else segmentation_cols
         escaped_timestamp_col = escape_identifier(timestamp_col)
         escaped_prediction_col = escape_identifier(prediction_col)
         escaped_ground_truth_col = escape_identifier(ground_truth_col)
-        count_query = f" \
-            SELECT time_bucket(INTERVAL '5 minutes', {escaped_timestamp_col}) as ts, \
-            SUM(POW({escaped_prediction_col} - {escaped_ground_truth_col}, 2)) as squared_error, \
-            COUNT(*) as count \
-            FROM {dataset.dataset_table_name} \
-            WHERE {escaped_prediction_col} IS NOT NULL \
-            AND {escaped_ground_truth_col} IS NOT NULL \
-            GROUP BY ts order by ts desc \
-        "
 
-        results = ddb_conn.sql(count_query).df()
-        count_series = self.dimensionless_query_results_to_numeric_metrics(
+        # build query components with segmentation columns
+        escaped_segmentation_cols = [
+            escape_identifier(col) for col in segmentation_cols
+        ]
+        all_select_clause_cols = [
+            f"time_bucket(INTERVAL '5 minutes', {escaped_timestamp_col}) as ts",
+            f"SUM(POW({escaped_prediction_col} - {escaped_ground_truth_col}, 2)) as squared_error",
+            f"COUNT(*) as count",
+        ] + escaped_segmentation_cols
+        all_group_by_cols = ["ts"] + escaped_segmentation_cols
+
+        # build query
+        mse_query = f"""
+            SELECT {", ".join(all_select_clause_cols)}
+            FROM {dataset.dataset_table_name}
+            WHERE {escaped_prediction_col} IS NOT NULL
+            AND {escaped_ground_truth_col} IS NOT NULL
+            GROUP BY {", ".join(all_group_by_cols)} order by ts desc
+        """
+
+        results = ddb_conn.sql(mse_query).df()
+        count_series = self.group_query_results_to_numeric_metrics(
             results,
             "count",
+            segmentation_cols,
             "ts",
         )
-        squared_error_series = self.dimensionless_query_results_to_numeric_metrics(
+        squared_error_series = self.group_query_results_to_numeric_metrics(
            results,
            "squared_error",
+            segmentation_cols,
            "ts",
        )
 
-        count_metric = self.series_to_metric("squared_error_count", [count_series])
+        count_metric = self.series_to_metric("squared_error_count", count_series)
         absolute_error_metric = self.series_to_metric(
             "squared_error_sum",
-            [squared_error_series],
+            squared_error_series,
         )
 
         return [count_metric, absolute_error_metric]
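The MSE variant follows the same two-series contract: squared_error_sum and squared_error_count per time bucket and segment. Recovering MSE or RMSE is then simple arithmetic, assuming the two series are aligned on the same bucket and segment values (the values below are hypothetical):

```python
import math

# Hypothetical aligned values for one (time bucket, segment) pair, read from
# the squared_error_sum and squared_error_count series emitted above.
squared_error_sum = 12.5
squared_error_count = 10

mse = squared_error_sum / squared_error_count  # 1.25
rmse = math.sqrt(mse)                          # ~1.118
print(mse, rmse)
```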
@@ -1,4 +1,4 @@
-from typing import Annotated
+from typing import Annotated, Optional
 from uuid import UUID
 
 from arthur_common.aggregations.aggregator import NumericAggregationFunction
@@ -9,6 +9,7 @@ from arthur_common.models.schema_definitions import (
     MetricColumnParameterAnnotation,
     MetricDatasetParameterAnnotation,
     MetricLiteralParameterAnnotation,
+    MetricMultipleColumnParameterAnnotation,
     ScalarType,
     ScopeSchemaTag,
 )
@@ -90,7 +91,24 @@ class MulticlassClassifierStringLabelSingleClassConfusionMatrixAggregationFunction(
                 description="The label indicating a positive class.",
             ),
         ],
+        segmentation_cols: Annotated[
+            Optional[list[str]],
+            MetricMultipleColumnParameterAnnotation(
+                source_dataset_parameter_key="dataset",
+                allowed_column_types=[
+                    ScalarType(dtype=DType.INT),
+                    ScalarType(dtype=DType.BOOL),
+                    ScalarType(dtype=DType.STRING),
+                    ScalarType(dtype=DType.UUID),
+                ],
+                tag_hints=[],
+                friendly_name="Segmentation Columns",
+                description="All columns to include as dimensions for segmentation.",
+                optional=True,
+            ),
+        ] = None,
     ) -> list[NumericMetric]:
+        segmentation_cols = [] if not segmentation_cols else segmentation_cols
         escaped_positive_class_label = escape_str_literal(positive_class_label)
         normalization_case = f"""
             CASE
@@ -107,6 +125,7 @@ class MulticlassClassifierStringLabelSingleClassConfusionMatrixAggregationFunction(
             normalization_case,
             dataset,
             escaped_positive_class_label,
+            segmentation_cols,
         )
 
     def generate_confusion_matrix_metrics(
@@ -119,6 +138,7 @@ class MulticlassClassifierStringLabelSingleClassConfusionMatrixAggregationFunction(
         gt_normalization_case: str,
         dataset: DatasetReference,
         escaped_positive_class_label: str,
+        segmentation_cols: list[str],
     ) -> list[NumericMetric]:
         """
         Generate a SQL query to compute confusion matrix metrics over time.
@@ -132,58 +152,92 @@ class MulticlassClassifierStringLabelSingleClassConfusionMatrixAggregationFunction(
             gt_normalization_case: SQL CASE statement for normalizing ground truth values to 0 / 1 / null using 'value' as the target column name
             dataset: DatasetReference containing dataset metadata
             escaped_positive_class_label: escaped label for the class to include in the dimensions
+            segmentation_cols: List of columns to segment by
 
         Returns:
             str: SQL query that computes confusion matrix metrics
+            Returns the following SQL with no segmentation:
+            WITH normalized_data AS (
+                SELECT
+                    {escaped_timestamp_col} AS timestamp,
+                    {prediction_normalization_case.replace('value', escaped_prediction_col)} AS prediction,
+                    {gt_normalization_case.replace('value', escaped_gt_values_col)} AS actual_value
+                FROM {dataset.dataset_table_name}
+                WHERE {escaped_timestamp_col} IS NOT NULL
+            )
+            SELECT
+                time_bucket(INTERVAL '5 minutes', timestamp) AS ts,
+                SUM(CASE WHEN prediction = 1 AND actual_value = 1 THEN 1 ELSE 0 END) AS true_positive_count,
+                SUM(CASE WHEN prediction = 1 AND actual_value = 0 THEN 1 ELSE 0 END) AS false_positive_count,
+                SUM(CASE WHEN prediction = 0 AND actual_value = 1 THEN 1 ELSE 0 END) AS false_negative_count,
+                SUM(CASE WHEN prediction = 0 AND actual_value = 0 THEN 1 ELSE 0 END) AS true_negative_count,
+                any_value({escaped_positive_class_label}) as class_label
+            FROM normalized_data
+            GROUP BY ts
+            ORDER BY ts
+
         """
         escaped_timestamp_col = escape_identifier(timestamp_col)
         escaped_prediction_col = escape_identifier(prediction_col)
         escaped_gt_values_col = escape_identifier(gt_values_col)
+
+        # build query components with segmentation columns
+        escaped_segmentation_cols = [
+            escape_identifier(col) for col in segmentation_cols
+        ]
+        first_subquery_select_cols = [
+            f"{escaped_timestamp_col} AS timestamp",
+            f"{prediction_normalization_case.replace('value', escaped_prediction_col)} AS prediction",
+            f"{gt_normalization_case.replace('value', escaped_gt_values_col)} AS actual_value",
+        ] + escaped_segmentation_cols
+        second_subquery_select_cols = [
+            "time_bucket(INTERVAL '5 minutes', timestamp) AS ts",
+            "SUM(CASE WHEN prediction = 1 AND actual_value = 1 THEN 1 ELSE 0 END) AS true_positive_count",
+            "SUM(CASE WHEN prediction = 1 AND actual_value = 0 THEN 1 ELSE 0 END) AS false_positive_count",
+            "SUM(CASE WHEN prediction = 0 AND actual_value = 1 THEN 1 ELSE 0 END) AS false_negative_count",
+            "SUM(CASE WHEN prediction = 0 AND actual_value = 0 THEN 1 ELSE 0 END) AS true_negative_count",
+            f"any_value({escaped_positive_class_label}) as class_label",
+        ] + escaped_segmentation_cols
+        second_subquery_group_by_cols = ["ts"] + escaped_segmentation_cols
+        extra_dims = ["class_label"]
+
+        # build query
         confusion_matrix_query = f"""
-        WITH normalized_data AS (
-            SELECT
-                {escaped_timestamp_col} AS timestamp,
-                {prediction_normalization_case.replace('value', escaped_prediction_col)} AS prediction,
-                {gt_normalization_case.replace('value', escaped_gt_values_col)} AS actual_value
-            FROM {dataset.dataset_table_name}
-            WHERE {escaped_timestamp_col} IS NOT NULL
-        )
-        SELECT
-            time_bucket(INTERVAL '5 minutes', timestamp) AS ts,
-            SUM(CASE WHEN prediction = 1 AND actual_value = 1 THEN 1 ELSE 0 END) AS true_positive_count,
-            SUM(CASE WHEN prediction = 1 AND actual_value = 0 THEN 1 ELSE 0 END) AS false_positive_count,
-            SUM(CASE WHEN prediction = 0 AND actual_value = 1 THEN 1 ELSE 0 END) AS false_negative_count,
-            SUM(CASE WHEN prediction = 0 AND actual_value = 0 THEN 1 ELSE 0 END) AS true_negative_count,
-            any_value({escaped_positive_class_label}) as class_label
-        FROM normalized_data
-        GROUP BY ts
-        ORDER BY ts
-        """
+            WITH normalized_data AS (
+                SELECT {", ".join(first_subquery_select_cols)}
+                FROM {dataset.dataset_table_name}
+                WHERE {escaped_timestamp_col} IS NOT NULL
+            )
+            SELECT {", ".join(second_subquery_select_cols)}
+            FROM normalized_data
+            GROUP BY {", ".join(second_subquery_group_by_cols)}
+            ORDER BY ts
+        """
 
         results = ddb_conn.sql(confusion_matrix_query).df()
 
         tp = self.group_query_results_to_numeric_metrics(
             results,
             "true_positive_count",
-            dim_columns=["class_label"],
+            dim_columns=segmentation_cols + extra_dims,
             timestamp_col="ts",
         )
         fp = self.group_query_results_to_numeric_metrics(
             results,
             "false_positive_count",
-            dim_columns=["class_label"],
+            dim_columns=segmentation_cols + extra_dims,
             timestamp_col="ts",
         )
         fn = self.group_query_results_to_numeric_metrics(
             results,
             "false_negative_count",
-            dim_columns=["class_label"],
+            dim_columns=segmentation_cols + extra_dims,
             timestamp_col="ts",
         )
         tn = self.group_query_results_to_numeric_metrics(
             results,
             "true_negative_count",
-            dim_columns=["class_label"],
+            dim_columns=segmentation_cols + extra_dims,
             timestamp_col="ts",
         )
         tp_metric = self.series_to_metric(
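The normalization CASE statements here are plain SQL templates that use the literal token value as a placeholder column name and are specialized with str.replace. A small sketch of that mechanism with a hypothetical CASE body (the real one is built around escaped_positive_class_label and is only partially visible in this diff):

```python
# Hypothetical normalization CASE that uses the literal token 'value' as the
# placeholder column name, mirroring the .replace('value', ...) calls above.
prediction_normalization_case = """
    CASE
        WHEN value = 'spam' THEN 1
        WHEN value IS NULL THEN NULL
        ELSE 0
    END"""

escaped_prediction_col = '"predicted_label"'  # illustrative escaped identifier
select_expr = (
    prediction_normalization_case.replace("value", escaped_prediction_col)
    + " AS prediction"
)
print(select_expr)
# CASE
#     WHEN "predicted_label" = 'spam' THEN 1
#     WHEN "predicted_label" IS NULL THEN NULL
#     ELSE 0
# END AS prediction
```

Since str.replace rewrites the token wherever it appears, the placeholder only works because the template contains no other occurrence of the word value; the replacement happens before aliases like actual_value are introduced.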
@@ -1,4 +1,4 @@
-from typing import Annotated
+from typing import Annotated, Optional
 from uuid import UUID
 
 from arthur_common.aggregations.functions.inference_count_by_class import (
@@ -10,6 +10,7 @@ from arthur_common.models.schema_definitions import (
     DType,
     MetricColumnParameterAnnotation,
     MetricDatasetParameterAnnotation,
+    MetricMultipleColumnParameterAnnotation,
     ScalarType,
     ScopeSchemaTag,
 )
@@ -81,10 +82,27 @@ class MulticlassClassifierCountByClassAggregationFunction(
                 description="A column containing boolean, integer, or string labelled prediction values.",
             ),
         ],
+        segmentation_cols: Annotated[
+            Optional[list[str]],
+            MetricMultipleColumnParameterAnnotation(
+                source_dataset_parameter_key="dataset",
+                allowed_column_types=[
+                    ScalarType(dtype=DType.INT),
+                    ScalarType(dtype=DType.BOOL),
+                    ScalarType(dtype=DType.STRING),
+                    ScalarType(dtype=DType.UUID),
+                ],
+                tag_hints=[],
+                friendly_name="Segmentation Columns",
+                description="All columns to include as dimensions for segmentation.",
+                optional=True,
+            ),
+        ] = None,
     ) -> list[NumericMetric]:
         return super().aggregate(
             ddb_conn=ddb_conn,
             dataset=dataset,
             timestamp_col=timestamp_col,
             prediction_col=prediction_col,
+            segmentation_cols=segmentation_cols,
         )
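Because segmentation_cols defaults to None and is normalized to an empty list on entry, existing callers keep their old behavior. A runnable sketch of that parameter pattern, with a hypothetical ParamMeta standing in for MetricMultipleColumnParameterAnnotation:

```python
from typing import Annotated, Optional


class ParamMeta:
    """Hypothetical stand-in for MetricMultipleColumnParameterAnnotation."""

    def __init__(self, friendly_name: str, optional: bool = False) -> None:
        self.friendly_name = friendly_name
        self.optional = optional


def aggregate(
    segmentation_cols: Annotated[
        Optional[list[str]],
        ParamMeta(friendly_name="Segmentation Columns", optional=True),
    ] = None,
) -> list[str]:
    # normalize None -> [] exactly as the new aggregate() methods do
    segmentation_cols = [] if not segmentation_cols else segmentation_cols
    return segmentation_cols


print(aggregate())                    # [] -- old call sites are unaffected
print(aggregate(["region", "tier"]))  # ['region', 'tier']
```

Annotated carries the UI metadata without affecting the runtime type, so the framework can introspect the parameter while plain Python callers just see an optional list.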
@@ -1,4 +1,4 @@
-from typing import Annotated
+from typing import Annotated, Optional
 from uuid import UUID
 
 from arthur_common.aggregations.aggregator import SketchAggregationFunction
@@ -7,6 +7,7 @@ from arthur_common.models.schema_definitions import (
     DType,
     MetricColumnParameterAnnotation,
     MetricDatasetParameterAnnotation,
+    MetricMultipleColumnParameterAnnotation,
     ScalarType,
     ScopeSchemaTag,
 )
@@ -66,23 +67,59 @@ class NumericSketchAggregationFunction(SketchAggregationFunction):
                 description="A column containing numeric values to calculate a data sketch on.",
             ),
         ],
+        segmentation_cols: Annotated[
+            Optional[list[str]],
+            MetricMultipleColumnParameterAnnotation(
+                source_dataset_parameter_key="dataset",
+                allowed_column_types=[
+                    ScalarType(dtype=DType.INT),
+                    ScalarType(dtype=DType.BOOL),
+                    ScalarType(dtype=DType.STRING),
+                    ScalarType(dtype=DType.UUID),
+                ],
+                tag_hints=[],
+                friendly_name="Segmentation Columns",
+                description="All columns to include as dimensions for segmentation.",
+                optional=True,
+            ),
+        ] = None,
     ) -> list[SketchMetric]:
+        """Executed SQL with no segmentation columns:
+        select {escaped_timestamp_col_id} as ts, \
+            {escaped_numeric_col_id}, \
+            {numeric_col_name_str} as column_name \
+            from {dataset.dataset_table_name} \
+            where {escaped_numeric_col_id} is not null \
+        """
+        segmentation_cols = [] if not segmentation_cols else segmentation_cols
         escaped_timestamp_col_id = escape_identifier(timestamp_col)
         escaped_numeric_col_id = escape_identifier(numeric_col)
         numeric_col_name_str = escape_str_literal(numeric_col)
-        data_query = f" \
-            select {escaped_timestamp_col_id} as ts, \
-            {escaped_numeric_col_id}, \
-            {numeric_col_name_str} as column_name \
-            from {dataset.dataset_table_name} \
-            where {escaped_numeric_col_id} is not null \
-        "
+
+        # build query components with segmentation columns
+        escaped_segmentation_cols = [
+            escape_identifier(col) for col in segmentation_cols
+        ]
+        all_select_clause_cols = [
+            f"{escaped_timestamp_col_id} as ts",
+            f"{escaped_numeric_col_id}",
+            f"{numeric_col_name_str} as column_name",
+        ] + escaped_segmentation_cols
+        extra_dims = ["column_name"]
+
+        # build query
+        data_query = f"""
+            select {", ".join(all_select_clause_cols)}
+            from {dataset.dataset_table_name}
+            where {escaped_numeric_col_id} is not null
+        """
+
         results = ddb_conn.sql(data_query).df()
 
         series = self.group_query_results_to_sketch_metrics(
             results,
             numeric_col,
-            ["column_name"],
+            segmentation_cols + extra_dims,
             "ts",
         )
 
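group_query_results_to_sketch_metrics itself is not shown in this diff; presumably it partitions the result frame by the dimension columns and builds one data sketch per group. A pandas sketch of that partitioning on illustrative data:

```python
import pandas as pd

# Illustrative result frame: one row per raw value, carrying the dimension
# columns selected above ("column_name" plus any segmentation columns).
results = pd.DataFrame(
    {
        "ts": pd.to_datetime(["2024-01-01 00:00:00"] * 4),
        "latency_ms": [10.0, 12.0, 50.0, 55.0],
        "column_name": ["latency_ms"] * 4,
        "region": ["us", "us", "eu", "eu"],
    }
)

dim_columns = ["region"] + ["column_name"]  # segmentation_cols + extra_dims
for dims, group in results.groupby(dim_columns):
    # The real helper presumably feeds each group into a data sketch;
    # here we only show the split into per-segment series.
    print(dict(zip(dim_columns, dims)), group["latency_ms"].tolist())
# {'region': 'eu', 'column_name': 'latency_ms'} [50.0, 55.0]
# {'region': 'us', 'column_name': 'latency_ms'} [10.0, 12.0]
```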
@@ -1,4 +1,4 @@
-from typing import Annotated
+from typing import Annotated, Optional
 from uuid import UUID
 
 from arthur_common.aggregations.aggregator import NumericAggregationFunction
@@ -7,6 +7,7 @@ from arthur_common.models.schema_definitions import (
     DType,
     MetricColumnParameterAnnotation,
     MetricDatasetParameterAnnotation,
+    MetricMultipleColumnParameterAnnotation,
     ScalarType,
     ScopeSchemaTag,
 )
@@ -64,24 +65,63 @@ class NumericSumAggregationFunction(NumericAggregationFunction):
                 description="A column containing numeric values to sum.",
             ),
         ],
+        segmentation_cols: Annotated[
+            Optional[list[str]],
+            MetricMultipleColumnParameterAnnotation(
+                source_dataset_parameter_key="dataset",
+                allowed_column_types=[
+                    ScalarType(dtype=DType.INT),
+                    ScalarType(dtype=DType.BOOL),
+                    ScalarType(dtype=DType.STRING),
+                    ScalarType(dtype=DType.UUID),
+                ],
+                tag_hints=[],
+                friendly_name="Segmentation Columns",
+                description="All columns to include as dimensions for segmentation.",
+                optional=True,
+            ),
+        ] = None,
     ) -> list[NumericMetric]:
+        """Executed SQL with no segmentation columns:
+        select time_bucket(INTERVAL '5 minutes', {escaped_timestamp_col}) as ts, \
+            sum({escaped_numeric_col}) as sum \
+            from {dataset.dataset_table_name} \
+            where {escaped_numeric_col} is not null \
+            group by ts \
+        """
+        segmentation_cols = [] if not segmentation_cols else segmentation_cols
         escaped_timestamp_col = escape_identifier(timestamp_col)
         escaped_numeric_col = escape_identifier(numeric_col)
-        count_query = f" \
-            select time_bucket(INTERVAL '5 minutes', {escaped_timestamp_col}) as ts, \
-            sum({escaped_numeric_col}) as sum \
-            from {dataset.dataset_table_name} \
-            where {escaped_numeric_col} is not null \
-            group by ts \
-        "
-        results = ddb_conn.sql(count_query).df()
 
-        series = self.dimensionless_query_results_to_numeric_metrics(
+        # build query components with segmentation columns
+        escaped_segmentation_cols = [
+            escape_identifier(col) for col in segmentation_cols
+        ]
+        all_select_clause_cols = [
+            f"time_bucket(INTERVAL '5 minutes', {escaped_timestamp_col}) as ts",
+            f"sum({escaped_numeric_col}) as sum",
+        ] + escaped_segmentation_cols
+        all_group_by_cols = ["ts"] + escaped_segmentation_cols
+
+        # build query
+        query = f"""
+            select {", ".join(all_select_clause_cols)}
+            from {dataset.dataset_table_name}
+            where {escaped_numeric_col} is not null
+            group by {", ".join(all_group_by_cols)}
+        """
+
+        results = ddb_conn.sql(query).df()
+
+        series = self.group_query_results_to_numeric_metrics(
             results,
             "sum",
+            segmentation_cols,
             "ts",
         )
-        series.dimensions = [Dimension(name="column_name", value=numeric_col)]
+        # preserve dimension that identifies the name of the numeric column used for the aggregation
+        for point in series:
+            point.dimensions.append(Dimension(name="column_name", value=numeric_col))
 
-        metric = self.series_to_metric(self.METRIC_NAME, [series])
+        metric = self.series_to_metric(self.METRIC_NAME, series)
         return [metric]
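Where 1.0.1 attached a single dimensions list to one series, 2.1.48 appends the column_name dimension to every point of every per-segment series. A minimal sketch of that loop with simplified stand-ins for the real Dimension and series point types:

```python
from dataclasses import dataclass, field


@dataclass
class Dimension:
    """Simplified stand-in for the real Dimension model."""

    name: str
    value: str


@dataclass
class SeriesPoint:
    """Simplified stand-in for an element of the returned series."""

    dimensions: list[Dimension] = field(default_factory=list)


series = [
    SeriesPoint([Dimension("region", "us")]),
    SeriesPoint([Dimension("region", "eu")]),
]

# append the column-name dimension to every point, as the new code does
for point in series:
    point.dimensions.append(Dimension(name="column_name", value="revenue"))

for point in series:
    print([(d.name, d.value) for d in point.dimensions])
# [('region', 'us'), ('column_name', 'revenue')]
# [('region', 'eu'), ('column_name', 'revenue')]
```

Appending rather than overwriting keeps the segmentation dimensions produced by the grouped query intact, which is what series_to_metric now receives as a list of per-segment series.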