data-designer 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (177) hide show
  1. data_designer/__init__.py +15 -0
  2. data_designer/_version.py +34 -0
  3. data_designer/cli/README.md +236 -0
  4. data_designer/cli/__init__.py +6 -0
  5. data_designer/cli/commands/__init__.py +2 -0
  6. data_designer/cli/commands/list.py +130 -0
  7. data_designer/cli/commands/models.py +10 -0
  8. data_designer/cli/commands/providers.py +11 -0
  9. data_designer/cli/commands/reset.py +100 -0
  10. data_designer/cli/controllers/__init__.py +7 -0
  11. data_designer/cli/controllers/model_controller.py +246 -0
  12. data_designer/cli/controllers/provider_controller.py +317 -0
  13. data_designer/cli/forms/__init__.py +20 -0
  14. data_designer/cli/forms/builder.py +51 -0
  15. data_designer/cli/forms/field.py +180 -0
  16. data_designer/cli/forms/form.py +59 -0
  17. data_designer/cli/forms/model_builder.py +125 -0
  18. data_designer/cli/forms/provider_builder.py +76 -0
  19. data_designer/cli/main.py +44 -0
  20. data_designer/cli/repositories/__init__.py +8 -0
  21. data_designer/cli/repositories/base.py +39 -0
  22. data_designer/cli/repositories/model_repository.py +42 -0
  23. data_designer/cli/repositories/provider_repository.py +43 -0
  24. data_designer/cli/services/__init__.py +7 -0
  25. data_designer/cli/services/model_service.py +116 -0
  26. data_designer/cli/services/provider_service.py +111 -0
  27. data_designer/cli/ui.py +448 -0
  28. data_designer/cli/utils.py +47 -0
  29. data_designer/config/__init__.py +2 -0
  30. data_designer/config/analysis/column_profilers.py +89 -0
  31. data_designer/config/analysis/column_statistics.py +274 -0
  32. data_designer/config/analysis/dataset_profiler.py +60 -0
  33. data_designer/config/analysis/utils/errors.py +8 -0
  34. data_designer/config/analysis/utils/reporting.py +188 -0
  35. data_designer/config/base.py +68 -0
  36. data_designer/config/column_configs.py +354 -0
  37. data_designer/config/column_types.py +168 -0
  38. data_designer/config/config_builder.py +660 -0
  39. data_designer/config/data_designer_config.py +40 -0
  40. data_designer/config/dataset_builders.py +11 -0
  41. data_designer/config/datastore.py +151 -0
  42. data_designer/config/default_model_settings.py +123 -0
  43. data_designer/config/errors.py +19 -0
  44. data_designer/config/interface.py +54 -0
  45. data_designer/config/models.py +231 -0
  46. data_designer/config/preview_results.py +32 -0
  47. data_designer/config/processors.py +41 -0
  48. data_designer/config/sampler_constraints.py +51 -0
  49. data_designer/config/sampler_params.py +604 -0
  50. data_designer/config/seed.py +145 -0
  51. data_designer/config/utils/code_lang.py +83 -0
  52. data_designer/config/utils/constants.py +313 -0
  53. data_designer/config/utils/errors.py +19 -0
  54. data_designer/config/utils/info.py +88 -0
  55. data_designer/config/utils/io_helpers.py +273 -0
  56. data_designer/config/utils/misc.py +81 -0
  57. data_designer/config/utils/numerical_helpers.py +28 -0
  58. data_designer/config/utils/type_helpers.py +100 -0
  59. data_designer/config/utils/validation.py +336 -0
  60. data_designer/config/utils/visualization.py +427 -0
  61. data_designer/config/validator_params.py +96 -0
  62. data_designer/engine/__init__.py +2 -0
  63. data_designer/engine/analysis/column_profilers/base.py +55 -0
  64. data_designer/engine/analysis/column_profilers/judge_score_profiler.py +160 -0
  65. data_designer/engine/analysis/column_profilers/registry.py +20 -0
  66. data_designer/engine/analysis/column_statistics.py +142 -0
  67. data_designer/engine/analysis/dataset_profiler.py +125 -0
  68. data_designer/engine/analysis/errors.py +7 -0
  69. data_designer/engine/analysis/utils/column_statistics_calculations.py +209 -0
  70. data_designer/engine/analysis/utils/judge_score_processing.py +128 -0
  71. data_designer/engine/column_generators/__init__.py +2 -0
  72. data_designer/engine/column_generators/generators/__init__.py +2 -0
  73. data_designer/engine/column_generators/generators/base.py +61 -0
  74. data_designer/engine/column_generators/generators/expression.py +63 -0
  75. data_designer/engine/column_generators/generators/llm_generators.py +172 -0
  76. data_designer/engine/column_generators/generators/samplers.py +75 -0
  77. data_designer/engine/column_generators/generators/seed_dataset.py +149 -0
  78. data_designer/engine/column_generators/generators/validation.py +147 -0
  79. data_designer/engine/column_generators/registry.py +56 -0
  80. data_designer/engine/column_generators/utils/errors.py +13 -0
  81. data_designer/engine/column_generators/utils/judge_score_factory.py +57 -0
  82. data_designer/engine/column_generators/utils/prompt_renderer.py +98 -0
  83. data_designer/engine/configurable_task.py +82 -0
  84. data_designer/engine/dataset_builders/artifact_storage.py +181 -0
  85. data_designer/engine/dataset_builders/column_wise_builder.py +287 -0
  86. data_designer/engine/dataset_builders/errors.py +13 -0
  87. data_designer/engine/dataset_builders/multi_column_configs.py +44 -0
  88. data_designer/engine/dataset_builders/utils/__init__.py +2 -0
  89. data_designer/engine/dataset_builders/utils/concurrency.py +184 -0
  90. data_designer/engine/dataset_builders/utils/config_compiler.py +60 -0
  91. data_designer/engine/dataset_builders/utils/dag.py +56 -0
  92. data_designer/engine/dataset_builders/utils/dataset_batch_manager.py +190 -0
  93. data_designer/engine/dataset_builders/utils/errors.py +13 -0
  94. data_designer/engine/errors.py +49 -0
  95. data_designer/engine/model_provider.py +75 -0
  96. data_designer/engine/models/__init__.py +2 -0
  97. data_designer/engine/models/errors.py +308 -0
  98. data_designer/engine/models/facade.py +225 -0
  99. data_designer/engine/models/litellm_overrides.py +162 -0
  100. data_designer/engine/models/parsers/__init__.py +2 -0
  101. data_designer/engine/models/parsers/errors.py +34 -0
  102. data_designer/engine/models/parsers/parser.py +236 -0
  103. data_designer/engine/models/parsers/postprocessors.py +93 -0
  104. data_designer/engine/models/parsers/tag_parsers.py +60 -0
  105. data_designer/engine/models/parsers/types.py +82 -0
  106. data_designer/engine/models/recipes/base.py +79 -0
  107. data_designer/engine/models/recipes/response_recipes.py +291 -0
  108. data_designer/engine/models/registry.py +118 -0
  109. data_designer/engine/models/usage.py +75 -0
  110. data_designer/engine/models/utils.py +38 -0
  111. data_designer/engine/processing/ginja/__init__.py +2 -0
  112. data_designer/engine/processing/ginja/ast.py +64 -0
  113. data_designer/engine/processing/ginja/environment.py +461 -0
  114. data_designer/engine/processing/ginja/exceptions.py +54 -0
  115. data_designer/engine/processing/ginja/record.py +30 -0
  116. data_designer/engine/processing/gsonschema/__init__.py +2 -0
  117. data_designer/engine/processing/gsonschema/exceptions.py +8 -0
  118. data_designer/engine/processing/gsonschema/schema_transformers.py +81 -0
  119. data_designer/engine/processing/gsonschema/types.py +8 -0
  120. data_designer/engine/processing/gsonschema/validators.py +143 -0
  121. data_designer/engine/processing/processors/base.py +15 -0
  122. data_designer/engine/processing/processors/drop_columns.py +46 -0
  123. data_designer/engine/processing/processors/registry.py +20 -0
  124. data_designer/engine/processing/utils.py +120 -0
  125. data_designer/engine/registry/base.py +97 -0
  126. data_designer/engine/registry/data_designer_registry.py +37 -0
  127. data_designer/engine/registry/errors.py +10 -0
  128. data_designer/engine/resources/managed_dataset_generator.py +35 -0
  129. data_designer/engine/resources/managed_dataset_repository.py +194 -0
  130. data_designer/engine/resources/managed_storage.py +63 -0
  131. data_designer/engine/resources/resource_provider.py +46 -0
  132. data_designer/engine/resources/seed_dataset_data_store.py +66 -0
  133. data_designer/engine/sampling_gen/column.py +89 -0
  134. data_designer/engine/sampling_gen/constraints.py +95 -0
  135. data_designer/engine/sampling_gen/data_sources/base.py +214 -0
  136. data_designer/engine/sampling_gen/data_sources/errors.py +10 -0
  137. data_designer/engine/sampling_gen/data_sources/sources.py +342 -0
  138. data_designer/engine/sampling_gen/entities/__init__.py +2 -0
  139. data_designer/engine/sampling_gen/entities/assets/zip_area_code_map.parquet +0 -0
  140. data_designer/engine/sampling_gen/entities/dataset_based_person_fields.py +64 -0
  141. data_designer/engine/sampling_gen/entities/email_address_utils.py +169 -0
  142. data_designer/engine/sampling_gen/entities/errors.py +8 -0
  143. data_designer/engine/sampling_gen/entities/national_id_utils.py +100 -0
  144. data_designer/engine/sampling_gen/entities/person.py +142 -0
  145. data_designer/engine/sampling_gen/entities/phone_number.py +122 -0
  146. data_designer/engine/sampling_gen/errors.py +24 -0
  147. data_designer/engine/sampling_gen/generator.py +121 -0
  148. data_designer/engine/sampling_gen/jinja_utils.py +60 -0
  149. data_designer/engine/sampling_gen/people_gen.py +203 -0
  150. data_designer/engine/sampling_gen/person_constants.py +54 -0
  151. data_designer/engine/sampling_gen/schema.py +143 -0
  152. data_designer/engine/sampling_gen/schema_builder.py +59 -0
  153. data_designer/engine/sampling_gen/utils.py +40 -0
  154. data_designer/engine/secret_resolver.py +80 -0
  155. data_designer/engine/validators/__init__.py +17 -0
  156. data_designer/engine/validators/base.py +36 -0
  157. data_designer/engine/validators/local_callable.py +34 -0
  158. data_designer/engine/validators/python.py +245 -0
  159. data_designer/engine/validators/remote.py +83 -0
  160. data_designer/engine/validators/sql.py +60 -0
  161. data_designer/errors.py +5 -0
  162. data_designer/essentials/__init__.py +137 -0
  163. data_designer/interface/__init__.py +2 -0
  164. data_designer/interface/data_designer.py +351 -0
  165. data_designer/interface/errors.py +16 -0
  166. data_designer/interface/results.py +55 -0
  167. data_designer/logging.py +161 -0
  168. data_designer/plugin_manager.py +83 -0
  169. data_designer/plugins/__init__.py +6 -0
  170. data_designer/plugins/errors.py +10 -0
  171. data_designer/plugins/plugin.py +69 -0
  172. data_designer/plugins/registry.py +86 -0
  173. data_designer-0.1.0.dist-info/METADATA +173 -0
  174. data_designer-0.1.0.dist-info/RECORD +177 -0
  175. data_designer-0.1.0.dist-info/WHEEL +4 -0
  176. data_designer-0.1.0.dist-info/entry_points.txt +2 -0
  177. data_designer-0.1.0.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,93 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from typing import Optional, Type
5
+
6
+ import json_repair
7
+ from pydantic import BaseModel, ValidationError
8
+
9
+ from data_designer.engine.models.parsers.types import (
10
+ CodeBlock,
11
+ LLMStructuredResponse,
12
+ PydanticTypeBlock,
13
+ StructuredDataBlock,
14
+ TextBlock,
15
+ )
16
+
17
+
18
def merge_text_blocks(
    structured_response: LLMStructuredResponse,
) -> LLMStructuredResponse:
    """Collapse every run of adjacent TextBlocks into a single TextBlock.

    The texts of neighbouring TextBlocks are concatenated with no separator.
    Blocks of any other type are kept as-is and in their original order, and
    they terminate the current run of text.

    Args:
        structured_response: Parsed response whose `parsed` list is scanned.

    Returns:
        LLMStructuredResponse: A copy of the input (via `model_copy`) whose
            `parsed` attribute is a newly built list.
    """
    merged = []
    pending = None  # TextBlock accumulated for the current run, if any

    for item in structured_response.parsed:
        if not isinstance(item, TextBlock):
            # A non-text block ends the current run: emit the run first.
            if pending is not None:
                merged.append(pending)
                pending = None
            merged.append(item)
            continue

        # The first block of a run is reused untouched; later ones are folded in.
        pending = item if pending is None else TextBlock(text=pending.text + item.text)

    # Emit a run that reaches the end of the document.
    if pending is not None:
        merged.append(pending)

    result = structured_response.model_copy()
    result.parsed = merged
    return result
41
+
42
+
43
def deserialize_json_code(
    structured_response: LLMStructuredResponse,
) -> LLMStructuredResponse:
    """Turn every JSON code block into a StructuredDataBlock.

    A CodeBlock whose `code_lang` equals "json" is replaced by a
    StructuredDataBlock carrying the original code as `serialized` and the
    result of `json_repair.loads` on that code as `obj`. All other blocks are
    passed through untouched, and document order is preserved.

    Args:
        structured_response: Parsed response whose `parsed` list is scanned.

    Returns:
        LLMStructuredResponse: A copy of the input (via `model_copy`) whose
            `parsed` attribute is a newly built list.
    """

    def _convert(item):
        # Only code fences explicitly tagged as JSON are deserialized.
        if isinstance(item, CodeBlock) and item.code_lang == "json":
            return StructuredDataBlock(serialized=item.code, obj=json_repair.loads(item.code))
        return item

    result = structured_response.model_copy()
    result.parsed = [_convert(item) for item in structured_response.parsed]
    return result
60
+
61
+
62
class RealizePydanticTypes:
    """Postprocessor that upgrades StructuredDataBlocks to PydanticTypeBlocks.

    Each StructuredDataBlock's `obj` is validated against the configured
    Pydantic model types. When at least one type accepts it, the block is
    replaced by a PydanticTypeBlock holding the validated model instance;
    otherwise the block is left unchanged.
    """

    # Candidate model types, tried in list order.
    types: list[Type[BaseModel]]

    def __init__(self, types: list[Type[BaseModel]]):
        """Store the candidate model types.

        Args:
            types: Pydantic model classes to try against each structured block.
        """
        self.types = types

    def _fit_types(self, obj: dict) -> Optional[BaseModel]:
        """Validate `obj` against every candidate type.

        Args:
            obj: Deserialized data to validate.

        Returns:
            Optional[BaseModel]: The validated instance, or None when no type
                accepts `obj`. Every type is tried, so when several types
                accept the object the LAST one in `self.types` wins.
        """
        final_obj = None

        for t in self.types:
            try:
                final_obj = t.model_validate(obj)
            except ValidationError:
                # This type does not fit; keep any earlier match and move on.
                pass

        return final_obj

    def __call__(self, structured_response: LLMStructuredResponse) -> LLMStructuredResponse:
        """Apply the type realization to a parsed response.

        Args:
            structured_response: Parsed response whose `parsed` list is scanned.

        Returns:
            LLMStructuredResponse: A copy of the input (via `model_copy`) whose
                `parsed` attribute is a newly built list in the same order.
        """
        processed_response = structured_response.model_copy()
        processed_response.parsed = []

        for block in structured_response.parsed:
            if isinstance(block, StructuredDataBlock):
                pydantic_obj = self._fit_types(block.obj)
                # Compare against None explicitly: a successfully validated
                # model must not be discarded just because it is falsy
                # (e.g. a model class that defines __len__ or __bool__).
                if pydantic_obj is not None:
                    block = PydanticTypeBlock(serialized=block.serialized, obj=pydantic_obj)

            processed_response.parsed.append(block)

        return processed_response
@@ -0,0 +1,60 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from lxml.etree import _Element
5
+
6
+ from data_designer.engine.models.parsers.types import CodeBlock, TextBlock
7
+
8
+
9
def text_parser(element: _Element) -> TextBlock:
    """Wrap the element's text in a TextBlock.

    Args:
        element: An element of the lxml-parsed element tree.

    Returns:
        TextBlock: Holds `element.text`, or an empty string when the
            element has no text.
    """
    return TextBlock(text=element.text or "")
11
+
12
+
13
def text_parser_keep_markup(element: _Element) -> TextBlock:
    """Wrap the element's text in a TextBlock, re-adding its own tag.

    Args:
        element: An element of the lxml-parsed element tree.

    Returns:
        TextBlock: Holds the element's text (empty string when absent)
            surrounded by an opening and a closing tag named `element.tag`.
    """
    tag = element.tag
    inner = element.text or ""
    return TextBlock(text=f"<{tag}>{inner}</{tag}>")
16
+
17
+
18
def inline_code_parser(element: _Element) -> TextBlock:
    """Wrap the element's text in a TextBlock, surrounded by backticks.

    Args:
        element: An element of the lxml-parsed element tree.

    Returns:
        TextBlock: Holds the element's text (empty string when absent)
            enclosed in a single pair of backticks.
    """
    inner = element.text or ""
    return TextBlock(text=f"`{inner}`")
20
+
21
+
22
def code_block_parser(element: _Element) -> CodeBlock:
    """Parse a <pre><code> element node.

    This parser handles the special case of Markdown->HTML conversion
    for fenced code blocks, where

        ```xx
        ...
        ```

    becomes

        <pre><code class="language-xx">...</code></pre>

    It is intended to be attached to the special case of "pre.code"
    tag hierarchies.

    Syntax Handling

    If the syntax is not specified, e.g. ``<code>...</code>`` or
    ``<code class="">...</code>``, then the syntax field is returned
    as None. The parser does not _enforce_ the `language-` prefix on the
    value of the class attribute: when the prefix is absent, the entire
    value of the class attribute is used as the syntax.

    Args:
        element (lxml.etree._Element): An element of the lxml-parsed
            element tree.

    Returns:
        CodeBlock: Data structure containing both the body of the code
            (stripped of surrounding whitespace, or an empty string when
            the element has no text) and the specified syntax of the
            code block.
    """
    css_class = element.attrib.get("class", "")
    syntax = css_class.removeprefix("language-")
    body = element.text
    return CodeBlock(
        code=body.strip() if body else "",
        # An empty class value means no syntax was given.
        code_lang=syntax or None,
    )
@@ -0,0 +1,82 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from typing import Any, Optional, Protocol, Type, runtime_checkable
5
+
6
+ from lxml.etree import _Element
7
+ from pydantic import BaseModel, Field
8
+ from typing_extensions import Self
9
+
10
+
11
class LLMStructuredResponse(BaseModel):
    """Output format for the LLM Response Parser.

    The `head`, `tail` and `filter` helpers never modify the instance they
    are called on: each returns a copy (via `model_copy`) whose `parsed`
    attribute is a new list.
    """

    response: str = Field(description="Raw Markdown/Markup response received from the LLM and input to the parser.")
    markup: str = Field(description="Markup/HTML resulting from running Markdown parsing on response.")
    parsed: list[BaseModel] = Field(
        default_factory=list,
        description="Structured content parsed from markup. Elements of this list are in document-order.",
    )

    def head(self, n: int) -> Self:
        """Retain only the first n elements of the parsed response.

        Args:
            n: Number of leading elements to keep. Uses plain list slicing
                (`parsed[:n]`), so a negative value drops elements from
                the end instead.

        Returns:
            Self: A copy with the truncated `parsed` list.
        """
        out = self.model_copy()
        out.parsed = out.parsed[:n]
        return out

    def tail(self, n: int) -> Self:
        """Retain only the last n elements of the parsed response.

        Args:
            n: Number of trailing elements to keep. Zero yields an empty
                `parsed` list. Non-zero values use plain list slicing
                (`parsed[-n:]`), so a negative value drops elements from
                the start instead.

        Returns:
            Self: A copy with the truncated `parsed` list.
        """
        out = self.model_copy()
        # `parsed[-0:]` is the same as `parsed[0:]` (the whole list), so a
        # request for zero elements needs an explicit empty result.
        out.parsed = out.parsed[-n:] if n else []
        return out

    def filter(self, block_types: list[Type[BaseModel]]) -> Self:
        """Retain only the parsed elements that are instances of the given types.

        Args:
            block_types: Block classes to keep; matching uses `isinstance`,
                so subclasses of the listed types are kept as well.

        Returns:
            Self: A copy whose `parsed` list holds the matching elements in
                their original order.
        """
        out = self.model_copy()
        out.parsed = [b for b in out.parsed if isinstance(b, tuple(block_types))]
        return out
37
+
38
+
39
@runtime_checkable
class TagParser(Protocol):
    """Structural type for tag parsing implementations.

    A TagParser is any object that can be called with an `lxml` element,
    do some computation, and return structured output represented as an
    instance of a Pydantic `BaseModel` subclass. Because this is a
    protocol, it covers classes as well as curried functions used as
    parsers (e.g. `partial`).
    """

    def __call__(self, element: _Element) -> BaseModel: ...
51
+
52
+
53
@runtime_checkable
class PostProcessor(Protocol):
    """Structural type for parsed-output postprocessing implementations.

    A PostProcessor transforms the result of the LLM response parser and
    returns it in the same output structure. Keeping the input and output
    types identical is what allows PostProcessor implementations to be
    chained together.
    """

    def __call__(self, structured_response: LLMStructuredResponse) -> LLMStructuredResponse: ...
64
+
65
+
66
class TextBlock(BaseModel):
    """Parsed block holding a piece of text."""

    # Text content of the block.
    text: str
68
+
69
+
70
class CodeBlock(BaseModel):
    """Parsed block holding a code snippet and, optionally, its language."""

    # Body of the code snippet.
    code: str
    # Language identifier of the snippet; None when none was specified.
    code_lang: Optional[str] = None
73
+
74
+
75
class StructuredDataBlock(BaseModel):
    """Parsed block pairing serialized data with its deserialized object."""

    # Serialized text the object was obtained from.
    serialized: str
    # Deserialized object; its type is not constrained.
    obj: Any
78
+
79
+
80
class PydanticTypeBlock(BaseModel):
    """Parsed block pairing serialized data with a Pydantic model instance."""

    # Serialized text the model instance was obtained from.
    serialized: str
    # Pydantic model instance realized from the serialized data.
    obj: BaseModel
@@ -0,0 +1,79 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import abc
5
+ from collections.abc import Callable
6
+ from typing import Generic, TypeVar
7
+
8
+ T = TypeVar("T")
9
+
10
+
11
+ class ResponseRecipe(abc.ABC, Generic[T]):
12
+ """Base class for defining response recipes.
13
+
14
+ Response recipes contain all necessary information for
15
+ getting an LLM to perform a particular common task,
16
+ like outputting code in a desired format or following
17
+ structured output.
18
+ """
19
+
20
+ @abc.abstractmethod
21
+ def _build_parser_fn(self) -> Callable[[str], T]:
22
+ """Build the recipe's output parser function."""
23
+ ...
24
+
25
+ @property
26
+ @abc.abstractmethod
27
+ def example_template(self) -> str: ...
28
+
29
+ @abc.abstractmethod
30
+ def serialize_output(self, output: T) -> str:
31
+ """Serialize an instance of the parser output."""
32
+ ...
33
+
34
+ @abc.abstractmethod
35
+ def deserialize_output(self, serialized_output: str) -> T:
36
+ """Deserialize a serialized instance of the parser output."""
37
+ ...
38
+
39
+ def __init__(self):
40
+ self._parse_fn = self._build_parser_fn()
41
+
42
+ @property
43
+ def task_instructions(self) -> str | None:
44
+ """Specifies task instructions.
45
+
46
+ These instructions lay out the particular task information the
47
+ LLM requires in order to carry out the function of the recipe.
48
+ """
49
+ return None
50
+
51
+ def parse(self, response: str) -> T:
52
+ """Apply the recipe's parser to a raw model output."""
53
+ return self._parse_fn(response)
54
+
55
+ def generate_response_example(self, example: T) -> str:
56
+ """Create a serialized response example that the parser would admit."""
57
+ return self.example_template.format(example=example)
58
+
59
+ def apply_recipe_to_user_prompt(self, user_prompt: str) -> str:
60
+ """Appends recipe specific task instructions if applicable.
61
+
62
+ Args:
63
+ user_prompt (str): User prompt to be appended with recipe specific task instructions if applicable.
64
+
65
+ Returns:
66
+ str: Final user prompt
67
+ """
68
+ return f"{user_prompt}\n\n{self.task_instructions}" if self.task_instructions else user_prompt
69
+
70
+ def apply_recipe_to_system_prompt(self, system_prompt: str | None) -> str:
71
+ """Appends recipe specific task instructions if applicable.
72
+
73
+ Args:
74
+ system_prompt (str): System prompt to be appended with recipe specific task instructions if applicable.
75
+
76
+ Returns:
77
+ str: Final system prompt
78
+ """
79
+ return system_prompt
@@ -0,0 +1,291 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from collections.abc import Callable
5
+ import json
6
+
7
+ from pydantic import BaseModel
8
+
9
+ from data_designer.config.utils.code_lang import CodeLang
10
+ from data_designer.engine.models.parsers.errors import ParserException
11
+ from data_designer.engine.models.parsers.parser import LLMResponseParser
12
+ from data_designer.engine.models.parsers.postprocessors import (
13
+ StructuredDataBlock,
14
+ deserialize_json_code,
15
+ merge_text_blocks,
16
+ )
17
+ from data_designer.engine.models.parsers.types import CodeBlock
18
+ from data_designer.engine.models.recipes.base import (
19
+ ResponseRecipe,
20
+ )
21
+ from data_designer.engine.processing.gsonschema.validators import JSONSchemaValidationError, validate
22
+
23
+
24
class TextResponseRecipe(ResponseRecipe[str]):
    """Default text recipe.

    Covers the "pass-through" case of natural language LLM responses:
    serialization and deserialization are both the identity, and examples
    are rendered verbatim.
    """

    @property
    def example_template(self) -> str:
        """Template that renders the example exactly as given."""
        return "{example}"

    def serialize_output(self, output: str) -> str:
        """Return the output unchanged."""
        return output

    def deserialize_output(self, serialized_output: str) -> str:
        """Return the serialized output unchanged."""
        return serialized_output

    def _build_parser_fn(self) -> Callable[[str], str]:
        """Build a parser that returns the `response` attribute of the parse result."""
        parser = LLMResponseParser(postprocessors=[merge_text_blocks])

        def parse_fn(response: str) -> str:
            return parser.parse(response).response

        return parse_fn
48
+
49
+
50
class StructuredResponseRecipe(ResponseRecipe[dict]):
    """Recipe for structured responses.

    This recipe covers the generic case of prompting-based requests for
    structured data outputs, where the structure is determined by a
    provided JSON Schema.

    The LLM's response is validated against the provided JSON Schema;
    the object returned is the Python dictionary obtained from
    deserializing the LLM's JSON response.
    """

    json_schema: dict
    pruning: bool
    no_extra_properties: bool

    def __init__(
        self,
        json_schema: dict,
        pruning: bool = True,
        no_extra_properties: bool = True,
        **kwargs,
    ):
        """Initialize StructuredResponseRecipe.

        Args:
            json_schema (dict): A target JSON schema that the LLM
                should adhere to when making its response.
            pruning (bool): If `True`, then any extra fields in the returned
                JSON object will be removed. Otherwise, they are retained,
                which could raise validation errors. Default=True.
            no_extra_properties (bool): If `True`, then validation will fail
                if extra properties are encountered in the returned JSON
                response. Default=True.
            **kwargs: Forwarded to the base-class initializer.
        """
        # The base initializer builds the parser; the parser reads the
        # attributes below lazily, at parse time.
        super().__init__(**kwargs)
        self.json_schema = json_schema
        self.pruning = pruning
        self.no_extra_properties = no_extra_properties

    @property
    def schema(self) -> str:
        """The target JSON Schema serialized as a JSON string."""
        return json.dumps(self.json_schema)

    @property
    def task_instructions(self) -> str:
        """Instructions asking for fenced JSON that follows the schema."""
        rules = (
            "* Your response must be in JSON format.\n"
            "* Your JSON response must be returned within a Markdown, ```json code fence.\n"
            "* The JSON format is given as a JSON Schema description within <response_schema> markup tags.\n\n"
        )
        return rules + f"<response_schema>\n{self.schema}\n</response_schema>"

    @property
    def example_template(self) -> str:
        """Template that wraps the example in a ```json code fence."""
        return "```json\n{example}\n```"

    def generate_response_example(self, example: dict) -> str:
        """Render `example` as JSON inside the example template."""
        rendered = json.dumps(example)
        return self.example_template.format(example=rendered)

    def serialize_output(self, output: dict) -> str:
        """Serialize the parsed dictionary to JSON, keeping non-ASCII characters."""
        return json.dumps(output, ensure_ascii=False)

    def deserialize_output(self, serialized_output: str) -> dict:
        """Load a dictionary back from its JSON serialization."""
        return json.loads(serialized_output)

    @property
    def _validate_args(self):
        """Keyword arguments passed to `validate` for every parsed object."""
        return dict(
            schema=self.json_schema,
            pruning=self.pruning,
            no_extra_properties=self.no_extra_properties,
        )

    def _build_parser_fn(self) -> Callable[[str], dict]:
        """Build a parser that extracts and validates the last JSON block.

        The returned callable raises ParserException when the response has
        no structured data block or when validation against the schema fails.
        """
        parser = LLMResponseParser(postprocessors=[merge_text_blocks, deserialize_json_code])

        def parse_fn(response: str) -> dict:
            try:
                structured = parser.parse(response).filter([StructuredDataBlock]).parsed
                # pop() takes the last structured block and raises
                # IndexError when there is none.
                candidate = structured.pop().obj
                return validate(candidate, **self._validate_args)
            except IndexError:
                raise ParserException(
                    "No parsable JSON structure within ```json markdown fence.",
                    source=response,
                ) from None
            except JSONSchemaValidationError as exc:
                raise ParserException(
                    "Response doesn't match requested <response_schema>\n" + str(exc),
                    source=response,
                ) from None

        return parse_fn
150
+
151
+
152
class PydanticResponseRecipe(ResponseRecipe[BaseModel]):
    """Recipe for Pydantic responses.

    This recipe covers the case where a Pydantic data type (BaseModel)
    is already specified in the runtime making LLM calls, and an object
    of that same data type is wanted as the output from the parser.

    It operates in a very similar fashion to `StructuredResponseRecipe`,
    except that it is initialized from a Pydantic `BaseModel` and
    produces its return value with `BaseModel.model_validate`.
    """

    data_type: type[BaseModel]

    def __init__(self, data_type: type[BaseModel], **kwargs):
        """Initialize a PydanticResponseRecipe.

        Args:
            data_type (type[BaseModel]): The target Pydantic BaseModel
                subclass that the LLM should adhere to in its response,
                and which defines the output type of the parser.
            **kwargs: Forwarded to the base-class initializer.
        """
        # The base initializer builds the parser; the parser reads
        # `data_type` lazily, at parse time.
        super().__init__(**kwargs)
        self.data_type = data_type

    @property
    def schema(self) -> str:
        """JSON Schema of `data_type`, serialized as a JSON string."""
        return json.dumps(self.data_type.model_json_schema())

    @property
    def task_instructions(self) -> str:
        """Instructions asking for fenced JSON that follows the schema."""
        rules = (
            "* Your response must be in JSON format.\n"
            "* Your JSON response must be returned within a Markdown, ```json code fence.\n"
            "* The JSON format is given as a JSON Schema description within <response_schema> markup tags.\n\n"
        )
        return rules + f"<response_schema>\n{self.schema}\n</response_schema>"

    @property
    def example_template(self) -> str:
        """Template that wraps the example in a ```json code fence."""
        return "```json\n{example}\n```"

    def generate_response_example(self, example: BaseModel) -> str:
        """Render the model instance as JSON inside the example template."""
        rendered = example.model_dump_json()
        return self.example_template.format(example=rendered)

    def serialize_output(self, output: BaseModel) -> str:
        """Serialize the model instance to a JSON string."""
        return output.model_dump_json()

    def deserialize_output(self, serialized_output: str) -> BaseModel:
        """Rebuild a `data_type` instance from its JSON serialization."""
        return self.data_type.model_validate_json(serialized_output)

    def _build_parser_fn(self) -> Callable[[str], BaseModel]:
        """Build a parser that extracts the last JSON block as a `data_type`.

        The returned callable raises ParserException when the response has
        no structured data block, or when any other exception occurs while
        parsing or validating it.
        """
        parser = LLMResponseParser(postprocessors=[merge_text_blocks, deserialize_json_code])

        def parse_fn(response: str) -> BaseModel:
            try:
                structured = parser.parse(response).filter([StructuredDataBlock]).parsed
                # pop() takes the last structured block and raises
                # IndexError when there is none.
                candidate = structured.pop().obj
                return self.data_type.model_validate(candidate)
            except IndexError:
                raise ParserException(
                    "No parsable JSON structure within ```json markdown fence.",
                    source=response,
                ) from None
            # NOTE(review): this catch-all reports every failure, including
            # ones raised by the parser itself, as a schema mismatch --
            # confirm that this is intended.
            except Exception as exc:
                raise ParserException(
                    "Response doesn't match requested <response_schema>\n" + str(exc),
                    source=response,
                ) from None

        return parse_fn
230
+
231
+
232
class CodeResponseRecipe(ResponseRecipe[str]):
    """Obtain a code snippet from an LLM."""

    def __init__(self, syntax: str | CodeLang, **kwargs):
        """Initialize a CodeResponseRecipe.

        Args:
            syntax (str | CodeLang): The code syntax that the
                LLM should adhere to, e.g. `"python"`, `"sql"`, etc.
                It is normalized through `CodeLang.parse_lang`.
            **kwargs: Forwarded to the base-class initializer.
        """
        # The base initializer builds the parser; the parser reads
        # `self.syntax` lazily, at parse time.
        super().__init__(**kwargs)
        self.syntax = CodeLang.parse_lang(syntax)

    @property
    def task_instructions(self) -> str:
        """Instructions asking for a single fenced code block in the target syntax."""
        return (
            f"* Your response must be code written in {self.syntax}.\n"
            "* You will follow accepted and common syntax and best-practices.\n"
            "* Your response will be given in markdown code fences specifying the correct language.\n"
            "* Only respond with a SINGLE code block."
        )

    @property
    def example_template(self) -> str:
        """Template that wraps the example in a code fence tagged with the syntax."""
        return f"```{self.syntax}\n{{example}}\n```\n"

    def serialize_output(self, output: str) -> str:
        """Return the code unchanged."""
        return output

    def deserialize_output(self, serialized_output: str) -> str:
        """Return the serialized code unchanged."""
        return serialized_output

    def _build_parser_fn(self) -> Callable[[str], str]:
        """Build a parser that returns the body of the last code block.

        The returned callable raises ParserException when the response has
        no code block, or when the block names a language that differs from
        the requested syntax.
        """
        parser = LLMResponseParser(
            postprocessors=[
                merge_text_blocks,
            ]
        )

        def parse_fn(response: str) -> str:
            try:
                # pop() takes the last code block and raises IndexError
                # when there is none.
                code_block = parser.parse(response).filter([CodeBlock]).parsed.pop()
                # For the type checker -- should always pass
                assert isinstance(code_block, CodeBlock)
            except IndexError:
                raise ParserException(
                    "No parsable code response.",
                    source=response,
                ) from None

            # Only report this as a parser error if there was a mismatch;
            # a fence without a language tag is accepted as-is.
            if code_block.code_lang and code_block.code_lang != self.syntax:
                raise ParserException(
                    f"Responded with code not matching the requested syntax ({self.syntax}).",
                    source=response,
                )

            return code_block.code.strip()

        return parse_fn