data-designer 0.1.0 (data_designer-0.1.0-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (177)
  1. data_designer/__init__.py +15 -0
  2. data_designer/_version.py +34 -0
  3. data_designer/cli/README.md +236 -0
  4. data_designer/cli/__init__.py +6 -0
  5. data_designer/cli/commands/__init__.py +2 -0
  6. data_designer/cli/commands/list.py +130 -0
  7. data_designer/cli/commands/models.py +10 -0
  8. data_designer/cli/commands/providers.py +11 -0
  9. data_designer/cli/commands/reset.py +100 -0
  10. data_designer/cli/controllers/__init__.py +7 -0
  11. data_designer/cli/controllers/model_controller.py +246 -0
  12. data_designer/cli/controllers/provider_controller.py +317 -0
  13. data_designer/cli/forms/__init__.py +20 -0
  14. data_designer/cli/forms/builder.py +51 -0
  15. data_designer/cli/forms/field.py +180 -0
  16. data_designer/cli/forms/form.py +59 -0
  17. data_designer/cli/forms/model_builder.py +125 -0
  18. data_designer/cli/forms/provider_builder.py +76 -0
  19. data_designer/cli/main.py +44 -0
  20. data_designer/cli/repositories/__init__.py +8 -0
  21. data_designer/cli/repositories/base.py +39 -0
  22. data_designer/cli/repositories/model_repository.py +42 -0
  23. data_designer/cli/repositories/provider_repository.py +43 -0
  24. data_designer/cli/services/__init__.py +7 -0
  25. data_designer/cli/services/model_service.py +116 -0
  26. data_designer/cli/services/provider_service.py +111 -0
  27. data_designer/cli/ui.py +448 -0
  28. data_designer/cli/utils.py +47 -0
  29. data_designer/config/__init__.py +2 -0
  30. data_designer/config/analysis/column_profilers.py +89 -0
  31. data_designer/config/analysis/column_statistics.py +274 -0
  32. data_designer/config/analysis/dataset_profiler.py +60 -0
  33. data_designer/config/analysis/utils/errors.py +8 -0
  34. data_designer/config/analysis/utils/reporting.py +188 -0
  35. data_designer/config/base.py +68 -0
  36. data_designer/config/column_configs.py +354 -0
  37. data_designer/config/column_types.py +168 -0
  38. data_designer/config/config_builder.py +660 -0
  39. data_designer/config/data_designer_config.py +40 -0
  40. data_designer/config/dataset_builders.py +11 -0
  41. data_designer/config/datastore.py +151 -0
  42. data_designer/config/default_model_settings.py +123 -0
  43. data_designer/config/errors.py +19 -0
  44. data_designer/config/interface.py +54 -0
  45. data_designer/config/models.py +231 -0
  46. data_designer/config/preview_results.py +32 -0
  47. data_designer/config/processors.py +41 -0
  48. data_designer/config/sampler_constraints.py +51 -0
  49. data_designer/config/sampler_params.py +604 -0
  50. data_designer/config/seed.py +145 -0
  51. data_designer/config/utils/code_lang.py +83 -0
  52. data_designer/config/utils/constants.py +313 -0
  53. data_designer/config/utils/errors.py +19 -0
  54. data_designer/config/utils/info.py +88 -0
  55. data_designer/config/utils/io_helpers.py +273 -0
  56. data_designer/config/utils/misc.py +81 -0
  57. data_designer/config/utils/numerical_helpers.py +28 -0
  58. data_designer/config/utils/type_helpers.py +100 -0
  59. data_designer/config/utils/validation.py +336 -0
  60. data_designer/config/utils/visualization.py +427 -0
  61. data_designer/config/validator_params.py +96 -0
  62. data_designer/engine/__init__.py +2 -0
  63. data_designer/engine/analysis/column_profilers/base.py +55 -0
  64. data_designer/engine/analysis/column_profilers/judge_score_profiler.py +160 -0
  65. data_designer/engine/analysis/column_profilers/registry.py +20 -0
  66. data_designer/engine/analysis/column_statistics.py +142 -0
  67. data_designer/engine/analysis/dataset_profiler.py +125 -0
  68. data_designer/engine/analysis/errors.py +7 -0
  69. data_designer/engine/analysis/utils/column_statistics_calculations.py +209 -0
  70. data_designer/engine/analysis/utils/judge_score_processing.py +128 -0
  71. data_designer/engine/column_generators/__init__.py +2 -0
  72. data_designer/engine/column_generators/generators/__init__.py +2 -0
  73. data_designer/engine/column_generators/generators/base.py +61 -0
  74. data_designer/engine/column_generators/generators/expression.py +63 -0
  75. data_designer/engine/column_generators/generators/llm_generators.py +172 -0
  76. data_designer/engine/column_generators/generators/samplers.py +75 -0
  77. data_designer/engine/column_generators/generators/seed_dataset.py +149 -0
  78. data_designer/engine/column_generators/generators/validation.py +147 -0
  79. data_designer/engine/column_generators/registry.py +56 -0
  80. data_designer/engine/column_generators/utils/errors.py +13 -0
  81. data_designer/engine/column_generators/utils/judge_score_factory.py +57 -0
  82. data_designer/engine/column_generators/utils/prompt_renderer.py +98 -0
  83. data_designer/engine/configurable_task.py +82 -0
  84. data_designer/engine/dataset_builders/artifact_storage.py +181 -0
  85. data_designer/engine/dataset_builders/column_wise_builder.py +287 -0
  86. data_designer/engine/dataset_builders/errors.py +13 -0
  87. data_designer/engine/dataset_builders/multi_column_configs.py +44 -0
  88. data_designer/engine/dataset_builders/utils/__init__.py +2 -0
  89. data_designer/engine/dataset_builders/utils/concurrency.py +184 -0
  90. data_designer/engine/dataset_builders/utils/config_compiler.py +60 -0
  91. data_designer/engine/dataset_builders/utils/dag.py +56 -0
  92. data_designer/engine/dataset_builders/utils/dataset_batch_manager.py +190 -0
  93. data_designer/engine/dataset_builders/utils/errors.py +13 -0
  94. data_designer/engine/errors.py +49 -0
  95. data_designer/engine/model_provider.py +75 -0
  96. data_designer/engine/models/__init__.py +2 -0
  97. data_designer/engine/models/errors.py +308 -0
  98. data_designer/engine/models/facade.py +225 -0
  99. data_designer/engine/models/litellm_overrides.py +162 -0
  100. data_designer/engine/models/parsers/__init__.py +2 -0
  101. data_designer/engine/models/parsers/errors.py +34 -0
  102. data_designer/engine/models/parsers/parser.py +236 -0
  103. data_designer/engine/models/parsers/postprocessors.py +93 -0
  104. data_designer/engine/models/parsers/tag_parsers.py +60 -0
  105. data_designer/engine/models/parsers/types.py +82 -0
  106. data_designer/engine/models/recipes/base.py +79 -0
  107. data_designer/engine/models/recipes/response_recipes.py +291 -0
  108. data_designer/engine/models/registry.py +118 -0
  109. data_designer/engine/models/usage.py +75 -0
  110. data_designer/engine/models/utils.py +38 -0
  111. data_designer/engine/processing/ginja/__init__.py +2 -0
  112. data_designer/engine/processing/ginja/ast.py +64 -0
  113. data_designer/engine/processing/ginja/environment.py +461 -0
  114. data_designer/engine/processing/ginja/exceptions.py +54 -0
  115. data_designer/engine/processing/ginja/record.py +30 -0
  116. data_designer/engine/processing/gsonschema/__init__.py +2 -0
  117. data_designer/engine/processing/gsonschema/exceptions.py +8 -0
  118. data_designer/engine/processing/gsonschema/schema_transformers.py +81 -0
  119. data_designer/engine/processing/gsonschema/types.py +8 -0
  120. data_designer/engine/processing/gsonschema/validators.py +143 -0
  121. data_designer/engine/processing/processors/base.py +15 -0
  122. data_designer/engine/processing/processors/drop_columns.py +46 -0
  123. data_designer/engine/processing/processors/registry.py +20 -0
  124. data_designer/engine/processing/utils.py +120 -0
  125. data_designer/engine/registry/base.py +97 -0
  126. data_designer/engine/registry/data_designer_registry.py +37 -0
  127. data_designer/engine/registry/errors.py +10 -0
  128. data_designer/engine/resources/managed_dataset_generator.py +35 -0
  129. data_designer/engine/resources/managed_dataset_repository.py +194 -0
  130. data_designer/engine/resources/managed_storage.py +63 -0
  131. data_designer/engine/resources/resource_provider.py +46 -0
  132. data_designer/engine/resources/seed_dataset_data_store.py +66 -0
  133. data_designer/engine/sampling_gen/column.py +89 -0
  134. data_designer/engine/sampling_gen/constraints.py +95 -0
  135. data_designer/engine/sampling_gen/data_sources/base.py +214 -0
  136. data_designer/engine/sampling_gen/data_sources/errors.py +10 -0
  137. data_designer/engine/sampling_gen/data_sources/sources.py +342 -0
  138. data_designer/engine/sampling_gen/entities/__init__.py +2 -0
  139. data_designer/engine/sampling_gen/entities/assets/zip_area_code_map.parquet +0 -0
  140. data_designer/engine/sampling_gen/entities/dataset_based_person_fields.py +64 -0
  141. data_designer/engine/sampling_gen/entities/email_address_utils.py +169 -0
  142. data_designer/engine/sampling_gen/entities/errors.py +8 -0
  143. data_designer/engine/sampling_gen/entities/national_id_utils.py +100 -0
  144. data_designer/engine/sampling_gen/entities/person.py +142 -0
  145. data_designer/engine/sampling_gen/entities/phone_number.py +122 -0
  146. data_designer/engine/sampling_gen/errors.py +24 -0
  147. data_designer/engine/sampling_gen/generator.py +121 -0
  148. data_designer/engine/sampling_gen/jinja_utils.py +60 -0
  149. data_designer/engine/sampling_gen/people_gen.py +203 -0
  150. data_designer/engine/sampling_gen/person_constants.py +54 -0
  151. data_designer/engine/sampling_gen/schema.py +143 -0
  152. data_designer/engine/sampling_gen/schema_builder.py +59 -0
  153. data_designer/engine/sampling_gen/utils.py +40 -0
  154. data_designer/engine/secret_resolver.py +80 -0
  155. data_designer/engine/validators/__init__.py +17 -0
  156. data_designer/engine/validators/base.py +36 -0
  157. data_designer/engine/validators/local_callable.py +34 -0
  158. data_designer/engine/validators/python.py +245 -0
  159. data_designer/engine/validators/remote.py +83 -0
  160. data_designer/engine/validators/sql.py +60 -0
  161. data_designer/errors.py +5 -0
  162. data_designer/essentials/__init__.py +137 -0
  163. data_designer/interface/__init__.py +2 -0
  164. data_designer/interface/data_designer.py +351 -0
  165. data_designer/interface/errors.py +16 -0
  166. data_designer/interface/results.py +55 -0
  167. data_designer/logging.py +161 -0
  168. data_designer/plugin_manager.py +83 -0
  169. data_designer/plugins/__init__.py +6 -0
  170. data_designer/plugins/errors.py +10 -0
  171. data_designer/plugins/plugin.py +69 -0
  172. data_designer/plugins/registry.py +86 -0
  173. data_designer-0.1.0.dist-info/METADATA +173 -0
  174. data_designer-0.1.0.dist-info/RECORD +177 -0
  175. data_designer-0.1.0.dist-info/WHEEL +4 -0
  176. data_designer-0.1.0.dist-info/entry_points.txt +2 -0
  177. data_designer-0.1.0.dist-info/licenses/LICENSE +201 -0
data_designer/config/utils/visualization.py
@@ -0,0 +1,427 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+
+ from __future__ import annotations
+
+ from collections import OrderedDict
+ from enum import Enum
+ from functools import cached_property
+ import json
+ import os
+ from typing import TYPE_CHECKING, Optional, Union
+
+ import numpy as np
+ import pandas as pd
+ from rich.console import Console, Group
+ from rich.padding import Padding
+ from rich.panel import Panel
+ from rich.pretty import Pretty
+ from rich.rule import Rule
+ from rich.syntax import Syntax
+ from rich.table import Table
+ from rich.text import Text
+
+ from ..base import ConfigBase
+ from ..column_types import DataDesignerColumnType
+ from ..models import ModelConfig, ModelProvider
+ from ..sampler_params import SamplerType
+ from .code_lang import code_lang_to_syntax_lexer
+ from .constants import NVIDIA_API_KEY_ENV_VAR_NAME, OPENAI_API_KEY_ENV_VAR_NAME
+ from .errors import DatasetSampleDisplayError
+
+ if TYPE_CHECKING:
+     from ..config_builder import DataDesignerConfigBuilder
+
+
+ console = Console()
+
+
+ def get_nvidia_api_key() -> Optional[str]:
+     return os.getenv(NVIDIA_API_KEY_ENV_VAR_NAME)
+
+
+ def get_openai_api_key() -> Optional[str]:
+     return os.getenv(OPENAI_API_KEY_ENV_VAR_NAME)
+
+
+ class ColorPalette(str, Enum):
+     NVIDIA_GREEN = "#76b900"
+     PURPLE = "#9525c6"
+     YELLOW = "#f9c500"
+     BLUE = "#0074df"
+     RED = "#e52020"
+     ORANGE = "#ef9100"
+     MAGENTA = "#d2308e"
+     TEAL = "#1dbba4"
+
+
+ class WithRecordSamplerMixin:
+     _display_cycle_index: int = 0
+
+     @cached_property
+     def _record_sampler_dataset(self) -> pd.DataFrame:
+         if hasattr(self, "dataset") and self.dataset is not None and isinstance(self.dataset, pd.DataFrame):
+             return self.dataset
+         elif (
+             hasattr(self, "load_dataset")
+             and callable(self.load_dataset)
+             and (dataset := self.load_dataset()) is not None
+             and isinstance(dataset, pd.DataFrame)
+         ):
+             return dataset
+         else:
+             raise DatasetSampleDisplayError("No valid dataset found in results object.")
+
+     def display_sample_record(
+         self,
+         index: Optional[int] = None,
+         *,
+         hide_seed_columns: bool = False,
+         syntax_highlighting_theme: str = "dracula",
+         background_color: Optional[str] = None,
+     ) -> None:
+         """Display a sample record from the Data Designer dataset preview.
+
+         Args:
+             index: Index of the record to display. If None, the next record will be displayed.
+                 This is useful for running the cell in a notebook multiple times.
+             hide_seed_columns: If True, the columns from the seed dataset (if any) will not be displayed.
+             syntax_highlighting_theme: Theme to use for syntax highlighting. See the `Syntax`
+                 documentation from `rich` for information about available themes.
+             background_color: Background color to use for the record. See the `Syntax`
+                 documentation from `rich` for information about available background colors.
+         """
+         i = index if index is not None else self._display_cycle_index
+         num_records = len(self._record_sampler_dataset)
+
+         try:
+             record = self._record_sampler_dataset.iloc[i]
+         except IndexError:
+             raise DatasetSampleDisplayError(f"Index {i} is out of bounds for dataset of length {num_records}.")
+
+         display_sample_record(
+             record=record,
+             config_builder=self._config_builder,
+             background_color=background_color,
+             syntax_highlighting_theme=syntax_highlighting_theme,
+             hide_seed_columns=hide_seed_columns,
+             record_index=i,
+         )
+         if index is None:
+             self._display_cycle_index = (self._display_cycle_index + 1) % num_records
+
+
+ def create_rich_histogram_table(
+     data: dict[str, Union[int, float]],
+     column_names: tuple[str, str],
+     name_style: str = ColorPalette.BLUE.value,
+     value_style: str = ColorPalette.TEAL.value,
+     title: Optional[str] = None,
+     **kwargs,
+ ) -> Table:
+     table = Table(title=title, **kwargs)
+     table.add_column(column_names[0], justify="right", style=name_style)
+     table.add_column(column_names[1], justify="left", style=value_style)
+
+     max_count = max(data.values())
+     for name, value in data.items():
+         bar = "" if max_count <= 0 else "█" * int((value / max_count) * 20)
+         table.add_row(str(name), f"{bar} {value:.1f}")
+
+     return table
+
+
+ def display_sample_record(
+     record: Union[dict, pd.Series, pd.DataFrame],
+     config_builder: DataDesignerConfigBuilder,
+     background_color: Optional[str] = None,
+     syntax_highlighting_theme: str = "dracula",
+     record_index: Optional[int] = None,
+     hide_seed_columns: bool = False,
+ ):
+     if isinstance(record, (dict, pd.Series)):
+         record = pd.DataFrame([record]).iloc[0]
+     elif isinstance(record, pd.DataFrame):
+         if record.shape[0] > 1:
+             raise DatasetSampleDisplayError(
+                 f"The record must be a single record. You provided a DataFrame with {record.shape[0]} records."
+             )
+         record = record.iloc[0]
+     else:
+         raise DatasetSampleDisplayError(
+             "The record must be a single record in a dictionary, pandas Series, "
+             f"or pandas DataFrame. You provided: {type(record)}."
+         )
+
+     render_list = []
+     table_kws = dict(show_lines=True, expand=True)
+
+     seed_columns = config_builder.get_columns_of_type(DataDesignerColumnType.SEED_DATASET)
+     if not hide_seed_columns and len(seed_columns) > 0:
+         table = Table(title="Seed Columns", **table_kws)
+         table.add_column("Name")
+         table.add_column("Value")
+         for col in seed_columns:
+             if not col.drop:
+                 table.add_row(col.name, convert_to_row_element(record[col.name]))
+         render_list.append(pad_console_element(table))
+
+     non_code_columns = (
+         config_builder.get_columns_of_type(DataDesignerColumnType.SAMPLER)
+         + config_builder.get_columns_of_type(DataDesignerColumnType.EXPRESSION)
+         + config_builder.get_columns_of_type(DataDesignerColumnType.LLM_TEXT)
+         + config_builder.get_columns_of_type(DataDesignerColumnType.LLM_STRUCTURED)
+     )
+     if len(non_code_columns) > 0:
+         table = Table(title="Generated Columns", **table_kws)
+         table.add_column("Name")
+         table.add_column("Value")
+         for col in non_code_columns:
+             if not col.drop:
+                 table.add_row(col.name, convert_to_row_element(record[col.name]))
+         render_list.append(pad_console_element(table))
+
+     for col in config_builder.get_columns_of_type(DataDesignerColumnType.LLM_CODE):
+         panel = Panel(
+             Syntax(
+                 record[col.name],
+                 lexer=code_lang_to_syntax_lexer(col.code_lang),
+                 theme=syntax_highlighting_theme,
+                 word_wrap=True,
+                 background_color=background_color,
+             ),
+             title=col.name,
+             expand=True,
+         )
+         render_list.append(pad_console_element(panel))
+
+     validation_columns = config_builder.get_columns_of_type(DataDesignerColumnType.VALIDATION)
+     if len(validation_columns) > 0:
+         table = Table(title="Validation", **table_kws)
+         table.add_column("Name")
+         table.add_column("Value", ratio=1)
+         for col in validation_columns:
+             if not col.drop:
+                 # Add is_valid before other fields
+                 if "is_valid" in record[col.name]:
+                     value_to_display = {"is_valid": record[col.name].get("is_valid")} | record[col.name]
+                 else:  # if columns treated separately
+                     value_to_display = {}
+                     for col_name, validation_output in record[col.name].items():
+                         value_to_display[col_name] = {
+                             "is_valid": validation_output.get("is_valid", None)
+                         } | validation_output
+
+                 table.add_row(col.name, convert_to_row_element(value_to_display))
+         render_list.append(pad_console_element(table, (1, 0, 1, 0)))
+
+     llm_judge_columns = config_builder.get_columns_of_type(DataDesignerColumnType.LLM_JUDGE)
+     if len(llm_judge_columns) > 0:
+         for col in llm_judge_columns:
+             if col.drop:
+                 continue
+             table = Table(title=f"LLM-as-a-Judge: {col.name}", **table_kws)
+             row = []
+             judge = record[col.name]
+
+             for measure, results in judge.items():
+                 table.add_column(measure)
+                 row.append(f"score: {results['score']}\nreasoning: {results['reasoning']}")
+             table.add_row(*row)
+             render_list.append(pad_console_element(table, (1, 0, 1, 0)))
+
+     if record_index is not None:
+         index_label = Text(f"[index: {record_index}]", justify="center")
+         render_list.append(index_label)
+
+     console.print(Group(*render_list), markup=False)
+
+
+ def display_sampler_table(
+     sampler_params: dict[SamplerType, ConfigBase],
+     title: Optional[str] = None,
+ ) -> None:
+     table = Table(expand=True)
+     table.add_column("Type")
+     table.add_column("Parameter")
+     table.add_column("Data Type")
+     table.add_column("Required", justify="center")
+     table.add_column("Constraints")
+
+     for sampler_type, params in sampler_params.items():
+         num = 0
+         schema = params.model_json_schema()
+         for param_name, field_info in schema["properties"].items():
+             is_required = param_name in schema.get("required", [])
+             table.add_row(
+                 sampler_type if num == 0 else "",
+                 param_name,
+                 _get_field_type(field_info),
+                 "✓" if is_required else "",
+                 _get_field_constraints(field_info, schema),
+             )
+             num += 1
+         table.add_section()
+
+     title = title or "NeMo Data Designer Samplers"
+
+     group = Group(Rule(title, end="\n\n"), table)
+     console.print(group)
+
+
+ def display_model_configs_table(model_configs: list[ModelConfig]) -> None:
+     table_model_configs = Table(expand=True)
+     table_model_configs.add_column("Alias")
+     table_model_configs.add_column("Model")
+     table_model_configs.add_column("Provider")
+     table_model_configs.add_column("Temperature")
+     table_model_configs.add_column("Top P")
+     for model_config in model_configs:
+         table_model_configs.add_row(
+             model_config.alias,
+             model_config.model,
+             model_config.provider,
+             str(model_config.inference_parameters.temperature),
+             str(model_config.inference_parameters.top_p),
+         )
+     group_args: list = [Rule(title="Model Configs"), table_model_configs]
+     if len(model_configs) == 0:
+         subtitle = Text(
+             "‼️ No model configs found. Please provide at least one model config to the config builder",
+             style="dim",
+             justify="center",
+         )
+         group_args.insert(1, subtitle)
+     group = Group(*group_args)
+     console.print(group)
+
+
+ def display_model_providers_table(model_providers: list[ModelProvider]) -> None:
+     table_model_providers = Table(expand=True)
+     table_model_providers.add_column("Name")
+     table_model_providers.add_column("Endpoint")
+     table_model_providers.add_column("API Key")
+     for model_provider in model_providers:
+         api_key = model_provider.api_key
+         if model_provider.api_key == OPENAI_API_KEY_ENV_VAR_NAME:
+             if get_openai_api_key() is not None:
+                 api_key = mask_api_key(get_openai_api_key())
+             else:
+                 api_key = f"* {OPENAI_API_KEY_ENV_VAR_NAME!r} not set in environment variables * "
+         elif model_provider.api_key == NVIDIA_API_KEY_ENV_VAR_NAME:
+             if get_nvidia_api_key() is not None:
+                 api_key = mask_api_key(get_nvidia_api_key())
+             else:
+                 api_key = f"* {NVIDIA_API_KEY_ENV_VAR_NAME!r} not set in environment variables *"
+         else:
+             api_key = mask_api_key(model_provider.api_key)
+         table_model_providers.add_row(model_provider.name, model_provider.endpoint, api_key)
+     group = Group(Rule(title="Model Providers"), table_model_providers)
+     console.print(group)
+
+
+ def mask_api_key(api_key: str | None) -> str:
+     """Mask API keys for display.
+
+     Environment variable names (all uppercase) are kept visible.
+     Actual API keys are masked to show only the last 4 characters.
+
+     Args:
+         api_key: The API key to mask.
+
+     Returns:
+         Masked API key string or "(not set)" if None.
+     """
+     if not api_key:
+         return "(not set)"
+
+     # Keep environment variable names visible
+     if api_key.isupper():
+         return api_key
+
+     # Mask actual API keys
+     return "***" + api_key[-4:] if len(api_key) > 4 else "***"
+
+
+ def convert_to_row_element(elem):
+     try:
+         elem = Pretty(json.loads(elem))
+     except (TypeError, json.JSONDecodeError):
+         pass
+     if isinstance(elem, (np.integer, np.floating, np.ndarray)):
+         elem = str(elem)
+     elif isinstance(elem, (list, dict)):
+         elem = Pretty(elem)
+     return elem
+
+
+ def pad_console_element(elem, padding=(1, 0, 1, 0)):
+     return Padding(elem, padding)
+
+
+ def _get_field_type(field: dict) -> str:
+     """Extract human-readable type information from a JSON Schema field."""
+
+     # single type
+     if "type" in field:
+         if field["type"] == "array":
+             return " | ".join([f"{f.strip()}[]" for f in _get_field_type(field["items"]).split("|")])
+         if field["type"] == "object":
+             return "dict"
+         return field["type"]
+
+     # union type
+     elif "anyOf" in field:
+         types = []
+         for f in field["anyOf"]:
+             if "$ref" in f:
+                 types.append("enum")
+             elif f.get("type") == "array":
+                 if "items" in f and "$ref" in f["items"]:
+                     types.append("enum[]")
+                 else:
+                     types.append(f"{f['items']['type']}[]")
+             else:
+                 types.append(f.get("type", ""))
+         return " | ".join(t for t in types if t)
+
+     return ""
+
+
+ def _get_field_constraints(field: dict, schema: dict) -> str:
+     """Extract human-readable constraints from a JSON Schema field."""
+     constraints = []
+
+     # numeric constraints
+     if "minimum" in field:
+         constraints.append(f">= {field['minimum']}")
+     if "exclusiveMinimum" in field:
+         constraints.append(f"> {field['exclusiveMinimum']}")
+     if "maximum" in field:
+         constraints.append(f"<= {field['maximum']}")
+     if "exclusiveMaximum" in field:
+         constraints.append(f"< {field['exclusiveMaximum']}")
+
+     # string constraints
+     if "minLength" in field:
+         constraints.append(f"len >= {field['minLength']}")
+     if "maxLength" in field:
+         constraints.append(f"len <= {field['maxLength']}")
+
+     # array constraints
+     if "minItems" in field:
+         constraints.append(f"len >= {field['minItems']}")
+     if "maxItems" in field:
+         constraints.append(f"len <= {field['maxItems']}")
+
+     # enum constraints
+     if "enum" in _get_field_type(field) and "$defs" in schema:
+         enum_values = []
+         for defs in schema["$defs"].values():
+             if "enum" in defs:
+                 enum_values.extend(defs["enum"])
+         if len(enum_values) > 0:
+             enum_values = OrderedDict.fromkeys(enum_values)
+             constraints.append(f"allowed: {', '.join(enum_values.keys())}")
+
+     return ", ".join(constraints)
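The hunk above adds the display helpers for Data Designer previews (entry 60 in the file list). As a minimal usage sketch of its two self-contained helpers, assuming the wheel is installed and the import path follows the file listing:

# Sketch only: import path inferred from the file list above, not confirmed by package docs.
from rich.console import Console
from data_designer.config.utils.visualization import create_rich_histogram_table, mask_api_key

console = Console()

# Histogram table: names on the left, a bar plus the value on the right.
scores = {"helpfulness": 4.5, "accuracy": 3.0, "style": 2.5}
console.print(create_rich_histogram_table(scores, column_names=("Measure", "Avg score")))

# Key masking: uppercase env-var names pass through, real keys show only the last 4 characters.
print(mask_api_key("sk-abcdef1234"))   # ***1234
print(mask_api_key("NVIDIA_API_KEY"))  # NVIDIA_API_KEY
print(mask_api_key(None))              # (not set)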
data_designer/config/validator_params.py
@@ -0,0 +1,96 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+
+ from enum import Enum
+ from typing import Any, Optional, Union
+
+ from pydantic import Field, field_serializer, model_validator
+ from typing_extensions import Self, TypeAlias
+
+ from .base import ConfigBase
+ from .utils.code_lang import SQL_DIALECTS, CodeLang
+
+ SUPPORTED_CODE_LANGUAGES = {CodeLang.PYTHON, *SQL_DIALECTS}
+
+
+ class ValidatorType(str, Enum):
+     CODE = "code"
+     LOCAL_CALLABLE = "local_callable"
+     REMOTE = "remote"
+
+
+ class CodeValidatorParams(ConfigBase):
+     """Configuration for code validation. Supports Python and SQL code validation.
+
+     Attributes:
+         code_lang: The language of the code to validate. Supported values include: `python`,
+             `sql:sqlite`, `sql:postgres`, `sql:mysql`, `sql:tsql`, `sql:bigquery`, `sql:ansi`.
+     """
+
+     code_lang: CodeLang = Field(description="The language of the code to validate")
+
+     @model_validator(mode="after")
+     def validate_code_lang(self) -> Self:
+         if self.code_lang not in SUPPORTED_CODE_LANGUAGES:
+             raise ValueError(
+                 f"Unsupported code language, supported languages are: {[lang.value for lang in SUPPORTED_CODE_LANGUAGES]}"
+             )
+         return self
+
+
+ class LocalCallableValidatorParams(ConfigBase):
+     """Configuration for local callable validation. Expects a function to be passed that validates the data.
+
+     Attributes:
+         validation_function: Function (`Callable[[pd.DataFrame], pd.DataFrame]`) to validate the
+             data. Output must contain a column `is_valid` of type `bool`.
+         output_schema: The JSON schema for the local callable validator's output. If not provided,
+             the output will not be validated.
+     """
+
+     validation_function: Any = Field(
+         description="Function (Callable[[pd.DataFrame], pd.DataFrame]) to validate the data"
+     )
+     output_schema: Optional[dict[str, Any]] = Field(
+         default=None, description="Expected schema for local callable validator's output"
+     )
+
+     @field_serializer("validation_function")
+     def serialize_validation_function(self, v: Any) -> Any:
+         return v.__name__
+
+     @model_validator(mode="after")
+     def validate_validation_function(self) -> Self:
+         if not callable(self.validation_function):
+             raise ValueError("Validation function must be a callable")
+         return self
+
+
+ class RemoteValidatorParams(ConfigBase):
+     """Configuration for remote validation. Sends data to a remote endpoint for validation.
+
+     Attributes:
+         endpoint_url: The URL of the remote endpoint.
+         output_schema: The JSON schema for the remote validator's output. If not provided,
+             the output will not be validated.
+         timeout: The timeout for the HTTP request in seconds. Defaults to 30.0.
+         max_retries: The maximum number of retry attempts. Defaults to 3.
+         retry_backoff: The backoff factor for the retry delay in seconds. Defaults to 2.0.
+         max_parallel_requests: The maximum number of parallel requests to make. Defaults to 4.
+     """
+
+     endpoint_url: str = Field(description="URL of the remote endpoint")
+     output_schema: Optional[dict[str, Any]] = Field(
+         default=None, description="Expected schema for remote validator's output"
+     )
+     timeout: float = Field(default=30.0, gt=0, description="The timeout for the HTTP request")
+     max_retries: int = Field(default=3, ge=0, description="The maximum number of retry attempts")
+     retry_backoff: float = Field(default=2.0, gt=1, description="The backoff factor for the retry delay")
+     max_parallel_requests: int = Field(default=4, ge=1, description="The maximum number of parallel requests to make")
+
+
+ ValidatorParamsT: TypeAlias = Union[
+     CodeValidatorParams,
+     LocalCallableValidatorParams,
+     RemoteValidatorParams,
+ ]
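The hunk above defines the validator parameter models (entry 61 in the file list). A hedged construction sketch, assuming the wheel is installed and that CodeLang lives in data_designer.config.utils.code_lang as the relative import above suggests; the example DataFrame column "amount" is purely illustrative:

import pandas as pd
from data_designer.config.utils.code_lang import CodeLang
from data_designer.config.validator_params import (
    CodeValidatorParams,
    LocalCallableValidatorParams,
    RemoteValidatorParams,
)

# Code validation: only Python and the SQL dialects pass the model_validator above.
code_params = CodeValidatorParams(code_lang=CodeLang.PYTHON)

# Local callable validation: the callable must return a DataFrame with an `is_valid` column.
def has_positive_amount(df: pd.DataFrame) -> pd.DataFrame:
    return df.assign(is_valid=df["amount"] > 0)

callable_params = LocalCallableValidatorParams(validation_function=has_positive_amount)

# Remote validation, overriding one of the documented defaults.
remote_params = RemoteValidatorParams(endpoint_url="https://validator.example.com/check", timeout=10.0)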
@@ -0,0 +1,2 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
data_designer/engine/analysis/column_profilers/base.py
@@ -0,0 +1,55 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+
+ from __future__ import annotations
+
+ from abc import ABC, abstractmethod
+ import logging
+
+ import pandas as pd
+ import pyarrow as pa
+ from pydantic import BaseModel, model_validator
+ from typing_extensions import Self
+
+ from data_designer.config.base import ConfigBase
+ from data_designer.config.column_configs import SingleColumnConfig
+ from data_designer.config.column_types import DataDesignerColumnType
+ from data_designer.engine.configurable_task import ConfigurableTask, ConfigurableTaskMetadata, TaskConfigT
+
+ logger = logging.getLogger(__name__)
+
+
+ class ColumnConfigWithDataFrame(ConfigBase):
+     column_config: SingleColumnConfig
+     df: pd.DataFrame
+
+     @model_validator(mode="after")
+     def validate_column_exists(self) -> Self:
+         if self.column_config.name not in self.df.columns:
+             raise ValueError(f"Column {self.column_config.name!r} not found in DataFrame")
+         return self
+
+     @model_validator(mode="after")
+     def ensure_pyarrow_backend(self) -> Self:
+         if not all(isinstance(dtype, pd.ArrowDtype) for dtype in self.df.dtypes):
+             self.df = pa.Table.from_pandas(self.df).to_pandas(types_mapper=pd.ArrowDtype)
+         return self
+
+     def as_tuple(self) -> tuple[SingleColumnConfig, pd.DataFrame]:
+         return (self.column_config, self.df)
+
+
+ class ColumnProfilerMetadata(ConfigurableTaskMetadata):
+     applicable_column_types: list[DataDesignerColumnType]
+
+
+ class ColumnProfiler(ConfigurableTask[TaskConfigT], ABC):
+     @staticmethod
+     @abstractmethod
+     def metadata() -> ColumnProfilerMetadata: ...
+
+     @abstractmethod
+     def profile(self, column_config_with_df: ColumnConfigWithDataFrame) -> BaseModel: ...
+
+     def _initialize(self) -> None:
+         logger.info(f"💫 Initializing column profiler: '{self.metadata().name}'")
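The hunk above (entry 63 in the file list) normalizes every incoming DataFrame to pyarrow-backed dtypes before profiling. The dtype round-trip it performs in ensure_pyarrow_backend can be reproduced standalone with only pandas (2.x) and pyarrow; this sketch illustrates just that conversion, not the package's profiler API:

import pandas as pd
import pyarrow as pa

# A default, NumPy-backed frame.
df = pd.DataFrame({"name": ["a", "b"], "score": [0.1, 0.9]})
print(all(isinstance(dtype, pd.ArrowDtype) for dtype in df.dtypes))  # False

# Round-trip through pyarrow so every column uses pd.ArrowDtype.
df_arrow = pa.Table.from_pandas(df).to_pandas(types_mapper=pd.ArrowDtype)
print(all(isinstance(dtype, pd.ArrowDtype) for dtype in df_arrow.dtypes))  # True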