adaptive-harmony 0.1.23__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- adaptive_harmony/__init__.py +162 -0
- adaptive_harmony/common/__init__.py +40 -0
- adaptive_harmony/common/callbacks.py +219 -0
- adaptive_harmony/common/checkpointing.py +163 -0
- adaptive_harmony/common/dpo.py +92 -0
- adaptive_harmony/common/env_grpo.py +361 -0
- adaptive_harmony/common/grpo.py +260 -0
- adaptive_harmony/common/gspo.py +70 -0
- adaptive_harmony/common/ppo.py +303 -0
- adaptive_harmony/common/rm.py +79 -0
- adaptive_harmony/common/sft.py +121 -0
- adaptive_harmony/core/__init__.py +0 -0
- adaptive_harmony/core/dataset.py +72 -0
- adaptive_harmony/core/display.py +93 -0
- adaptive_harmony/core/image_utils.py +110 -0
- adaptive_harmony/core/reasoning.py +12 -0
- adaptive_harmony/core/reward_client/__init__.py +19 -0
- adaptive_harmony/core/reward_client/client.py +160 -0
- adaptive_harmony/core/reward_client/reward_types.py +49 -0
- adaptive_harmony/core/reward_client/websocket_utils.py +18 -0
- adaptive_harmony/core/rich_counter.py +351 -0
- adaptive_harmony/core/rl_utils.py +38 -0
- adaptive_harmony/core/schedulers.py +38 -0
- adaptive_harmony/core/structured_output.py +385 -0
- adaptive_harmony/core/utils.py +365 -0
- adaptive_harmony/environment/__init__.py +8 -0
- adaptive_harmony/environment/environment.py +121 -0
- adaptive_harmony/evaluation/__init__.py +1 -0
- adaptive_harmony/evaluation/evaluation_artifact.py +67 -0
- adaptive_harmony/graders/__init__.py +20 -0
- adaptive_harmony/graders/answer_relevancy_judge/__init__.py +3 -0
- adaptive_harmony/graders/answer_relevancy_judge/answer_relevancy_judge.py +102 -0
- adaptive_harmony/graders/answer_relevancy_judge/prompts.py +58 -0
- adaptive_harmony/graders/base_grader.py +265 -0
- adaptive_harmony/graders/binary_judge/__init__.py +8 -0
- adaptive_harmony/graders/binary_judge/binary_judge.py +202 -0
- adaptive_harmony/graders/binary_judge/prompts.py +125 -0
- adaptive_harmony/graders/combined_grader.py +118 -0
- adaptive_harmony/graders/context_relevancy_judge/__init__.py +3 -0
- adaptive_harmony/graders/context_relevancy_judge/context_relevancy_judge.py +128 -0
- adaptive_harmony/graders/context_relevancy_judge/prompts.py +84 -0
- adaptive_harmony/graders/exceptions.py +9 -0
- adaptive_harmony/graders/faithfulness_judge/__init__.py +3 -0
- adaptive_harmony/graders/faithfulness_judge/faithfulness_judge.py +159 -0
- adaptive_harmony/graders/faithfulness_judge/prompts.py +22 -0
- adaptive_harmony/graders/range_judge/__init__.py +7 -0
- adaptive_harmony/graders/range_judge/prompts.py +232 -0
- adaptive_harmony/graders/range_judge/range_judge.py +188 -0
- adaptive_harmony/graders/range_judge/types.py +12 -0
- adaptive_harmony/graders/reward_server_grader.py +36 -0
- adaptive_harmony/graders/templated_prompt_judge.py +237 -0
- adaptive_harmony/graders/utils.py +79 -0
- adaptive_harmony/logging_table.py +1 -0
- adaptive_harmony/metric_logger.py +452 -0
- adaptive_harmony/parameters/__init__.py +2 -0
- adaptive_harmony/py.typed +0 -0
- adaptive_harmony/runtime/__init__.py +2 -0
- adaptive_harmony/runtime/context.py +2 -0
- adaptive_harmony/runtime/data.py +2 -0
- adaptive_harmony/runtime/decorators.py +2 -0
- adaptive_harmony/runtime/model_artifact_save.py +2 -0
- adaptive_harmony/runtime/runner.py +27 -0
- adaptive_harmony/runtime/simple_notifier.py +2 -0
- adaptive_harmony-0.1.23.dist-info/METADATA +37 -0
- adaptive_harmony-0.1.23.dist-info/RECORD +67 -0
- adaptive_harmony-0.1.23.dist-info/WHEEL +5 -0
- adaptive_harmony-0.1.23.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,385 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import re
|
|
3
|
+
from enum import Enum
|
|
4
|
+
from typing import Literal, Type, Union, get_args, get_origin
|
|
5
|
+
|
|
6
|
+
from pydantic import BaseModel, ValidationError
|
|
7
|
+
|
|
8
|
+
from adaptive_harmony import InferenceModel, StringThread
|
|
9
|
+
from adaptive_harmony.core.reasoning import remove_reasoning
|
|
10
|
+
|
|
11
|
+
FIX_OUTPUT_FORMAT = """Below, the COMPLETION did not satisfy the constraints given in the PROMPT. Please rewrite the completion to comply with constraints, nothing else.
|
|
12
|
+
|
|
13
|
+
PROMPT
|
|
14
|
+
The output should be a well-formatted JSON instance that conforms to the JSON schema below. All fields are required. Do not output anything else other than the JSON.
|
|
15
|
+
|
|
16
|
+
As an example, for the schema
|
|
17
|
+
{{
|
|
18
|
+
"foo": {{
|
|
19
|
+
"items":{{"type": "string"}},
|
|
20
|
+
"type": "array"
|
|
21
|
+
}},
|
|
22
|
+
"bar": {{"type": "integer"}}
|
|
23
|
+
}}
|
|
24
|
+
the object {{"foo": ["hey", "bye"], "bar": 1}} is a well-formatted instance of the schema. The object {{"properties": {{"foo": ["hey", "bye"], "bar":"1" }}}} is not well-formatted.
|
|
25
|
+
|
|
26
|
+
Here is the output JSON schema:
|
|
27
|
+
{json_schema}
|
|
28
|
+
|
|
29
|
+
COMPLETION
|
|
30
|
+
{completion}
|
|
31
|
+
"""
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class JsonParseError(Exception):
    """Raised when a completion still cannot be parsed after every repair attempt.

    Attributes:
        completion: The last raw completion text that failed to parse.
    """

    def __init__(self, message: str, completion: str):
        self.completion = completion
        super().__init__(message)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def get_pydantic_schema(base_model: Type[BaseModel]) -> str:
    """Return ``base_model``'s JSON schema serialized as indented JSON text.

    The ``title`` key is removed from each top-level property schema; titles
    anywhere else in the schema (e.g. nested definitions) are kept.

    Args:
        base_model: The pydantic model class to describe.

    Returns:
        The schema as a JSON string (2-space indent).
    """
    json_schema = base_model.model_json_schema()
    top_level_properties = json_schema.get("properties", {})
    for property_schema in top_level_properties.values():
        property_schema.pop("title", None)
    return json.dumps(json_schema, indent=2)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class OutputParserException(Exception):
|
|
48
|
+
"""Exception raised for parsing errors."""
|
|
49
|
+
|
|
50
|
+
def __init__(self, message: str, llm_output: str | None = None):
|
|
51
|
+
super().__init__(message)
|
|
52
|
+
self.llm_output = llm_output
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def pydantic_parse[T: BaseModel](text: str, pydantic_object: type[T]) -> T:
|
|
56
|
+
"""Parse the output of an LLM call to a pydantic object.
|
|
57
|
+
|
|
58
|
+
Args:
|
|
59
|
+
text: The output of the LLM call.
|
|
60
|
+
pydantic_object: The pydantic model to parse into.
|
|
61
|
+
|
|
62
|
+
Returns:
|
|
63
|
+
The parsed pydantic object.
|
|
64
|
+
"""
|
|
65
|
+
# Remove Qwen3 reasoning
|
|
66
|
+
text = remove_reasoning(text)
|
|
67
|
+
|
|
68
|
+
# Strip initial whitespace
|
|
69
|
+
text = text.strip()
|
|
70
|
+
|
|
71
|
+
def parse_json_with_completion(json_text):
|
|
72
|
+
"""Parse JSON, handling partial JSON by completing missing brackets."""
|
|
73
|
+
# Strip whitespace and backticks
|
|
74
|
+
json_text = json_text.strip(" \n\r\t`")
|
|
75
|
+
|
|
76
|
+
# Handle action_input special case - escape special chars
|
|
77
|
+
if '"action_input"' in json_text:
|
|
78
|
+
|
|
79
|
+
def fix_action_input(match):
|
|
80
|
+
value = match.group(2)
|
|
81
|
+
value = re.sub(r"\n", r"\\n", value)
|
|
82
|
+
value = re.sub(r"\r", r"\\r", value)
|
|
83
|
+
value = re.sub(r"\t", r"\\t", value)
|
|
84
|
+
value = re.sub(r'(?<!\\)"', r"\"", value)
|
|
85
|
+
return match.group(1) + value + match.group(3)
|
|
86
|
+
|
|
87
|
+
json_text = re.sub(r'("action_input"\:\s*")(.*?)(")', fix_action_input, json_text, flags=re.DOTALL)
|
|
88
|
+
|
|
89
|
+
# NOTE Axel: gemma likes to escape the left bracket, patching for now
|
|
90
|
+
json_text = json_text.replace(r"\\[", "[")
|
|
91
|
+
|
|
92
|
+
# Try parsing as-is first
|
|
93
|
+
try:
|
|
94
|
+
return json.loads(json_text)
|
|
95
|
+
except json.JSONDecodeError:
|
|
96
|
+
pass
|
|
97
|
+
|
|
98
|
+
# Handle partial JSON - complete missing brackets and quotes
|
|
99
|
+
chars = list(json_text)
|
|
100
|
+
stack = []
|
|
101
|
+
in_string = False
|
|
102
|
+
escaped = False
|
|
103
|
+
|
|
104
|
+
for i, char in enumerate(chars):
|
|
105
|
+
if in_string:
|
|
106
|
+
if char == '"' and not escaped:
|
|
107
|
+
in_string = False
|
|
108
|
+
elif char == "\n" and not escaped:
|
|
109
|
+
chars[i] = "\\n"
|
|
110
|
+
escaped = char == "\\" and not escaped
|
|
111
|
+
elif char == '"':
|
|
112
|
+
in_string = True
|
|
113
|
+
escaped = False
|
|
114
|
+
elif char == "{":
|
|
115
|
+
stack.append("}")
|
|
116
|
+
elif char == "[":
|
|
117
|
+
stack.append("]")
|
|
118
|
+
elif char in {"}", "]"}:
|
|
119
|
+
if stack and stack[-1] == char:
|
|
120
|
+
stack.pop()
|
|
121
|
+
|
|
122
|
+
# Close unterminated string
|
|
123
|
+
if in_string:
|
|
124
|
+
if escaped and chars:
|
|
125
|
+
chars.pop()
|
|
126
|
+
chars.append('"')
|
|
127
|
+
|
|
128
|
+
# Add missing closing brackets
|
|
129
|
+
chars.extend(reversed(stack))
|
|
130
|
+
|
|
131
|
+
# Try parsing with progressively fewer characters
|
|
132
|
+
while chars:
|
|
133
|
+
try:
|
|
134
|
+
return json.loads("".join(chars))
|
|
135
|
+
except json.JSONDecodeError:
|
|
136
|
+
chars.pop()
|
|
137
|
+
|
|
138
|
+
# If nothing worked, raise with original
|
|
139
|
+
raise json.JSONDecodeError("Invalid JSON", json_text, 0)
|
|
140
|
+
|
|
141
|
+
# Try parsing the original text first
|
|
142
|
+
try:
|
|
143
|
+
json_object = parse_json_with_completion(text)
|
|
144
|
+
except json.JSONDecodeError:
|
|
145
|
+
# Try extracting from markdown blocks
|
|
146
|
+
markdown_match = re.search(r"```(json)(.*?)```", text, re.DOTALL)
|
|
147
|
+
if not markdown_match:
|
|
148
|
+
markdown_match = re.search(r"```(json)?(.*)", text, re.DOTALL)
|
|
149
|
+
xml_match = re.search(r"<json>(.*?)</json>", text, re.DOTALL)
|
|
150
|
+
if not xml_match:
|
|
151
|
+
xml_match = re.search(r"<json>(.*)", text, re.DOTALL)
|
|
152
|
+
|
|
153
|
+
if markdown_match or xml_match:
|
|
154
|
+
try:
|
|
155
|
+
json_object = parse_json_with_completion(
|
|
156
|
+
markdown_match.group(2) if markdown_match else (xml_match.group(1) if xml_match else "")
|
|
157
|
+
)
|
|
158
|
+
except json.JSONDecodeError:
|
|
159
|
+
msg = f"Invalid json output: {text}"
|
|
160
|
+
raise OutputParserException(msg, llm_output=text)
|
|
161
|
+
else:
|
|
162
|
+
msg = f"Invalid json output: {text}"
|
|
163
|
+
raise OutputParserException(msg, llm_output=text)
|
|
164
|
+
|
|
165
|
+
try:
|
|
166
|
+
return pydantic_object.model_validate(json_object)
|
|
167
|
+
except ValidationError as e:
|
|
168
|
+
json_string = json.dumps(json_object)
|
|
169
|
+
msg = f"Failed to parse {pydantic_object.__name__} from completion {json_string}. Got: {e}"
|
|
170
|
+
raise OutputParserException(msg, llm_output=json_string) from e
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
async def generate_and_validate[T: BaseModel](
|
|
174
|
+
model: InferenceModel,
|
|
175
|
+
thread: StringThread,
|
|
176
|
+
pydantic_model: Type[T],
|
|
177
|
+
max_parsing_retries: int = 1,
|
|
178
|
+
) -> tuple[str, T]:
|
|
179
|
+
"""
|
|
180
|
+
Generates with InferenceModel, validates completion against Pydantic model and retries
|
|
181
|
+
if validation fails. It's recommended you use a StructuredJSONOutputBaseModel as
|
|
182
|
+
the pydantic_object to clean up the JSON schema for the LLM. Does not support RootModel.
|
|
183
|
+
"""
|
|
184
|
+
|
|
185
|
+
json_schema = get_pydantic_schema(pydantic_model)
|
|
186
|
+
|
|
187
|
+
response_thread = await model.generate(thread)
|
|
188
|
+
completion = response_thread.last_content()
|
|
189
|
+
|
|
190
|
+
current_retries = 0
|
|
191
|
+
while current_retries <= max_parsing_retries:
|
|
192
|
+
try:
|
|
193
|
+
parsed = pydantic_parse(completion, pydantic_model)
|
|
194
|
+
return (completion, parsed)
|
|
195
|
+
except Exception:
|
|
196
|
+
if current_retries == max_parsing_retries:
|
|
197
|
+
break
|
|
198
|
+
|
|
199
|
+
# Create repair prompt
|
|
200
|
+
repair_thread = StringThread(
|
|
201
|
+
[("user", FIX_OUTPUT_FORMAT.format(json_schema=json_schema, completion=completion))]
|
|
202
|
+
)
|
|
203
|
+
response_thread = await model.generate(repair_thread)
|
|
204
|
+
completion = response_thread.last_content()
|
|
205
|
+
current_retries += 1
|
|
206
|
+
|
|
207
|
+
raise JsonParseError(f"Could not parse json output after {max_parsing_retries} retries", completion)
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
def _get_simplified_type(field_type):
|
|
211
|
+
origin = get_origin(field_type)
|
|
212
|
+
args = get_args(field_type)
|
|
213
|
+
|
|
214
|
+
if origin is list:
|
|
215
|
+
if args:
|
|
216
|
+
return [_get_simplified_type(args[0])]
|
|
217
|
+
else:
|
|
218
|
+
return "array"
|
|
219
|
+
elif origin is dict:
|
|
220
|
+
if len(args) == 2:
|
|
221
|
+
key_type = _get_simplified_type(args[0])
|
|
222
|
+
value_type = _get_simplified_type(args[1])
|
|
223
|
+
return f"Dict[{key_type}, {value_type}]"
|
|
224
|
+
else:
|
|
225
|
+
return "dict"
|
|
226
|
+
elif origin is tuple:
|
|
227
|
+
if args:
|
|
228
|
+
if len(args) == 2 and args[1] is ...:
|
|
229
|
+
# Variable length tuple like Tuple[str, ...]
|
|
230
|
+
element_type = _get_simplified_type(args[0])
|
|
231
|
+
return f"Tuple[{element_type}, ...]"
|
|
232
|
+
else:
|
|
233
|
+
# Fixed length tuple like Tuple[str, int]
|
|
234
|
+
element_types = [_get_simplified_type(arg) for arg in args]
|
|
235
|
+
return f"Tuple[{', '.join(element_types)}]"
|
|
236
|
+
else:
|
|
237
|
+
return "tuple"
|
|
238
|
+
elif origin is set:
|
|
239
|
+
if args:
|
|
240
|
+
element_type = _get_simplified_type(args[0])
|
|
241
|
+
return f"Set[{element_type}]"
|
|
242
|
+
else:
|
|
243
|
+
return "set"
|
|
244
|
+
elif origin is type(None):
|
|
245
|
+
return "null"
|
|
246
|
+
elif origin is Literal:
|
|
247
|
+
# Handle Literal types by showing them as Literal["value1", "value2"]
|
|
248
|
+
# Use double quotes for strings to match JSON format and prevent LLM confusion
|
|
249
|
+
literal_values = []
|
|
250
|
+
for arg in args:
|
|
251
|
+
if isinstance(arg, str):
|
|
252
|
+
# Use double quotes for strings to match JSON format
|
|
253
|
+
literal_values.append(f'"{arg}"')
|
|
254
|
+
else:
|
|
255
|
+
# Use repr() for non-strings (numbers, booleans, etc.)
|
|
256
|
+
literal_values.append(repr(arg))
|
|
257
|
+
return f"Literal[{', '.join(literal_values)}]"
|
|
258
|
+
elif origin is Union:
|
|
259
|
+
# Handle Union types by showing all possible types
|
|
260
|
+
if len(args) == 2 and type(None) in args:
|
|
261
|
+
# This is Optional[T] which is Union[T, None]
|
|
262
|
+
non_none_type = [arg for arg in args if arg is not type(None)][0]
|
|
263
|
+
simplified_type = _get_simplified_type(non_none_type)
|
|
264
|
+
# Convert to string representation if needed
|
|
265
|
+
if isinstance(simplified_type, (list, dict)):
|
|
266
|
+
simplified_type = str(simplified_type).replace("'", '"')
|
|
267
|
+
return f"Optional[{simplified_type}]"
|
|
268
|
+
else:
|
|
269
|
+
# Regular Union with multiple types
|
|
270
|
+
union_types = []
|
|
271
|
+
for arg in args:
|
|
272
|
+
simplified = _get_simplified_type(arg)
|
|
273
|
+
# Convert to string representation if needed
|
|
274
|
+
if isinstance(simplified, (list, dict)):
|
|
275
|
+
union_types.append(str(simplified).replace("'", '"'))
|
|
276
|
+
else:
|
|
277
|
+
union_types.append(str(simplified))
|
|
278
|
+
return f"Union[{', '.join(union_types)}]"
|
|
279
|
+
elif origin is not None:
|
|
280
|
+
return str(origin.__name__) if origin else str(field_type.__name__)
|
|
281
|
+
elif hasattr(field_type, "__bases__") and issubclass(field_type, BaseModel):
|
|
282
|
+
return get_simple_pydantic_schema(field_type)
|
|
283
|
+
elif hasattr(field_type, "__bases__") and issubclass(field_type, Enum):
|
|
284
|
+
# Handle Enum types by showing possible values
|
|
285
|
+
enum_values = [f'"{value.value}"' for value in field_type]
|
|
286
|
+
return f"Enum[{', '.join(enum_values)}]"
|
|
287
|
+
elif field_type is str:
|
|
288
|
+
return "str"
|
|
289
|
+
elif field_type is int:
|
|
290
|
+
return "int"
|
|
291
|
+
elif field_type is float:
|
|
292
|
+
return "float"
|
|
293
|
+
elif field_type is bool:
|
|
294
|
+
return "bool"
|
|
295
|
+
else:
|
|
296
|
+
return str(field_type.__name__)
|
|
297
|
+
|
|
298
|
+
|
|
299
|
+
def get_simple_pydantic_schema(model: type[BaseModel]):
    """Map each field of ``model`` to a compact type description.

    Args:
        model: The pydantic model class to describe.

    Returns:
        A dict from field name (in ``model.model_fields`` order) to the value
        ``_get_simplified_type`` returns for that field's annotation.
    """
    fields = model.model_fields
    return {name: _get_simplified_type(info.annotation) for name, info in fields.items()}
|
|
304
|
+
|
|
305
|
+
|
|
306
|
+
def _format_schema_value(value, indent=0):
|
|
307
|
+
"""Format a schema value for display, handling nested structures."""
|
|
308
|
+
indent_str = " " * indent
|
|
309
|
+
if isinstance(value, dict):
|
|
310
|
+
if not value:
|
|
311
|
+
return "{}"
|
|
312
|
+
lines = ["{"]
|
|
313
|
+
for k, v in value.items():
|
|
314
|
+
formatted_value = _format_schema_value(v, indent + 1)
|
|
315
|
+
lines.append(f' {indent_str}"{k}": {formatted_value},')
|
|
316
|
+
# Remove trailing comma from last item
|
|
317
|
+
if lines[-1].endswith(","):
|
|
318
|
+
lines[-1] = lines[-1][:-1]
|
|
319
|
+
lines.append(f"{indent_str}" + "}")
|
|
320
|
+
return "\n".join(lines)
|
|
321
|
+
elif isinstance(value, list):
|
|
322
|
+
if not value:
|
|
323
|
+
return "[]"
|
|
324
|
+
elif len(value) == 1:
|
|
325
|
+
formatted_item = _format_schema_value(value[0], indent)
|
|
326
|
+
return f"[{formatted_item}]"
|
|
327
|
+
else:
|
|
328
|
+
lines = ["["]
|
|
329
|
+
for item in value:
|
|
330
|
+
formatted_item = _format_schema_value(item, indent + 1)
|
|
331
|
+
lines.append(f" {indent_str}{formatted_item},")
|
|
332
|
+
# Remove trailing comma from last item
|
|
333
|
+
if lines[-1].endswith(","):
|
|
334
|
+
lines[-1] = lines[-1][:-1]
|
|
335
|
+
lines.append(f"{indent_str}]")
|
|
336
|
+
return "\n".join(lines)
|
|
337
|
+
elif isinstance(value, str) and (
|
|
338
|
+
value.startswith(("Literal[", "Union[", "Optional[", "Dict[", "Tuple[", "Set[", "Enum["))
|
|
339
|
+
or value in ("str", "int", "float", "bool")
|
|
340
|
+
):
|
|
341
|
+
# Don't add quotes around type annotations or basic type names
|
|
342
|
+
return value
|
|
343
|
+
else:
|
|
344
|
+
# Regular string values get quotes
|
|
345
|
+
return f'"{value}"'
|
|
346
|
+
|
|
347
|
+
|
|
348
|
+
def render_schema(pydantic_model: type[BaseModel], with_field_descriptions: bool = True) -> str:
    """Render a pydantic model as a compact, prompt-friendly schema.

    Args:
        pydantic_model: The model class to render.
        with_field_descriptions: When True, append one ``path: description``
            line per field after a blank line. This covers top-level fields,
            then the fields of directly nested models (``field.sub``) and of
            ``list[Model]`` items (``field[].sub``).

    Returns:
        The rendered schema, optionally followed by the description lines.

    Raises:
        ValueError: If descriptions are requested and a covered field has none.
            All top-level fields are checked before any nested field.
    """
    # Custom formatter (not json.dumps) so Literal/Union/... annotations stay unquoted.
    schema_str = _format_schema_value(get_simple_pydantic_schema(pydantic_model))
    if not with_field_descriptions:
        return schema_str

    fields = pydantic_model.model_fields
    lines: list[str] = []

    # Pass 1: every top-level field must be documented.
    for name, info in fields.items():
        if not info.description:
            raise ValueError(f"Field '{name}' in model '{pydantic_model.__name__}' is missing a description.")
        lines.append(f"{name}: {info.description}")

    # Pass 2: one level of nesting (a model field, or a list of models).
    for name, info in fields.items():
        annotation = info.annotation
        if isinstance(annotation, type) and issubclass(annotation, BaseModel):
            sub_model, path, kind = annotation, name, "nested model"
        elif get_origin(annotation) is list and get_args(annotation):
            item_type = get_args(annotation)[0]
            if not (isinstance(item_type, type) and issubclass(item_type, BaseModel)):
                continue
            sub_model, path, kind = item_type, f"{name}[]", "list item model"
        else:
            continue

        for sub_name, sub_info in sub_model.model_fields.items():
            if not sub_info.description:
                raise ValueError(
                    f"Field '{sub_name}' in {kind} '{sub_model.__name__}' is missing a description."
                )
            lines.append(f"{path}.{sub_name}: {sub_info.description}")

    return f"{schema_str}\n\n" + "\n".join(lines)
|
|
382
|
+
|
|
383
|
+
|
|
384
|
+
def render_pydantic_model(pydantic_model: BaseModel) -> str:
    """Serialize a pydantic model instance to a JSON string via ``model_dump_json()``."""
    serialized = pydantic_model.model_dump_json()
    return serialized
|