adaptive-harmony 0.1.23__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. adaptive_harmony/__init__.py +162 -0
  2. adaptive_harmony/common/__init__.py +40 -0
  3. adaptive_harmony/common/callbacks.py +219 -0
  4. adaptive_harmony/common/checkpointing.py +163 -0
  5. adaptive_harmony/common/dpo.py +92 -0
  6. adaptive_harmony/common/env_grpo.py +361 -0
  7. adaptive_harmony/common/grpo.py +260 -0
  8. adaptive_harmony/common/gspo.py +70 -0
  9. adaptive_harmony/common/ppo.py +303 -0
  10. adaptive_harmony/common/rm.py +79 -0
  11. adaptive_harmony/common/sft.py +121 -0
  12. adaptive_harmony/core/__init__.py +0 -0
  13. adaptive_harmony/core/dataset.py +72 -0
  14. adaptive_harmony/core/display.py +93 -0
  15. adaptive_harmony/core/image_utils.py +110 -0
  16. adaptive_harmony/core/reasoning.py +12 -0
  17. adaptive_harmony/core/reward_client/__init__.py +19 -0
  18. adaptive_harmony/core/reward_client/client.py +160 -0
  19. adaptive_harmony/core/reward_client/reward_types.py +49 -0
  20. adaptive_harmony/core/reward_client/websocket_utils.py +18 -0
  21. adaptive_harmony/core/rich_counter.py +351 -0
  22. adaptive_harmony/core/rl_utils.py +38 -0
  23. adaptive_harmony/core/schedulers.py +38 -0
  24. adaptive_harmony/core/structured_output.py +385 -0
  25. adaptive_harmony/core/utils.py +365 -0
  26. adaptive_harmony/environment/__init__.py +8 -0
  27. adaptive_harmony/environment/environment.py +121 -0
  28. adaptive_harmony/evaluation/__init__.py +1 -0
  29. adaptive_harmony/evaluation/evaluation_artifact.py +67 -0
  30. adaptive_harmony/graders/__init__.py +20 -0
  31. adaptive_harmony/graders/answer_relevancy_judge/__init__.py +3 -0
  32. adaptive_harmony/graders/answer_relevancy_judge/answer_relevancy_judge.py +102 -0
  33. adaptive_harmony/graders/answer_relevancy_judge/prompts.py +58 -0
  34. adaptive_harmony/graders/base_grader.py +265 -0
  35. adaptive_harmony/graders/binary_judge/__init__.py +8 -0
  36. adaptive_harmony/graders/binary_judge/binary_judge.py +202 -0
  37. adaptive_harmony/graders/binary_judge/prompts.py +125 -0
  38. adaptive_harmony/graders/combined_grader.py +118 -0
  39. adaptive_harmony/graders/context_relevancy_judge/__init__.py +3 -0
  40. adaptive_harmony/graders/context_relevancy_judge/context_relevancy_judge.py +128 -0
  41. adaptive_harmony/graders/context_relevancy_judge/prompts.py +84 -0
  42. adaptive_harmony/graders/exceptions.py +9 -0
  43. adaptive_harmony/graders/faithfulness_judge/__init__.py +3 -0
  44. adaptive_harmony/graders/faithfulness_judge/faithfulness_judge.py +159 -0
  45. adaptive_harmony/graders/faithfulness_judge/prompts.py +22 -0
  46. adaptive_harmony/graders/range_judge/__init__.py +7 -0
  47. adaptive_harmony/graders/range_judge/prompts.py +232 -0
  48. adaptive_harmony/graders/range_judge/range_judge.py +188 -0
  49. adaptive_harmony/graders/range_judge/types.py +12 -0
  50. adaptive_harmony/graders/reward_server_grader.py +36 -0
  51. adaptive_harmony/graders/templated_prompt_judge.py +237 -0
  52. adaptive_harmony/graders/utils.py +79 -0
  53. adaptive_harmony/logging_table.py +1 -0
  54. adaptive_harmony/metric_logger.py +452 -0
  55. adaptive_harmony/parameters/__init__.py +2 -0
  56. adaptive_harmony/py.typed +0 -0
  57. adaptive_harmony/runtime/__init__.py +2 -0
  58. adaptive_harmony/runtime/context.py +2 -0
  59. adaptive_harmony/runtime/data.py +2 -0
  60. adaptive_harmony/runtime/decorators.py +2 -0
  61. adaptive_harmony/runtime/model_artifact_save.py +2 -0
  62. adaptive_harmony/runtime/runner.py +27 -0
  63. adaptive_harmony/runtime/simple_notifier.py +2 -0
  64. adaptive_harmony-0.1.23.dist-info/METADATA +37 -0
  65. adaptive_harmony-0.1.23.dist-info/RECORD +67 -0
  66. adaptive_harmony-0.1.23.dist-info/WHEEL +5 -0
  67. adaptive_harmony-0.1.23.dist-info/top_level.txt +1 -0
@@ -0,0 +1,351 @@
1
+ #!/usr/bin/env python
2
+
3
+ """
4
+ Progress tracking and coroutine observability for async operations.
5
+
6
+ Provides real-time visualization of async task execution with a progress bar
7
+ and hierarchical coroutine call tree display.
8
+ """
9
+
10
+ import asyncio
11
+ import os
12
+ import weakref
13
+
14
+ from rich.console import Group
15
+ from rich.live import Live
16
+ from rich.panel import Panel
17
+ from rich.progress import (
18
+ BarColumn,
19
+ Progress,
20
+ TaskID,
21
+ TextColumn,
22
+ TimeRemainingColumn,
23
+ )
24
+ from rich.table import Table
25
+ from rich.text import Text
26
+
27
+
28
+ def _is_stdlib(filename: str) -> bool:
29
+ """Check if a filename is from Python's standard library."""
30
+ if not filename:
31
+ return False
32
+
33
+ # Check for common stdlib patterns
34
+ stdlib_indicators = [
35
+ "asyncio",
36
+ "threading",
37
+ ]
38
+
39
+ return any(indicator in filename for indicator in stdlib_indicators)
40
+
41
+
42
def describe_coroutine(coro):
    """Describe the await chain of ``coro`` as a list of short location strings.

    Starting at ``coro``, the chain is followed through ``cr_await`` for as
    long as the awaited object exposes a ``cr_frame``. Each visited coroutine
    contributes one entry formatted as ``"<file name>:<line number> <function>()"``.

    When a frame belongs to code flagged by ``_is_stdlib``, a single
    ``"Internals:..."`` entry is appended and the walk stops there.

    Returns:
        The collected entries, outermost coroutine first. The list is empty
        when ``coro`` is ``None``, has no ``cr_frame`` attribute, or its frame
        is ``None``.
    """
    user_code_locations = []

    while coro is not None:
        frame = getattr(coro, "cr_frame", None)
        if frame is None:
            break

        code = frame.f_code

        # Don't descend into stdlib code: mark it once and stop.
        if _is_stdlib(code.co_filename):
            user_code_locations.append("Internals:...")
            break

        # Keep only the file name so the entry fits in the panel.
        # os.path.basename honours the platform's path separators, whereas a
        # hard-coded "/" split leaves Windows paths untouched.
        short_filename = os.path.basename(code.co_filename)
        user_code_locations.append(f"{short_filename}:{frame.f_lineno} {code.co_name}()")

        # Only keep walking while the awaited object is itself coroutine-like.
        awaited = getattr(coro, "cr_await", None)
        if awaited is None or not hasattr(awaited, "cr_frame"):
            break
        coro = awaited

    return user_code_locations
81
+
82
+
83
class ProgressCounterRegistry:
    """Singleton that tracks the main progress counter and its named wrappers.

    Every call to ``ProgressCounterRegistry()`` returns the same object, and
    the state is only populated the first time it is constructed.
    """

    _instance = None

    def __new__(cls):
        existing = cls._instance
        if existing is not None:
            return existing

        fresh = super().__new__(cls)
        fresh._initialized = False
        cls._instance = fresh
        return fresh

    def __init__(self):
        # __init__ runs on every ProgressCounterRegistry() call; only the very
        # first one is allowed to set up (and thereby clear) the state.
        if not self._initialized:
            self._main_counter_instance: ProgressCounter | None = None
            self._wrappers: dict[str, ProgressCounterWrapper] = {}
            self._initialized = True

    def get_main_counter(self) -> "ProgressCounter | None":
        """Return the registered main counter, or ``None`` if there is none."""
        return self._main_counter_instance

    def set_main_counter(self, counter: "ProgressCounter | None"):
        """Register ``counter`` as the main counter (``None`` clears it)."""
        self._main_counter_instance = counter

    def get_wrapper(self, key: str) -> "ProgressCounterWrapper | None":
        """Return the wrapper stored under ``key``, or ``None`` if absent."""
        return self._wrappers.get(key)

    def set_wrapper(self, key: str, wrapper: "ProgressCounterWrapper"):
        """Store ``wrapper`` under ``key``, replacing any previous entry."""
        self._wrappers[key] = wrapper

    def reset(self):
        """Forget the main counter and drop every stored wrapper."""
        self._main_counter_instance = None
        self._wrappers.clear()
114
+
115
+
116
class CoroutineTree:
    """Aggregates coroutine call chains into a counted tree and renders it.

    ``max_height`` records the tallest rendering produced so far; later
    renderings are padded with blank lines up to that height.
    """

    def __init__(self):
        self.max_height = 1

    def build_from_chains(self, chains: list[list[str]]) -> dict:
        """Merge call chains into a nested dict keyed by description.

        Each node has the shape ``{"count": int, "children": dict}`` where
        ``count`` is the number of chains passing through that node.
        """
        root = {}

        for chain in chains:
            level = root
            for label in chain:
                node = level.setdefault(label, {"count": 0, "children": {}})
                node["count"] += 1
                level = node["children"]

        return root

    def format_as_text(self, tree: dict, prefix: str = "") -> Text:
        """Render ``tree`` as Rich ``Text``, one line per node, with counts."""
        rendered = Text()
        last_index = len(tree) - 1

        for index, (desc, node) in enumerate(tree.items()):
            is_last_item = index == last_index

            # Top-level entries (empty prefix) get no branch glyph.
            if prefix == "":
                branch = ""
            elif is_last_item:
                branch = "└── "
            else:
                branch = "├── "

            rendered.append(prefix + branch, style="dim")
            rendered.append(f"[{node['count']}] ", style="bold magenta")
            rendered.append(desc + "\n")

            children = node["children"]
            if children:
                # Children are indented under their parent; a vertical bar is
                # kept when more siblings follow at this level.
                child_prefix = prefix + (" " if is_last_item else "│ ")
                rendered.append(self.format_as_text(children, child_prefix))

        return rendered

    def create_padded_display(self, chains: list[list[str]], active_count: int) -> Text:
        """Render header plus tree, padded up to the tallest display seen so far."""
        if not chains:
            return Text("No chains available", style="dim")

        header = Text(f"{active_count} tasks active\n\n", style="bold cyan")
        tree_content = self.format_as_text(self.build_from_chains(chains))

        # Height is measured as the number of newline-separated segments.
        header_height = header.plain.count("\n") + 1
        tree_height = tree_content.plain.count("\n") + 1
        current_height = header_height + tree_height
        self.max_height = max(self.max_height, current_height)

        display = Text()
        display.append(header)
        display.append(tree_content)

        # Pad with blank lines so the panel keeps a consistent height.
        missing = max(0, self.max_height - current_height)
        if missing:
            display.append("\n" * missing)

        return display

    def create_empty_display(self) -> Text:
        """Render the placeholder shown when there is nothing to display."""
        display = Text("No active tasks but updated", style="dim")

        # The placeholder is one line tall; pad the rest of the recorded height.
        missing = max(0, self.max_height - 1)
        if missing:
            display.append("\n" * missing)

        return display
202
+
203
+
204
class ProgressCounter:
    """Async context manager showing overall progress next to a coroutine tree.

    The display is a two-column grid: an "Overall Progress" panel on the left
    and a "Coroutine Tree" panel on the right. While the context is active, a
    background task periodically inspects the registered tasks and refreshes
    the tree.

    The constructor asserts that no main counter is registered in
    ``ProgressCounterRegistry``; it does not register itself. Registration is
    done by the caller (see ``get_progress_counter_or_wrapper``). Leaving the
    context resets the registry.

    Setting the environment variable ``DISABLE_RICH_PROGRESS`` to ``"1"``
    disables the live display; monitoring still runs.
    """

    def __init__(self, main_job_name: str, total_tasks: int):
        registry = ProgressCounterRegistry()
        assert registry.get_main_counter() is None, "Only one main counter instance is allowed"
        self.total_tasks = total_tasks

        # Left panel: state of the batch at a glance.
        self.overall_progress = Progress(
            "{task.description}",
            BarColumn(),
            TextColumn("[magenta]{task.completed:n} steps"),
            TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
            TimeRemainingColumn(),
        )
        self.overall_task = self.overall_progress.add_task(main_job_name, total=self.total_tasks)
        self.overall_panel = Panel(
            self.overall_progress, title="Overall Progress", border_style="green", padding=(1, 1, 0, 1)
        )

        # Right panel: the coroutine tree.
        self.coroutine_tree = CoroutineTree()
        self.tree_text = Text("Waiting for tasks...", style="dim")
        self.tree_panel = Panel(self.tree_text, title="[b]Coroutine Tree", border_style="blue", padding=(1, 2))

        # Weak references so finished tasks can be garbage collected.
        self._monitored_tasks: weakref.WeakSet[asyncio.Task] = weakref.WeakSet()
        self._monitor_task: asyncio.Task | None = None
        # Seconds slept between two refreshes of the tree.
        self.interval = 0.5

        # General table with two columns.
        self.progress_table = Table.grid()
        self.progress_table.add_row(self.overall_panel, self.tree_panel)

        # Disable the display if DISABLE_RICH_PROGRESS is set
        # TODO: make this less dirty
        self.enable_display = os.getenv("DISABLE_RICH_PROGRESS", "0") != "1"

    def register_task(self, task: asyncio.Task):
        """Start including ``task`` in the coroutine tree."""
        self._monitored_tasks.add(task)

    def increment_total_counter(self):
        """Advance the overall progress bar by one step."""
        # Use the id returned by add_task instead of assuming it is 0.
        self.overall_progress.advance(self.overall_task)

    def is_done(self):
        """Return whether the overall progress bar reports itself finished."""
        return self.overall_progress.finished

    async def _monitor_loop(self):
        """Rebuild the panels every ``self.interval`` seconds until cancelled."""
        try:
            while True:
                await asyncio.sleep(self.interval)

                # Filter out done tasks (just in case gc hasn't run yet)
                active = [t for t in self._monitored_tasks if not t.done()]

                # Collect all non-empty call chains.
                chains = []
                for task in active:
                    descriptions = describe_coroutine(task.get_coro())
                    if descriptions:
                        chains.append(descriptions)

                # No chains also covers the "no active task" case.
                if not chains:
                    self.tree_text = self.coroutine_tree.create_empty_display()
                else:
                    self.tree_text = self.coroutine_tree.create_padded_display(chains, len(active))

                self.tree_panel = Panel(
                    self.tree_text, title="[b]Coroutine Tree", border_style="blue", padding=(1, 2)
                )

                # Pad the overall panel so both panels have a matching height.
                overall_with_padding = Group(
                    self.overall_progress,
                    Text("\n" * max(0, self.coroutine_tree.max_height - 2)),  # Account for progress bar base height
                )
                self.overall_panel = Panel(
                    overall_with_padding, title="Overall Progress", border_style="green", padding=(1, 1, 0, 1)
                )

                # Recreate the progress table with the updated panels.
                self.progress_table = Table.grid()
                self.progress_table.add_row(self.overall_panel, self.tree_panel)
                if self.enable_display:
                    self.live.update(self.progress_table)

        except asyncio.CancelledError:
            # Cancellation is the normal way this loop is stopped.
            pass

    async def __aenter__(self):
        if self.enable_display:
            self.live = Live(self.progress_table, refresh_per_second=4).__enter__()
        self._monitor_task = asyncio.create_task(self._monitor_loop())
        return self

    async def __aexit__(self, exc_type, exc_value, exc_traceback):
        try:
            # Cancel the monitor task and wait for it to wind down.
            if self._monitor_task and not self._monitor_task.done():
                self._monitor_task.cancel()
                try:
                    await self._monitor_task
                except asyncio.CancelledError:
                    pass

            if self.enable_display:
                self.live.__exit__(exc_type, exc_value, exc_traceback)
            assert ProgressCounterRegistry().get_main_counter() is not None, "Weird state"
        finally:
            # Always clear both the main counter and any wrapper pointing to the
            # old counter; otherwise a failure above would leave a stale main
            # counter that blocks every later ProgressCounter.
            ProgressCounterRegistry().reset()
313
+
314
+
315
class ProgressCounterWrapper(ProgressCounter):
    """Used to wrap a preexisting counter, needed to have simple logic in case an async_map is used inside an async_map_batch

    Task registration and completion checks are forwarded to the wrapped
    counter. Incrementing the total and entering or leaving the context are
    no-ops, because the wrapped counter owns the display and its lifecycle.
    """

    def __init__(self, inner: ProgressCounter, main_function_name: str):
        # ProgressCounter.__init__ is not called here, so none of the display
        # attributes (progress bar, panels, ...) exist on the wrapper itself.
        self.inner = inner
        self.main_function_name = main_function_name

    def register_task(self, task: asyncio.Task):
        """Register ``task`` with the wrapped counter."""
        return self.inner.register_task(task)

    def increment_total_counter(self):
        # For now we do nothing
        ...

    def is_done(self):
        """Return whether the wrapped counter is finished.

        The inherited implementation reads ``self.overall_progress``, which the
        wrapper never creates, so the call has to be delegated.
        """
        return self.inner.is_done()

    async def __aenter__(self):
        # do nothing, the inner progress counter has already been initialized
        return self

    async def __aexit__(self, _exc_type, _exc_value, _exc_traceback):
        # do nothing, the inner progress counter will be exited later
        ...
336
+
337
+
338
def get_progress_counter_or_wrapper(main_job_name: str, total_samples: int):
    """Return the main progress counter, creating it when none is registered.

    If no main counter is registered, a new ``ProgressCounter`` is created,
    registered and returned. Otherwise a ``ProgressCounterWrapper`` around the
    existing main counter is returned; wrappers are cached per
    ``main_job_name`` in the registry.
    """
    registry = ProgressCounterRegistry()
    existing_main = registry.get_main_counter()

    # First caller: becomes the main counter.
    if existing_main is None:
        created = ProgressCounter(main_job_name, total_samples)
        registry.set_main_counter(created)
        return created

    # Nested caller: reuse the wrapper for this job name if there is one.
    cached = registry.get_wrapper(main_job_name)
    if cached is not None:
        return cached

    new_wrapper = ProgressCounterWrapper(existing_main, main_job_name)
    registry.set_wrapper(main_job_name, new_wrapper)
    return new_wrapper
@@ -0,0 +1,38 @@
1
def gae_advantages(
    values: list[float],
    rewards: list[float],
    gae_lambda: float,
    gae_gamma: float,
) -> list[float]:
    """Compute one advantage per step by a backward recursion over the sequence.

    For each step ``t`` (last to first)::

        delta_t = rewards[t] + gae_gamma * values[t + 1] - values[t]
        adv_t   = delta_t + gae_gamma * gae_lambda * adv_{t + 1}

    where the value and the advantage after the final step are both ``0.0``.
    The result has ``len(values)`` entries, in forward order.
    """
    n_steps = len(values)
    advantages = [0.0] * n_steps
    decay = gae_gamma * gae_lambda

    running = 0.0
    next_value = 0.0  # value after the final step
    for t in range(n_steps - 1, -1, -1):
        td_error = rewards[t] + gae_gamma * next_value - values[t]
        running = td_error + decay * running
        advantages[t] = running
        next_value = values[t]

    return advantages
18
+
19
+
20
def discounted_cumulative_rewards(
    rewards: list[float],
    gamma: float,
) -> list[float]:
    """Return the discounted return-to-go for every step.

    ``returns[t] = rewards[t] + gamma * returns[t + 1]``, with the last entry
    equal to the last reward.

    Args:
        rewards: Per-step rewards.
        gamma: Discount factor applied to the following step's return.

    Returns:
        A list with one return per reward; an empty list for empty input.
    """
    n = len(rewards)
    # Nothing to accumulate; indexing rewards[-1] below would raise IndexError.
    if n == 0:
        return []

    returns = [0.0] * n
    returns[-1] = rewards[-1]
    for t in reversed(range(n - 1)):
        returns[t] = rewards[t] + gamma * returns[t + 1]

    return returns
31
+
32
+
33
def gae_td_returns(
    advantages: list[float],
    values: list[float],
) -> list[float]:
    """Return the element-wise sum of ``advantages`` and ``values``.

    Pairs are formed with ``zip``, so the result is as long as the shorter of
    the two inputs.
    """
    returns = []
    for advantage, value in zip(advantages, values):
        returns.append(advantage + value)
    return returns
@@ -0,0 +1,38 @@
1
+ import math
2
+ from typing import Callable
3
+
4
+ Scheduler = Callable[[float], float]
5
+
6
+
7
class CosineSchedulerWithoutWarmup:
    """Cosine schedule going from ``lr`` down to ``lr / decay_factor``.

    Calling the instance with a completion value of ``0.0`` yields
    ``max_value``; a completion value of ``1.0`` yields ``min_value``.
    """

    def __init__(self, lr=1e-5, decay_factor=10.0) -> None:
        self.max_value = lr
        self.min_value = lr / decay_factor

    def __call__(self, completion_percentage: float) -> float:
        """Return the scheduled value for the given completion fraction."""
        span = self.max_value - self.min_value
        # Half-cosine weight: 1.0 at completion 0.0, 0.0 at completion 1.0.
        cosine_weight = 0.5 * (math.cos(math.pi * completion_percentage) + 1.0)
        return self.min_value + cosine_weight * span
16
+
17
+
18
class CombinedSchedule:
    """Chain two schedules, switching from ``a`` to ``b`` at ``change_point``.

    Each schedule receives its own progress rescaled to start at 0: ``a`` gets
    ``completion / change_point`` and ``b`` gets
    ``(completion - change_point) / (1 - change_point)``.
    """

    def __init__(self, a: Scheduler, b: Scheduler, change_point: float) -> None:
        self.a = a
        self.b = b
        self.change_point = change_point

    def __call__(self, completion_percentage: float) -> float:
        """Return the value of whichever schedule covers ``completion_percentage``."""
        remaining = 1 - self.change_point
        # With change_point == 1 there is no room left for ``b``; rescaling its
        # progress would divide by zero, so ``a`` handles that case as well.
        if completion_percentage < self.change_point or remaining == 0:
            return self.a(completion_percentage / self.change_point)
        return self.b((completion_percentage - self.change_point) / remaining)
29
+
30
+
31
class CosineScheduler:
    """Linear warmup followed by a cosine schedule.

    Until ``warmup_percentage`` of the run, the value is the rescaled warmup
    progress multiplied by ``lr``; afterwards a
    ``CosineSchedulerWithoutWarmup(lr, decay_factor)`` takes over.
    """

    def __init__(self, lr=1e-5, warmup_percentage=0.1, decay_factor=10.0):
        def linear_warmup(progress):
            return progress * lr

        cosine_phase = CosineSchedulerWithoutWarmup(lr, decay_factor)
        self.combined = CombinedSchedule(linear_warmup, cosine_phase, warmup_percentage)

    def __call__(self, completion_percentage: float) -> float:
        """Return the scheduled value for the given completion fraction."""
        return self.combined(completion_percentage)