adaptive-harmony 0.1.23 (adaptive_harmony-0.1.23-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. adaptive_harmony/__init__.py +162 -0
  2. adaptive_harmony/common/__init__.py +40 -0
  3. adaptive_harmony/common/callbacks.py +219 -0
  4. adaptive_harmony/common/checkpointing.py +163 -0
  5. adaptive_harmony/common/dpo.py +92 -0
  6. adaptive_harmony/common/env_grpo.py +361 -0
  7. adaptive_harmony/common/grpo.py +260 -0
  8. adaptive_harmony/common/gspo.py +70 -0
  9. adaptive_harmony/common/ppo.py +303 -0
  10. adaptive_harmony/common/rm.py +79 -0
  11. adaptive_harmony/common/sft.py +121 -0
  12. adaptive_harmony/core/__init__.py +0 -0
  13. adaptive_harmony/core/dataset.py +72 -0
  14. adaptive_harmony/core/display.py +93 -0
  15. adaptive_harmony/core/image_utils.py +110 -0
  16. adaptive_harmony/core/reasoning.py +12 -0
  17. adaptive_harmony/core/reward_client/__init__.py +19 -0
  18. adaptive_harmony/core/reward_client/client.py +160 -0
  19. adaptive_harmony/core/reward_client/reward_types.py +49 -0
  20. adaptive_harmony/core/reward_client/websocket_utils.py +18 -0
  21. adaptive_harmony/core/rich_counter.py +351 -0
  22. adaptive_harmony/core/rl_utils.py +38 -0
  23. adaptive_harmony/core/schedulers.py +38 -0
  24. adaptive_harmony/core/structured_output.py +385 -0
  25. adaptive_harmony/core/utils.py +365 -0
  26. adaptive_harmony/environment/__init__.py +8 -0
  27. adaptive_harmony/environment/environment.py +121 -0
  28. adaptive_harmony/evaluation/__init__.py +1 -0
  29. adaptive_harmony/evaluation/evaluation_artifact.py +67 -0
  30. adaptive_harmony/graders/__init__.py +20 -0
  31. adaptive_harmony/graders/answer_relevancy_judge/__init__.py +3 -0
  32. adaptive_harmony/graders/answer_relevancy_judge/answer_relevancy_judge.py +102 -0
  33. adaptive_harmony/graders/answer_relevancy_judge/prompts.py +58 -0
  34. adaptive_harmony/graders/base_grader.py +265 -0
  35. adaptive_harmony/graders/binary_judge/__init__.py +8 -0
  36. adaptive_harmony/graders/binary_judge/binary_judge.py +202 -0
  37. adaptive_harmony/graders/binary_judge/prompts.py +125 -0
  38. adaptive_harmony/graders/combined_grader.py +118 -0
  39. adaptive_harmony/graders/context_relevancy_judge/__init__.py +3 -0
  40. adaptive_harmony/graders/context_relevancy_judge/context_relevancy_judge.py +128 -0
  41. adaptive_harmony/graders/context_relevancy_judge/prompts.py +84 -0
  42. adaptive_harmony/graders/exceptions.py +9 -0
  43. adaptive_harmony/graders/faithfulness_judge/__init__.py +3 -0
  44. adaptive_harmony/graders/faithfulness_judge/faithfulness_judge.py +159 -0
  45. adaptive_harmony/graders/faithfulness_judge/prompts.py +22 -0
  46. adaptive_harmony/graders/range_judge/__init__.py +7 -0
  47. adaptive_harmony/graders/range_judge/prompts.py +232 -0
  48. adaptive_harmony/graders/range_judge/range_judge.py +188 -0
  49. adaptive_harmony/graders/range_judge/types.py +12 -0
  50. adaptive_harmony/graders/reward_server_grader.py +36 -0
  51. adaptive_harmony/graders/templated_prompt_judge.py +237 -0
  52. adaptive_harmony/graders/utils.py +79 -0
  53. adaptive_harmony/logging_table.py +1 -0
  54. adaptive_harmony/metric_logger.py +452 -0
  55. adaptive_harmony/parameters/__init__.py +2 -0
  56. adaptive_harmony/py.typed +0 -0
  57. adaptive_harmony/runtime/__init__.py +2 -0
  58. adaptive_harmony/runtime/context.py +2 -0
  59. adaptive_harmony/runtime/data.py +2 -0
  60. adaptive_harmony/runtime/decorators.py +2 -0
  61. adaptive_harmony/runtime/model_artifact_save.py +2 -0
  62. adaptive_harmony/runtime/runner.py +27 -0
  63. adaptive_harmony/runtime/simple_notifier.py +2 -0
  64. adaptive_harmony-0.1.23.dist-info/METADATA +37 -0
  65. adaptive_harmony-0.1.23.dist-info/RECORD +67 -0
  66. adaptive_harmony-0.1.23.dist-info/WHEEL +5 -0
  67. adaptive_harmony-0.1.23.dist-info/top_level.txt +1 -0
adaptive_harmony/core/structured_output.py
@@ -0,0 +1,385 @@
+ import json
+ import re
+ from enum import Enum
+ from typing import Literal, Type, Union, get_args, get_origin
+
+ from pydantic import BaseModel, ValidationError
+
+ from adaptive_harmony import InferenceModel, StringThread
+ from adaptive_harmony.core.reasoning import remove_reasoning
+
+ FIX_OUTPUT_FORMAT = """Below, the COMPLETION did not satisfy the constraints given in the PROMPT. Please rewrite the completion to comply with constraints, nothing else.
+
+ PROMPT
+ The output should be a well-formatted JSON instance that conforms to the JSON schema below. All fields are required. Do not output anything else other than the JSON.
+
+ As an example, for the schema
+ {{
+ "foo": {{
+ "items":{{"type": "string"}},
+ "type": "array"
+ }},
+ "bar": {{"type": "integer"}}
+ }}
+ the object {{"foo": ["hey", "bye"], "bar": 1}} is a well-formatted instance of the schema. The object {{"properties": {{"foo": ["hey", "bye"], "bar":"1" }}}} is not well-formatted.
+
+ Here is the output JSON schema:
+ {json_schema}
+
+ COMPLETION
+ {completion}
+ """
+
+
+ class JsonParseError(Exception):
+     def __init__(self, message: str, completion: str):
+         super().__init__(message)
+         self.completion = completion
+
+
+ def get_pydantic_schema(base_model: Type[BaseModel]) -> str:
+     schema = base_model.model_json_schema()
+     for prop in schema.get("properties", {}).values():
+         prop.pop("title", None)
+     return json.dumps(schema, indent=2)
+
+
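For illustration, a minimal sketch of what get_pydantic_schema strips from the Pydantic JSON schema; the Point model is hypothetical and not part of the package, only the import path and function come from this file:

from pydantic import BaseModel
from adaptive_harmony.core.structured_output import get_pydantic_schema

class Point(BaseModel):  # hypothetical model, for illustration only
    x: int
    y: int

print(get_pydantic_schema(Point))
# Property-level "title" keys are popped, so "x" and "y" render as plain
# {"type": "integer"} entries; the model-level "title": "Point" is kept.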
+ class OutputParserException(Exception):
+     """Exception raised for parsing errors."""
+
+     def __init__(self, message: str, llm_output: str | None = None):
+         super().__init__(message)
+         self.llm_output = llm_output
+
+
+ def pydantic_parse[T: BaseModel](text: str, pydantic_object: type[T]) -> T:
+     """Parse the output of an LLM call to a pydantic object.
+
+     Args:
+         text: The output of the LLM call.
+         pydantic_object: The pydantic model to parse into.
+
+     Returns:
+         The parsed pydantic object.
+     """
+     # Remove Qwen3 reasoning
+     text = remove_reasoning(text)
+
+     # Strip initial whitespace
+     text = text.strip()
+
+     def parse_json_with_completion(json_text):
+         """Parse JSON, handling partial JSON by completing missing brackets."""
+         # Strip whitespace and backticks
+         json_text = json_text.strip(" \n\r\t`")
+
+         # Handle action_input special case - escape special chars
+         if '"action_input"' in json_text:
+
+             def fix_action_input(match):
+                 value = match.group(2)
+                 value = re.sub(r"\n", r"\\n", value)
+                 value = re.sub(r"\r", r"\\r", value)
+                 value = re.sub(r"\t", r"\\t", value)
+                 value = re.sub(r'(?<!\\)"', r"\"", value)
+                 return match.group(1) + value + match.group(3)
+
+             json_text = re.sub(r'("action_input"\:\s*")(.*?)(")', fix_action_input, json_text, flags=re.DOTALL)
+
+         # NOTE Axel: gemma likes to escape the left bracket, patching for now
+         json_text = json_text.replace(r"\\[", "[")
+
+         # Try parsing as-is first
+         try:
+             return json.loads(json_text)
+         except json.JSONDecodeError:
+             pass
+
+         # Handle partial JSON - complete missing brackets and quotes
+         chars = list(json_text)
+         stack = []
+         in_string = False
+         escaped = False
+
+         for i, char in enumerate(chars):
+             if in_string:
+                 if char == '"' and not escaped:
+                     in_string = False
+                 elif char == "\n" and not escaped:
+                     chars[i] = "\\n"
+                 escaped = char == "\\" and not escaped
+             elif char == '"':
+                 in_string = True
+                 escaped = False
+             elif char == "{":
+                 stack.append("}")
+             elif char == "[":
+                 stack.append("]")
+             elif char in {"}", "]"}:
+                 if stack and stack[-1] == char:
+                     stack.pop()
+
+         # Close unterminated string
+         if in_string:
+             if escaped and chars:
+                 chars.pop()
+             chars.append('"')
+
+         # Add missing closing brackets
+         chars.extend(reversed(stack))
+
+         # Try parsing with progressively fewer characters
+         while chars:
+             try:
+                 return json.loads("".join(chars))
+             except json.JSONDecodeError:
+                 chars.pop()
+
+         # If nothing worked, raise with original
+         raise json.JSONDecodeError("Invalid JSON", json_text, 0)
+
+     # Try parsing the original text first
+     try:
+         json_object = parse_json_with_completion(text)
+     except json.JSONDecodeError:
+         # Try extracting from markdown blocks
+         markdown_match = re.search(r"```(json)(.*?)```", text, re.DOTALL)
+         if not markdown_match:
+             markdown_match = re.search(r"```(json)?(.*)", text, re.DOTALL)
+         xml_match = re.search(r"<json>(.*?)</json>", text, re.DOTALL)
+         if not xml_match:
+             xml_match = re.search(r"<json>(.*)", text, re.DOTALL)
+
+         if markdown_match or xml_match:
+             try:
+                 json_object = parse_json_with_completion(
+                     markdown_match.group(2) if markdown_match else (xml_match.group(1) if xml_match else "")
+                 )
+             except json.JSONDecodeError:
+                 msg = f"Invalid json output: {text}"
+                 raise OutputParserException(msg, llm_output=text)
+         else:
+             msg = f"Invalid json output: {text}"
+             raise OutputParserException(msg, llm_output=text)
+
+     try:
+         return pydantic_object.model_validate(json_object)
+     except ValidationError as e:
+         json_string = json.dumps(json_object)
+         msg = f"Failed to parse {pydantic_object.__name__} from completion {json_string}. Got: {e}"
+         raise OutputParserException(msg, llm_output=json_string) from e
+
+
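A minimal sketch of how pydantic_parse recovers a truncated, fenced completion. The Verdict model and the raw string are hypothetical, and the trace assumes remove_reasoning leaves text without reasoning tags unchanged:

from pydantic import BaseModel
from adaptive_harmony.core.structured_output import pydantic_parse

class Verdict(BaseModel):  # hypothetical model, for illustration only
    label: str
    score: float

raw = 'Here is the result:\n```json\n{"label": "positive", "score": 0.9'
parsed = pydantic_parse(raw, Verdict)
# The fenced JSON is pulled out by the markdown fallback, the missing "}" is
# added by parse_json_with_completion, and the result validates to
# Verdict(label="positive", score=0.9).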
+ async def generate_and_validate[T: BaseModel](
+     model: InferenceModel,
+     thread: StringThread,
+     pydantic_model: Type[T],
+     max_parsing_retries: int = 1,
+ ) -> tuple[str, T]:
+     """
+     Generates with InferenceModel, validates completion against Pydantic model and retries
+     if validation fails. It's recommended you use a StructuredJSONOutputBaseModel as
+     the pydantic_object to clean up the JSON schema for the LLM. Does not support RootModel.
+     """
+
+     json_schema = get_pydantic_schema(pydantic_model)
+
+     response_thread = await model.generate(thread)
+     completion = response_thread.last_content()
+
+     current_retries = 0
+     while current_retries <= max_parsing_retries:
+         try:
+             parsed = pydantic_parse(completion, pydantic_model)
+             return (completion, parsed)
+         except Exception:
+             if current_retries == max_parsing_retries:
+                 break
+
+             # Create repair prompt
+             repair_thread = StringThread(
+                 [("user", FIX_OUTPUT_FORMAT.format(json_schema=json_schema, completion=completion))]
+             )
+             response_thread = await model.generate(repair_thread)
+             completion = response_thread.last_content()
+             current_retries += 1
+
+     raise JsonParseError(f"Could not parse json output after {max_parsing_retries} retries", completion)
+
+
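A hedged usage sketch of generate_and_validate: the Score model, the grade coroutine, and the prompt wording are hypothetical, and the InferenceModel instance is assumed to be constructed elsewhere by the runtime. Only StringThread, generate_and_validate, and render_schema come from this package:

from pydantic import BaseModel, Field
from adaptive_harmony import InferenceModel, StringThread
from adaptive_harmony.core.structured_output import generate_and_validate, render_schema

class Score(BaseModel):  # hypothetical model, for illustration only
    reasoning: str = Field(description="Short justification for the score")
    score: int = Field(description="Integer score from 1 to 5")

async def grade(model: InferenceModel, sample: str) -> Score:
    # Field descriptions are required because render_schema(Score) is embedded in the prompt.
    prompt = (
        "Grade the sample below. Return JSON matching this schema:\n"
        f"{render_schema(Score)}\n\nSAMPLE\n{sample}"
    )
    thread = StringThread([("user", prompt)])
    _completion, parsed = await generate_and_validate(model, thread, Score, max_parsing_retries=2)
    return parsed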
+ def _get_simplified_type(field_type):
+     origin = get_origin(field_type)
+     args = get_args(field_type)
+
+     if origin is list:
+         if args:
+             return [_get_simplified_type(args[0])]
+         else:
+             return "array"
+     elif origin is dict:
+         if len(args) == 2:
+             key_type = _get_simplified_type(args[0])
+             value_type = _get_simplified_type(args[1])
+             return f"Dict[{key_type}, {value_type}]"
+         else:
+             return "dict"
+     elif origin is tuple:
+         if args:
+             if len(args) == 2 and args[1] is ...:
+                 # Variable length tuple like Tuple[str, ...]
+                 element_type = _get_simplified_type(args[0])
+                 return f"Tuple[{element_type}, ...]"
+             else:
+                 # Fixed length tuple like Tuple[str, int]
+                 element_types = [_get_simplified_type(arg) for arg in args]
+                 return f"Tuple[{', '.join(element_types)}]"
+         else:
+             return "tuple"
+     elif origin is set:
+         if args:
+             element_type = _get_simplified_type(args[0])
+             return f"Set[{element_type}]"
+         else:
+             return "set"
+     elif origin is type(None):
+         return "null"
+     elif origin is Literal:
+         # Handle Literal types by showing them as Literal["value1", "value2"]
+         # Use double quotes for strings to match JSON format and prevent LLM confusion
+         literal_values = []
+         for arg in args:
+             if isinstance(arg, str):
+                 # Use double quotes for strings to match JSON format
+                 literal_values.append(f'"{arg}"')
+             else:
+                 # Use repr() for non-strings (numbers, booleans, etc.)
+                 literal_values.append(repr(arg))
+         return f"Literal[{', '.join(literal_values)}]"
+     elif origin is Union:
+         # Handle Union types by showing all possible types
+         if len(args) == 2 and type(None) in args:
+             # This is Optional[T] which is Union[T, None]
+             non_none_type = [arg for arg in args if arg is not type(None)][0]
+             simplified_type = _get_simplified_type(non_none_type)
+             # Convert to string representation if needed
+             if isinstance(simplified_type, (list, dict)):
+                 simplified_type = str(simplified_type).replace("'", '"')
+             return f"Optional[{simplified_type}]"
+         else:
+             # Regular Union with multiple types
+             union_types = []
+             for arg in args:
+                 simplified = _get_simplified_type(arg)
+                 # Convert to string representation if needed
+                 if isinstance(simplified, (list, dict)):
+                     union_types.append(str(simplified).replace("'", '"'))
+                 else:
+                     union_types.append(str(simplified))
+             return f"Union[{', '.join(union_types)}]"
+     elif origin is not None:
+         return str(origin.__name__) if origin else str(field_type.__name__)
+     elif hasattr(field_type, "__bases__") and issubclass(field_type, BaseModel):
+         return get_simple_pydantic_schema(field_type)
+     elif hasattr(field_type, "__bases__") and issubclass(field_type, Enum):
+         # Handle Enum types by showing possible values
+         enum_values = [f'"{value.value}"' for value in field_type]
+         return f"Enum[{', '.join(enum_values)}]"
+     elif field_type is str:
+         return "str"
+     elif field_type is int:
+         return "int"
+     elif field_type is float:
+         return "float"
+     elif field_type is bool:
+         return "bool"
+     else:
+         return str(field_type.__name__)
+
+
+ def get_simple_pydantic_schema(model: type[BaseModel]):
+     representation = {}
+     for field_name, field in model.model_fields.items():
+         representation[field_name] = _get_simplified_type(field.annotation)
+     return representation
+
+
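To show the mapping _get_simplified_type produces, here is a small hypothetical model and the dictionary get_simple_pydantic_schema returns for it, traced from the branches above:

from enum import Enum
from typing import Literal, Optional
from pydantic import BaseModel
from adaptive_harmony.core.structured_output import get_simple_pydantic_schema

class Sentiment(Enum):  # hypothetical enum, for illustration only
    POSITIVE = "positive"
    NEGATIVE = "negative"

class Review(BaseModel):  # hypothetical model, for illustration only
    title: str
    rating: Optional[int]
    tags: list[str]
    sentiment: Sentiment
    verdict: Literal["keep", "discard"]

get_simple_pydantic_schema(Review)
# {
#     "title": "str",
#     "rating": "Optional[int]",
#     "tags": ["str"],
#     "sentiment": 'Enum["positive", "negative"]',
#     "verdict": 'Literal["keep", "discard"]',
# }

Note that Optional here is typing.Optional, whose origin is typing.Union; an annotation written as int | None has origin types.UnionType and would instead fall through to the branch that returns origin.__name__.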
+ def _format_schema_value(value, indent=0):
+     """Format a schema value for display, handling nested structures."""
+     indent_str = " " * indent
+     if isinstance(value, dict):
+         if not value:
+             return "{}"
+         lines = ["{"]
+         for k, v in value.items():
+             formatted_value = _format_schema_value(v, indent + 1)
+             lines.append(f' {indent_str}"{k}": {formatted_value},')
+         # Remove trailing comma from last item
+         if lines[-1].endswith(","):
+             lines[-1] = lines[-1][:-1]
+         lines.append(f"{indent_str}" + "}")
+         return "\n".join(lines)
+     elif isinstance(value, list):
+         if not value:
+             return "[]"
+         elif len(value) == 1:
+             formatted_item = _format_schema_value(value[0], indent)
+             return f"[{formatted_item}]"
+         else:
+             lines = ["["]
+             for item in value:
+                 formatted_item = _format_schema_value(item, indent + 1)
+                 lines.append(f" {indent_str}{formatted_item},")
+             # Remove trailing comma from last item
+             if lines[-1].endswith(","):
+                 lines[-1] = lines[-1][:-1]
+             lines.append(f"{indent_str}]")
+             return "\n".join(lines)
+     elif isinstance(value, str) and (
+         value.startswith(("Literal[", "Union[", "Optional[", "Dict[", "Tuple[", "Set[", "Enum["))
+         or value in ("str", "int", "float", "bool")
+     ):
+         # Don't add quotes around type annotations or basic type names
+         return value
+     else:
+         # Regular string values get quotes
+         return f'"{value}"'
+
+
+ def render_schema(pydantic_model: type[BaseModel], with_field_descriptions: bool = True) -> str:
+     simplified_schema = get_simple_pydantic_schema(pydantic_model)
+     # Use custom formatting instead of json.dumps to handle Literal types properly
+     schema_str = _format_schema_value(simplified_schema)
+
+     if not with_field_descriptions:
+         return schema_str
+
+     descriptions = []
+     for field_name, field in pydantic_model.model_fields.items():
+         if not field.description:
+             raise ValueError(f"Field '{field_name}' in model '{pydantic_model.__name__}' is missing a description.")
+         descriptions.append(f"{field_name}: {field.description}")
+
+     for field_name, field in pydantic_model.model_fields.items():
+         if isinstance(field.annotation, type) and issubclass(field.annotation, BaseModel):
+             nested_model = field.annotation
+             for nested_field_name, nested_field in nested_model.model_fields.items():
+                 if not nested_field.description:
+                     raise ValueError(
+                         f"Field '{nested_field_name}' in nested model '{nested_model.__name__}' is missing a description."
+                     )
+                 descriptions.append(f"{field_name}.{nested_field_name}: {nested_field.description}")
+         elif get_origin(field.annotation) is list and get_args(field.annotation):
+             list_item_type = get_args(field.annotation)[0]
+             if isinstance(list_item_type, type) and issubclass(list_item_type, BaseModel):
+                 for nested_field_name, nested_field in list_item_type.model_fields.items():
+                     if not nested_field.description:
+                         raise ValueError(
+                             f"Field '{nested_field_name}' in list item model '{list_item_type.__name__}' is missing a description."
+                         )
+                     descriptions.append(f"{field_name}[].{nested_field_name}: {nested_field.description}")
+
+     return f"{schema_str}\n\n{'\n'.join(descriptions)}"
+
+
+ def render_pydantic_model(pydantic_model: BaseModel) -> str:
+     return pydantic_model.model_dump_json()
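Finally, a sketch of what render_schema produces end to end, using a hypothetical Judgement model; the field descriptions are mandatory, since render_schema raises ValueError when one is missing:

from typing import Literal
from pydantic import BaseModel, Field
from adaptive_harmony.core.structured_output import render_schema

class Judgement(BaseModel):  # hypothetical model, for illustration only
    reasoning: str = Field(description="Step-by-step justification")
    verdict: Literal["PASS", "FAIL"] = Field(description="Final binary verdict")

print(render_schema(Judgement))
# {
#  "reasoning": str,
#  "verdict": Literal["PASS", "FAIL"]
# }
#
# reasoning: Step-by-step justification
# verdict: Final binary verdict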