data_designer_config-0.4.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. data_designer/config/__init__.py +149 -0
  2. data_designer/config/_version.py +34 -0
  3. data_designer/config/analysis/__init__.py +2 -0
  4. data_designer/config/analysis/column_profilers.py +159 -0
  5. data_designer/config/analysis/column_statistics.py +421 -0
  6. data_designer/config/analysis/dataset_profiler.py +84 -0
  7. data_designer/config/analysis/utils/errors.py +10 -0
  8. data_designer/config/analysis/utils/reporting.py +192 -0
  9. data_designer/config/base.py +69 -0
  10. data_designer/config/column_configs.py +476 -0
  11. data_designer/config/column_types.py +141 -0
  12. data_designer/config/config_builder.py +595 -0
  13. data_designer/config/data_designer_config.py +40 -0
  14. data_designer/config/dataset_builders.py +13 -0
  15. data_designer/config/dataset_metadata.py +18 -0
  16. data_designer/config/default_model_settings.py +129 -0
  17. data_designer/config/errors.py +24 -0
  18. data_designer/config/interface.py +55 -0
  19. data_designer/config/models.py +486 -0
  20. data_designer/config/preview_results.py +41 -0
  21. data_designer/config/processors.py +148 -0
  22. data_designer/config/run_config.py +56 -0
  23. data_designer/config/sampler_constraints.py +52 -0
  24. data_designer/config/sampler_params.py +639 -0
  25. data_designer/config/seed.py +116 -0
  26. data_designer/config/seed_source.py +84 -0
  27. data_designer/config/seed_source_types.py +19 -0
  28. data_designer/config/testing/__init__.py +6 -0
  29. data_designer/config/testing/fixtures.py +308 -0
  30. data_designer/config/utils/code_lang.py +93 -0
  31. data_designer/config/utils/constants.py +365 -0
  32. data_designer/config/utils/errors.py +21 -0
  33. data_designer/config/utils/info.py +94 -0
  34. data_designer/config/utils/io_helpers.py +258 -0
  35. data_designer/config/utils/misc.py +78 -0
  36. data_designer/config/utils/numerical_helpers.py +30 -0
  37. data_designer/config/utils/type_helpers.py +106 -0
  38. data_designer/config/utils/visualization.py +482 -0
  39. data_designer/config/validator_params.py +94 -0
  40. data_designer/errors.py +7 -0
  41. data_designer/lazy_heavy_imports.py +56 -0
  42. data_designer/logging.py +180 -0
  43. data_designer/plugin_manager.py +78 -0
  44. data_designer/plugins/__init__.py +8 -0
  45. data_designer/plugins/errors.py +15 -0
  46. data_designer/plugins/plugin.py +141 -0
  47. data_designer/plugins/registry.py +88 -0
  48. data_designer_config-0.4.0.dist-info/METADATA +75 -0
  49. data_designer_config-0.4.0.dist-info/RECORD +50 -0
  50. data_designer_config-0.4.0.dist-info/WHEEL +4 -0
data_designer/config/seed.py
@@ -0,0 +1,116 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+
+ from __future__ import annotations
+
+ from enum import Enum
+
+ from pydantic import Field, model_validator
+ from typing_extensions import Self
+
+ from data_designer.config.base import ConfigBase
+ from data_designer.config.seed_source_types import SeedSourceT
+
+
+ class SamplingStrategy(str, Enum):
+     ORDERED = "ordered"
+     SHUFFLE = "shuffle"
+
+
+ class IndexRange(ConfigBase):
+     start: int = Field(ge=0, description="The start index of the index range (inclusive)")
+     end: int = Field(ge=0, description="The end index of the index range (inclusive)")
+
+     @model_validator(mode="after")
+     def _validate_index_range(self) -> Self:
+         if self.start > self.end:
+             raise ValueError("'start' index must be less than or equal to 'end' index")
+         return self
+
+     @property
+     def size(self) -> int:
+         return self.end - self.start + 1
+
+
+ class PartitionBlock(ConfigBase):
+     index: int = Field(default=0, ge=0, description="The index of the partition to sample from")
+     num_partitions: int = Field(default=1, ge=1, description="The total number of partitions in the dataset")
+
+     @model_validator(mode="after")
+     def _validate_partition_block(self) -> Self:
+         if self.index >= self.num_partitions:
+             raise ValueError("'index' must be less than 'num_partitions'")
+         return self
+
+     def to_index_range(self, dataset_size: int) -> IndexRange:
+         partition_size = dataset_size // self.num_partitions
+         start = self.index * partition_size
+
+         # For the last partition, extend to the end of the dataset to include remainder rows
+         if self.index == self.num_partitions - 1:
+             end = dataset_size - 1
+         else:
+             end = ((self.index + 1) * partition_size) - 1
+         return IndexRange(start=start, end=end)
+
+
+ class SeedConfig(ConfigBase):
+     """Configuration for sampling data from a seed dataset.
+
+     Args:
+         source: A SeedSource defining where the seed data exists.
+         sampling_strategy: Strategy for how to sample rows from the dataset.
+             - ORDERED: Read rows sequentially in their original order.
+             - SHUFFLE: Randomly shuffle rows before sampling. When used with
+               selection_strategy, shuffling occurs within the selected range/partition.
+         selection_strategy: Optional strategy to select a subset of the dataset.
+             - IndexRange: Select a specific range of indices (e.g., rows 100-200).
+             - PartitionBlock: Select a partition by splitting the dataset into N equal parts.
+               Partition indices are zero-based (index=0 is the first partition, index=1 is
+               the second, etc.).
+
+     Examples:
+         Read rows sequentially from start to end:
+             SeedConfig(
+                 source=LocalFileSeedSource(path="my_data.parquet"),
+                 sampling_strategy=SamplingStrategy.ORDERED
+             )
+
+         Read rows in random order:
+             SeedConfig(
+                 source=LocalFileSeedSource(path="my_data.parquet"),
+                 sampling_strategy=SamplingStrategy.SHUFFLE
+             )
+
+         Read a specific index range (rows 100-199):
+             SeedConfig(
+                 source=LocalFileSeedSource(path="my_data.parquet"),
+                 sampling_strategy=SamplingStrategy.ORDERED,
+                 selection_strategy=IndexRange(start=100, end=199)
+             )
+
+         Read random rows from a specific index range (shuffles within rows 100-199):
+             SeedConfig(
+                 source=LocalFileSeedSource(path="my_data.parquet"),
+                 sampling_strategy=SamplingStrategy.SHUFFLE,
+                 selection_strategy=IndexRange(start=100, end=199)
+             )
+
+         Read from partition 2 (3rd partition, zero-based) of 5 partitions (20% of dataset):
+             SeedConfig(
+                 source=LocalFileSeedSource(path="my_data.parquet"),
+                 sampling_strategy=SamplingStrategy.ORDERED,
+                 selection_strategy=PartitionBlock(index=2, num_partitions=5)
+             )
+
+         Read shuffled rows from partition 0 of 10 partitions (shuffles within the partition):
+             SeedConfig(
+                 source=LocalFileSeedSource(path="my_data.parquet"),
+                 sampling_strategy=SamplingStrategy.SHUFFLE,
+                 selection_strategy=PartitionBlock(index=0, num_partitions=10)
+             )
+     """
+
+     source: SeedSourceT
+     sampling_strategy: SamplingStrategy = SamplingStrategy.ORDERED
+     selection_strategy: IndexRange | PartitionBlock | None = None
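The floor-division arithmetic in to_index_range is easiest to see with a dataset size that does not divide evenly. A minimal sketch, assuming these classes live at data_designer.config.seed per the file list; the dataset size of 103 is an arbitrary illustration, not from the package:

    from data_designer.config.seed import PartitionBlock

    # floor(103 / 5) = 20 rows per partition; the last partition absorbs the
    # 3 remainder rows, so it covers 23 rows instead of 20.
    for i in range(5):
        r = PartitionBlock(index=i, num_partitions=5).to_index_range(dataset_size=103)
        print(i, r.start, r.end, r.size)
    # 0 0 19 20
    # 1 20 39 20
    # 2 40 59 20
    # 3 60 79 20
    # 4 80 102 23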
data_designer/config/seed_source.py
@@ -0,0 +1,84 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+
+ from __future__ import annotations
+
+ from abc import ABC
+ from typing import TYPE_CHECKING, Literal
+
+ from pydantic import BaseModel, ConfigDict, Field, field_validator
+ from pydantic.json_schema import SkipJsonSchema
+ from typing_extensions import Self
+
+ from data_designer.config.utils.io_helpers import (
+     VALID_DATASET_FILE_EXTENSIONS,
+     validate_dataset_file_path,
+     validate_path_contains_files_of_type,
+ )
+ from data_designer.lazy_heavy_imports import pd
+
+ if TYPE_CHECKING:
+     import pandas as pd
+
+
+ class SeedSource(BaseModel, ABC):
+     """Base class for seed dataset configurations.
+
+     All subclasses must define a `seed_type` field with a Literal value.
+     That field serves as the discriminator for the discriminated union.
+     """
+
+     seed_type: str
+
+
+ class LocalFileSeedSource(SeedSource):
+     seed_type: Literal["local"] = "local"
+
+     path: str
+
+     @field_validator("path", mode="after")
+     def validate_path(cls, v: str) -> str:
+         valid_wild_card_versions = {f"*{ext}" for ext in VALID_DATASET_FILE_EXTENSIONS}
+         if any(v.endswith(wildcard) for wildcard in valid_wild_card_versions):
+             parts = v.split("*.")
+             file_path = parts[0]
+             file_extension = parts[-1]
+             validate_path_contains_files_of_type(file_path, file_extension)
+         else:
+             validate_dataset_file_path(v)
+         return v
+
+     @classmethod
+     def from_dataframe(cls, df: pd.DataFrame, path: str) -> Self:
+         df.to_parquet(path, index=False)
+         return cls(path=path)
+
+
+ class HuggingFaceSeedSource(SeedSource):
+     seed_type: Literal["hf"] = "hf"
+
+     path: str = Field(
+         ...,
+         description=(
+             "Path to the seed data in HuggingFace. Wildcards are allowed. Examples include "
+             "'datasets/my-username/my-dataset/data/000_00000.parquet', 'datasets/my-username/my-dataset/data/*.parquet', "
+             "and 'datasets/my-username/my-dataset/**/*.parquet'"
+         ),
+     )
+     token: str | None = None
+     endpoint: str = "https://huggingface.co"
+
+
+ class DataFrameSeedSource(SeedSource):
+     seed_type: Literal["df"] = "df"
+
+     model_config = ConfigDict(arbitrary_types_allowed=True)
+
+     df: SkipJsonSchema[pd.DataFrame] = Field(
+         ...,
+         exclude=True,
+         description=(
+             "DataFrame to use directly as the seed dataset. NOTE: if you need to write a Data Designer config, "
+             "you must use `LocalFileSeedSource` instead, since DataFrame objects are not serializable."
+         ),
+     )
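A quick sketch of the round trip from_dataframe performs; the frame and file name here are hypothetical. Because the DataFrame is written to parquet before the model is constructed, the path validator sees an existing file:

    import pandas as pd

    from data_designer.config.seed_source import LocalFileSeedSource

    df = pd.DataFrame({"topic": ["Data Science", "Cloud Computing"]})
    source = LocalFileSeedSource.from_dataframe(df, path="seed_data.parquet")
    assert source.seed_type == "local"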
data_designer/config/seed_source_types.py
@@ -0,0 +1,19 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+
+ from __future__ import annotations
+
+ from typing import Annotated
+
+ from pydantic import Field
+ from typing_extensions import TypeAlias
+
+ from data_designer.config.seed_source import DataFrameSeedSource, HuggingFaceSeedSource, LocalFileSeedSource
+ from data_designer.plugin_manager import PluginManager
+
+ plugin_manager = PluginManager()
+
+ _SeedSourceT: TypeAlias = LocalFileSeedSource | HuggingFaceSeedSource | DataFrameSeedSource
+ _SeedSourceT = plugin_manager.inject_into_seed_source_type_union(_SeedSourceT)
+
+ SeedSourceT = Annotated[_SeedSourceT, Field(discriminator="seed_type")]
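Because SeedSourceT is an Annotated union discriminated on seed_type, a plain dict routes to the matching subclass during validation. A minimal sketch using pydantic's TypeAdapter; the adapter usage is illustrative, not part of the package:

    from pydantic import TypeAdapter

    from data_designer.config.seed_source import HuggingFaceSeedSource
    from data_designer.config.seed_source_types import SeedSourceT

    adapter = TypeAdapter(SeedSourceT)
    # "seed_type": "hf" selects HuggingFaceSeedSource via the discriminator
    source = adapter.validate_python(
        {"seed_type": "hf", "path": "datasets/my-username/my-dataset/data/*.parquet"}
    )
    assert isinstance(source, HuggingFaceSeedSource)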
data_designer/config/testing/__init__.py
@@ -0,0 +1,6 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+
+ """Testing utilities for the config package."""
+
+ from __future__ import annotations
data_designer/config/testing/fixtures.py
@@ -0,0 +1,308 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+
+ """Pytest fixtures for config testing."""
+
+ from __future__ import annotations
+
+ import os
+ import tarfile
+ import tempfile
+ import textwrap
+ from typing import TYPE_CHECKING
+
+ import pytest
+ import yaml
+
+ from data_designer.config.analysis.column_statistics import GeneralColumnStatistics
+ from data_designer.config.analysis.dataset_profiler import DatasetProfilerResults
+ from data_designer.config.column_configs import SamplerColumnConfig
+ from data_designer.config.config_builder import DataDesignerConfigBuilder
+ from data_designer.config.data_designer_config import DataDesignerConfig
+ from data_designer.config.models import ChatCompletionInferenceParams, ModelConfig, ModelProvider
+ from data_designer.lazy_heavy_imports import pd
+
+ if TYPE_CHECKING:
+     import pandas as pd
+
+
+ @pytest.fixture
+ def stub_data_designer_config_str() -> str:
+     return """
+ model_configs:
+   - alias: my_own_code_model
+     model: openai/meta/llama-3.3-70b-instruct
+     inference_parameters:
+       temperature:
+         distribution_type: uniform
+         params:
+           low: 0.5
+           high: 0.9
+       top_p:
+         distribution_type: manual
+         params:
+           values: [0.1, 0.2, 0.33]
+           weights: [0.3, 0.2, 0.50]
+
+ seed_config:
+   source:
+     seed_type: hf
+     path: datasets/test-repo/testing/data.csv
+   sampling_strategy: shuffle
+
+ columns:
+   - name: code_id
+     sampler_type: uuid
+     column_type: sampler
+     params:
+       prefix: code_
+       short_form: true
+       uppercase: true
+   - name: age
+     sampler_type: uniform
+     column_type: sampler
+     params:
+       low: 35
+       high: 88
+   - name: domain
+     sampler_type: category
+     column_type: sampler
+     params:
+       values: [Healthcare, Finance, Education, Government]
+   - name: topic
+     sampler_type: category
+     column_type: sampler
+     params:
+       values: [Web Development, Data Science, Machine Learning, Cloud Computing]
+   - name: text
+     column_type: llm-text
+     prompt: Write a description of python code in topic {topic} and domain {domain}
+     model_alias: my_own_code_model
+   - name: code
+     column_type: llm-code
+     prompt: Write Python code that will be paired with the following prompt {text}
+     model_alias: my_own_code_model
+     code_lang: python
+   - name: code_validation_result
+     column_type: validation
+     target_columns:
+       - code
+     validator_type: code
+     validator_params:
+       code_lang: python
+   - name: code_judge_result
+     model_alias: my_own_code_model
+     column_type: llm-judge
+     prompt: You are an expert in Python programming and make appropriate judgement on the quality of the code.
+     scores:
+       - name: Pythonic
+         description: Pythonic Code and Best Practices (Does the code follow Python conventions and best practices?)
+         options:
+           "4": The code exemplifies Pythonic principles, making excellent use of Python-specific constructs, standard library modules and programming idioms; follows all relevant PEPs.
+           "3": The code closely follows Python conventions and adheres to many best practices; good use of Python-specific constructs, standard library modules and programming idioms.
+           "2": The code generally follows Python conventions but has room for better alignment with Pythonic practices.
+           "1": The code loosely follows Python conventions, with several deviations from best practices.
+           "0": The code does not follow Python conventions or best practices, using non-Pythonic approaches.
+       - name: Readability
+         description: Readability and Maintainability (Is the Python code easy to understand and maintain?)
+         options:
+           "4": The code is excellently formatted, follows PEP 8 guidelines, is elegantly concise and clear, uses meaningful variable names, ensuring high readability and ease of maintenance; organizes complex logic well. Docstrings are given in a Google Docstring format.
+           "3": The code is well-formatted in the sense of code-as-documentation, making it relatively easy to understand and maintain; uses descriptive names and organizes logic clearly.
+           "2": The code is somewhat readable with basic formatting and some comments, but improvements are needed; needs better use of descriptive names and organization.
+           "1": The code has minimal formatting, making it hard to understand; lacks meaningful names and organization.
+           "0": The code is unreadable, with no attempt at formatting or description.
+
+ constraints:
+   - target_column: age
+     operator: "lt"
+     rhs: 65
+ """
+
+
+ @pytest.fixture
+ def stub_data_designer_builder_config_str(stub_data_designer_config_str: str) -> str:
+     return f"""
+ data_designer:
+ {textwrap.indent(stub_data_designer_config_str, prefix=" ")}
+ """
+
+
+ @pytest.fixture
+ def stub_data_designer_config(stub_data_designer_config_str: str) -> DataDesignerConfig:
+     json_config = yaml.safe_load(stub_data_designer_config_str)
+     return DataDesignerConfig.model_validate(json_config)
+
+
+ @pytest.fixture
+ def stub_model_configs() -> list[ModelConfig]:
+     return [
+         ModelConfig(
+             alias="stub-model",
+             model="stub-model",
+             inference_parameters=ChatCompletionInferenceParams(
+                 temperature=0.9,
+                 top_p=0.9,
+                 max_tokens=2048,
+             ),
+         )
+     ]
+
+
+ @pytest.fixture
+ def stub_model_providers() -> list[ModelProvider]:
+     return [
+         ModelProvider(
+             name="provider-1",
+             endpoint="https://api.provider-1.com/v1",
+             api_key="PROVIDER_1_API_KEY",
+         )
+     ]
+
+
+ @pytest.fixture
+ def stub_empty_builder(stub_model_configs: list[ModelConfig]) -> DataDesignerConfigBuilder:
+     """Test builder with model configs."""
+     return DataDesignerConfigBuilder(model_configs=stub_model_configs)
+
+
+ @pytest.fixture
+ def stub_complete_builder(stub_data_designer_builder_config_str: str) -> DataDesignerConfigBuilder:
+     return DataDesignerConfigBuilder.from_config(config=stub_data_designer_builder_config_str)
+
+
+ @pytest.fixture
+ def stub_dataframe() -> pd.DataFrame:
+     return pd.DataFrame(
+         {
+             "name": ["John", "Jane", "Jim", "Jill", "Mike", "Mary", "Mark", "Martha", "Alex", "Alice", "Bob", "Bella"],
+             "age": [25, 30, 35, 40, 45, 50, 55, 60, 22, 28, 65, 38],
+             "city": [
+                 "New York",
+                 "Los Angeles",
+                 "Chicago",
+                 "Houston",
+                 "Miami",
+                 "Seattle",
+                 "San Francisco",
+                 "Boston",
+                 "Denver",
+                 "Austin",
+                 "Portland",
+                 "Atlanta",
+             ],
+             "state": ["NY", "CA", "IL", "TX", "FL", "WA", "CA", "MA", "CO", "TX", "OR", "GA"],
+             "zip": [
+                 "10001",
+                 "90001",
+                 "60601",
+                 "77001",
+                 "33101",
+                 "98101",
+                 "94101",
+                 "02101",
+                 "80201",
+                 "73301",
+                 "97201",
+                 "30301",
+             ],
+             "email": [
+                 "john@example.com",
+                 "jane@example.com",
+                 "jim@example.com",
+                 "jill@example.com",
+                 "mike@example.com",
+                 "mary@example.com",
+                 "mark@example.com",
+                 "martha@example.com",
+                 "alex.smith@example.co.uk",
+                 "alice.wu@example.ca",
+                 "bob.martin@example.org",
+                 "bella.rossi@example.it",
+             ],
+             "phone": [
+                 "123-456-7890",
+                 "213-555-1234",
+                 "312-222-3333",
+                 "713-444-5555",
+                 "305-888-9999",
+                 "206-777-8888",
+                 "415-999-0000",
+                 "617-111-2222",
+                 "+44 20 7946 0958",
+                 "+1-416-555-0199",
+                 "+39 06 6982 1234",
+                 "+49 30 123456",
+             ],
+             "address": [
+                 "123 Main St",
+                 "456 Oak Ave",
+                 "789 Pine Rd",
+                 "101 Maple Blvd",
+                 "202 Elm St",
+                 "303 Cedar Ave",
+                 "404 Spruce Dr",
+                 "505 Birch Ln",
+                 "12 Baker St",
+                 "88 King St W",
+                 "Via Roma 1",
+                 "Unter den Linden 5",
+             ],
+         }
+     )
+
+
+ @pytest.fixture
+ def stub_dataset_tar_file():
+     with tempfile.TemporaryDirectory() as temp_dir:
+         # Create valid parquet files with actual data
+         df1 = pd.DataFrame({"id": ["1", "2"], "name": ["test", "sample"]})
+         df2 = pd.DataFrame({"id": ["3", "4"], "name": ["data", "example"]})
+
+         # Write parquet files
+         os.makedirs(temp_dir + "/dataset", exist_ok=True)
+         df1.to_parquet(temp_dir + "/dataset/dataset-001.parquet", index=False)
+         df2.to_parquet(temp_dir + "/dataset/dataset-002.parquet", index=False)
+
+         # Create tar file
+         tar_path = temp_dir + "/dataset.tar"
+         with tarfile.open(tar_path, "w:gz") as tar:
+             tar.add(temp_dir + "/dataset/dataset-001.parquet", arcname="dataset/dataset-001.parquet")
+             tar.add(temp_dir + "/dataset/dataset-002.parquet", arcname="dataset/dataset-002.parquet")
+         with open(tar_path, "rb") as tar_file:
+             yield tar_file
+
+
+ @pytest.fixture
+ def stub_dataset_profiler_results() -> DatasetProfilerResults:
+     stub_column_statistics = GeneralColumnStatistics(
+         column_name="some",
+         num_records=1,
+         num_unique=1,
+         num_null=0,
+         pyarrow_dtype="string",
+         simple_dtype="string",
+     )
+     return DatasetProfilerResults(
+         num_records=1,
+         target_num_records=100,
+         column_statistics=[stub_column_statistics],
+         side_effect_column_names=None,
+         column_profiles=None,
+     )
+
+
+ @pytest.fixture
+ def stub_sampler_only_config_builder(stub_model_configs: list[ModelConfig]) -> DataDesignerConfigBuilder:
+     config_builder = DataDesignerConfigBuilder(model_configs=stub_model_configs)
+     config_builder.add_column(
+         SamplerColumnConfig(
+             name="uuid", sampler_type="uuid", params={"prefix": "code_", "short_form": True, "uppercase": True}
+         )
+     )
+     config_builder.add_column(
+         SamplerColumnConfig(name="category", sampler_type="category", params={"values": ["a", "b", "c"]})
+     )
+     config_builder.add_column(
+         SamplerColumnConfig(name="uniform", sampler_type="uniform", params={"low": 1, "high": 100})
+     )
+     return config_builder
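These fixtures are intended to be consumed by downstream test suites. A hypothetical test sketch, assuming the module is registered as a pytest plugin; the plugin path and assertions are illustrative, not from the package:

    # test_config.py
    pytest_plugins = ["data_designer.config.testing.fixtures"]


    def test_stub_config_parses(stub_data_designer_config):
        # The YAML stub above declares one model config and eight columns
        assert len(stub_data_designer_config.model_configs) == 1
        assert len(stub_data_designer_config.columns) == 8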
data_designer/config/utils/code_lang.py
@@ -0,0 +1,93 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+
+ from __future__ import annotations
+
+ from enum import Enum
+
+
+ class CodeLang(str, Enum):
+     BASH = "bash"
+     C = "c"
+     COBOL = "cobol"
+     CPP = "cpp"
+     CSHARP = "csharp"
+     GO = "go"
+     JAVA = "java"
+     JAVASCRIPT = "javascript"
+     KOTLIN = "kotlin"
+     PYTHON = "python"
+     RUBY = "ruby"
+     RUST = "rust"
+     SCALA = "scala"
+     SWIFT = "swift"
+     TYPESCRIPT = "typescript"
+     SQL_SQLITE = "sql:sqlite"
+     SQL_TSQL = "sql:tsql"
+     SQL_BIGQUERY = "sql:bigquery"
+     SQL_MYSQL = "sql:mysql"
+     SQL_POSTGRES = "sql:postgres"
+     SQL_ANSI = "sql:ansi"
+
+     @staticmethod
+     def parse(value: str | CodeLang) -> tuple[str, str | None]:
+         value = value.value if isinstance(value, CodeLang) else value
+         split_vals = value.split(":")
+         return (split_vals[0], split_vals[1] if len(split_vals) > 1 else None)
+
+     @staticmethod
+     def parse_lang(value: str | CodeLang) -> str:
+         return CodeLang.parse(value)[0]
+
+     @staticmethod
+     def parse_dialect(value: str | CodeLang) -> str | None:
+         return CodeLang.parse(value)[1]
+
+     @staticmethod
+     def supported_values() -> set[str]:
+         return {lang.value for lang in CodeLang}
+
+
+ SQL_DIALECTS: set[CodeLang] = {
+     CodeLang.SQL_SQLITE,
+     CodeLang.SQL_TSQL,
+     CodeLang.SQL_BIGQUERY,
+     CodeLang.SQL_MYSQL,
+     CodeLang.SQL_POSTGRES,
+     CodeLang.SQL_ANSI,
+ }
+
+ ##########################################################
+ # Helper functions
+ ##########################################################
+
+
+ def code_lang_to_syntax_lexer(code_lang: CodeLang | str) -> str:
+     """Convert the code language to a syntax lexer for Pygments.
+
+     Reference: https://pygments.org/docs/lexers/
+     """
+     code_lang_to_lexer = {
+         CodeLang.BASH: "bash",
+         CodeLang.C: "c",
+         CodeLang.COBOL: "cobol",
+         CodeLang.CPP: "cpp",
+         CodeLang.CSHARP: "csharp",
+         CodeLang.GO: "golang",
+         CodeLang.JAVA: "java",
+         CodeLang.JAVASCRIPT: "javascript",
+         CodeLang.KOTLIN: "kotlin",
+         CodeLang.PYTHON: "python",
+         CodeLang.RUBY: "ruby",
+         CodeLang.RUST: "rust",
+         CodeLang.SCALA: "scala",
+         CodeLang.SWIFT: "swift",
+         CodeLang.TYPESCRIPT: "typescript",
+         CodeLang.SQL_SQLITE: "sql",
+         CodeLang.SQL_ANSI: "sql",
+         CodeLang.SQL_TSQL: "tsql",
+         CodeLang.SQL_BIGQUERY: "sql",
+         CodeLang.SQL_MYSQL: "mysql",
+         CodeLang.SQL_POSTGRES: "postgres",
+     }
+     return code_lang_to_lexer.get(code_lang, code_lang)
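A short sketch of the lang:dialect convention the static helpers implement; all values below follow directly from the definitions above:

    from data_designer.config.utils.code_lang import CodeLang, code_lang_to_syntax_lexer

    assert CodeLang.parse("sql:postgres") == ("sql", "postgres")
    assert CodeLang.parse(CodeLang.PYTHON) == ("python", None)
    assert CodeLang.parse_dialect("sql:tsql") == "tsql"

    # CodeLang is a str subclass, so a raw string keys the lexer map too
    assert code_lang_to_syntax_lexer(CodeLang.GO) == "golang"
    assert code_lang_to_syntax_lexer("sql:bigquery") == "sql"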