data-designer 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (177) hide show
  1. data_designer/__init__.py +15 -0
  2. data_designer/_version.py +34 -0
  3. data_designer/cli/README.md +236 -0
  4. data_designer/cli/__init__.py +6 -0
  5. data_designer/cli/commands/__init__.py +2 -0
  6. data_designer/cli/commands/list.py +130 -0
  7. data_designer/cli/commands/models.py +10 -0
  8. data_designer/cli/commands/providers.py +11 -0
  9. data_designer/cli/commands/reset.py +100 -0
  10. data_designer/cli/controllers/__init__.py +7 -0
  11. data_designer/cli/controllers/model_controller.py +246 -0
  12. data_designer/cli/controllers/provider_controller.py +317 -0
  13. data_designer/cli/forms/__init__.py +20 -0
  14. data_designer/cli/forms/builder.py +51 -0
  15. data_designer/cli/forms/field.py +180 -0
  16. data_designer/cli/forms/form.py +59 -0
  17. data_designer/cli/forms/model_builder.py +125 -0
  18. data_designer/cli/forms/provider_builder.py +76 -0
  19. data_designer/cli/main.py +44 -0
  20. data_designer/cli/repositories/__init__.py +8 -0
  21. data_designer/cli/repositories/base.py +39 -0
  22. data_designer/cli/repositories/model_repository.py +42 -0
  23. data_designer/cli/repositories/provider_repository.py +43 -0
  24. data_designer/cli/services/__init__.py +7 -0
  25. data_designer/cli/services/model_service.py +116 -0
  26. data_designer/cli/services/provider_service.py +111 -0
  27. data_designer/cli/ui.py +448 -0
  28. data_designer/cli/utils.py +47 -0
  29. data_designer/config/__init__.py +2 -0
  30. data_designer/config/analysis/column_profilers.py +89 -0
  31. data_designer/config/analysis/column_statistics.py +274 -0
  32. data_designer/config/analysis/dataset_profiler.py +60 -0
  33. data_designer/config/analysis/utils/errors.py +8 -0
  34. data_designer/config/analysis/utils/reporting.py +188 -0
  35. data_designer/config/base.py +68 -0
  36. data_designer/config/column_configs.py +354 -0
  37. data_designer/config/column_types.py +168 -0
  38. data_designer/config/config_builder.py +660 -0
  39. data_designer/config/data_designer_config.py +40 -0
  40. data_designer/config/dataset_builders.py +11 -0
  41. data_designer/config/datastore.py +151 -0
  42. data_designer/config/default_model_settings.py +123 -0
  43. data_designer/config/errors.py +19 -0
  44. data_designer/config/interface.py +54 -0
  45. data_designer/config/models.py +231 -0
  46. data_designer/config/preview_results.py +32 -0
  47. data_designer/config/processors.py +41 -0
  48. data_designer/config/sampler_constraints.py +51 -0
  49. data_designer/config/sampler_params.py +604 -0
  50. data_designer/config/seed.py +145 -0
  51. data_designer/config/utils/code_lang.py +83 -0
  52. data_designer/config/utils/constants.py +313 -0
  53. data_designer/config/utils/errors.py +19 -0
  54. data_designer/config/utils/info.py +88 -0
  55. data_designer/config/utils/io_helpers.py +273 -0
  56. data_designer/config/utils/misc.py +81 -0
  57. data_designer/config/utils/numerical_helpers.py +28 -0
  58. data_designer/config/utils/type_helpers.py +100 -0
  59. data_designer/config/utils/validation.py +336 -0
  60. data_designer/config/utils/visualization.py +427 -0
  61. data_designer/config/validator_params.py +96 -0
  62. data_designer/engine/__init__.py +2 -0
  63. data_designer/engine/analysis/column_profilers/base.py +55 -0
  64. data_designer/engine/analysis/column_profilers/judge_score_profiler.py +160 -0
  65. data_designer/engine/analysis/column_profilers/registry.py +20 -0
  66. data_designer/engine/analysis/column_statistics.py +142 -0
  67. data_designer/engine/analysis/dataset_profiler.py +125 -0
  68. data_designer/engine/analysis/errors.py +7 -0
  69. data_designer/engine/analysis/utils/column_statistics_calculations.py +209 -0
  70. data_designer/engine/analysis/utils/judge_score_processing.py +128 -0
  71. data_designer/engine/column_generators/__init__.py +2 -0
  72. data_designer/engine/column_generators/generators/__init__.py +2 -0
  73. data_designer/engine/column_generators/generators/base.py +61 -0
  74. data_designer/engine/column_generators/generators/expression.py +63 -0
  75. data_designer/engine/column_generators/generators/llm_generators.py +172 -0
  76. data_designer/engine/column_generators/generators/samplers.py +75 -0
  77. data_designer/engine/column_generators/generators/seed_dataset.py +149 -0
  78. data_designer/engine/column_generators/generators/validation.py +147 -0
  79. data_designer/engine/column_generators/registry.py +56 -0
  80. data_designer/engine/column_generators/utils/errors.py +13 -0
  81. data_designer/engine/column_generators/utils/judge_score_factory.py +57 -0
  82. data_designer/engine/column_generators/utils/prompt_renderer.py +98 -0
  83. data_designer/engine/configurable_task.py +82 -0
  84. data_designer/engine/dataset_builders/artifact_storage.py +181 -0
  85. data_designer/engine/dataset_builders/column_wise_builder.py +287 -0
  86. data_designer/engine/dataset_builders/errors.py +13 -0
  87. data_designer/engine/dataset_builders/multi_column_configs.py +44 -0
  88. data_designer/engine/dataset_builders/utils/__init__.py +2 -0
  89. data_designer/engine/dataset_builders/utils/concurrency.py +184 -0
  90. data_designer/engine/dataset_builders/utils/config_compiler.py +60 -0
  91. data_designer/engine/dataset_builders/utils/dag.py +56 -0
  92. data_designer/engine/dataset_builders/utils/dataset_batch_manager.py +190 -0
  93. data_designer/engine/dataset_builders/utils/errors.py +13 -0
  94. data_designer/engine/errors.py +49 -0
  95. data_designer/engine/model_provider.py +75 -0
  96. data_designer/engine/models/__init__.py +2 -0
  97. data_designer/engine/models/errors.py +308 -0
  98. data_designer/engine/models/facade.py +225 -0
  99. data_designer/engine/models/litellm_overrides.py +162 -0
  100. data_designer/engine/models/parsers/__init__.py +2 -0
  101. data_designer/engine/models/parsers/errors.py +34 -0
  102. data_designer/engine/models/parsers/parser.py +236 -0
  103. data_designer/engine/models/parsers/postprocessors.py +93 -0
  104. data_designer/engine/models/parsers/tag_parsers.py +60 -0
  105. data_designer/engine/models/parsers/types.py +82 -0
  106. data_designer/engine/models/recipes/base.py +79 -0
  107. data_designer/engine/models/recipes/response_recipes.py +291 -0
  108. data_designer/engine/models/registry.py +118 -0
  109. data_designer/engine/models/usage.py +75 -0
  110. data_designer/engine/models/utils.py +38 -0
  111. data_designer/engine/processing/ginja/__init__.py +2 -0
  112. data_designer/engine/processing/ginja/ast.py +64 -0
  113. data_designer/engine/processing/ginja/environment.py +461 -0
  114. data_designer/engine/processing/ginja/exceptions.py +54 -0
  115. data_designer/engine/processing/ginja/record.py +30 -0
  116. data_designer/engine/processing/gsonschema/__init__.py +2 -0
  117. data_designer/engine/processing/gsonschema/exceptions.py +8 -0
  118. data_designer/engine/processing/gsonschema/schema_transformers.py +81 -0
  119. data_designer/engine/processing/gsonschema/types.py +8 -0
  120. data_designer/engine/processing/gsonschema/validators.py +143 -0
  121. data_designer/engine/processing/processors/base.py +15 -0
  122. data_designer/engine/processing/processors/drop_columns.py +46 -0
  123. data_designer/engine/processing/processors/registry.py +20 -0
  124. data_designer/engine/processing/utils.py +120 -0
  125. data_designer/engine/registry/base.py +97 -0
  126. data_designer/engine/registry/data_designer_registry.py +37 -0
  127. data_designer/engine/registry/errors.py +10 -0
  128. data_designer/engine/resources/managed_dataset_generator.py +35 -0
  129. data_designer/engine/resources/managed_dataset_repository.py +194 -0
  130. data_designer/engine/resources/managed_storage.py +63 -0
  131. data_designer/engine/resources/resource_provider.py +46 -0
  132. data_designer/engine/resources/seed_dataset_data_store.py +66 -0
  133. data_designer/engine/sampling_gen/column.py +89 -0
  134. data_designer/engine/sampling_gen/constraints.py +95 -0
  135. data_designer/engine/sampling_gen/data_sources/base.py +214 -0
  136. data_designer/engine/sampling_gen/data_sources/errors.py +10 -0
  137. data_designer/engine/sampling_gen/data_sources/sources.py +342 -0
  138. data_designer/engine/sampling_gen/entities/__init__.py +2 -0
  139. data_designer/engine/sampling_gen/entities/assets/zip_area_code_map.parquet +0 -0
  140. data_designer/engine/sampling_gen/entities/dataset_based_person_fields.py +64 -0
  141. data_designer/engine/sampling_gen/entities/email_address_utils.py +169 -0
  142. data_designer/engine/sampling_gen/entities/errors.py +8 -0
  143. data_designer/engine/sampling_gen/entities/national_id_utils.py +100 -0
  144. data_designer/engine/sampling_gen/entities/person.py +142 -0
  145. data_designer/engine/sampling_gen/entities/phone_number.py +122 -0
  146. data_designer/engine/sampling_gen/errors.py +24 -0
  147. data_designer/engine/sampling_gen/generator.py +121 -0
  148. data_designer/engine/sampling_gen/jinja_utils.py +60 -0
  149. data_designer/engine/sampling_gen/people_gen.py +203 -0
  150. data_designer/engine/sampling_gen/person_constants.py +54 -0
  151. data_designer/engine/sampling_gen/schema.py +143 -0
  152. data_designer/engine/sampling_gen/schema_builder.py +59 -0
  153. data_designer/engine/sampling_gen/utils.py +40 -0
  154. data_designer/engine/secret_resolver.py +80 -0
  155. data_designer/engine/validators/__init__.py +17 -0
  156. data_designer/engine/validators/base.py +36 -0
  157. data_designer/engine/validators/local_callable.py +34 -0
  158. data_designer/engine/validators/python.py +245 -0
  159. data_designer/engine/validators/remote.py +83 -0
  160. data_designer/engine/validators/sql.py +60 -0
  161. data_designer/errors.py +5 -0
  162. data_designer/essentials/__init__.py +137 -0
  163. data_designer/interface/__init__.py +2 -0
  164. data_designer/interface/data_designer.py +351 -0
  165. data_designer/interface/errors.py +16 -0
  166. data_designer/interface/results.py +55 -0
  167. data_designer/logging.py +161 -0
  168. data_designer/plugin_manager.py +83 -0
  169. data_designer/plugins/__init__.py +6 -0
  170. data_designer/plugins/errors.py +10 -0
  171. data_designer/plugins/plugin.py +69 -0
  172. data_designer/plugins/registry.py +86 -0
  173. data_designer-0.1.0.dist-info/METADATA +173 -0
  174. data_designer-0.1.0.dist-info/RECORD +177 -0
  175. data_designer-0.1.0.dist-info/WHEEL +4 -0
  176. data_designer-0.1.0.dist-info/entry_points.txt +2 -0
  177. data_designer-0.1.0.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,461 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from collections.abc import Callable
5
+ from functools import partial, wraps
6
+ import re
7
+ from typing import Any
8
+
9
+ from jinja2 import meta
10
+ from jinja2 import nodes as j_nodes
11
+ from jinja2.exceptions import SecurityError, TemplateSyntaxError
12
+ from jinja2.nodes import Template
13
+ from jinja2.sandbox import ImmutableSandboxedEnvironment
14
+ from jsonpath_rust_bindings import Finder
15
+
16
+ from data_designer.engine.processing.ginja.ast import (
17
+ ast_count_name_references,
18
+ ast_descendant_count,
19
+ ast_max_depth,
20
+ )
21
+ from data_designer.engine.processing.ginja.exceptions import (
22
+ UserTemplateError,
23
+ UserTemplateUnsupportedFiltersError,
24
+ maybe_handle_missing_filter_exception,
25
+ )
26
+ from data_designer.engine.processing.ginja.record import sanitize_record
27
+
28
# Default upper bound on the character count of a rendered template.
MAX_RENDERED_LEN = 512_000

# Default static-analysis limits on a parsed template: total AST nodes and nesting depth.
MAX_AST_NODE_COUNT = 600
MAX_AST_DEPTH = 10

# Names of the filters that remain registered on the sandbox environment.
ALLOWED_JINJA_FILTERS = [
    ## Jinja2 Builtin Filters
    "abs",
    "capitalize",
    "escape",
    "first",
    "float",
    "forceescape",
    "int",
    "items",
    "last",
    "length",
    "list",
    "lower",
    "max",
    "min",
    "random",
    "replace",
    "reverse",
    "round",
    "sort",
    "string",
    "title",
    "trim",
    "truncate",
    "unique",
    "urlencode",
    ## Custom Filters
    "jsonpath",
]

# Generic, detail-free message used when a user template is rejected.
USER_PROMPT_TEMPLATE_ERROR_MESSAGE = "User provided prompt generation template is invalid."
65
# Jinja2 AST node types that user templates may not contain.
#
# In jinja2.nodes, `Import`, `FromImport` and `Include` are distinct node
# classes, as are `Assign` and `AssignBlock`, so each statement form that
# should be rejected needs its own entry here.
UNSUPPORTED_AST_NODES = [
    j_nodes.Import,  # No {% import ... as ... %}
    j_nodes.FromImport,  # No {% from ... import ... %}
    j_nodes.Include,  # No {% include ... %}
    j_nodes.Macro,  # No {% macro ... %}
    j_nodes.Assign,  # No {% set ... = ... %}
    j_nodes.AssignBlock,  # No {% set ... %}...{% endset %}
    j_nodes.Extends,  # No {% extends ... %}
    j_nodes.Block,  # No {% block ... %}
]
72
+
73
+
74
def jsonpath_jinja_filter(data: dict, expression: str) -> list[Any]:
    """Run a JSONPath query against a dictionary (the ``jsonpath`` template filter).

    Args:
        data (dict): The object to query.
        expression (str): The JSONPath expression handed to the finder.

    Returns:
        list[Any]: The ``data`` attribute of every match the finder returns.

    Raises:
        ValueError: If ``data`` is not a ``dict``.
    """
    if isinstance(data, dict):
        matches = Finder(data).find(expression)
        return [match.data for match in matches]

    raise ValueError("Cannot perform JSONPath filter on non-structured data.")
88
+
89
+
90
def is_jinja_template(user_template: str) -> bool:
    """Guess whether a prompt template is written in Jinja2.

    The check is a heuristic meant to ease migration from format strings
    to Jinja2: the template counts as Jinja2 when it contains both halves
    of at least one delimiter pair (``{{ }}``, ``{% %}`` or ``{# #}``).
    The relative position of the two halves is not examined.

    Args:
        user_template (str): A user-provided template string to test.

    Returns:
        True if the heuristic believes it is a Jinja2 template.
    """
    delimiter_pairs = (("{{", "}}"), ("{%", "%}"), ("{#", "#}"))
    return any(
        opener in user_template and closer in user_template
        for opener, closer in delimiter_pairs
    )
108
+
109
+
110
class UserTemplateSandboxEnvironment(ImmutableSandboxedEnvironment):
    """Sandboxed Jinja2 environment for rendering user-supplied templates.

    Templates are taken from users, so steps are needed to ensure that they
    cannot break containment or exfiltrate server-side secrets. This class
    extends the restrictions of `ImmutableSandboxedEnvironment`, which
    already provides a base layer of protections, including:

    - No references to private attributes
    - Restrictions on loop iterations (OverflowError)

    On top of that, this environment:

    - Rejects templates that reference `self`.
    - Rejects references to variables outside a provided allow-list.
    - Rejects the node types listed in `UNSUPPORTED_AST_NODES`, as well as
      recursive and nested `for` loops.
    - Uses AST static analysis (node count and depth) to threshold the
      complexity of allowed templates.
    - Keeps only the filters named in `ALLOWED_JINJA_FILTERS`, which
      includes the custom `jsonpath` filter.
    - Rejects rendered text that is empty, longer than `max_rendered_len`
      characters, or that contains the repr of an uncalled function or
      built-in method.

    Note:
        The environment is created with `autoescape=False`, so rendered
        values are NOT HTML-escaped.
    """

    max_rendered_len: int
    max_ast_node_count: int
    max_ast_depth: int
    allowed_references: list[str]

    def __init__(
        self,
        allowed_references: list[str] | None = None,
        max_rendered_len: int = MAX_RENDERED_LEN,
        max_ast_node_count: int = MAX_AST_NODE_COUNT,
        max_ast_depth: int = MAX_AST_DEPTH,
        **kwargs,
    ):
        """Args:
        allowed_references (optional, list[str]): If set, indicates which variables
            are allowed to be referenced by the Jinja2 template. If not specified
            (or empty), defaults to [], which indicates that the Jinja2 template
            is not allowed to refer to _any_ variables outside of itself.

        max_rendered_len (int): The maximum allowable character count for
            rendered templates. Defaults to MAX_RENDERED_LEN set by this module.

        max_ast_node_count (optional, int): Parameter for static analysis of
            Jinja2 template complexity -- the node count of the parsed Jinja2
            AST. A large number of nodes indicates many distinct operations
            within the provided user template, which can cause long compute
            times, or may be malicious in nature. If not specified, defaults
            to MAX_AST_NODE_COUNT set by this module.

        max_ast_depth (optional, int): Parameter for static analysis of
            Jinja2 template complexity -- measures the maximum depth of the
            parsed Jinja2 AST. A high depth indicates a high degree of nesting
            within the user template. This can cause long compute times,
            or may be malicious in nature. If not specified, defaults to
            MAX_AST_DEPTH set by this module.

        **kwargs: Additional kwargs passed to ImmutableSandboxedEnvironment.
            `autoescape` is already supplied by this class, so passing it
            here raises a TypeError for the duplicated keyword.
        """
        super().__init__(autoescape=False, **kwargs)
        self.max_rendered_len = max_rendered_len
        self.max_ast_node_count = max_ast_node_count
        self.max_ast_depth = max_ast_depth
        self.allowed_references = allowed_references if allowed_references else []

        ## Add on our supported filters
        self.filters["jsonpath"] = jsonpath_jinja_filter

        ## Cut out all but approved Jinja filters
        self.filters = {k: v for k, v in self.filters.items() if k in ALLOWED_JINJA_FILTERS}

    def _assert_template_has_valid_references(self, ast: Template) -> None:
        """Assert that all named variable references are allowed.

        Checks the template's undeclared variables against the environment's
        allowed reference list created at initialization.

        Raises:
            UserTemplateError: If the template references any other variable.
        """
        template_vars = meta.find_undeclared_variables(ast)
        unallowed_vars = set(template_vars) - set(self.allowed_references)
        if unallowed_vars:
            raise UserTemplateError(f"Unknown variable references in Jinja template: {unallowed_vars}")

    def _assert_template_has_valid_ast_nodes(self, ast: Template) -> None:
        """Assert that un-allowed operations aren't in the template.

        Raises:
            UserTemplateError: If any node type from `UNSUPPORTED_AST_NODES`
                is counted in the AST.
        """
        black_list_node_count = sum(ast_descendant_count(ast, node_type) for node_type in UNSUPPORTED_AST_NODES)

        if black_list_node_count != 0:
            raise UserTemplateError("Non-permitted operations in Jinja template.")

    def _assert_template_has_no_recursive_for(self, ast: Template) -> None:
        """Assert that the template does not use {% for ... recursive %}"""
        if any(node.recursive for node in ast.find_all(j_nodes.For)):
            raise UserTemplateError("Non-permitted operations in Jinja template.")

    def _assert_template_has_no_nested_for(self, ast: Template) -> None:
        """Assert that the template does not contain nested loops.

        This assertion is made to ensure that templates cannot combinatorially
        explode. Large `range()` calls are limited separately by the
        `MAX_RANGE` constant of the `jinja2.sandbox` module.
        """
        # Check each For node in the AST to see if it has For descendants
        for node in ast.find_all(j_nodes.For):
            if ast_descendant_count(node, only_type=j_nodes.For):
                raise UserTemplateError("Non-permitted operations in Jinja template (nested-for).")

    def _assert_template_ast_complexity(self, ast: Template) -> None:
        """Assert that the AST tree parsed from the template is not overly complex.

        Complexity is measured by the depth of the tree (measure of nesting),
        as well as the number of nodes it contains (how many distinct operations).
        If either is over the limit specified at initialization, the assert fails.
        """
        node_count = ast_descendant_count(ast)
        max_depth = ast_max_depth(ast)

        if node_count > self.max_ast_node_count or max_depth > self.max_ast_depth:
            raise UserTemplateError("Jinja template too complex, simplify your template.")

    def _assert_template_has_no_self_reference(self, ast: Template) -> None:
        """Assert that the template cannot refer to its own settings.

        Templates may attempt to use {{ self }} references to gain
        access to properties of the template object itself. This
        is disallowed.
        """
        if ast_count_name_references(ast, "self") != 0:
            raise UserTemplateError("Non-permitted operations in Jinja template.")

    def validate_template(self, user_template: str) -> None:
        """Statically validate a user template without rendering it.

        The template is parsed and the resulting AST is checked against each
        of this environment's restrictions (permitted node types, no recursive
        or nested loops, complexity limits, no `self` reference, allowed
        variable references).

        Args:
            user_template (str): A submitted user Jinja2 template.

        Raises:
            TemplateSyntaxError: If the provided template is malformed or
                not parseable as a Jinja2 template.
            UserTemplateError: If any of the assertions fail.
            UserTemplateUnsupportedFiltersError: If a raised exception is
                identified as a missing-filter error.
        """
        try:
            ast = self.parse(user_template)
            self._assert_template_has_valid_ast_nodes(ast)
            self._assert_template_has_no_recursive_for(ast)
            self._assert_template_has_no_nested_for(ast)
            self._assert_template_ast_complexity(ast)
            self._assert_template_has_no_self_reference(ast)
            self._assert_template_has_valid_references(ast)
        except Exception as exception:
            ## Raises a more informative error if this was a missing filter,
            ## otherwise does nothing and the original exception propagates.
            maybe_handle_missing_filter_exception(exception, available_jinja_filters=list(self.filters.keys()))
            raise

    def _assert_rendered_text_length(self, rendered_text: str) -> None:
        """Check against the length of the rendered string."""
        rendered_len = len(rendered_text)
        if rendered_len > self.max_rendered_len:
            raise UserTemplateError(f"Rendered Jinja template too large ({rendered_len} > {self.max_rendered_len}).")

    def _assert_rendered_text_has_no_builtin_descriptions(self, rendered_text: str) -> None:
        """Check to make sure that the outputs aren't descriptions of methods.

        In the event that the user types the name of a __builtin__
        object method, but doesn't call it, we don't want to report
        information about the system's memory contents.

        Further, if the user made a mistake, we'd rather error out
        rather than continue task processing, for instance.
        """
        patterns = [
            r"<built-in method (.*?) of (.*?) object at 0x(.*?)>",
            r"<function (.*?) at (.*?)>",
        ]
        for pattern in patterns:
            if re.search(pattern, rendered_text):
                raise UserTemplateError("User template has uncalled __builtin__ method.")

    def _assert_rendered_text_not_empty(self, rendered_text: str) -> None:
        """Check to make sure the resulting text isn't an empty string"""
        if len(rendered_text) == 0:
            raise UserTemplateError("User template renders to empty text.")

    def validate_rendered_text(self, rendered_text: str) -> None:
        """Raises UserTemplateError on invalid renders.

        This is used as a post-processing step for capturing and
        acting on strings before they go out the door. The text must be
        non-empty, within the length limit, and free of function or
        built-in method reprs.
        """
        self._assert_rendered_text_not_empty(rendered_text)
        self._assert_rendered_text_length(rendered_text)
        self._assert_rendered_text_has_no_builtin_descriptions(rendered_text)

    def safe_render(self, user_template: str, record: dict, skip_template_validation: bool = False) -> str:
        """Attempt to safely render a user's template.

        Args:
            user_template (str): The user submitted Jinja2 template string.
            record (dict): a record of fields which are able to be referenced by the template.
            skip_template_validation (optional, bool): If true, then AST checks against the
                template itself will not be performed. WARNING: this should ONLY be set to true
                if the template has already been validated.

        Returns:
            str: The rendered text, after it has passed `validate_rendered_text`.

        Raises:
            UserTemplateError: If the template cannot be rendered because the
                user template does not conform to the supported Jinja2 subset,
                is too long, or contains some attempted malicious payload.
                If skip_template_validation is False, this error may also indicate
                that the template itself has failed static analysis. See the error
                message for more details.

            RecordContentsError: If there is a system-internal error with
                the supplied record data. This error is raised to prevent Jinja2
                processing of potentially insecure data objects.
        """
        if not skip_template_validation:
            self.validate_template(user_template)

        record = sanitize_record(record)

        try:
            template = self.from_string(user_template)
            rendered_text = template.render(record)
        except SecurityError as error:
            ## Keep the sandbox violation attached as the cause.
            raise UserTemplateError("Non-permitted operations in Jinja template.") from error
        except OverflowError as error:
            raise UserTemplateError("Template too large.") from error
        except Exception as exception:
            maybe_handle_missing_filter_exception(exception, available_jinja_filters=list(self.filters.keys()))
            raise

        self.validate_rendered_text(rendered_text)

        return rendered_text

    def get_references(self, user_template: str) -> set[str]:
        """Get all referenced variables from the provided template.

        Args:
            user_template (str): A user-provided Jinja template.

        Returns:
            set[str]: A set of all undeclared variable names referenced in
                the supplied Jinja template. If no variables are
                referenced, then this will be an empty set.
        """
        ast = self.parse(user_template)
        return meta.find_undeclared_variables(ast)
368
+
369
+
370
def sanitize_user_exceptions(func):
    """Decorator that replaces detailed template errors with a generic one.

    `UserTemplateUnsupportedFiltersError` is re-raised untouched, since its
    message is already meant for the user. Any other `UserTemplateError`,
    and any `TemplateSyntaxError`, is replaced by a `UserTemplateError`
    carrying `USER_PROMPT_TEMPLATE_ERROR_MESSAGE`. Exceptions of other
    types are not intercepted.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            result = func(*args, **kwargs)
        except UserTemplateUnsupportedFiltersError:
            ## Must be listed first: it subclasses UserTemplateError and
            ## would otherwise be swallowed by the generic branch below.
            raise
        except (UserTemplateError, TemplateSyntaxError):
            ## All other details are wrapped in a generic error message
            raise UserTemplateError(USER_PROMPT_TEMPLATE_ERROR_MESSAGE)
        return result

    return wrapper
386
+
387
+
388
class WithJinja2UserTemplateRendering:
    """Mixin class to support user-supplied Jinja2 rendering.

    A single template is prepared with `prepare_jinja2_template_renderer`
    and rendered with `render_template(record)`. Several named templates
    can be prepared with `prepare_jinja2_multi_template_renderer` and
    rendered with `render_multi_template(template_name, record)`. Each
    template is validated once when prepared and rendered through a
    `UserTemplateSandboxEnvironment`.

    The public methods are wrapped with `sanitize_user_exceptions`, which
    sanitizes error messages raised by the rendering environment.

    Usage:

        class Foo(WithJinja2UserTemplateRendering):
            def my_func(self, user_template: str, dataset_variables: list[str], records: list[dict]):

                ## Call once per template -- must be prepared before
                ## being able to call self.render_template
                self.prepare_jinja2_template_renderer(user_template, dataset_variables)

                ## Can call many times after
                for record in records:
                    self.render_template(record)
    """

    _template_render_fn: Callable

    @sanitize_user_exceptions
    def prepare_jinja2_template_renderer(self, prompt_template: str, dataset_variables: list[str]) -> None:
        """Validate `prompt_template` and store a render function for it.

        Args:
            prompt_template (str): The user-provided Jinja2 template.
            dataset_variables (list[str]): Variable names the template may reference.
        """
        sandbox = UserTemplateSandboxEnvironment(allowed_references=dataset_variables)
        sandbox.validate_template(prompt_template)
        ## Validation just happened, so the stored renderer skips it per record.
        render_fn = partial(sandbox.safe_render, prompt_template, skip_template_validation=True)
        self._template_render_fn = render_fn

    @sanitize_user_exceptions
    def render_template(self, record: dict) -> str:
        """Render the prepared template with `record` as its context."""
        return self._template_render_fn(record)

    @sanitize_user_exceptions
    def prepare_jinja2_multi_template_renderer(
        self,
        template_name: str,
        prompt_template: str,
        dataset_variables: list[str],
    ) -> None:
        """Validate `prompt_template` and register a render function under `template_name`.

        If `template_name` is already registered, the call does nothing and
        the existing render function is kept.
        """
        if self._template_prepared_in_multi_template_renderer(template_name):
            return

        self._create_render_func_registry()
        sandbox = UserTemplateSandboxEnvironment(allowed_references=dataset_variables)
        sandbox.validate_template(prompt_template)
        self._render_func_registry[template_name] = partial(
            sandbox.safe_render,
            prompt_template,
            skip_template_validation=True,
        )

    @sanitize_user_exceptions
    def render_multi_template(self, template_name: str, record: dict) -> str:
        """Render the template registered under `template_name` with `record`.

        Raises:
            UserTemplateError: If no registry exists yet or `template_name`
                has not been prepared.
        """
        registry_exists = hasattr(self, "_render_func_registry")
        if not registry_exists:
            raise UserTemplateError("Multi-template renderer not prepared.")
        if template_name not in self._render_func_registry:
            raise UserTemplateError(f"Template {template_name} not prepared.")

        render_fn = self._render_func_registry[template_name]
        return render_fn(record)

    def _template_prepared_in_multi_template_renderer(self, template_name: str) -> bool:
        """Return True when a registry exists and already holds `template_name`."""
        return hasattr(self, "_render_func_registry") and template_name in self._render_func_registry

    def _create_render_func_registry(self) -> None:
        """Create the empty render-function registry unless it already exists."""
        if hasattr(self, "_render_func_registry"):
            return
        self._render_func_registry = {}
@@ -0,0 +1,54 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import re
5
+
6
+ from jinja2 import TemplateAssertionError
7
+
8
+
9
class UserTemplateError(Exception):
    """Raised when a user-supplied template is flawed, whether by accident or by intent."""
11
+
12
+
13
class UserTemplateUnsupportedFiltersError(UserTemplateError):
    """Raised when a user template uses a filter that is not supported."""
15
+
16
+
17
class RecordContentsError(Exception):
    """Raised for problems with the record that supplies the template context."""
19
+
20
+
21
def maybe_handle_missing_filter_exception(exception: BaseException, available_jinja_filters: list[str]) -> None:
    """Raise a descriptive error if `exception` reports a missing Jinja2 filter.

    The function returns without doing anything unless `exception` is a
    `TemplateAssertionError` whose message names a missing filter.

    Args:
        exception (BaseException): The caught exception.
        available_jinja_filters (list[str]): The list of Jinja filters that
            are known to be available within the environment.

    Raises:
        UserTemplateUnsupportedFiltersError: If the exception was specifically for an unknown
            or unsupported Jinja2 filter. The original exception is attached as its cause.
    """
    if not isinstance(exception, TemplateAssertionError):
        return

    ## The missing filter message has the format:
    ## "No filter named '____'"
    found = re.search(r"No filter named '([^']+)'", exception.message or "")
    if found is None:
        return

    filter_name = found.group(1)
    permitted = ", ".join(available_jinja_filters)
    raise UserTemplateUnsupportedFiltersError(
        f"The Jinja2 filter `{{{{ ... | {filter_name} }}}}` "
        f"is not a permitted operation. Available filters: {permitted}"
    ) from exception
@@ -0,0 +1,30 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import json
5
+
6
+ from data_designer.config.utils.io_helpers import serialize_data
7
+ from data_designer.engine.processing.ginja.exceptions import RecordContentsError
8
+
9
+
10
def sanitize_record(record: dict) -> dict:
    """Reduce a record to basic Python types via a serialize/deserialize round trip.

    The record is used as context when rendering templates. Round-tripping it
    through its serialized form ensures the context only holds basic types
    (those representable in JSON), so that no unexpected attributes can be
    called from within a template.

    Args:
        record (dict): A dictionary object which can be serialized.

    Returns:
        dict: The result of parsing the serialized record with ``json.loads``.

    Raises:
        RecordContentsError: If serializing the record raises ``TypeError``
            or ``ValueError``.
    """
    try:
        serialized = serialize_data(record)
    except (TypeError, ValueError) as err:
        raise RecordContentsError("Unexpected unserializable content found in record.") from err
    else:
        # Parsing happens outside the guarded section, so decoding errors are
        # not converted into RecordContentsError.
        return json.loads(serialized)
@@ -0,0 +1,2 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
@@ -0,0 +1,8 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from jsonschema import ValidationError
5
+
6
+
7
class JSONSchemaValidationError(ValidationError):
    """Thin subclass of ``jsonschema``'s ``ValidationError``, provided to ease imports."""
@@ -0,0 +1,81 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from copy import deepcopy
5
+ from typing import Any
6
+
7
+ from data_designer.engine.processing.gsonschema.types import JSONSchemaT
8
+
9
+
10
def _is_bare_dictionary_schema(schema_part: Any) -> bool:
    """Tell whether a schema fragment describes a bare dictionary.

    A bare dictionary schema only states that an object exists, without
    declaring any "properties". It has one of these two shapes:

        { "title": ... , "type": "object" }
        { "type": "object" }

    Args:
        schema_part (Any): The schema fragment to classify. Anything that is
            not a dict is never considered a bare dictionary schema.

    Returns:
        bool: True when the fragment is a dict whose "type" is "object" and
            whose only other key, if any, is "title".
    """
    if not isinstance(schema_part, dict) or schema_part.get("type") != "object":
        return False

    # "type" is known to be present here, so one key means just "type" and
    # two keys qualify only when the second one is "title".
    key_count = len(schema_part)
    return key_count == 1 or (key_count == 2 and "title" in schema_part)
30
+
31
+
32
def forbid_additional_properties(schema: JSONSchemaT) -> JSONSchemaT:
    """Transform the provided schema into one which forbids additional properties.

    The input is deep-copied first, so the caller's schema is never mutated.
    Every object schema found while walking the copy (a dict whose "type" is
    "object" or which declares "properties") gets its `additionalProperties`
    overwritten, replacing any value that was already present.

    Args:
        schema (JSONSchemaT): A JSONSchema to transform.

    Returns:
        JSONSchemaT: A new JSONSchema matching the provided one, but
            with `additionalProperties: False` set everywhere, except on
            bare dictionary schemas, where it is set to `True`.
    """
    new_schema = deepcopy(schema)

    ## Keywords whose value maps names to sub-schemas. "definitions" is the
    ## pre-2019-09 spelling of "$defs"; both are walked so that object schemas
    ## reachable through "$ref" are tightened as well.
    named_schema_keys = ("properties", "patternProperties", "$defs", "definitions")
    ## Keywords whose value is a list of sub-schemas. "prefixItems" is the
    ## 2020-12 replacement for the list form of "items".
    schema_list_keys = ("allOf", "anyOf", "oneOf", "prefixItems")
    ## Keywords whose value is a single sub-schema.
    single_schema_keys = ("not", "if", "then", "else")

    def _enforce(schema_part: Any) -> None:
        if isinstance(schema_part, list):
            for item in schema_part:
                _enforce(item)
            return

        if not isinstance(schema_part, dict):
            return

        if schema_part.get("type") == "object" or "properties" in schema_part:
            ## We need to handle the special case that the schema specifies just a bare
            ## dictionary. In those cases, the implication is that _all_ dictionaries
            ## are valid, so we should not forbid extra properties in that case.
            schema_part["additionalProperties"] = _is_bare_dictionary_schema(schema_part)

        # Traverse into nested schemas. Values of an unexpected type are
        # skipped rather than raising, for every keyword alike.
        for key, value in schema_part.items():
            if key in named_schema_keys:
                if isinstance(value, dict):
                    for sub_schema in value.values():
                        _enforce(sub_schema)
            elif key == "items":
                # "items" may be a single schema or a list of schemas.
                if isinstance(value, (dict, list)):
                    _enforce(value)
            elif key in schema_list_keys:
                if isinstance(value, list):
                    _enforce(value)
            elif key in single_schema_keys:
                if isinstance(value, dict):
                    _enforce(value)

    _enforce(new_schema)
    return new_schema
@@ -0,0 +1,8 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from typing import Any, TypeVar
5
+
6
+ T_primitive = TypeVar("T_primitive", str, int, float, bool)
7
+ DataObjectT = dict | list | str | int | float | bool
8
+ JSONSchemaT = dict[str, Any]