data-designer 0.3.8rc2__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166)
  1. data_designer/cli/commands/__init__.py +1 -1
  2. data_designer/interface/__init__.py +21 -1
  3. data_designer/{_version.py → interface/_version.py} +2 -2
  4. data_designer/interface/data_designer.py +1 -7
  5. {data_designer-0.3.8rc2.dist-info → data_designer-0.4.0.dist-info}/METADATA +10 -42
  6. data_designer-0.4.0.dist-info/RECORD +39 -0
  7. data_designer/__init__.py +0 -17
  8. data_designer/config/__init__.py +0 -2
  9. data_designer/config/analysis/__init__.py +0 -2
  10. data_designer/config/analysis/column_profilers.py +0 -159
  11. data_designer/config/analysis/column_statistics.py +0 -421
  12. data_designer/config/analysis/dataset_profiler.py +0 -84
  13. data_designer/config/analysis/utils/errors.py +0 -10
  14. data_designer/config/analysis/utils/reporting.py +0 -192
  15. data_designer/config/base.py +0 -69
  16. data_designer/config/column_configs.py +0 -470
  17. data_designer/config/column_types.py +0 -141
  18. data_designer/config/config_builder.py +0 -595
  19. data_designer/config/data_designer_config.py +0 -40
  20. data_designer/config/dataset_builders.py +0 -13
  21. data_designer/config/dataset_metadata.py +0 -18
  22. data_designer/config/default_model_settings.py +0 -129
  23. data_designer/config/errors.py +0 -24
  24. data_designer/config/exports.py +0 -145
  25. data_designer/config/interface.py +0 -55
  26. data_designer/config/models.py +0 -455
  27. data_designer/config/preview_results.py +0 -41
  28. data_designer/config/processors.py +0 -148
  29. data_designer/config/run_config.py +0 -51
  30. data_designer/config/sampler_constraints.py +0 -52
  31. data_designer/config/sampler_params.py +0 -639
  32. data_designer/config/seed.py +0 -116
  33. data_designer/config/seed_source.py +0 -84
  34. data_designer/config/seed_source_types.py +0 -19
  35. data_designer/config/utils/code_lang.py +0 -82
  36. data_designer/config/utils/constants.py +0 -363
  37. data_designer/config/utils/errors.py +0 -21
  38. data_designer/config/utils/info.py +0 -94
  39. data_designer/config/utils/io_helpers.py +0 -258
  40. data_designer/config/utils/misc.py +0 -78
  41. data_designer/config/utils/numerical_helpers.py +0 -30
  42. data_designer/config/utils/type_helpers.py +0 -106
  43. data_designer/config/utils/visualization.py +0 -482
  44. data_designer/config/validator_params.py +0 -94
  45. data_designer/engine/__init__.py +0 -2
  46. data_designer/engine/analysis/column_profilers/base.py +0 -49
  47. data_designer/engine/analysis/column_profilers/judge_score_profiler.py +0 -153
  48. data_designer/engine/analysis/column_profilers/registry.py +0 -22
  49. data_designer/engine/analysis/column_statistics.py +0 -145
  50. data_designer/engine/analysis/dataset_profiler.py +0 -149
  51. data_designer/engine/analysis/errors.py +0 -9
  52. data_designer/engine/analysis/utils/column_statistics_calculations.py +0 -234
  53. data_designer/engine/analysis/utils/judge_score_processing.py +0 -132
  54. data_designer/engine/column_generators/__init__.py +0 -2
  55. data_designer/engine/column_generators/generators/__init__.py +0 -2
  56. data_designer/engine/column_generators/generators/base.py +0 -122
  57. data_designer/engine/column_generators/generators/embedding.py +0 -35
  58. data_designer/engine/column_generators/generators/expression.py +0 -55
  59. data_designer/engine/column_generators/generators/llm_completion.py +0 -113
  60. data_designer/engine/column_generators/generators/samplers.py +0 -69
  61. data_designer/engine/column_generators/generators/seed_dataset.py +0 -144
  62. data_designer/engine/column_generators/generators/validation.py +0 -140
  63. data_designer/engine/column_generators/registry.py +0 -60
  64. data_designer/engine/column_generators/utils/errors.py +0 -15
  65. data_designer/engine/column_generators/utils/generator_classification.py +0 -43
  66. data_designer/engine/column_generators/utils/judge_score_factory.py +0 -58
  67. data_designer/engine/column_generators/utils/prompt_renderer.py +0 -100
  68. data_designer/engine/compiler.py +0 -97
  69. data_designer/engine/configurable_task.py +0 -71
  70. data_designer/engine/dataset_builders/artifact_storage.py +0 -283
  71. data_designer/engine/dataset_builders/column_wise_builder.py +0 -335
  72. data_designer/engine/dataset_builders/errors.py +0 -15
  73. data_designer/engine/dataset_builders/multi_column_configs.py +0 -46
  74. data_designer/engine/dataset_builders/utils/__init__.py +0 -2
  75. data_designer/engine/dataset_builders/utils/concurrency.py +0 -212
  76. data_designer/engine/dataset_builders/utils/config_compiler.py +0 -62
  77. data_designer/engine/dataset_builders/utils/dag.py +0 -62
  78. data_designer/engine/dataset_builders/utils/dataset_batch_manager.py +0 -200
  79. data_designer/engine/dataset_builders/utils/errors.py +0 -15
  80. data_designer/engine/errors.py +0 -51
  81. data_designer/engine/model_provider.py +0 -77
  82. data_designer/engine/models/__init__.py +0 -2
  83. data_designer/engine/models/errors.py +0 -300
  84. data_designer/engine/models/facade.py +0 -287
  85. data_designer/engine/models/factory.py +0 -42
  86. data_designer/engine/models/litellm_overrides.py +0 -179
  87. data_designer/engine/models/parsers/__init__.py +0 -2
  88. data_designer/engine/models/parsers/errors.py +0 -34
  89. data_designer/engine/models/parsers/parser.py +0 -235
  90. data_designer/engine/models/parsers/postprocessors.py +0 -93
  91. data_designer/engine/models/parsers/tag_parsers.py +0 -62
  92. data_designer/engine/models/parsers/types.py +0 -84
  93. data_designer/engine/models/recipes/base.py +0 -81
  94. data_designer/engine/models/recipes/response_recipes.py +0 -293
  95. data_designer/engine/models/registry.py +0 -146
  96. data_designer/engine/models/telemetry.py +0 -359
  97. data_designer/engine/models/usage.py +0 -73
  98. data_designer/engine/models/utils.py +0 -38
  99. data_designer/engine/processing/ginja/__init__.py +0 -2
  100. data_designer/engine/processing/ginja/ast.py +0 -65
  101. data_designer/engine/processing/ginja/environment.py +0 -463
  102. data_designer/engine/processing/ginja/exceptions.py +0 -56
  103. data_designer/engine/processing/ginja/record.py +0 -32
  104. data_designer/engine/processing/gsonschema/__init__.py +0 -2
  105. data_designer/engine/processing/gsonschema/exceptions.py +0 -15
  106. data_designer/engine/processing/gsonschema/schema_transformers.py +0 -83
  107. data_designer/engine/processing/gsonschema/types.py +0 -10
  108. data_designer/engine/processing/gsonschema/validators.py +0 -202
  109. data_designer/engine/processing/processors/base.py +0 -13
  110. data_designer/engine/processing/processors/drop_columns.py +0 -42
  111. data_designer/engine/processing/processors/registry.py +0 -25
  112. data_designer/engine/processing/processors/schema_transform.py +0 -49
  113. data_designer/engine/processing/utils.py +0 -169
  114. data_designer/engine/registry/base.py +0 -99
  115. data_designer/engine/registry/data_designer_registry.py +0 -39
  116. data_designer/engine/registry/errors.py +0 -12
  117. data_designer/engine/resources/managed_dataset_generator.py +0 -39
  118. data_designer/engine/resources/managed_dataset_repository.py +0 -197
  119. data_designer/engine/resources/managed_storage.py +0 -65
  120. data_designer/engine/resources/resource_provider.py +0 -77
  121. data_designer/engine/resources/seed_reader.py +0 -154
  122. data_designer/engine/sampling_gen/column.py +0 -91
  123. data_designer/engine/sampling_gen/constraints.py +0 -100
  124. data_designer/engine/sampling_gen/data_sources/base.py +0 -217
  125. data_designer/engine/sampling_gen/data_sources/errors.py +0 -12
  126. data_designer/engine/sampling_gen/data_sources/sources.py +0 -347
  127. data_designer/engine/sampling_gen/entities/__init__.py +0 -2
  128. data_designer/engine/sampling_gen/entities/assets/zip_area_code_map.parquet +0 -0
  129. data_designer/engine/sampling_gen/entities/dataset_based_person_fields.py +0 -86
  130. data_designer/engine/sampling_gen/entities/email_address_utils.py +0 -171
  131. data_designer/engine/sampling_gen/entities/errors.py +0 -10
  132. data_designer/engine/sampling_gen/entities/national_id_utils.py +0 -102
  133. data_designer/engine/sampling_gen/entities/person.py +0 -144
  134. data_designer/engine/sampling_gen/entities/phone_number.py +0 -128
  135. data_designer/engine/sampling_gen/errors.py +0 -26
  136. data_designer/engine/sampling_gen/generator.py +0 -122
  137. data_designer/engine/sampling_gen/jinja_utils.py +0 -64
  138. data_designer/engine/sampling_gen/people_gen.py +0 -199
  139. data_designer/engine/sampling_gen/person_constants.py +0 -56
  140. data_designer/engine/sampling_gen/schema.py +0 -147
  141. data_designer/engine/sampling_gen/schema_builder.py +0 -61
  142. data_designer/engine/sampling_gen/utils.py +0 -46
  143. data_designer/engine/secret_resolver.py +0 -82
  144. data_designer/engine/validation.py +0 -367
  145. data_designer/engine/validators/__init__.py +0 -19
  146. data_designer/engine/validators/base.py +0 -38
  147. data_designer/engine/validators/local_callable.py +0 -39
  148. data_designer/engine/validators/python.py +0 -254
  149. data_designer/engine/validators/remote.py +0 -89
  150. data_designer/engine/validators/sql.py +0 -65
  151. data_designer/errors.py +0 -7
  152. data_designer/essentials/__init__.py +0 -33
  153. data_designer/lazy_heavy_imports.py +0 -54
  154. data_designer/logging.py +0 -163
  155. data_designer/plugin_manager.py +0 -78
  156. data_designer/plugins/__init__.py +0 -8
  157. data_designer/plugins/errors.py +0 -15
  158. data_designer/plugins/plugin.py +0 -141
  159. data_designer/plugins/registry.py +0 -88
  160. data_designer/plugins/testing/__init__.py +0 -10
  161. data_designer/plugins/testing/stubs.py +0 -116
  162. data_designer/plugins/testing/utils.py +0 -20
  163. data_designer-0.3.8rc2.dist-info/RECORD +0 -196
  164. data_designer-0.3.8rc2.dist-info/licenses/LICENSE +0 -201
  165. {data_designer-0.3.8rc2.dist-info → data_designer-0.4.0.dist-info}/WHEEL +0 -0
  166. {data_designer-0.3.8rc2.dist-info → data_designer-0.4.0.dist-info}/entry_points.txt +0 -0
@@ -1,455 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
-
4
- from __future__ import annotations
5
-
6
- import logging
7
- from abc import ABC, abstractmethod
8
- from enum import Enum
9
- from pathlib import Path
10
- from typing import TYPE_CHECKING, Annotated, Any, Generic, Literal, TypeVar
11
-
12
- from pydantic import BaseModel, Field, field_validator, model_validator
13
- from typing_extensions import Self, TypeAlias
14
-
15
- from data_designer.config.base import ConfigBase
16
- from data_designer.config.errors import InvalidConfigError
17
- from data_designer.config.utils.constants import (
18
- MAX_TEMPERATURE,
19
- MAX_TOP_P,
20
- MIN_TEMPERATURE,
21
- MIN_TOP_P,
22
- )
23
- from data_designer.config.utils.io_helpers import smart_load_yaml
24
- from data_designer.lazy_heavy_imports import np
25
-
26
- if TYPE_CHECKING:
27
- import numpy as np
28
-
29
- logger = logging.getLogger(__name__)
30
-
31
-
32
- class Modality(str, Enum):
33
- """Supported modality types for multimodal model data."""
34
-
35
- IMAGE = "image"
36
-
37
-
38
- class ModalityDataType(str, Enum):
39
- """Data type formats for multimodal data."""
40
-
41
- URL = "url"
42
- BASE64 = "base64"
43
-
44
-
45
- class ImageFormat(str, Enum):
46
- """Supported image formats for image modality."""
47
-
48
- PNG = "png"
49
- JPG = "jpg"
50
- JPEG = "jpeg"
51
- GIF = "gif"
52
- WEBP = "webp"
53
-
54
-
55
- class DistributionType(str, Enum):
56
- """Types of distributions for sampling inference parameters."""
57
-
58
- UNIFORM = "uniform"
59
- MANUAL = "manual"
60
-
61
-
62
- class ModalityContext(ABC, BaseModel):
63
- modality: Modality
64
- column_name: str
65
- data_type: ModalityDataType
66
-
67
- @abstractmethod
68
- def get_context(self, record: dict) -> dict[str, Any]: ...
69
-
70
-
71
- class ImageContext(ModalityContext):
72
- """Configuration for providing image context to multimodal models.
73
-
74
- Attributes:
75
- modality: The modality type (always "image").
76
- column_name: Name of the column containing image data.
77
- data_type: Format of the image data ("url" or "base64").
78
- image_format: Image format (required for base64 data).
79
- """
80
-
81
- modality: Modality = Modality.IMAGE
82
- image_format: ImageFormat | None = None
83
-
84
- def get_context(self, record: dict) -> dict[str, Any]:
85
- """Get the context for the image modality.
86
-
87
- Args:
88
- record: The record containing the image data.
89
-
90
- Returns:
91
- The context for the image modality.
92
- """
93
- context = dict(type="image_url")
94
- context_value = record[self.column_name]
95
- if self.data_type == ModalityDataType.URL:
96
- context["image_url"] = context_value
97
- else:
98
- context["image_url"] = {
99
- "url": f"data:image/{self.image_format.value};base64,{context_value}",
100
- "format": self.image_format.value,
101
- }
102
- return context
103
-
104
- @model_validator(mode="after")
105
- def _validate_image_format(self) -> Self:
106
- if self.data_type == ModalityDataType.BASE64 and self.image_format is None:
107
- raise ValueError(f"image_format is required when data_type is {self.data_type.value}")
108
- return self
109
-
110
-
111
- DistributionParamsT = TypeVar("DistributionParamsT", bound=ConfigBase)
112
-
113
-
114
- class Distribution(ABC, ConfigBase, Generic[DistributionParamsT]):
115
- distribution_type: DistributionType
116
- params: DistributionParamsT
117
-
118
- @abstractmethod
119
- def sample(self) -> float: ...
120
-
121
-
122
- class ManualDistributionParams(ConfigBase):
123
- """Parameters for manual distribution sampling.
124
-
125
- Attributes:
126
- values: List of possible values to sample from.
127
- weights: Optional list of weights for each value. If not provided, all values have equal probability.
128
- """
129
-
130
- values: list[float] = Field(min_length=1)
131
- weights: list[float] | None = None
132
-
133
- @model_validator(mode="after")
134
- def _normalize_weights(self) -> Self:
135
- if self.weights is not None:
136
- self.weights = [w / sum(self.weights) for w in self.weights]
137
- return self
138
-
139
- @model_validator(mode="after")
140
- def _validate_equal_lengths(self) -> Self:
141
- if self.weights and len(self.values) != len(self.weights):
142
- raise ValueError("`values` and `weights` must have the same length")
143
- return self
144
-
145
-
146
- class ManualDistribution(Distribution[ManualDistributionParams]):
147
- """Manual (discrete) distribution for sampling inference parameters.
148
-
149
- Samples from a discrete set of values with optional weights. Useful for testing
150
- specific values or creating custom probability distributions for temperature or top_p.
151
-
152
- Attributes:
153
- distribution_type: Type of distribution ("manual").
154
- params: Distribution parameters (values, weights).
155
- """
156
-
157
- distribution_type: DistributionType | None = "manual"
158
- params: ManualDistributionParams
159
-
160
- def sample(self) -> float:
161
- """Sample a value from the manual distribution.
162
-
163
- Returns:
164
- A float value sampled from the manual distribution.
165
- """
166
- return float(np.random.choice(self.params.values, p=self.params.weights))
167
-
168
-
169
- class UniformDistributionParams(ConfigBase):
170
- """Parameters for uniform distribution sampling.
171
-
172
- Attributes:
173
- low: Lower bound (inclusive).
174
- high: Upper bound (exclusive).
175
- """
176
-
177
- low: float
178
- high: float
179
-
180
- @model_validator(mode="after")
181
- def _validate_low_lt_high(self) -> Self:
182
- if self.low >= self.high:
183
- raise ValueError("`low` must be less than `high`")
184
- return self
185
-
186
-
187
- class UniformDistribution(Distribution[UniformDistributionParams]):
188
- """Uniform distribution for sampling inference parameters.
189
-
190
- Samples values uniformly between low and high bounds. Useful for exploring
191
- a continuous range of values for temperature or top_p.
192
-
193
- Attributes:
194
- distribution_type: Type of distribution ("uniform").
195
- params: Distribution parameters (low, high).
196
- """
197
-
198
- distribution_type: DistributionType | None = "uniform"
199
- params: UniformDistributionParams
200
-
201
- def sample(self) -> float:
202
- """Sample a value from the uniform distribution.
203
-
204
- Returns:
205
- A float value sampled from the uniform distribution.
206
- """
207
- return float(np.random.uniform(low=self.params.low, high=self.params.high, size=1)[0])
208
-
209
-
210
- DistributionT: TypeAlias = UniformDistribution | ManualDistribution
211
-
212
-
213
- class GenerationType(str, Enum):
214
- CHAT_COMPLETION = "chat-completion"
215
- EMBEDDING = "embedding"
216
-
217
-
218
- class BaseInferenceParams(ConfigBase, ABC):
219
- """Base configuration for inference parameters.
220
-
221
- Attributes:
222
- generation_type: Type of generation (chat-completion or embedding). Acts as discriminator.
223
- max_parallel_requests: Maximum number of parallel requests to the model API.
224
- timeout: Timeout in seconds for each request.
225
- extra_body: Additional parameters to pass to the model API.
226
- """
227
-
228
- generation_type: GenerationType
229
- max_parallel_requests: int = Field(default=4, ge=1)
230
- timeout: int | None = Field(default=None, ge=1)
231
- extra_body: dict[str, Any] | None = None
232
-
233
- @property
234
- def generate_kwargs(self) -> dict[str, Any]:
235
- """Get the generate kwargs for the inference parameters.
236
-
237
- Returns:
238
- A dictionary of the generate kwargs.
239
- """
240
- result = {}
241
- if self.timeout is not None:
242
- result["timeout"] = self.timeout
243
- if self.extra_body is not None and self.extra_body != {}:
244
- result["extra_body"] = self.extra_body
245
- return result
246
-
247
- def format_for_display(self) -> str:
248
- """Format inference parameters for display.
249
-
250
- Returns:
251
- Formatted string of inference parameters
252
- """
253
- params_dict = self.model_dump(exclude_none=True, mode="json")
254
-
255
- if not params_dict:
256
- return "(none)"
257
-
258
- parts = []
259
- for key, value in params_dict.items():
260
- formatted_value = self._format_value(key, value)
261
- parts.append(f"{key}={formatted_value}")
262
- return ", ".join(parts)
263
-
264
- def _format_value(self, key: str, value: Any) -> str:
265
- """Format a single parameter value. Override in subclasses for custom formatting.
266
-
267
- Args:
268
- key: Parameter name
269
- value: Parameter value
270
-
271
- Returns:
272
- Formatted string representation of the value
273
- """
274
- if isinstance(value, float):
275
- return f"{value:.2f}"
276
- return str(value)
277
-
278
-
279
- class ChatCompletionInferenceParams(BaseInferenceParams):
280
- """Configuration for LLM inference parameters.
281
-
282
- Attributes:
283
- generation_type: Type of generation, always "chat-completion" for this class.
284
- temperature: Sampling temperature (0.0-2.0). Can be a fixed value or a distribution for dynamic sampling.
285
- top_p: Nucleus sampling probability (0.0-1.0). Can be a fixed value or a distribution for dynamic sampling.
286
- max_tokens: Maximum number of tokens to generate in the response.
287
- """
288
-
289
- generation_type: Literal[GenerationType.CHAT_COMPLETION] = GenerationType.CHAT_COMPLETION
290
- temperature: float | DistributionT | None = None
291
- top_p: float | DistributionT | None = None
292
- max_tokens: int | None = Field(default=None, ge=1)
293
-
294
- @property
295
- def generate_kwargs(self) -> dict[str, Any]:
296
- result = super().generate_kwargs
297
- if self.temperature is not None:
298
- result["temperature"] = (
299
- self.temperature.sample() if hasattr(self.temperature, "sample") else self.temperature
300
- )
301
- if self.top_p is not None:
302
- result["top_p"] = self.top_p.sample() if hasattr(self.top_p, "sample") else self.top_p
303
- if self.max_tokens is not None:
304
- result["max_tokens"] = self.max_tokens
305
- return result
306
-
307
- @model_validator(mode="after")
308
- def _validate_temperature(self) -> Self:
309
- return self._run_validation(
310
- value=self.temperature,
311
- param_name="temperature",
312
- min_value=MIN_TEMPERATURE,
313
- max_value=MAX_TEMPERATURE,
314
- )
315
-
316
- @model_validator(mode="after")
317
- def _validate_top_p(self) -> Self:
318
- return self._run_validation(
319
- value=self.top_p,
320
- param_name="top_p",
321
- min_value=MIN_TOP_P,
322
- max_value=MAX_TOP_P,
323
- )
324
-
325
- def _run_validation(
326
- self,
327
- value: float | DistributionT | None,
328
- param_name: str,
329
- min_value: float,
330
- max_value: float,
331
- ) -> Self:
332
- if value is None:
333
- return self
334
- value_err = ValueError(f"{param_name} defined in model config must be between {min_value} and {max_value}")
335
- if isinstance(value, Distribution):
336
- if value.distribution_type == DistributionType.UNIFORM:
337
- if value.params.low < min_value or value.params.high > max_value:
338
- raise value_err
339
- elif value.distribution_type == DistributionType.MANUAL:
340
- if any(not self._is_value_in_range(v, min_value, max_value) for v in value.params.values):
341
- raise value_err
342
- else:
343
- if not self._is_value_in_range(value, min_value, max_value):
344
- raise value_err
345
- return self
346
-
347
- def _is_value_in_range(self, value: float, min_value: float, max_value: float) -> bool:
348
- return min_value <= value <= max_value
349
-
350
- def _format_value(self, key: str, value: Any) -> str:
351
- """Format chat completion parameter values, including distributions.
352
-
353
- Args:
354
- key: Parameter name
355
- value: Parameter value
356
-
357
- Returns:
358
- Formatted string representation of the value
359
- """
360
- if isinstance(value, dict) and "distribution_type" in value:
361
- return "dist"
362
- return super()._format_value(key, value)
363
-
364
-
365
- class EmbeddingInferenceParams(BaseInferenceParams):
366
- """Configuration for embedding generation parameters.
367
-
368
- Attributes:
369
- generation_type: Type of generation, always "embedding" for this class.
370
- encoding_format: Format of the embedding encoding ("float" or "base64").
371
- dimensions: Number of dimensions for the embedding.
372
- """
373
-
374
- generation_type: Literal[GenerationType.EMBEDDING] = GenerationType.EMBEDDING
375
- encoding_format: Literal["float", "base64"] = "float"
376
- dimensions: int | None = None
377
-
378
- @property
379
- def generate_kwargs(self) -> dict[str, float | int]:
380
- result = super().generate_kwargs
381
- if self.encoding_format is not None:
382
- result["encoding_format"] = self.encoding_format
383
- if self.dimensions is not None:
384
- result["dimensions"] = self.dimensions
385
- return result
386
-
387
-
388
- InferenceParamsT: TypeAlias = Annotated[
389
- ChatCompletionInferenceParams | EmbeddingInferenceParams, Field(discriminator="generation_type")
390
- ]
391
-
392
-
393
- class ModelConfig(ConfigBase):
394
- """Configuration for a model used for generation.
395
-
396
- Attributes:
397
- alias: User-defined alias to reference in column configurations.
398
- model: Model identifier (e.g., from build.nvidia.com or other providers).
399
- inference_parameters: Inference parameters for the model (temperature, top_p, max_tokens, etc.).
400
- The generation_type is determined by the type of inference_parameters.
401
- provider: Optional model provider name if using custom providers.
402
- """
403
-
404
- alias: str
405
- model: str
406
- inference_parameters: InferenceParamsT = Field(default_factory=ChatCompletionInferenceParams)
407
- provider: str | None = None
408
-
409
- @property
410
- def generation_type(self) -> GenerationType:
411
- """Get the generation type from the inference parameters."""
412
- return self.inference_parameters.generation_type
413
-
414
- @field_validator("inference_parameters", mode="before")
415
- @classmethod
416
- def _convert_inference_parameters(cls, value: Any) -> Any:
417
- """Convert raw dict to appropriate inference parameters type based on field presence."""
418
- if isinstance(value, dict):
419
- # Infer type from presence of embedding-specific fields
420
- if "encoding_format" in value or "dimensions" in value:
421
- return EmbeddingInferenceParams(**value)
422
- else:
423
- return ChatCompletionInferenceParams(**value)
424
- return value
425
-
426
-
427
- class ModelProvider(ConfigBase):
428
- """Configuration for a custom model provider.
429
-
430
- Attributes:
431
- name: Name of the model provider.
432
- endpoint: API endpoint URL for the provider.
433
- provider_type: Provider type (default: "openai"). Determines the API format to use.
434
- api_key: Optional API key for authentication.
435
- extra_body: Additional parameters to pass in API requests.
436
- extra_headers: Additional headers to pass in API requests.
437
- """
438
-
439
- name: str
440
- endpoint: str
441
- provider_type: str = "openai"
442
- api_key: str | None = None
443
- extra_body: dict[str, Any] | None = None
444
- extra_headers: dict[str, str] | None = None
445
-
446
-
447
- def load_model_configs(model_configs: list[ModelConfig] | str | Path) -> list[ModelConfig]:
448
- if isinstance(model_configs, list) and all(isinstance(mc, ModelConfig) for mc in model_configs):
449
- return model_configs
450
- json_config = smart_load_yaml(model_configs)
451
- if "model_configs" not in json_config:
452
- raise InvalidConfigError(
453
- "The list of model configs must be provided under model_configs in the configuration file."
454
- )
455
- return [ModelConfig.model_validate(mc) for mc in json_config["model_configs"]]
@@ -1,41 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
-
4
- from __future__ import annotations
5
-
6
- from typing import TYPE_CHECKING
7
-
8
- from data_designer.config.analysis.dataset_profiler import DatasetProfilerResults
9
- from data_designer.config.config_builder import DataDesignerConfigBuilder
10
- from data_designer.config.dataset_metadata import DatasetMetadata
11
- from data_designer.config.utils.visualization import WithRecordSamplerMixin
12
- from data_designer.lazy_heavy_imports import pd
13
-
14
- if TYPE_CHECKING:
15
- import pandas as pd
16
-
17
-
18
- class PreviewResults(WithRecordSamplerMixin):
19
- def __init__(
20
- self,
21
- *,
22
- config_builder: DataDesignerConfigBuilder,
23
- dataset_metadata: DatasetMetadata | None = None,
24
- dataset: pd.DataFrame | None = None,
25
- analysis: DatasetProfilerResults | None = None,
26
- processor_artifacts: dict[str, list[str] | str] | None = None,
27
- ):
28
- """Creates a new instance with results from a Data Designer preview run.
29
-
30
- Args:
31
- config_builder: Data Designer configuration builder.
32
- dataset_metadata: Metadata about the generated dataset (e.g., seed column names).
33
- dataset: Dataset of the preview run.
34
- analysis: Analysis of the preview run.
35
- processor_artifacts: Artifacts generated by the processors.
36
- """
37
- self.dataset: pd.DataFrame | None = dataset
38
- self.analysis: DatasetProfilerResults | None = analysis
39
- self.processor_artifacts: dict[str, list[str] | str] | None = processor_artifacts
40
- self.dataset_metadata: DatasetMetadata | None = dataset_metadata
41
- self._config_builder = config_builder
@@ -1,148 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
-
4
- from __future__ import annotations
5
-
6
- import json
7
- from abc import ABC
8
- from enum import Enum
9
- from typing import Any, Literal
10
-
11
- from pydantic import Field, field_validator
12
- from typing_extensions import TypeAlias
13
-
14
- from data_designer.config.base import ConfigBase
15
- from data_designer.config.dataset_builders import BuildStage
16
- from data_designer.config.errors import InvalidConfigError
17
-
18
- SUPPORTED_STAGES = [BuildStage.POST_BATCH]
19
-
20
-
21
- class ProcessorType(str, Enum):
22
- """Enumeration of available processor types.
23
-
24
- Attributes:
25
- DROP_COLUMNS: Processor that removes specified columns from the output dataset.
26
- SCHEMA_TRANSFORM: Processor that creates a new dataset with a transformed schema using Jinja2 templates.
27
- """
28
-
29
- DROP_COLUMNS = "drop_columns"
30
- SCHEMA_TRANSFORM = "schema_transform"
31
-
32
-
33
- class ProcessorConfig(ConfigBase, ABC):
34
- """Abstract base class for all processor configuration types.
35
-
36
- Processors are transformations that run before or after columns are generated.
37
- They can modify, reshape, or augment the dataset before it's saved.
38
-
39
- Attributes:
40
- name: Unique name of the processor, used to identify the processor in results
41
- and to name output artifacts on disk.
42
- build_stage: The stage at which the processor runs. Currently only `POST_BATCH`
43
- is supported, meaning processors run after each batch of columns is generated.
44
- """
45
-
46
- name: str = Field(
47
- description="The name of the processor, used to identify the processor in the results and to write the artifacts to disk.",
48
- )
49
- build_stage: BuildStage = Field(
50
- default=BuildStage.POST_BATCH,
51
- description=f"The stage at which the processor will run. Supported stages: {', '.join(SUPPORTED_STAGES)}",
52
- )
53
- processor_type: str
54
-
55
- @field_validator("build_stage")
56
- def validate_build_stage(cls, v: BuildStage) -> BuildStage:
57
- if v not in SUPPORTED_STAGES:
58
- raise ValueError(
59
- f"Invalid dataset builder stage: {v}. Only these stages are supported: {', '.join(SUPPORTED_STAGES)}"
60
- )
61
- return v
62
-
63
-
64
- def get_processor_config_from_kwargs(processor_type: ProcessorType, **kwargs: Any) -> ProcessorConfig:
65
- """Create a processor configuration from a processor type and keyword arguments.
66
-
67
- Args:
68
- processor_type: The type of processor to create.
69
- **kwargs: Additional keyword arguments passed to the processor constructor.
70
-
71
- Returns:
72
- A processor configuration object of the specified type.
73
- """
74
- if processor_type == ProcessorType.DROP_COLUMNS:
75
- return DropColumnsProcessorConfig(**kwargs)
76
- elif processor_type == ProcessorType.SCHEMA_TRANSFORM:
77
- return SchemaTransformProcessorConfig(**kwargs)
78
-
79
-
80
- class DropColumnsProcessorConfig(ProcessorConfig):
81
- """Configuration for dropping columns from the output dataset.
82
-
83
- This processor removes specified columns from the generated dataset. The dropped
84
- columns are saved separately in a `dropped-columns` directory for reference.
85
- When this processor is added via the config builder, the corresponding column
86
- configs are automatically marked with `drop = True`.
87
-
88
- Alternatively, you can set `drop = True` when configuring a column.
89
-
90
- Attributes:
91
- column_names: List of column names to remove from the output dataset.
92
- processor_type: Discriminator field, always `ProcessorType.DROP_COLUMNS` for this configuration type.
93
- """
94
-
95
- column_names: list[str] = Field(description="List of column names to drop from the output dataset.")
96
- processor_type: Literal[ProcessorType.DROP_COLUMNS] = ProcessorType.DROP_COLUMNS
97
-
98
-
99
- class SchemaTransformProcessorConfig(ProcessorConfig):
100
- """Configuration for transforming the dataset schema using Jinja2 templates.
101
-
102
- This processor creates a new dataset with a transformed schema. Each key in the
103
- template becomes a column in the output, and values are Jinja2 templates that
104
- can reference any column in the batch. The transformed dataset is written to
105
- a `processors-outputs/{processor_name}/` directory alongside the main dataset.
106
-
107
- Attributes:
108
- template: Dictionary defining the output schema. Keys are new column names,
109
- values are Jinja2 templates (strings, lists, or nested structures).
110
- Must be JSON-serializable.
111
- processor_type: Discriminator field, always `ProcessorType.SCHEMA_TRANSFORM` for this configuration type.
112
- """
113
-
114
- template: dict[str, Any] = Field(
115
- ...,
116
- description="""
117
- Dictionary specifying columns and templates to use in the new dataset with transformed schema.
118
-
119
- Each key is a new column name, and each value is an object containing Jinja2 templates - for instance, a string or a list of strings.
120
- Values must be JSON-serializable.
121
-
122
- Example:
123
-
124
- ```python
125
- template = {
126
- "list_of_strings": ["{{ col1 }}", "{{ col2 }}"],
127
- "uppercase_string": "{{ col1 | upper }}",
128
- "lowercase_string": "{{ col2 | lower }}",
129
- }
130
- ```
131
-
132
- The above templates will create a new dataset with three columns: "list_of_strings", "uppercase_string", and "lowercase_string".
133
- References to columns "col1" and "col2" in the templates will be replaced with the actual values of the columns in the dataset.
134
- """,
135
- )
136
- processor_type: Literal[ProcessorType.SCHEMA_TRANSFORM] = ProcessorType.SCHEMA_TRANSFORM
137
-
138
- @field_validator("template")
139
- def validate_template(cls, v: dict[str, Any]) -> dict[str, Any]:
140
- try:
141
- json.dumps(v)
142
- except TypeError as e:
143
- if "not JSON serializable" in str(e):
144
- raise InvalidConfigError("Template must be JSON serializable")
145
- return v
146
-
147
-
148
- ProcessorConfigT: TypeAlias = DropColumnsProcessorConfig | SchemaTransformProcessorConfig
@@ -1,51 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
-
4
- from __future__ import annotations
5
-
6
- from pydantic import Field, model_validator
7
- from typing_extensions import Self
8
-
9
- from data_designer.config.base import ConfigBase
10
-
11
-
12
- class RunConfig(ConfigBase):
13
- """Runtime configuration for dataset generation.
14
-
15
- Groups configuration options that control generation behavior but aren't
16
- part of the dataset configuration itself.
17
-
18
- Attributes:
19
- disable_early_shutdown: If True, disables the executor's early-shutdown behavior entirely.
20
- Generation will continue regardless of error rate, and the early-shutdown exception
21
- will never be raised. Error counts and summaries are still collected. Default is False.
22
- shutdown_error_rate: Error rate threshold (0.0-1.0) that triggers early shutdown when
23
- early shutdown is enabled. Default is 0.5.
24
- shutdown_error_window: Minimum number of completed tasks before error rate
25
- monitoring begins. Must be >= 0. Default is 10.
26
- buffer_size: Number of records to process in each batch during dataset generation.
27
- A batch is processed end-to-end (column generation, post-batch processors, and writing the batch
28
- to artifact storage) before moving on to the next batch. Must be > 0. Default is 1000.
29
- non_inference_max_parallel_workers: Maximum number of worker threads used for non-inference
30
- cell-by-cell generators. Must be >= 1. Default is 4.
31
- max_conversation_restarts: Maximum number of full conversation restarts permitted when
32
- generation tasks call `ModelFacade.generate(...)`. Must be >= 0. Default is 5.
33
- max_conversation_correction_steps: Maximum number of correction rounds permitted within a
34
- single conversation when generation tasks call `ModelFacade.generate(...)`. Must be >= 0.
35
- Default is 0.
36
- """
37
-
38
- disable_early_shutdown: bool = False
39
- shutdown_error_rate: float = Field(default=0.5, ge=0.0, le=1.0)
40
- shutdown_error_window: int = Field(default=10, ge=0)
41
- buffer_size: int = Field(default=1000, gt=0)
42
- non_inference_max_parallel_workers: int = Field(default=4, ge=1)
43
- max_conversation_restarts: int = Field(default=5, ge=0)
44
- max_conversation_correction_steps: int = Field(default=0, ge=0)
45
-
46
- @model_validator(mode="after")
47
- def normalize_shutdown_settings(self) -> Self:
48
- """Normalize shutdown settings for compatibility."""
49
- if self.disable_early_shutdown:
50
- self.shutdown_error_rate = 1.0
51
- return self