data-designer-engine 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (114) hide show
  1. data_designer/engine/__init__.py +2 -0
  2. data_designer/engine/_version.py +34 -0
  3. data_designer/engine/analysis/column_profilers/base.py +49 -0
  4. data_designer/engine/analysis/column_profilers/judge_score_profiler.py +153 -0
  5. data_designer/engine/analysis/column_profilers/registry.py +22 -0
  6. data_designer/engine/analysis/column_statistics.py +145 -0
  7. data_designer/engine/analysis/dataset_profiler.py +149 -0
  8. data_designer/engine/analysis/errors.py +9 -0
  9. data_designer/engine/analysis/utils/column_statistics_calculations.py +234 -0
  10. data_designer/engine/analysis/utils/judge_score_processing.py +132 -0
  11. data_designer/engine/column_generators/__init__.py +2 -0
  12. data_designer/engine/column_generators/generators/__init__.py +2 -0
  13. data_designer/engine/column_generators/generators/base.py +122 -0
  14. data_designer/engine/column_generators/generators/embedding.py +35 -0
  15. data_designer/engine/column_generators/generators/expression.py +55 -0
  16. data_designer/engine/column_generators/generators/llm_completion.py +116 -0
  17. data_designer/engine/column_generators/generators/samplers.py +69 -0
  18. data_designer/engine/column_generators/generators/seed_dataset.py +144 -0
  19. data_designer/engine/column_generators/generators/validation.py +140 -0
  20. data_designer/engine/column_generators/registry.py +60 -0
  21. data_designer/engine/column_generators/utils/errors.py +15 -0
  22. data_designer/engine/column_generators/utils/generator_classification.py +43 -0
  23. data_designer/engine/column_generators/utils/judge_score_factory.py +58 -0
  24. data_designer/engine/column_generators/utils/prompt_renderer.py +100 -0
  25. data_designer/engine/compiler.py +97 -0
  26. data_designer/engine/configurable_task.py +71 -0
  27. data_designer/engine/dataset_builders/artifact_storage.py +283 -0
  28. data_designer/engine/dataset_builders/column_wise_builder.py +354 -0
  29. data_designer/engine/dataset_builders/errors.py +15 -0
  30. data_designer/engine/dataset_builders/multi_column_configs.py +46 -0
  31. data_designer/engine/dataset_builders/utils/__init__.py +2 -0
  32. data_designer/engine/dataset_builders/utils/concurrency.py +212 -0
  33. data_designer/engine/dataset_builders/utils/config_compiler.py +62 -0
  34. data_designer/engine/dataset_builders/utils/dag.py +62 -0
  35. data_designer/engine/dataset_builders/utils/dataset_batch_manager.py +200 -0
  36. data_designer/engine/dataset_builders/utils/errors.py +15 -0
  37. data_designer/engine/dataset_builders/utils/progress_tracker.py +122 -0
  38. data_designer/engine/errors.py +51 -0
  39. data_designer/engine/model_provider.py +77 -0
  40. data_designer/engine/models/__init__.py +2 -0
  41. data_designer/engine/models/errors.py +300 -0
  42. data_designer/engine/models/facade.py +284 -0
  43. data_designer/engine/models/factory.py +42 -0
  44. data_designer/engine/models/litellm_overrides.py +179 -0
  45. data_designer/engine/models/parsers/__init__.py +2 -0
  46. data_designer/engine/models/parsers/errors.py +34 -0
  47. data_designer/engine/models/parsers/parser.py +235 -0
  48. data_designer/engine/models/parsers/postprocessors.py +93 -0
  49. data_designer/engine/models/parsers/tag_parsers.py +62 -0
  50. data_designer/engine/models/parsers/types.py +84 -0
  51. data_designer/engine/models/recipes/base.py +81 -0
  52. data_designer/engine/models/recipes/response_recipes.py +293 -0
  53. data_designer/engine/models/registry.py +151 -0
  54. data_designer/engine/models/telemetry.py +362 -0
  55. data_designer/engine/models/usage.py +73 -0
  56. data_designer/engine/models/utils.py +101 -0
  57. data_designer/engine/processing/ginja/__init__.py +2 -0
  58. data_designer/engine/processing/ginja/ast.py +65 -0
  59. data_designer/engine/processing/ginja/environment.py +463 -0
  60. data_designer/engine/processing/ginja/exceptions.py +56 -0
  61. data_designer/engine/processing/ginja/record.py +32 -0
  62. data_designer/engine/processing/gsonschema/__init__.py +2 -0
  63. data_designer/engine/processing/gsonschema/exceptions.py +15 -0
  64. data_designer/engine/processing/gsonschema/schema_transformers.py +83 -0
  65. data_designer/engine/processing/gsonschema/types.py +10 -0
  66. data_designer/engine/processing/gsonschema/validators.py +202 -0
  67. data_designer/engine/processing/processors/base.py +13 -0
  68. data_designer/engine/processing/processors/drop_columns.py +42 -0
  69. data_designer/engine/processing/processors/registry.py +25 -0
  70. data_designer/engine/processing/processors/schema_transform.py +71 -0
  71. data_designer/engine/processing/utils.py +169 -0
  72. data_designer/engine/registry/base.py +99 -0
  73. data_designer/engine/registry/data_designer_registry.py +39 -0
  74. data_designer/engine/registry/errors.py +12 -0
  75. data_designer/engine/resources/managed_dataset_generator.py +39 -0
  76. data_designer/engine/resources/managed_dataset_repository.py +197 -0
  77. data_designer/engine/resources/managed_storage.py +65 -0
  78. data_designer/engine/resources/resource_provider.py +77 -0
  79. data_designer/engine/resources/seed_reader.py +154 -0
  80. data_designer/engine/sampling_gen/column.py +91 -0
  81. data_designer/engine/sampling_gen/constraints.py +100 -0
  82. data_designer/engine/sampling_gen/data_sources/base.py +217 -0
  83. data_designer/engine/sampling_gen/data_sources/errors.py +12 -0
  84. data_designer/engine/sampling_gen/data_sources/sources.py +347 -0
  85. data_designer/engine/sampling_gen/entities/__init__.py +2 -0
  86. data_designer/engine/sampling_gen/entities/assets/zip_area_code_map.parquet +0 -0
  87. data_designer/engine/sampling_gen/entities/dataset_based_person_fields.py +90 -0
  88. data_designer/engine/sampling_gen/entities/email_address_utils.py +171 -0
  89. data_designer/engine/sampling_gen/entities/errors.py +10 -0
  90. data_designer/engine/sampling_gen/entities/national_id_utils.py +102 -0
  91. data_designer/engine/sampling_gen/entities/person.py +144 -0
  92. data_designer/engine/sampling_gen/entities/phone_number.py +128 -0
  93. data_designer/engine/sampling_gen/errors.py +26 -0
  94. data_designer/engine/sampling_gen/generator.py +122 -0
  95. data_designer/engine/sampling_gen/jinja_utils.py +64 -0
  96. data_designer/engine/sampling_gen/people_gen.py +199 -0
  97. data_designer/engine/sampling_gen/person_constants.py +56 -0
  98. data_designer/engine/sampling_gen/schema.py +147 -0
  99. data_designer/engine/sampling_gen/schema_builder.py +61 -0
  100. data_designer/engine/sampling_gen/utils.py +46 -0
  101. data_designer/engine/secret_resolver.py +82 -0
  102. data_designer/engine/testing/__init__.py +12 -0
  103. data_designer/engine/testing/stubs.py +133 -0
  104. data_designer/engine/testing/utils.py +20 -0
  105. data_designer/engine/validation.py +367 -0
  106. data_designer/engine/validators/__init__.py +19 -0
  107. data_designer/engine/validators/base.py +38 -0
  108. data_designer/engine/validators/local_callable.py +39 -0
  109. data_designer/engine/validators/python.py +254 -0
  110. data_designer/engine/validators/remote.py +89 -0
  111. data_designer/engine/validators/sql.py +65 -0
  112. data_designer_engine-0.4.0.dist-info/METADATA +50 -0
  113. data_designer_engine-0.4.0.dist-info/RECORD +114 -0
  114. data_designer_engine-0.4.0.dist-info/WHEEL +4 -0
@@ -0,0 +1,463 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from __future__ import annotations
5
+
6
+ import re
7
+ from collections.abc import Callable
8
+ from functools import partial, wraps
9
+ from typing import Any
10
+
11
+ from jinja2 import meta
12
+ from jinja2 import nodes as j_nodes
13
+ from jinja2.exceptions import SecurityError, TemplateSyntaxError
14
+ from jinja2.nodes import Template
15
+ from jinja2.sandbox import ImmutableSandboxedEnvironment
16
+ from jsonpath_rust_bindings import Finder
17
+
18
+ from data_designer.engine.processing.ginja.ast import (
19
+ ast_count_name_references,
20
+ ast_descendant_count,
21
+ ast_max_depth,
22
+ )
23
+ from data_designer.engine.processing.ginja.exceptions import (
24
+ UserTemplateError,
25
+ UserTemplateUnsupportedFiltersError,
26
+ maybe_handle_missing_filter_exception,
27
+ )
28
+ from data_designer.engine.processing.ginja.record import sanitize_record
29
+
30
# Default cap on the character length of a rendered template.
MAX_RENDERED_LEN = 512_000
# Default cap on the number of nodes in a parsed template AST.
MAX_AST_NODE_COUNT = 600
# Default cap on the nesting depth of a parsed template AST.
MAX_AST_DEPTH = 10
# Filter names kept by UserTemplateSandboxEnvironment; every other filter is removed.
ALLOWED_JINJA_FILTERS = [
    ## Jinja2 Builtin Filters
    "abs",
    "capitalize",
    "escape",
    "first",
    "float",
    "forceescape",
    "int",
    "items",
    "last",
    "length",
    "list",
    "lower",
    "max",
    "min",
    "random",
    "replace",
    "reverse",
    "round",
    "sort",
    "string",
    "title",
    "trim",
    "truncate",
    "unique",
    "urlencode",
    ## Custom Filters
    "jsonpath",
]

# Generic message that replaces detailed template errors shown to users.
USER_PROMPT_TEMPLATE_ERROR_MESSAGE = """\
User provided prompt generation template is invalid.\
"""
# AST node types that cause a template to be rejected during validation.
# Note that Jinja2 models each of these statements with its own node class
# (e.g. `Include` is not a subclass of `Import`, and the block form of `set`
# is `AssignBlock`, not `Assign`), so every variant must be listed explicitly.
UNSUPPORTED_AST_NODES = [
    j_nodes.Import,  # No {% import ... %}
    j_nodes.FromImport,  # No {% from ... import ... %}
    j_nodes.Include,  # No {% include ... %}
    j_nodes.Macro,  # No {% macro ... %}
    j_nodes.Assign,  # No {% set x = ... %}
    j_nodes.AssignBlock,  # No {% set x %}...{% endset %}
    j_nodes.Extends,  # No {% extends ... %}
    j_nodes.Block,  # No {% block ... %}
]
74
+
75
+
76
def jsonpath_jinja_filter(data: dict, expression: str) -> list[Any]:
    """Jinja filter that evaluates a JSONPath expression against a mapping.

    Registered as the custom `jsonpath` filter, e.g. `{{ record | jsonpath("$.a") }}`.

    Args:
        data (dict): The structured object to query. Only dicts are accepted.
        expression (str): A JSONPath expression.

    Returns:
        list[Any]: The `.data` payload of every match, in the order returned
            by the JSONPath finder.

    Raises:
        ValueError: If `data` is not a dict.
    """
    if isinstance(data, dict):
        matches = Finder(data).find(expression)
        return [match.data for match in matches]

    raise ValueError("Cannot perform JSONPath filter on non-structured data.")
90
+
91
+
92
def is_jinja_template(user_template: str) -> bool:
    """Heuristically guess whether `user_template` is written in Jinja2 syntax.

    Exists to ease migration from Python format strings to Jinja2; once only
    Jinja2 is supported, this function is no longer needed.

    The check is purely textual. It returns True when, for any one of the
    delimiter pairs `{{ }}`, `{% %}` or `{# #}`, both the opening and the
    closing token occur somewhere in the string. Their relative order is not
    checked.

    Args:
        user_template (str): A user-provided template string to test.

    Returns:
        bool: True if the heuristic believes it is a Jinja2 template.
    """
    delimiter_pairs = (("{{", "}}"), ("{%", "%}"), ("{#", "#}"))
    return any(opener in user_template and closer in user_template for opener, closer in delimiter_pairs)
110
+
111
+
112
class UserTemplateSandboxEnvironment(ImmutableSandboxedEnvironment):
    """Defines a locked-down environment for rendering Gretel's Jinja2 subset.

    The use of Jinja2 sandboxing is critical. We are taking Jinja2
    templates from users -- we need to take steps to ensure that users
    are not able to break containment or exfiltrate server-side secrets.

    This environment extends the restrictions of Jinja2's
    `ImmutableSandboxedEnvironment`. That base class already provides a layer
    of protection, including:

    - No references to private attributes.
    - Restrictions on loop ranges (`OverflowError`, see `MAX_RANGE`).

    On top of that, this environment:

    - Renders with auto-escaping DISABLED (`autoescape=False`). Rendered text
      is returned unescaped and must not be treated as HTML-safe.
    - Prevents access to the template's `self` reference.
    - Prevents references to variables outside a provided allow-list.
    - Rejects templates containing any node type in `UNSUPPORTED_AST_NODES`
      (import, macro, set, extends, block, ...), recursive loops, or nested loops.
    - Errors on empty rendered output and on rendered output longer than
      `max_rendered_len` characters (default `MAX_RENDERED_LEN`).
    - Removes every Jinja2 filter not named in `ALLOWED_JINJA_FILTERS`, and
      registers the custom `jsonpath` filter.
    - Uses AST static analysis to threshold the complexity of allowed templates.
    """

    max_rendered_len: int
    max_ast_node_count: int
    max_ast_depth: int
    allowed_references: list[str]

    def __init__(
        self,
        allowed_references: list[str] | None = None,
        max_rendered_len: int = MAX_RENDERED_LEN,
        max_ast_node_count: int = MAX_AST_NODE_COUNT,
        max_ast_depth: int = MAX_AST_DEPTH,
        **kwargs,
    ):
        """Args:
            allowed_references (optional, list[str]): If set, indicates which variables
                are allowed to be referenced by the Jinja2 template. If not specified,
                defaults to [], which indicates that the Jinja2 template is not
                allowed to refer to _any_ variables outside of itself.

            max_rendered_len (optional, int): The maximum allowable character count for
                rendered templates. If not specified, defaults to MAX_RENDERED_LEN
                set by this module.

            max_ast_node_count (optional, int): Parameter for static analysis of
                Jinja2 template complexity -- counts the number of distinct nodes
                in the parsed Jinja2 AST. A large number of nodes indicates many
                distinct operations within the provided user template, which can
                cause long compute times, or may be malicious in nature. If not
                specified, defaults to MAX_AST_NODE_COUNT set by this module.

            max_ast_depth (optional, int): Parameter for static analysis of
                Jinja2 template complexity -- measures the maximum depth of the
                parsed Jinja2 AST. A high depth indicates a high degree of nesting
                within the user template. This can cause long compute times,
                or may be malicious in nature. If not specified, defaults to
                MAX_AST_DEPTH set by this module.

            **kwargs: Additional kwargs passed to ImmutableSandboxedEnvironment.
                `autoescape` must not be among them, because it is always passed
                explicitly as False; supplying it again raises TypeError.
        """
        # Auto-escaping is explicitly disabled; rendered text is returned unescaped.
        super().__init__(autoescape=False, **kwargs)
        self.max_rendered_len = max_rendered_len
        self.max_ast_node_count = max_ast_node_count
        self.max_ast_depth = max_ast_depth
        self.allowed_references = allowed_references if allowed_references else []

        ## Add on our supported filters
        self.filters["jsonpath"] = jsonpath_jinja_filter

        ## Cut out all but approved Jinja filters
        self.filters = {k: v for k, v in self.filters.items() if k in ALLOWED_JINJA_FILTERS}

    def _assert_template_has_valid_references(self, ast: Template) -> None:
        """Assert that all named variable references are allowed.

        Checks the template's undeclared variables against the environment's
        allowed reference list created at initialization.
        """
        template_vars = meta.find_undeclared_variables(ast)
        unallowed_vars = set(template_vars) - set(self.allowed_references)
        if unallowed_vars:
            raise UserTemplateError(f"Unknown variable references in Jinja template: {unallowed_vars}")

    def _assert_template_has_valid_ast_nodes(self, ast: Template) -> None:
        """Assert that none of the node types in UNSUPPORTED_AST_NODES appear in the template."""
        black_list_node_count = sum(ast_descendant_count(ast, node_type) for node_type in UNSUPPORTED_AST_NODES)

        if black_list_node_count != 0:
            raise UserTemplateError("Non-permitted operations in Jinja template.")

    def _assert_template_has_no_recursive_for(self, ast: Template) -> None:
        """Assert that the template does not use {% for ... recursive %}."""
        if any(node.recursive for node in ast.find_all(j_nodes.For)):
            raise UserTemplateError("Non-permitted operations in Jinja template.")

    def _assert_template_has_no_nested_for(self, ast: Template) -> None:
        """Assert that the template does not contain nested loops.

        This assertion is made to ensure that templates cannot combinatorially
        explode. High-range values are controlled by the `MAX_RANGE` setting
        on `SandboxedEnvironment`.
        """
        # Check each For node in the AST to see if it has For descendants.
        for node in ast.find_all(j_nodes.For):
            if ast_descendant_count(node, only_type=j_nodes.For):
                raise UserTemplateError("Non-permitted operations in Jinja template (nested-for).")

    def _assert_template_ast_complexity(self, ast: Template) -> None:
        """Assert that the AST tree parsed from the template is not overly complex.

        Complexity is measured by the depth of the tree (measure of nesting),
        as well as the number of nodes it contains (how many distinct operations).
        If either is over the limit specified at initialization, the assert fails.
        """
        node_count = ast_descendant_count(ast)
        max_depth = ast_max_depth(ast)

        if node_count > self.max_ast_node_count or max_depth > self.max_ast_depth:
            raise UserTemplateError("Jinja template too complex, simplify your template.")

    def _assert_template_has_no_self_reference(self, ast: Template) -> None:
        """Assert that the template cannot refer to its own settings.

        Templates may attempt to use {{ self }} references to gain
        access to properties of the template object itself. This
        is disallowed.
        """
        if ast_count_name_references(ast, "self") != 0:
            raise UserTemplateError("Non-permitted operations in Jinja template.")

    def validate_template(self, user_template: str) -> None:
        """Run static analysis on a user template without rendering it.

        The template is parsed into a Jinja2 AST. The AST is then checked for
        disallowed node types, recursive and nested loops, excessive complexity,
        `self` references, and references to variables outside
        `allowed_references`.

        Note: this environment does NOT auto-escape (`autoescape=False`), so
        these AST checks are the primary line of defense.

        Args:
            user_template (str): A submitted user Jinja2 template.

        Raises:
            TemplateSyntaxError: If the provided template is malformed or
                not parseable as a Jinja2 template.
            UserTemplateError: If any of the assertions fail.
        """
        try:
            ast = self.parse(user_template)
            self._assert_template_has_valid_ast_nodes(ast)
            self._assert_template_has_no_recursive_for(ast)
            self._assert_template_has_no_nested_for(ast)
            self._assert_template_ast_complexity(ast)
            self._assert_template_has_no_self_reference(ast)
            self._assert_template_has_valid_references(ast)
        except Exception as exception:
            # Converts unknown-filter failures into UserTemplateUnsupportedFiltersError;
            # anything else is re-raised unchanged.
            maybe_handle_missing_filter_exception(exception, available_jinja_filters=list(self.filters.keys()))
            raise

    def _assert_rendered_text_length(self, rendered_text: str) -> None:
        """Check the length of the rendered string against max_rendered_len."""
        rendered_len = len(rendered_text)
        if rendered_len > self.max_rendered_len:
            raise UserTemplateError(f"Rendered Jinja template too large ({rendered_len} > {self.max_rendered_len}).")

    def _assert_rendered_text_has_no_builtin_descriptions(self, rendered_text: str) -> None:
        """Check to make sure that the outputs aren't descriptions of methods.

        If the user types the name of a __builtin__ object method but doesn't
        call it, the rendered text would contain its repr. That repr includes
        a memory address, which we don't want to report.

        Further, if the user made a mistake, we'd rather error out than
        continue task processing.
        """
        patterns = [
            r"<built-in method (.*?) of (.*?) object at 0x(.*?)>",
            r"<function (.*?) at (.*?)>",
        ]
        for pattern in patterns:
            if re.search(pattern, rendered_text):
                raise UserTemplateError("User template has uncalled __builtin__ method.")

    def _assert_rendered_text_not_empty(self, rendered_text: str) -> None:
        """Check to make sure the resulting text isn't an empty string."""
        if len(rendered_text) == 0:
            raise UserTemplateError("User template renders to empty text.")

    def validate_rendered_text(self, rendered_text: str) -> None:
        """Raise UserTemplateError on invalid renders.

        This is used as a post-processing step for capturing and
        acting on strings before they go out the door.
        """
        self._assert_rendered_text_not_empty(rendered_text)
        self._assert_rendered_text_length(rendered_text)
        self._assert_rendered_text_has_no_builtin_descriptions(rendered_text)

    def safe_render(self, user_template: str, record: dict, skip_template_validation: bool = False) -> str:
        """Attempt to safely render a user's template.

        Args:
            user_template (str): The user submitted Jinja2 template string.
            record (dict): A record of fields which are able to be referenced by the template.
            skip_template_validation (optional, bool): If true, then AST checks against the
                template itself will not be performed. WARNING: this should ONLY be set to true
                if the template has already been validated.

        Returns:
            str: The rendered text, after passing `validate_rendered_text`.

        Raises:
            UserTemplateError: If the template cannot be rendered because the
                user template does not conform to Gretel's Jinja2 subset,
                is too long, or contains some attempted malicious payload.
                If skip_template_validation is False, this error may also indicate
                that the template itself has failed static analysis. See the error
                message for more details.

            RecordContentsError: If there is a system-internal error with
                the supplied record data. This error is raised to prevent Jinja2
                processing of potentially insecure data objects.
        """
        if not skip_template_validation:
            self.validate_template(user_template)

        record = sanitize_record(record)

        try:
            template = self.from_string(user_template)
            rendered_text = template.render(record)
        except SecurityError as exception:
            raise UserTemplateError("Non-permitted operations in Jinja template.") from exception
        except OverflowError as exception:
            raise UserTemplateError("Template too large.") from exception
        except Exception as exception:
            # Converts unknown-filter failures into UserTemplateUnsupportedFiltersError;
            # anything else is re-raised unchanged.
            maybe_handle_missing_filter_exception(exception, available_jinja_filters=list(self.filters.keys()))
            raise

        self.validate_rendered_text(rendered_text)

        return rendered_text

    def get_references(self, user_template: str) -> set[str]:
        """Get all referenced variables from the provided template.

        Args:
            user_template (str): A user-provided Jinja template.

        Returns:
            set[str]: A set of all undeclared variable names referenced in
                the supplied Jinja template. If no variables are
                referenced, then this will be an empty set.
        """
        ast = self.parse(user_template)
        return meta.find_undeclared_variables(ast)
370
+
371
+
372
def sanitize_user_exceptions(func):
    """Decorator that hides the details of user-template failures.

    The wrapped callable runs unchanged. Failures are handled as follows:

    - `UserTemplateUnsupportedFiltersError` is re-raised as-is, because its
      message is already informative and safe to show.
    - Any other `UserTemplateError`, or a `TemplateSyntaxError`, is replaced
      by a `UserTemplateError` carrying the generic
      `USER_PROMPT_TEMPLATE_ERROR_MESSAGE`.
    - Every other exception propagates untouched.
    """

    @wraps(func)
    def _sanitized(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except UserTemplateUnsupportedFiltersError:
            # Informative messaging is already handled for this specific case.
            raise
        except (UserTemplateError, TemplateSyntaxError):
            # Every other template failure is reduced to one generic message.
            raise UserTemplateError(USER_PROMPT_TEMPLATE_ERROR_MESSAGE)

    return _sanitized
388
+
389
+
390
class WithJinja2UserTemplateRendering:
    """Mixin class to support user-supplied Jinja2 rendering.

    Provides `self.render_template(record: dict)` to the receiving
    class, which can be used to safely render user-provided Jinja2
    templates using `UserTemplateSandboxEnvironment`. A multi-template
    variant (`prepare_jinja2_multi_template_renderer` /
    `render_multi_template`) keeps one prepared renderer per template name.

    Every public method is wrapped with `sanitize_user_exceptions`, so
    template errors surface as a generic `UserTemplateError`. The exception is
    unsupported-filter errors, which keep their informative message.

    Usage:

        class Foo(WithJinja2UserTemplateRendering):
            def my_func(self, user_template: str, dataset_variables: list[str], records: list[dict]):

                ## Call once per template -- must be prepared before
                ## being able to call self.render_template
                self.prepare_jinja2_template_renderer(user_template, dataset_variables)

                ## Can call many times after
                for record in records:
                    self.render_template(record)
    """

    _template_render_fn: Callable

    @sanitize_user_exceptions
    def prepare_jinja2_template_renderer(self, prompt_template: str, dataset_variables: list[str]) -> None:
        """Validate `prompt_template` once and store a render function for it.

        Args:
            prompt_template (str): The user-supplied Jinja2 template.
            dataset_variables (list[str]): Variable names the template may reference.
        """
        jinja_render_env = UserTemplateSandboxEnvironment(allowed_references=dataset_variables)
        jinja_render_env.validate_template(prompt_template)
        # The template was just validated, so per-record renders skip re-validation.
        self._template_render_fn = partial(
            jinja_render_env.safe_render,
            prompt_template,
            skip_template_validation=True,
        )

    @sanitize_user_exceptions
    def render_template(self, record: dict) -> str:
        """Render the template prepared by `prepare_jinja2_template_renderer`.

        Args:
            record (dict): Values available to the template.

        Returns:
            str: The rendered text.

        Raises:
            UserTemplateError: If no template has been prepared yet, or the
                template fails a sandbox check while rendering.
        """
        # `_template_render_fn` is only annotated on the class, so it does not
        # exist until a template has been prepared.
        render_fn = getattr(self, "_template_render_fn", None)
        if render_fn is None:
            raise UserTemplateError("Template renderer not prepared.")
        return render_fn(record)

    @sanitize_user_exceptions
    def prepare_jinja2_multi_template_renderer(
        self,
        template_name: str,
        prompt_template: str,
        dataset_variables: list[str],
    ) -> None:
        """Validate `prompt_template` and register its render function under `template_name`.

        If a renderer is already registered under `template_name`, this is a
        no-op: the supplied `prompt_template` and `dataset_variables` are ignored.

        Args:
            template_name (str): Key used later with `render_multi_template`.
            prompt_template (str): The user-supplied Jinja2 template.
            dataset_variables (list[str]): Variable names the template may reference.
        """
        if self._template_prepared_in_multi_template_renderer(template_name):
            return

        self._create_render_func_registry()
        jinja_render_env = UserTemplateSandboxEnvironment(allowed_references=dataset_variables)
        jinja_render_env.validate_template(prompt_template)
        # The template was just validated, so per-record renders skip re-validation.
        self._render_func_registry[template_name] = partial(
            jinja_render_env.safe_render,
            prompt_template,
            skip_template_validation=True,
        )

    @sanitize_user_exceptions
    def render_multi_template(self, template_name: str, record: dict) -> str:
        """Render the template registered under `template_name` with `record`.

        Raises:
            UserTemplateError: If no templates, or not this template, have been
                prepared, or the template fails a sandbox check while rendering.
        """
        if not hasattr(self, "_render_func_registry"):
            raise UserTemplateError("Multi-template renderer not prepared.")
        if template_name not in self._render_func_registry:
            raise UserTemplateError(f"Template {template_name} not prepared.")
        return self._render_func_registry[template_name](record)

    def _template_prepared_in_multi_template_renderer(self, template_name: str) -> bool:
        """Return True if a render function is registered under `template_name`."""
        if not hasattr(self, "_render_func_registry"):
            return False
        return template_name in self._render_func_registry

    def _create_render_func_registry(self) -> None:
        """Lazily create the per-instance name -> render-function registry."""
        if not hasattr(self, "_render_func_registry"):
            self._render_func_registry = {}
@@ -0,0 +1,56 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from __future__ import annotations
5
+
6
+ import re
7
+
8
+ from jinja2 import TemplateAssertionError
9
+
10
+
11
class UserTemplateError(Exception):
    """Raised when a user-supplied template is flawed, whether by accident or on purpose."""
13
+
14
+
15
class UserTemplateUnsupportedFiltersError(UserTemplateError):
    """`UserTemplateError` raised specifically when a template uses a filter that is not permitted."""
17
+
18
+
19
class RecordContentsError(Exception):
    """Raised when the record supplied as template context cannot be used (e.g. cannot be serialized)."""
21
+
22
+
23
def maybe_handle_missing_filter_exception(exception: BaseException, available_jinja_filters: list[str]) -> None:
    """Interpret and handle the possible case of a missing filter exception.

    Jinja2 (>= 3.0) reports an unknown filter in one of two ways:

    - a `TemplateAssertionError` when the template is compiled, or
    - a `TemplateRuntimeError` when the filter is invoked during rendering.
      Jinja2 defers the check to render time for filters used inside
      `{% if %}` blocks and inline conditional expressions.

    Both cases are recognized here. If this wasn't a missing filter exception,
    this function does nothing.

    Args:
        exception (BaseException): The caught exception.
        available_jinja_filters (list[str]): The list of Jinja filters that
            are known to be available within the environment.

    Raises:
        UserTemplateUnsupportedFiltersError: If the exception was specifically for an unknown
            or unsupported Jinja2 filter.
    """
    # Local import so this module's top-level import block stays unchanged.
    from jinja2 import TemplateRuntimeError

    if not isinstance(exception, (TemplateAssertionError, TemplateRuntimeError)):
        return

    exc_message = exception.message or ""

    ## The missing filter message has one of the formats:
    ##   "No filter named '____'."        (compile time)
    ##   "No filter named '____' found."  (render time)
    match = re.search(r"No filter named '([^']+)'", exc_message)
    if match is None:
        return

    missing_filter_name = match.group(1)
    available_filter_str = ", ".join(available_jinja_filters)
    raise UserTemplateUnsupportedFiltersError(
        f"The Jinja2 filter `{{{{ ... | {missing_filter_name} }}}}` "
        f"is not a permitted operation. Available filters: {available_filter_str}"
    ) from exception
@@ -0,0 +1,32 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from __future__ import annotations
5
+
6
+ import json
7
+
8
+ from data_designer.config.utils.io_helpers import serialize_data
9
+ from data_designer.engine.processing.ginja.exceptions import RecordContentsError
10
+
11
+
12
def sanitize_record(record: dict) -> dict:
    """Reduce a record to plain, JSON-representable Python values.

    The record is serialized with `serialize_data` and parsed back with
    `json.loads`. The result is therefore a fresh copy containing only the
    basic types that JSON decoding produces. This keeps arbitrary objects, and
    any callable attributes they carry, out of the template rendering context.

    Args:
        record (dict): A dictionary object which can be serialized.

    Returns:
        dict: The round-tripped copy of `record`.

    Raises:
        RecordContentsError: If `serialize_data` rejects the record contents
            with a `TypeError` or `ValueError`.
    """
    try:
        serialized = serialize_data(record)
    except (TypeError, ValueError) as err:
        raise RecordContentsError("Unexpected unserializable content found in record.") from err
    else:
        # NOTE(review): relies on serialize_data returning a JSON string; json.loads requires it.
        return json.loads(serialized)
@@ -0,0 +1,2 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
@@ -0,0 +1,15 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from __future__ import annotations
5
+
6
+ from typing import TYPE_CHECKING
7
+
8
+ from data_designer.lazy_heavy_imports import jsonschema
9
+
10
+ if TYPE_CHECKING:
11
+ import jsonschema
12
+
13
+
14
class JSONSchemaValidationError(jsonschema.ValidationError):
    """Project-level name for `jsonschema.ValidationError`, exposed to ease imports.

    Note that this is a subclass rather than a true alias. Instances of it are
    `jsonschema.ValidationError`s, but a plain `jsonschema.ValidationError`
    raised directly by the jsonschema library is not an instance of this class.
    """
@@ -0,0 +1,83 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from __future__ import annotations
5
+
6
+ from copy import deepcopy
7
+ from typing import Any
8
+
9
+ from data_designer.engine.processing.gsonschema.types import JSONSchemaT
10
+
11
+
12
def _is_bare_dictionary_schema(schema_part: Any) -> bool:
    """Classify bare dictionary schemas.

    A bare dictionary schema is one which looks like the following:

        { "title": ... , "type": "object" }
        { "type": "object" }

    These schemas do not specify any "properties", just their existence.
    """
    if not isinstance(schema_part, dict):
        return False

    if schema_part.get("type") != "object":
        return False

    return ("title" in schema_part and len(schema_part) == 2) or len(schema_part) == 1


def forbid_additional_properties(schema: JSONSchemaT) -> JSONSchemaT:
    """Transform the provided schema into one which forbids additional properties.

    Every object schema (one with `"type": "object"` or with `"properties"`)
    gets `additionalProperties` set as follows:

    - `True` for bare dictionary schemas (see `_is_bare_dictionary_schema`),
      since those accept any dictionary.
    - Left untouched when it is already a sub-schema (e.g. the schema of a
      `dict[str, X]` map). Overwriting it with `False` would reject every
      non-empty map. The sub-schema itself is still transformed.
    - `False` otherwise.

    Nested schemas are reached through `properties`, `patternProperties`,
    `$defs`, `definitions`, `items`, `prefixItems`, `additionalProperties`,
    `allOf`/`anyOf`/`oneOf`, and `not`/`if`/`then`/`else`.

    Args:
        schema (JSONSchemaT): A JSONSchema to transform. It is not modified.

    Returns:
        JSONSchemaT: A new JSONSchema matching the provided one, but
            with `additionalProperties` set as described above.
    """
    new_schema = deepcopy(schema)

    def _enforce(schema_part: Any) -> None:
        if isinstance(schema_part, dict):
            if schema_part.get("type") == "object" or "properties" in schema_part:
                # A sub-schema here describes the values of arbitrary keys; keep it.
                if not isinstance(schema_part.get("additionalProperties"), dict):
                    ## A bare dictionary implies _all_ dictionaries are valid, so we
                    ## do not forbid extra properties in that case.
                    schema_part["additionalProperties"] = _is_bare_dictionary_schema(schema_part)

            # Traverse into nested schemas.
            for key, value in schema_part.items():
                if key in ("properties", "patternProperties", "$defs", "definitions"):
                    # Mappings of name -> sub-schema.
                    if isinstance(value, dict):
                        for sub_schema in value.values():
                            _enforce(sub_schema)
                elif key == "items":
                    if isinstance(value, dict):
                        _enforce(value)
                    elif isinstance(value, list):
                        for item in value:
                            _enforce(item)
                elif key in ("allOf", "anyOf", "oneOf", "prefixItems"):
                    # Lists of sub-schemas.
                    if isinstance(value, list):
                        for item in value:
                            _enforce(item)
                elif key in ("not", "if", "then", "else", "additionalProperties"):
                    # Single sub-schemas (booleans are skipped).
                    if isinstance(value, dict):
                        _enforce(value)

        elif isinstance(schema_part, list):
            for item in schema_part:
                _enforce(item)

    _enforce(new_schema)
    return new_schema
@@ -0,0 +1,10 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from __future__ import annotations
5
+
6
+ from typing import Any, TypeVar
7
+
8
# Constrained TypeVar over the scalar primitive types str, int, float and bool.
T_primitive = TypeVar("T_primitive", str, int, float, bool)
# Union of the container and scalar types treated as data objects (None is not included).
DataObjectT = dict | list | str | int | float | bool
# A JSON Schema document represented as a plain string-keyed dictionary.
JSONSchemaT = dict[str, Any]
+ JSONSchemaT = dict[str, Any]