arthur-common 2.1.50__tar.gz → 2.1.52__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of arthur-common might be problematic.

Files changed (43)
  1. {arthur_common-2.1.50 → arthur_common-2.1.52}/PKG-INFO +2 -1
  2. {arthur_common-2.1.50 → arthur_common-2.1.52}/pyproject.toml +3 -1
  3. {arthur_common-2.1.50 → arthur_common-2.1.52}/src/arthur_common/aggregations/aggregator.py +2 -1
  4. {arthur_common-2.1.50 → arthur_common-2.1.52}/src/arthur_common/aggregations/functions/categorical_count.py +5 -8
  5. {arthur_common-2.1.50 → arthur_common-2.1.52}/src/arthur_common/aggregations/functions/confusion_matrix.py +18 -31
  6. {arthur_common-2.1.50 → arthur_common-2.1.52}/src/arthur_common/aggregations/functions/inference_count.py +5 -8
  7. {arthur_common-2.1.50 → arthur_common-2.1.52}/src/arthur_common/aggregations/functions/inference_count_by_class.py +7 -15
  8. {arthur_common-2.1.50 → arthur_common-2.1.52}/src/arthur_common/aggregations/functions/inference_null_count.py +5 -8
  9. {arthur_common-2.1.50 → arthur_common-2.1.52}/src/arthur_common/aggregations/functions/mean_absolute_error.py +5 -8
  10. {arthur_common-2.1.50 → arthur_common-2.1.52}/src/arthur_common/aggregations/functions/mean_squared_error.py +5 -8
  11. {arthur_common-2.1.50 → arthur_common-2.1.52}/src/arthur_common/aggregations/functions/multiclass_confusion_matrix.py +5 -8
  12. {arthur_common-2.1.50 → arthur_common-2.1.52}/src/arthur_common/aggregations/functions/multiclass_inference_count_by_class.py +5 -8
  13. {arthur_common-2.1.50 → arthur_common-2.1.52}/src/arthur_common/aggregations/functions/numeric_stats.py +5 -8
  14. {arthur_common-2.1.50 → arthur_common-2.1.52}/src/arthur_common/aggregations/functions/numeric_sum.py +5 -8
  15. {arthur_common-2.1.50 → arthur_common-2.1.52}/src/arthur_common/aggregations/functions/shield_aggregations.py +3 -2
  16. arthur_common-2.1.52/src/arthur_common/config/config.py +42 -0
  17. arthur_common-2.1.52/src/arthur_common/config/settings.yaml +4 -0
  18. {arthur_common-2.1.50 → arthur_common-2.1.52}/src/arthur_common/models/connectors.py +8 -7
  19. {arthur_common-2.1.50 → arthur_common-2.1.52}/src/arthur_common/models/metrics.py +77 -13
  20. {arthur_common-2.1.50 → arthur_common-2.1.52}/src/arthur_common/models/schema_definitions.py +8 -1
  21. {arthur_common-2.1.50 → arthur_common-2.1.52}/src/arthur_common/models/shield.py +0 -18
  22. {arthur_common-2.1.50 → arthur_common-2.1.52}/src/arthur_common/models/task_job_specs.py +2 -1
  23. {arthur_common-2.1.50 → arthur_common-2.1.52}/src/arthur_common/tools/aggregation_analyzer.py +1 -1
  24. {arthur_common-2.1.50 → arthur_common-2.1.52}/src/arthur_common/tools/duckdb_data_loader.py +4 -3
  25. arthur_common-2.1.52/src/arthur_common/tools/duckdb_utils.py +35 -0
  26. arthur_common-2.1.52/src/arthur_common/tools/py.typed +0 -0
  27. {arthur_common-2.1.50 → arthur_common-2.1.52}/src/arthur_common/tools/schema_inferer.py +23 -2
  28. {arthur_common-2.1.50 → arthur_common-2.1.52}/README.md +0 -0
  29. {arthur_common-2.1.50 → arthur_common-2.1.52}/src/arthur_common/__init__.py +0 -0
  30. {arthur_common-2.1.50 → arthur_common-2.1.52}/src/arthur_common/aggregations/__init__.py +0 -0
  31. {arthur_common-2.1.50 → arthur_common-2.1.52}/src/arthur_common/aggregations/functions/README.md +0 -0
  32. {arthur_common-2.1.50 → arthur_common-2.1.52}/src/arthur_common/aggregations/functions/__init__.py +0 -0
  33. {arthur_common-2.1.50 → arthur_common-2.1.52}/src/arthur_common/aggregations/functions/py.typed +0 -0
  34. {arthur_common-2.1.50 → arthur_common-2.1.52}/src/arthur_common/aggregations/py.typed +0 -0
  35. {arthur_common-2.1.50/src/arthur_common/models → arthur_common-2.1.52/src/arthur_common/config}/__init__.py +0 -0
  36. {arthur_common-2.1.50/src/arthur_common/tools → arthur_common-2.1.52/src/arthur_common/models}/__init__.py +0 -0
  37. {arthur_common-2.1.50 → arthur_common-2.1.52}/src/arthur_common/models/datasets.py +0 -0
  38. {arthur_common-2.1.50 → arthur_common-2.1.52}/src/arthur_common/models/py.typed +0 -0
  39. {arthur_common-2.1.50 → arthur_common-2.1.52}/src/arthur_common/py.typed +0 -0
  40. /arthur_common-2.1.50/src/arthur_common/tools/py.typed → /arthur_common-2.1.52/src/arthur_common/tools/__init__.py +0 -0
  41. {arthur_common-2.1.50 → arthur_common-2.1.52}/src/arthur_common/tools/aggregation_loader.py +0 -0
  42. {arthur_common-2.1.50 → arthur_common-2.1.52}/src/arthur_common/tools/functions.py +0 -0
  43. {arthur_common-2.1.50 → arthur_common-2.1.52}/src/arthur_common/tools/time_utils.py +0 -0

PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.3
  Name: arthur-common
- Version: 2.1.50
+ Version: 2.1.52
  Summary: Utility code common to Arthur platform components.
  License: MIT
  Author: Arthur
@@ -16,6 +16,7 @@ Requires-Dist: fastapi (>=0.115.8)
  Requires-Dist: fsspec (>=2024.10.0)
  Requires-Dist: pandas (>=2.2.2)
  Requires-Dist: pydantic (>=2)
+ Requires-Dist: simple-settings (>=1.2.0)
  Requires-Dist: tokencost (==0.1.24)
  Requires-Dist: types-python-dateutil (>=2.9.0)
  Requires-Dist: types-requests (>=2.32.0.20241016)

pyproject.toml
@@ -1,6 +1,6 @@
  [tool.poetry]
  name = "arthur-common"
- version = "2.1.50"
+ version = "2.1.52"
  description = "Utility code common to Arthur platform components."
  authors = ["Arthur <engineering@arthur.ai>"]
  license = "MIT"
@@ -18,6 +18,7 @@ types-python-dateutil = ">=2.9.0"
  fsspec = ">=2024.10.0"
  tokencost = "0.1.24"
  fastapi = ">=0.115.8"
+ simple-settings = ">=1.2.0"


  [tool.poetry.group.dev.dependencies]
@@ -26,6 +27,7 @@ responses = "0.25.7"
  pytest-xdist = "3.6.1"
  pytest-cov = "^6.1.1"
  pre-commit = "^4.2.0"
+ mypy = "^1.16.1"

  [tool.pytest.ini_options]
  pythonpath = ["src"]

src/arthur_common/aggregations/aggregator.py
@@ -3,10 +3,11 @@ from base64 import b64encode
  from typing import Any, Type, Union

  import pandas as pd
- from arthur_common.models.metrics import *
  from datasketches import kll_floats_sketch
  from duckdb import DuckDBPyConnection

+ from arthur_common.models.metrics import *
+

  class AggregationFunction(ABC):
  @staticmethod

src/arthur_common/aggregations/functions/categorical_count.py
@@ -1,9 +1,12 @@
  from typing import Annotated, Optional
  from uuid import UUID

+ from duckdb import DuckDBPyConnection
+
  from arthur_common.aggregations.aggregator import NumericAggregationFunction
  from arthur_common.models.metrics import DatasetReference, NumericMetric
  from arthur_common.models.schema_definitions import (
+ SEGMENTATION_ALLOWED_COLUMN_TYPES,
  DType,
  MetricColumnParameterAnnotation,
  MetricDatasetParameterAnnotation,
@@ -12,7 +15,6 @@ from arthur_common.models.schema_definitions import (
  ScopeSchemaTag,
  )
  from arthur_common.tools.duckdb_data_loader import escape_identifier, escape_str_literal
- from duckdb import DuckDBPyConnection


  class CategoricalCountAggregationFunction(NumericAggregationFunction):
@@ -69,13 +71,8 @@ class CategoricalCountAggregationFunction(NumericAggregationFunction):
  Optional[list[str]],
  MetricMultipleColumnParameterAnnotation(
  source_dataset_parameter_key="dataset",
- allowed_column_types=[
- ScalarType(dtype=DType.INT),
- ScalarType(dtype=DType.BOOL),
- ScalarType(dtype=DType.STRING),
- ScalarType(dtype=DType.UUID),
- ],
- tag_hints=[],
+ allowed_column_types=SEGMENTATION_ALLOWED_COLUMN_TYPES,
+ tag_hints=[ScopeSchemaTag.POSSIBLE_SEGMENTATION],
  friendly_name="Segmentation Columns",
  description="All columns to include as dimensions for segmentation.",
  optional=True,

src/arthur_common/aggregations/functions/confusion_matrix.py
@@ -1,10 +1,13 @@
  from typing import Annotated, Optional
  from uuid import UUID

+ from duckdb import DuckDBPyConnection
+
  from arthur_common.aggregations.aggregator import NumericAggregationFunction
  from arthur_common.models.datasets import ModelProblemType
  from arthur_common.models.metrics import DatasetReference, NumericMetric
  from arthur_common.models.schema_definitions import (
+ SEGMENTATION_ALLOWED_COLUMN_TYPES,
  DType,
  MetricColumnParameterAnnotation,
  MetricDatasetParameterAnnotation,
@@ -14,7 +17,6 @@ from arthur_common.models.schema_definitions import (
  ScopeSchemaTag,
  )
  from arthur_common.tools.duckdb_data_loader import escape_identifier, escape_str_literal
- from duckdb import DuckDBPyConnection


  class ConfusionMatrixAggregationFunction(NumericAggregationFunction):
@@ -27,7 +29,7 @@ class ConfusionMatrixAggregationFunction(NumericAggregationFunction):
  prediction_normalization_case: str,
  gt_normalization_case: str,
  dataset: DatasetReference,
- segmentation_cols: list[str],
+ segmentation_cols: Optional[list[str]] = None,
  ) -> list[NumericMetric]:
  """
  Generate a SQL query to compute confusion matrix metrics over time.
@@ -202,13 +204,8 @@ class BinaryClassifierIntBoolConfusionMatrixAggregationFunction(
  Optional[list[str]],
  MetricMultipleColumnParameterAnnotation(
  source_dataset_parameter_key="dataset",
- allowed_column_types=[
- ScalarType(dtype=DType.INT),
- ScalarType(dtype=DType.BOOL),
- ScalarType(dtype=DType.STRING),
- ScalarType(dtype=DType.UUID),
- ],
- tag_hints=[],
+ allowed_column_types=SEGMENTATION_ALLOWED_COLUMN_TYPES,
+ tag_hints=[ScopeSchemaTag.POSSIBLE_SEGMENTATION],
  friendly_name="Segmentation Columns",
  description="All columns to include as dimensions for segmentation.",
  optional=True,
@@ -338,13 +335,8 @@ class BinaryClassifierStringLabelConfusionMatrixAggregationFunction(
  Optional[list[str]],
  MetricMultipleColumnParameterAnnotation(
  source_dataset_parameter_key="dataset",
- allowed_column_types=[
- ScalarType(dtype=DType.INT),
- ScalarType(dtype=DType.BOOL),
- ScalarType(dtype=DType.STRING),
- ScalarType(dtype=DType.UUID),
- ],
- tag_hints=[],
+ allowed_column_types=SEGMENTATION_ALLOWED_COLUMN_TYPES,
+ tag_hints=[ScopeSchemaTag.POSSIBLE_SEGMENTATION],
  friendly_name="Segmentation Columns",
  description="All columns to include as dimensions for segmentation.",
  optional=True,
@@ -446,13 +438,8 @@ class BinaryClassifierProbabilityThresholdConfusionMatrixAggregationFunction(
  Optional[list[str]],
  MetricMultipleColumnParameterAnnotation(
  source_dataset_parameter_key="dataset",
- allowed_column_types=[
- ScalarType(dtype=DType.INT),
- ScalarType(dtype=DType.BOOL),
- ScalarType(dtype=DType.STRING),
- ScalarType(dtype=DType.UUID),
- ],
- tag_hints=[],
+ allowed_column_types=SEGMENTATION_ALLOWED_COLUMN_TYPES,
+ tag_hints=[ScopeSchemaTag.POSSIBLE_SEGMENTATION],
  friendly_name="Segmentation Columns",
  description="All columns to include as dimensions for segmentation.",
  optional=True,
@@ -495,12 +482,12 @@ class BinaryClassifierProbabilityThresholdConfusionMatrixAggregationFunction(
  raise ValueError(f"Unsupported column type: {col_type}")

  return self.generate_confusion_matrix_metrics(
- ddb_conn,
- timestamp_col,
- prediction_col,
- gt_values_col,
- prediction_normalization_case,
- gt_normalization_case,
- dataset,
- segmentation_cols,
+ ddb_conn=ddb_conn,
+ timestamp_col=timestamp_col,
+ prediction_col=prediction_col,
+ gt_values_col=gt_values_col,
+ prediction_normalization_case=prediction_normalization_case,
+ gt_normalization_case=gt_normalization_case,
+ dataset=dataset,
+ segmentation_cols=segmentation_cols,
  )

src/arthur_common/aggregations/functions/inference_count.py
@@ -1,9 +1,12 @@
  from typing import Annotated, Optional
  from uuid import UUID

+ from duckdb import DuckDBPyConnection
+
  from arthur_common.aggregations.aggregator import NumericAggregationFunction
  from arthur_common.models.metrics import DatasetReference, NumericMetric
  from arthur_common.models.schema_definitions import (
+ SEGMENTATION_ALLOWED_COLUMN_TYPES,
  DType,
  MetricColumnParameterAnnotation,
  MetricDatasetParameterAnnotation,
@@ -12,7 +15,6 @@ from arthur_common.models.schema_definitions import (
  ScopeSchemaTag,
  )
  from arthur_common.tools.duckdb_data_loader import escape_identifier
- from duckdb import DuckDBPyConnection


  class InferenceCountAggregationFunction(NumericAggregationFunction):
@@ -56,13 +58,8 @@ class InferenceCountAggregationFunction(NumericAggregationFunction):
  Optional[list[str]],
  MetricMultipleColumnParameterAnnotation(
  source_dataset_parameter_key="dataset",
- allowed_column_types=[
- ScalarType(dtype=DType.INT),
- ScalarType(dtype=DType.BOOL),
- ScalarType(dtype=DType.STRING),
- ScalarType(dtype=DType.UUID),
- ],
- tag_hints=[],
+ allowed_column_types=SEGMENTATION_ALLOWED_COLUMN_TYPES,
+ tag_hints=[ScopeSchemaTag.POSSIBLE_SEGMENTATION],
  friendly_name="Segmentation Columns",
  description="All columns to include as dimensions for segmentation.",
  optional=True,

src/arthur_common/aggregations/functions/inference_count_by_class.py
@@ -1,10 +1,13 @@
  from typing import Annotated, Optional
  from uuid import UUID

+ from duckdb import DuckDBPyConnection
+
  from arthur_common.aggregations.aggregator import NumericAggregationFunction
  from arthur_common.models.datasets import ModelProblemType
  from arthur_common.models.metrics import DatasetReference, NumericMetric
  from arthur_common.models.schema_definitions import (
+ SEGMENTATION_ALLOWED_COLUMN_TYPES,
  DType,
  MetricColumnParameterAnnotation,
  MetricDatasetParameterAnnotation,
@@ -14,7 +17,6 @@ from arthur_common.models.schema_definitions import (
  ScopeSchemaTag,
  )
  from arthur_common.tools.duckdb_data_loader import escape_identifier
- from duckdb import DuckDBPyConnection


  class BinaryClassifierCountByClassAggregationFunction(NumericAggregationFunction):
@@ -75,13 +77,8 @@ class BinaryClassifierCountByClassAggregationFunction(NumericAggregationFunction
  Optional[list[str]],
  MetricMultipleColumnParameterAnnotation(
  source_dataset_parameter_key="dataset",
- allowed_column_types=[
- ScalarType(dtype=DType.INT),
- ScalarType(dtype=DType.BOOL),
- ScalarType(dtype=DType.STRING),
- ScalarType(dtype=DType.UUID),
- ],
- tag_hints=[],
+ allowed_column_types=SEGMENTATION_ALLOWED_COLUMN_TYPES,
+ tag_hints=[ScopeSchemaTag.POSSIBLE_SEGMENTATION],
  friendly_name="Segmentation Columns",
  description="All columns to include as dimensions for segmentation.",
  optional=True,
@@ -219,13 +216,8 @@ class BinaryClassifierCountThresholdClassAggregationFunction(
  Optional[list[str]],
  MetricMultipleColumnParameterAnnotation(
  source_dataset_parameter_key="dataset",
- allowed_column_types=[
- ScalarType(dtype=DType.INT),
- ScalarType(dtype=DType.BOOL),
- ScalarType(dtype=DType.STRING),
- ScalarType(dtype=DType.UUID),
- ],
- tag_hints=[],
+ allowed_column_types=SEGMENTATION_ALLOWED_COLUMN_TYPES,
+ tag_hints=[ScopeSchemaTag.POSSIBLE_SEGMENTATION],
  friendly_name="Segmentation Columns",
  description="All columns to include as dimensions for segmentation.",
  optional=True,

src/arthur_common/aggregations/functions/inference_null_count.py
@@ -1,9 +1,12 @@
  from typing import Annotated, Optional
  from uuid import UUID

+ from duckdb import DuckDBPyConnection
+
  from arthur_common.aggregations.aggregator import NumericAggregationFunction
  from arthur_common.models.metrics import DatasetReference, Dimension, NumericMetric
  from arthur_common.models.schema_definitions import (
+ SEGMENTATION_ALLOWED_COLUMN_TYPES,
  DType,
  MetricColumnParameterAnnotation,
  MetricDatasetParameterAnnotation,
@@ -12,7 +15,6 @@ from arthur_common.models.schema_definitions import (
  ScopeSchemaTag,
  )
  from arthur_common.tools.duckdb_data_loader import escape_identifier
- from duckdb import DuckDBPyConnection


  class InferenceNullCountAggregationFunction(NumericAggregationFunction):
@@ -65,13 +67,8 @@ class InferenceNullCountAggregationFunction(NumericAggregationFunction):
  Optional[list[str]],
  MetricMultipleColumnParameterAnnotation(
  source_dataset_parameter_key="dataset",
- allowed_column_types=[
- ScalarType(dtype=DType.INT),
- ScalarType(dtype=DType.BOOL),
- ScalarType(dtype=DType.STRING),
- ScalarType(dtype=DType.UUID),
- ],
- tag_hints=[],
+ allowed_column_types=SEGMENTATION_ALLOWED_COLUMN_TYPES,
+ tag_hints=[ScopeSchemaTag.POSSIBLE_SEGMENTATION],
  friendly_name="Segmentation Columns",
  description="All columns to include as dimensions for segmentation.",
  optional=True,

src/arthur_common/aggregations/functions/mean_absolute_error.py
@@ -1,10 +1,13 @@
  from typing import Annotated, Optional
  from uuid import UUID

+ from duckdb import DuckDBPyConnection
+
  from arthur_common.aggregations.aggregator import NumericAggregationFunction
  from arthur_common.models.datasets import ModelProblemType
  from arthur_common.models.metrics import DatasetReference, NumericMetric
  from arthur_common.models.schema_definitions import (
+ SEGMENTATION_ALLOWED_COLUMN_TYPES,
  DType,
  MetricColumnParameterAnnotation,
  MetricDatasetParameterAnnotation,
@@ -13,7 +16,6 @@ from arthur_common.models.schema_definitions import (
  ScopeSchemaTag,
  )
  from arthur_common.tools.duckdb_data_loader import escape_identifier
- from duckdb import DuckDBPyConnection


  class MeanAbsoluteErrorAggregationFunction(NumericAggregationFunction):
@@ -80,13 +82,8 @@ class MeanAbsoluteErrorAggregationFunction(NumericAggregationFunction):
  Optional[list[str]],
  MetricMultipleColumnParameterAnnotation(
  source_dataset_parameter_key="dataset",
- allowed_column_types=[
- ScalarType(dtype=DType.INT),
- ScalarType(dtype=DType.BOOL),
- ScalarType(dtype=DType.STRING),
- ScalarType(dtype=DType.UUID),
- ],
- tag_hints=[],
+ allowed_column_types=SEGMENTATION_ALLOWED_COLUMN_TYPES,
+ tag_hints=[ScopeSchemaTag.POSSIBLE_SEGMENTATION],
  friendly_name="Segmentation Columns",
  description="All columns to include as dimensions for segmentation.",
  optional=True,

src/arthur_common/aggregations/functions/mean_squared_error.py
@@ -1,10 +1,13 @@
  from typing import Annotated, Optional
  from uuid import UUID

+ from duckdb import DuckDBPyConnection
+
  from arthur_common.aggregations.aggregator import NumericAggregationFunction
  from arthur_common.models.datasets import ModelProblemType
  from arthur_common.models.metrics import DatasetReference, NumericMetric
  from arthur_common.models.schema_definitions import (
+ SEGMENTATION_ALLOWED_COLUMN_TYPES,
  DType,
  MetricColumnParameterAnnotation,
  MetricDatasetParameterAnnotation,
@@ -13,7 +16,6 @@ from arthur_common.models.schema_definitions import (
  ScopeSchemaTag,
  )
  from arthur_common.tools.duckdb_data_loader import escape_identifier
- from duckdb import DuckDBPyConnection


  class MeanSquaredErrorAggregationFunction(NumericAggregationFunction):
@@ -80,13 +82,8 @@ class MeanSquaredErrorAggregationFunction(NumericAggregationFunction):
  Optional[list[str]],
  MetricMultipleColumnParameterAnnotation(
  source_dataset_parameter_key="dataset",
- allowed_column_types=[
- ScalarType(dtype=DType.INT),
- ScalarType(dtype=DType.BOOL),
- ScalarType(dtype=DType.STRING),
- ScalarType(dtype=DType.UUID),
- ],
- tag_hints=[],
+ allowed_column_types=SEGMENTATION_ALLOWED_COLUMN_TYPES,
+ tag_hints=[ScopeSchemaTag.POSSIBLE_SEGMENTATION],
  friendly_name="Segmentation Columns",
  description="All columns to include as dimensions for segmentation.",
  optional=True,

src/arthur_common/aggregations/functions/multiclass_confusion_matrix.py
@@ -1,10 +1,13 @@
  from typing import Annotated, Optional
  from uuid import UUID

+ from duckdb import DuckDBPyConnection
+
  from arthur_common.aggregations.aggregator import NumericAggregationFunction
  from arthur_common.models.datasets import ModelProblemType
  from arthur_common.models.metrics import DatasetReference, NumericMetric
  from arthur_common.models.schema_definitions import (
+ SEGMENTATION_ALLOWED_COLUMN_TYPES,
  DType,
  MetricColumnParameterAnnotation,
  MetricDatasetParameterAnnotation,
@@ -14,7 +17,6 @@ from arthur_common.models.schema_definitions import (
  ScopeSchemaTag,
  )
  from arthur_common.tools.duckdb_data_loader import escape_identifier, escape_str_literal
- from duckdb import DuckDBPyConnection


  class MulticlassClassifierStringLabelSingleClassConfusionMatrixAggregationFunction(
@@ -95,13 +97,8 @@ class MulticlassClassifierStringLabelSingleClassConfusionMatrixAggregationFuncti
  Optional[list[str]],
  MetricMultipleColumnParameterAnnotation(
  source_dataset_parameter_key="dataset",
- allowed_column_types=[
- ScalarType(dtype=DType.INT),
- ScalarType(dtype=DType.BOOL),
- ScalarType(dtype=DType.STRING),
- ScalarType(dtype=DType.UUID),
- ],
- tag_hints=[],
+ allowed_column_types=SEGMENTATION_ALLOWED_COLUMN_TYPES,
+ tag_hints=[ScopeSchemaTag.POSSIBLE_SEGMENTATION],
  friendly_name="Segmentation Columns",
  description="All columns to include as dimensions for segmentation.",
  optional=True,

src/arthur_common/aggregations/functions/multiclass_inference_count_by_class.py
@@ -1,12 +1,15 @@
  from typing import Annotated, Optional
  from uuid import UUID

+ from duckdb import DuckDBPyConnection
+
  from arthur_common.aggregations.functions.inference_count_by_class import (
  BinaryClassifierCountByClassAggregationFunction,
  )
  from arthur_common.models.datasets import ModelProblemType
  from arthur_common.models.metrics import DatasetReference, NumericMetric
  from arthur_common.models.schema_definitions import (
+ SEGMENTATION_ALLOWED_COLUMN_TYPES,
  DType,
  MetricColumnParameterAnnotation,
  MetricDatasetParameterAnnotation,
@@ -14,7 +17,6 @@ from arthur_common.models.schema_definitions import (
  ScalarType,
  ScopeSchemaTag,
  )
- from duckdb import DuckDBPyConnection


  class MulticlassClassifierCountByClassAggregationFunction(
@@ -86,13 +88,8 @@ class MulticlassClassifierCountByClassAggregationFunction(
  Optional[list[str]],
  MetricMultipleColumnParameterAnnotation(
  source_dataset_parameter_key="dataset",
- allowed_column_types=[
- ScalarType(dtype=DType.INT),
- ScalarType(dtype=DType.BOOL),
- ScalarType(dtype=DType.STRING),
- ScalarType(dtype=DType.UUID),
- ],
- tag_hints=[],
+ allowed_column_types=SEGMENTATION_ALLOWED_COLUMN_TYPES,
+ tag_hints=[ScopeSchemaTag.POSSIBLE_SEGMENTATION],
  friendly_name="Segmentation Columns",
  description="All columns to include as dimensions for segmentation.",
  optional=True,

src/arthur_common/aggregations/functions/numeric_stats.py
@@ -1,9 +1,12 @@
  from typing import Annotated, Optional
  from uuid import UUID

+ from duckdb import DuckDBPyConnection
+
  from arthur_common.aggregations.aggregator import SketchAggregationFunction
  from arthur_common.models.metrics import DatasetReference, SketchMetric
  from arthur_common.models.schema_definitions import (
+ SEGMENTATION_ALLOWED_COLUMN_TYPES,
  DType,
  MetricColumnParameterAnnotation,
  MetricDatasetParameterAnnotation,
@@ -12,7 +15,6 @@ from arthur_common.models.schema_definitions import (
  ScopeSchemaTag,
  )
  from arthur_common.tools.duckdb_data_loader import escape_identifier, escape_str_literal
- from duckdb import DuckDBPyConnection


  class NumericSketchAggregationFunction(SketchAggregationFunction):
@@ -71,13 +73,8 @@ class NumericSketchAggregationFunction(SketchAggregationFunction):
  Optional[list[str]],
  MetricMultipleColumnParameterAnnotation(
  source_dataset_parameter_key="dataset",
- allowed_column_types=[
- ScalarType(dtype=DType.INT),
- ScalarType(dtype=DType.BOOL),
- ScalarType(dtype=DType.STRING),
- ScalarType(dtype=DType.UUID),
- ],
- tag_hints=[],
+ allowed_column_types=SEGMENTATION_ALLOWED_COLUMN_TYPES,
+ tag_hints=[ScopeSchemaTag.POSSIBLE_SEGMENTATION],
  friendly_name="Segmentation Columns",
  description="All columns to include as dimensions for segmentation.",
  optional=True,

src/arthur_common/aggregations/functions/numeric_sum.py
@@ -1,9 +1,12 @@
  from typing import Annotated, Optional
  from uuid import UUID

+ from duckdb import DuckDBPyConnection
+
  from arthur_common.aggregations.aggregator import NumericAggregationFunction
  from arthur_common.models.metrics import DatasetReference, Dimension, NumericMetric
  from arthur_common.models.schema_definitions import (
+ SEGMENTATION_ALLOWED_COLUMN_TYPES,
  DType,
  MetricColumnParameterAnnotation,
  MetricDatasetParameterAnnotation,
@@ -12,7 +15,6 @@ from arthur_common.models.schema_definitions import (
  ScopeSchemaTag,
  )
  from arthur_common.tools.duckdb_data_loader import escape_identifier
- from duckdb import DuckDBPyConnection


  class NumericSumAggregationFunction(NumericAggregationFunction):
@@ -69,13 +71,8 @@ class NumericSumAggregationFunction(NumericAggregationFunction):
  Optional[list[str]],
  MetricMultipleColumnParameterAnnotation(
  source_dataset_parameter_key="dataset",
- allowed_column_types=[
- ScalarType(dtype=DType.INT),
- ScalarType(dtype=DType.BOOL),
- ScalarType(dtype=DType.STRING),
- ScalarType(dtype=DType.UUID),
- ],
- tag_hints=[],
+ allowed_column_types=SEGMENTATION_ALLOWED_COLUMN_TYPES,
+ tag_hints=[ScopeSchemaTag.POSSIBLE_SEGMENTATION],
  friendly_name="Segmentation Columns",
  description="All columns to include as dimensions for segmentation.",
  optional=True,

src/arthur_common/aggregations/functions/shield_aggregations.py
@@ -2,6 +2,9 @@ from typing import Annotated
  from uuid import UUID

  import pandas as pd
+ from duckdb import DuckDBPyConnection
+ from tokencost import calculate_cost_by_tokens
+
  from arthur_common.aggregations.aggregator import (
  NumericAggregationFunction,
  SketchAggregationFunction,
@@ -13,8 +16,6 @@ from arthur_common.models.schema_definitions import (
  MetricColumnParameterAnnotation,
  MetricDatasetParameterAnnotation,
  )
- from duckdb import DuckDBPyConnection
- from tokencost import calculate_cost_by_tokens


  class ShieldInferencePassFailCountAggregation(NumericAggregationFunction):

src/arthur_common/config/config.py (new file)
@@ -0,0 +1,42 @@
+ # get the current directory of this file
+ import logging
+ import pathlib
+
+ directory = pathlib.Path(__file__).parent.resolve()
+
+ # create settings object that reads from settings.yaml and takes overrides from env
+ # can also be overwritten via the CLI
+ # https://github.com/drgarcia1986/simple-settings
+ from simple_settings import LazySettings
+
+ settings = LazySettings(f"{directory}/settings.yaml", ".environ")
+
+ logger = logging.getLogger()
+
+
+ class Config:
+ settings = settings
+
+ @staticmethod
+ def convert_to_int(value: str | int, setting_name: str) -> int:
+ if isinstance(value, int):
+ return value
+ elif value == "":
+ raise ValueError(
+ f"Config setting {setting_name} could not be cast to an int.",
+ )
+
+ # attempt to convert setting to int
+ try:
+ return int(value.strip())
+ except TypeError:
+ raise ValueError(
+ f"Config setting {setting_name} could not be cast to an int.",
+ )
+
+ @staticmethod
+ def segmentation_col_unique_values_limit() -> int:
+ return Config.convert_to_int(
+ settings.SEGMENTATION_COL_UNIQUE_VALUE_LIMIT,
+ "SEGMENTATION_COL_UNIQUE_VALUE_LIMIT",
+ )

src/arthur_common/config/settings.yaml (new file)
@@ -0,0 +1,4 @@
+ # add arthur-common default settings here
+ ################################################
+ # Aggregation Configurations
+ SEGMENTATION_COL_UNIQUE_VALUE_LIMIT: 100
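
A minimal sketch of how these two new files fit together, assuming the released package is installed; the override value below is made up for illustration, and the environment override relies on the ".environ" source passed to LazySettings above:

    import os

    # set before the lazy settings object is first read (illustrative override; optional)
    os.environ["SEGMENTATION_COL_UNIQUE_VALUE_LIMIT"] = "250"

    from arthur_common.config.config import Config

    # returns 250 with the override in place; otherwise the 100 default from settings.yaml
    limit = Config.segmentation_col_unique_values_limit()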

src/arthur_common/models/connectors.py
@@ -29,13 +29,14 @@ GOOGLE_CONNECTOR_PROJECT_ID_FIELD = "project_id"
  GOOGLE_CONNECTOR_LOCATION_FIELD = "location"
  SHIELD_CONNECTOR_API_KEY_FIELD = "api_key"
  SHIELD_CONNECTOR_ENDPOINT_FIELD = "endpoint"
- MSSQL_CONNECTOR_HOST_FIELD = "host"
- MSSQL_CONNECTOR_PORT_FIELD = "port"
- MSSQL_CONNECTOR_DATABASE_FIELD = "database"
- MSSQL_CONNECTOR_USERNAME_FIELD = "username"
- MSSQL_CONNECTOR_PASSWORD_FIELD = "password"
- MSSQL_CONNECTOR_DRIVER_FIELD = "driver"
- MSSQL_CONNECTOR_TABLE_NAME_FIELD = "table_name"
+ ODBC_CONNECTOR_HOST_FIELD = "host"
+ ODBC_CONNECTOR_PORT_FIELD = "port"
+ ODBC_CONNECTOR_DATABASE_FIELD = "database"
+ ODBC_CONNECTOR_USERNAME_FIELD = "username"
+ ODBC_CONNECTOR_PASSWORD_FIELD = "password"
+ ODBC_CONNECTOR_DRIVER_FIELD = "driver"
+ ODBC_CONNECTOR_TABLE_NAME_FIELD = "table_name"
+ ODBC_CONNECTOR_DIALECT_FIELD = "dialect"


  # dataset (connector type dependent) constants

src/arthur_common/models/metrics.py
@@ -4,14 +4,15 @@ from enum import Enum
  from typing import Literal, Optional
  from uuid import UUID

+ from pydantic import BaseModel, Field, field_validator, model_validator
+ from typing_extensions import Self
+
  from arthur_common.models.datasets import ModelProblemType
  from arthur_common.models.schema_definitions import (
  DType,
  SchemaTypeUnion,
  ScopeSchemaTag,
  )
- from pydantic import BaseModel, Field, model_validator
- from typing_extensions import Self


  # Temporary limited list, expand this as we grow and make it more in line with custom transformations later on
@@ -111,12 +112,9 @@ class AggregationMetricType(Enum):
  NUMERIC = "numeric"


- class MetricsParameterSchema(BaseModel):
+ class BaseAggregationParameterSchema(BaseModel):
+ # fields for aggregation parameters shared across all parameter types and between default and custom metrics
  parameter_key: str = Field(description="Name of the parameter.")
- optional: bool = Field(
- False,
- description="Boolean denoting if the parameter is optional.",
- )
  friendly_name: str = Field(
  description="User facing name of the parameter.",
  )
@@ -125,7 +123,16 @@ class MetricsParameterSchema(BaseModel):
  )


- class MetricsDatasetParameterSchema(MetricsParameterSchema):
+ class MetricsParameterSchema(BaseAggregationParameterSchema):
+ # specific to default metrics/Python metrics—not available to custom aggregations
+ optional: bool = Field(
+ False,
+ description="Boolean denoting if the parameter is optional.",
+ )
+
+
+ class BaseDatasetParameterSchema(BaseAggregationParameterSchema):
+ # fields specific to dataset parameters shared across default and custom metrics
  parameter_type: Literal["dataset"] = "dataset"
  model_problem_type: Optional[ModelProblemType] = Field(
  default=None,
@@ -133,13 +140,24 @@ class MetricsDatasetParameterSchema(MetricsParameterSchema):
  )


- class MetricsLiteralParameterSchema(MetricsParameterSchema):
+ class MetricsDatasetParameterSchema(MetricsParameterSchema, BaseDatasetParameterSchema):
+ # dataset parameter schema including fields specific to default metrics
+ pass
+
+
+ class BaseLiteralParameterSchema(BaseAggregationParameterSchema):
+ # fields specific to literal parameters shared across default and custom metrics
  parameter_type: Literal["literal"] = "literal"
  parameter_dtype: DType = Field(description="Data type of the parameter.")


- class MetricsColumnParameterSchema(MetricsParameterSchema):
- parameter_type: Literal["column"] = "column"
+ class MetricsLiteralParameterSchema(MetricsParameterSchema, BaseLiteralParameterSchema):
+ # literal parameter schema including fields specific to default metrics
+ pass
+
+
+ class BaseColumnBaseParameterSchema(BaseAggregationParameterSchema):
+ # fields specific to all single or multiple column parameters shared across default and custom metrics
  tag_hints: list[ScopeSchemaTag] = Field(
  [],
  description="List of tags that are applicable to this parameter. Datasets with columns that have matching tags can be inferred this way.",
@@ -165,8 +183,18 @@ class MetricsColumnParameterSchema(MetricsParameterSchema):
  return self


- # Not used /implemented yet. Might turn into group by column list
- class MetricsColumnListParameterSchema(MetricsColumnParameterSchema):
+ class BaseColumnParameterSchema(BaseColumnBaseParameterSchema):
+ # single column parameter schema common across default and custom metrics
+ parameter_type: Literal["column"] = "column"
+
+
+ class MetricsColumnParameterSchema(MetricsParameterSchema, BaseColumnParameterSchema):
+ # single column parameter schema specific to default metrics
+ parameter_type: Literal["column"] = "column"
+
+
+ class MetricsColumnListParameterSchema(MetricsParameterSchema, BaseColumnParameterSchema):
+ # list column parameter schema specific to default metrics
  parameter_type: Literal["column_list"] = "column_list"


@@ -177,6 +205,17 @@ MetricsParameterSchemaUnion = (
  | MetricsColumnListParameterSchema
  )

+ MetricsColumnSchemaUnion = (
+ MetricsColumnParameterSchema | MetricsColumnListParameterSchema
+ )
+
+
+ CustomAggregationParametersSchemaUnion = (
+ BaseDatasetParameterSchema
+ | BaseLiteralParameterSchema
+ | BaseColumnParameterSchema
+ )
+

  @dataclass
  class DatasetReference:
@@ -221,3 +260,28 @@ class AggregationSpecSchema(BaseModel):
  f"Column parameter '{param.parameter_key}' references dataset parameter '{param.source_dataset_parameter_key}' which does not exist.",
  )
  return self
+
+
+ class BaseReportedAggregation(BaseModel):
+ # in future will be used by default metrics
+ metric_name: str = Field(description="Name of the reported aggregation metric.")
+ description: str = Field(
+ description="Description of the reported aggregation metric and what it aggregates.",
+ )
+
+
+ class ReportedCustomAggregation(BaseReportedAggregation):
+ value_column: str = Field(description="Name of the column returned from the SQL query holding the metric value.")
+ timestamp_column: str = Field(description="Name of the column returned from the SQL query holding the timestamp buckets.")
+ metric_kind: AggregationMetricType = Field(
+ description="Return type of the reported aggregation metric value.",
+ )
+ dimension_columns: list[str] = Field(description="Name of any dimension columns returned from the SQL query. Max length is 1.")
+
+ @field_validator('dimension_columns')
+ @classmethod
+ def validate_dimension_columns_length(cls, v: list[str]) -> str:
+ if len(v) > 1:
+ raise ValueError('Only one dimension column can be specified.')
+ return v
+
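
To make the new reporting models concrete, here is a minimal sketch of constructing the ReportedCustomAggregation added above; the class, field names, and AggregationMetricType come from this diff, while the field values are illustrative:

    from arthur_common.models.metrics import (
        AggregationMetricType,
        ReportedCustomAggregation,
    )

    reported = ReportedCustomAggregation(
        metric_name="null_count",
        description="Null counts per time bucket from a custom SQL aggregation.",
        value_column="metric_value",
        timestamp_column="ts_bucket",
        metric_kind=AggregationMetricType.NUMERIC,
        dimension_columns=["segment"],  # a second entry would fail the field_validator
    )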

src/arthur_common/models/schema_definitions.py
@@ -4,9 +4,10 @@ from enum import Enum
  from typing import Optional, Self, Union
  from uuid import UUID, uuid4

- from arthur_common.models.datasets import ModelProblemType
  from pydantic import BaseModel, ConfigDict, Field, computed_field, model_validator

+ from arthur_common.models.datasets import ModelProblemType
+

  class ScopeSchemaTag(str, Enum):
  LLM_CONTEXT = "llm_context"
@@ -18,6 +19,7 @@ class ScopeSchemaTag(str, Enum):
  PREDICTION = "prediction"
  GROUND_TRUTH = "ground_truth"
  PIN_IN_DEEP_DIVE = "pin_in_deep_dive"
+ POSSIBLE_SEGMENTATION = "possible_segmentation"


  class DType(str, Enum):
@@ -420,3 +422,8 @@ def SHIELD_SCHEMA() -> DatasetSchema:

  SHIELD_RESPONSE_SCHEMA = create_shield_response_schema().to_base_type()
  SHIELD_PROMPT_SCHEMA = create_shield_prompt_schema().to_base_type()
+
+ SEGMENTATION_ALLOWED_DTYPES = [DType.INT, DType.BOOL, DType.STRING, DType.UUID]
+ SEGMENTATION_ALLOWED_COLUMN_TYPES = [
+ ScalarType(dtype=d_type) for d_type in SEGMENTATION_ALLOWED_DTYPES
+ ]
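
As a quick illustration of what the new constants expand to (a sketch that assumes ScalarType exposes the dtype field it is constructed with):

    from arthur_common.models.schema_definitions import (
        SEGMENTATION_ALLOWED_COLUMN_TYPES,
        SEGMENTATION_ALLOWED_DTYPES,
    )

    # four ScalarType entries, one each for INT, BOOL, STRING, and UUID
    assert [t.dtype for t in SEGMENTATION_ALLOWED_COLUMN_TYPES] == SEGMENTATION_ALLOWED_DTYPES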

src/arthur_common/models/shield.py
@@ -10,9 +10,7 @@ DEFAULT_PII_RULE_CONFIDENCE_SCORE_THRESHOLD = 0

  class RuleType(str, Enum):
  KEYWORD = "KeywordRule"
- MODEL_HALLUCINATION = "ModelHallucinationRule"
  MODEL_HALLUCINATION_V2 = "ModelHallucinationRuleV2"
- MODEL_HALLUCINATION_V3 = "ModelHallucinationRuleV3"
  MODEL_SENSITIVE_DATA = "ModelSensitiveDataRule"
  PII_DATA = "PIIDataRule"
  PROMPT_INJECTION = "PromptInjectionRule"
@@ -456,14 +454,6 @@ class NewRuleRequest(BaseModel):
  detail="PromptInjectionRule can only be enabled for prompt. Please set the 'apply_to_response' field "
  "to false.",
  )
- if (self.type == RuleType.MODEL_HALLUCINATION) and (
- self.apply_to_prompt is True
- ):
- raise HTTPException(
- status_code=400,
- detail="ModelHallucinationRule can only be enabled for response. Please set the 'apply_to_prompt' "
- "field to false.",
- )
  if (self.type == RuleType.MODEL_HALLUCINATION_V2) and (
  self.apply_to_prompt is True
  ):
@@ -472,14 +462,6 @@ class NewRuleRequest(BaseModel):
  detail="ModelHallucinationRuleV2 can only be enabled for response. Please set the 'apply_to_prompt' "
  "field to false.",
  )
- if (self.type == RuleType.MODEL_HALLUCINATION_V3) and (
- self.apply_to_prompt is True
- ):
- raise HTTPException(
- status_code=400,
- detail="ModelHallucinationRuleV3 can only be enabled for response. Please set the "
- "'apply_to_prompt' field to false.",
- )
  if (self.apply_to_prompt is False) and (self.apply_to_response is False):
  raise HTTPException(
  status_code=400,

src/arthur_common/models/task_job_specs.py
@@ -1,9 +1,10 @@
  from typing import Literal, Optional
  from uuid import UUID

- from arthur_common.models.shield import NewRuleRequest
  from pydantic import BaseModel, Field

+ from arthur_common.models.shield import NewRuleRequest
+
  onboarding_id_desc = "An identifier to assign to the created model to make it easy to retrieve. Used by the UI during the GenAI model creation flow."



src/arthur_common/tools/aggregation_analyzer.py
@@ -84,7 +84,7 @@ class FunctionAnalyzer:
  @staticmethod
  def _get_scope_metric_parameter_from_annotation(
  param_name: str,
- param_dtype: typing.Optional[DType],
+ param_dtype: DType,
  optional: bool,
  annotation: typing.Annotated, # type: ignore
  ) -> MetricsParameterSchemaUnion:

src/arthur_common/tools/duckdb_data_loader.py
@@ -3,6 +3,10 @@ from typing import Any

  import duckdb
  import pandas as pd
+ from dateutil.parser import parse
+ from fsspec import filesystem
+ from pydantic import BaseModel
+
  from arthur_common.models.datasets import DatasetJoinKind
  from arthur_common.models.schema_definitions import (
  DatasetListType,
@@ -11,9 +15,6 @@ from arthur_common.models.schema_definitions import (
  DatasetSchema,
  DType,
  )
- from dateutil.parser import parse
- from fsspec import filesystem
- from pydantic import BaseModel


  class ColumnFormat(BaseModel):

src/arthur_common/tools/duckdb_utils.py (new file)
@@ -0,0 +1,35 @@
+ import duckdb
+
+ from arthur_common.config.config import Config
+ from arthur_common.models.schema_definitions import SEGMENTATION_ALLOWED_DTYPES, DType
+ from arthur_common.tools.duckdb_data_loader import escape_identifier
+
+
+ def is_column_possible_segmentation(
+ conn: duckdb.DuckDBPyConnection,
+ table: str,
+ column_name: str,
+ column_dtype: DType,
+ ) -> bool:
+ """Returns whether column fits segmentation criteria:
+ 1. Has fewer than SEGMENTATION_COL_UNIQUE_VALUE_LIMIT unique values.
+ 2. Has an allowed DType.
+
+ PreReq: Table with column should already be loaded in DuckDB
+ """
+ segmentation_col_unique_val_limit = Config.segmentation_col_unique_values_limit()
+ if column_dtype not in SEGMENTATION_ALLOWED_DTYPES:
+ return False
+
+ # check column for unique value count
+ escaped_column = escape_identifier(column_name)
+
+ # count distinct values in this column
+ distinct_count_query = f"""
+ SELECT COUNT(DISTINCT {escaped_column}) as distinct_count
+ FROM {table}
+ """
+ result = conn.sql(distinct_count_query).fetchone()
+ distinct_count = result[0] if result else 0
+
+ return distinct_count < segmentation_col_unique_val_limit
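
A small usage sketch of the new helper, assuming the released package is installed; the connection, table, and column below are made up for illustration:

    import duckdb

    from arthur_common.models.schema_definitions import DType
    from arthur_common.tools.duckdb_utils import is_column_possible_segmentation

    conn = duckdb.connect()
    conn.sql("CREATE TABLE inferences AS SELECT 'us-east' AS region UNION ALL SELECT 'eu-west'")

    # True: STRING is an allowed dtype and the column has far fewer than the configured 100 distinct values
    print(is_column_possible_segmentation(conn, "inferences", "region", DType.STRING))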
File without changes

src/arthur_common/tools/schema_inferer.py
@@ -2,6 +2,7 @@ from typing import Any
  from uuid import uuid4

  import pandas as pd
+
  from arthur_common.models.schema_definitions import (
  DatasetColumn,
  DatasetListType,
@@ -12,6 +13,7 @@ from arthur_common.models.schema_definitions import (
  ScopeSchemaTag,
  )
  from arthur_common.tools.duckdb_data_loader import DuckDBOperator, escape_identifier
+ from arthur_common.tools.duckdb_utils import is_column_possible_segmentation


  class SchemaInferer:
@@ -38,14 +40,21 @@ class SchemaInferer:
  self.conn.sql(
  f"CREATE OR REPLACE TEMP TABLE {escaped_col} AS SELECT UNNEST({escaped_col}) as {escaped_col} FROM {table}",
  )
- return self._infer_schema(escaped_col)
+ return self._infer_schema(escaped_col, is_nested_col=True)

- def _infer_schema(self, table: str = "root") -> DatasetObjectType:
+ def _infer_schema(
+ self,
+ table: str = "root",
+ is_nested_col: bool = False,
+ ) -> DatasetObjectType:
+ """is_nested_col indicates whether the function is being called on an unnested/flattened table that represents
+ a struct column or list column in the root table."""
  ddb_schema: list[tuple[Any, Any, Any]] = self.conn.sql(
  f"DESCRIBE {table}",
  ).fetchall()

  obj = DatasetObjectType(id=uuid4(), object={}, nullable=False)
+ # object has a dict of each column
  timestamp_cols = []

  for column in ddb_schema:
@@ -94,6 +103,18 @@ class SchemaInferer:
  timestamp_cols.append(scalar_schema)
  case _:
  raise NotImplementedError(f"Type {col_type} not mappable.")
+
+ # tag column as a possible segmentation column if it meets criteria
+ # we only support top-level column aggregations right now (ie you can't aggregate on a nested column)
+ # so we don't want to tag nested columns as possible segmentation columns
+ if not is_nested_col and is_column_possible_segmentation(
+ self.conn,
+ table,
+ col_name,
+ scalar_schema.dtype,
+ ):
+ scalar_schema.tag_hints.append(ScopeSchemaTag.POSSIBLE_SEGMENTATION)
+
  obj.object[col_name] = scalar_schema

  # auto assign primary timestamp tag if there's only one timestamp column
File without changes