adaptive-harmony 0.1.23__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- adaptive_harmony/__init__.py +162 -0
- adaptive_harmony/common/__init__.py +40 -0
- adaptive_harmony/common/callbacks.py +219 -0
- adaptive_harmony/common/checkpointing.py +163 -0
- adaptive_harmony/common/dpo.py +92 -0
- adaptive_harmony/common/env_grpo.py +361 -0
- adaptive_harmony/common/grpo.py +260 -0
- adaptive_harmony/common/gspo.py +70 -0
- adaptive_harmony/common/ppo.py +303 -0
- adaptive_harmony/common/rm.py +79 -0
- adaptive_harmony/common/sft.py +121 -0
- adaptive_harmony/core/__init__.py +0 -0
- adaptive_harmony/core/dataset.py +72 -0
- adaptive_harmony/core/display.py +93 -0
- adaptive_harmony/core/image_utils.py +110 -0
- adaptive_harmony/core/reasoning.py +12 -0
- adaptive_harmony/core/reward_client/__init__.py +19 -0
- adaptive_harmony/core/reward_client/client.py +160 -0
- adaptive_harmony/core/reward_client/reward_types.py +49 -0
- adaptive_harmony/core/reward_client/websocket_utils.py +18 -0
- adaptive_harmony/core/rich_counter.py +351 -0
- adaptive_harmony/core/rl_utils.py +38 -0
- adaptive_harmony/core/schedulers.py +38 -0
- adaptive_harmony/core/structured_output.py +385 -0
- adaptive_harmony/core/utils.py +365 -0
- adaptive_harmony/environment/__init__.py +8 -0
- adaptive_harmony/environment/environment.py +121 -0
- adaptive_harmony/evaluation/__init__.py +1 -0
- adaptive_harmony/evaluation/evaluation_artifact.py +67 -0
- adaptive_harmony/graders/__init__.py +20 -0
- adaptive_harmony/graders/answer_relevancy_judge/__init__.py +3 -0
- adaptive_harmony/graders/answer_relevancy_judge/answer_relevancy_judge.py +102 -0
- adaptive_harmony/graders/answer_relevancy_judge/prompts.py +58 -0
- adaptive_harmony/graders/base_grader.py +265 -0
- adaptive_harmony/graders/binary_judge/__init__.py +8 -0
- adaptive_harmony/graders/binary_judge/binary_judge.py +202 -0
- adaptive_harmony/graders/binary_judge/prompts.py +125 -0
- adaptive_harmony/graders/combined_grader.py +118 -0
- adaptive_harmony/graders/context_relevancy_judge/__init__.py +3 -0
- adaptive_harmony/graders/context_relevancy_judge/context_relevancy_judge.py +128 -0
- adaptive_harmony/graders/context_relevancy_judge/prompts.py +84 -0
- adaptive_harmony/graders/exceptions.py +9 -0
- adaptive_harmony/graders/faithfulness_judge/__init__.py +3 -0
- adaptive_harmony/graders/faithfulness_judge/faithfulness_judge.py +159 -0
- adaptive_harmony/graders/faithfulness_judge/prompts.py +22 -0
- adaptive_harmony/graders/range_judge/__init__.py +7 -0
- adaptive_harmony/graders/range_judge/prompts.py +232 -0
- adaptive_harmony/graders/range_judge/range_judge.py +188 -0
- adaptive_harmony/graders/range_judge/types.py +12 -0
- adaptive_harmony/graders/reward_server_grader.py +36 -0
- adaptive_harmony/graders/templated_prompt_judge.py +237 -0
- adaptive_harmony/graders/utils.py +79 -0
- adaptive_harmony/logging_table.py +1 -0
- adaptive_harmony/metric_logger.py +452 -0
- adaptive_harmony/parameters/__init__.py +2 -0
- adaptive_harmony/py.typed +0 -0
- adaptive_harmony/runtime/__init__.py +2 -0
- adaptive_harmony/runtime/context.py +2 -0
- adaptive_harmony/runtime/data.py +2 -0
- adaptive_harmony/runtime/decorators.py +2 -0
- adaptive_harmony/runtime/model_artifact_save.py +2 -0
- adaptive_harmony/runtime/runner.py +27 -0
- adaptive_harmony/runtime/simple_notifier.py +2 -0
- adaptive_harmony-0.1.23.dist-info/METADATA +37 -0
- adaptive_harmony-0.1.23.dist-info/RECORD +67 -0
- adaptive_harmony-0.1.23.dist-info/WHEEL +5 -0
- adaptive_harmony-0.1.23.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,351 @@
|
|
|
1
|
+
#!/usr/bin/env python
|
|
2
|
+
|
|
3
|
+
"""
|
|
4
|
+
Progress tracking and coroutine observability for async operations.
|
|
5
|
+
|
|
6
|
+
Provides real-time visualization of async task execution with a progress bar
|
|
7
|
+
and hierarchical coroutine call tree display.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
import asyncio
|
|
11
|
+
import os
|
|
12
|
+
import weakref
|
|
13
|
+
|
|
14
|
+
from rich.console import Group
|
|
15
|
+
from rich.live import Live
|
|
16
|
+
from rich.panel import Panel
|
|
17
|
+
from rich.progress import (
|
|
18
|
+
BarColumn,
|
|
19
|
+
Progress,
|
|
20
|
+
TaskID,
|
|
21
|
+
TextColumn,
|
|
22
|
+
TimeRemainingColumn,
|
|
23
|
+
)
|
|
24
|
+
from rich.table import Table
|
|
25
|
+
from rich.text import Text
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def _is_stdlib(filename: str) -> bool:
|
|
29
|
+
"""Check if a filename is from Python's standard library."""
|
|
30
|
+
if not filename:
|
|
31
|
+
return False
|
|
32
|
+
|
|
33
|
+
# Check for common stdlib patterns
|
|
34
|
+
stdlib_indicators = [
|
|
35
|
+
"asyncio",
|
|
36
|
+
"threading",
|
|
37
|
+
]
|
|
38
|
+
|
|
39
|
+
return any(indicator in filename for indicator in stdlib_indicators)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def describe_coroutine(coro):
    """Return the user-code locations along a coroutine's await chain.

    Starting from *coro*, follow ``cr_await`` from each coroutine to the object it
    is currently awaiting, recording one entry per coroutine frame (outermost
    first) formatted as ``"<file basename>:<line> <function>()"``.

    The walk stops when it reaches:
    - an object without a ``cr_frame`` (or whose ``cr_frame`` is None);
    - a coroutine that is not awaiting another coroutine;
    - code classified as stdlib by ``_is_stdlib``, in which case a single
      ``"Internals:..."`` marker is appended instead of the location.

    Returns an empty list when *coro* is None or has no frame.
    """
    locations: list[str] = []
    current = coro

    while current is not None:
        frame = getattr(current, "cr_frame", None)
        if frame is None:
            break

        code = frame.f_code
        if _is_stdlib(code.co_filename):
            # Don't descend into stdlib code; just mark that it was reached.
            locations.append("Internals:...")
            break

        # Keep entries short to fit the panel. os.path.basename strips the directory
        # on every platform (a plain split("/") leaves Windows paths untouched).
        short_filename = os.path.basename(code.co_filename)
        locations.append(f"{short_filename}:{frame.f_lineno} {code.co_name}()")

        awaited = getattr(current, "cr_await", None)
        if awaited is None or not hasattr(awaited, "cr_frame"):
            # Innermost coroutine reached.
            break
        current = awaited

    return locations
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
class ProgressCounterRegistry:
    """Singleton that remembers the active main ProgressCounter and its wrappers.

    Every ``ProgressCounterRegistry()`` call returns the same object. State is
    populated only on the first construction; later calls leave it untouched.
    """

    _instance = None

    def __new__(cls):
        instance = cls._instance
        if instance is None:
            instance = super().__new__(cls)
            # Flag consulted by __init__ so state is only set up once.
            instance._initialized = False
            cls._instance = instance
        return instance

    def __init__(self):
        # __init__ runs on every constructor call; only the first may set state.
        if not self._initialized:
            self._main_counter_instance: ProgressCounter | None = None
            self._wrappers: dict[str, ProgressCounterWrapper] = {}
            self._initialized = True

    def get_main_counter(self) -> "ProgressCounter | None":
        """Return the registered main counter, or None when none is set."""
        counter = self._main_counter_instance
        return counter

    def set_main_counter(self, counter: "ProgressCounter | None"):
        """Register *counter* as the main counter (None clears it)."""
        self._main_counter_instance = counter

    def get_wrapper(self, key: str) -> "ProgressCounterWrapper | None":
        """Return the wrapper stored under *key*, or None if there is none."""
        return self._wrappers.get(key, None)

    def set_wrapper(self, key: str, wrapper: "ProgressCounterWrapper"):
        """Store *wrapper* under *key*, replacing any previous entry."""
        wrappers = self._wrappers
        wrappers[key] = wrapper

    def reset(self):
        """Forget every stored wrapper and the main counter."""
        self._wrappers.clear()
        self._main_counter_instance = None
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
class CoroutineTree:
    """Builds and renders a tree from coroutine call chains for the live display.

    Chains sharing a prefix are merged into shared nodes, and each node shows how
    many chains pass through it. ``max_height`` remembers the tallest rendering
    measured so far. Later, shorter renderings are padded to that height so the
    panel keeps a consistent height between refreshes.
    """

    # Connectors drawn before a node that has later siblings / is the last sibling.
    _BRANCH_MID = "├── "
    _BRANCH_LAST = "└── "
    # Child indentation must be exactly as wide as the connectors above it so that
    # nested levels line up under their parent's connector.
    _INDENT_MID = "│" + " " * (len(_BRANCH_MID) - 1)
    _INDENT_LAST = " " * len(_BRANCH_LAST)

    def __init__(self):
        # Tallest height seen so far, in the units computed by create_padded_display.
        self.max_height = 1

    def build_from_chains(self, chains: list[list[str]]) -> dict:
        """Merge chains into nested ``{label: {"count": int, "children": {...}}}`` dicts.

        ``count`` is the number of chains passing through a node; keys keep the
        order in which they were first seen.
        """
        tree: dict = {}
        for chain in chains:
            level = tree
            for desc in chain:
                node = level.setdefault(desc, {"count": 0, "children": {}})
                node["count"] += 1
                level = node["children"]
        return tree

    def format_as_text(self, tree: dict, prefix: str = "") -> "Text":
        """Render *tree* as Rich ``Text``, one ``[count] label`` line per node.

        Top-level nodes (``prefix == ""``) are drawn without a connector. Deeper
        nodes get ``├──``/``└──`` connectors, and *prefix* carries the vertical
        guides of their ancestors.
        """
        text = Text()
        last_index = len(tree) - 1

        for i, (desc, node) in enumerate(tree.items()):
            is_last_item = i == last_index

            if prefix == "":
                branch = ""
            else:
                branch = self._BRANCH_LAST if is_last_item else self._BRANCH_MID

            text.append(prefix + branch, style="dim")
            text.append(f"[{node['count']}] ", style="bold magenta")
            text.append(desc + "\n")

            if node["children"]:
                new_prefix = prefix + (self._INDENT_LAST if is_last_item else self._INDENT_MID)
                text.append(self.format_as_text(node["children"], new_prefix))

        return text

    def create_padded_display(self, chains: list[list[str]], active_count: int) -> "Text":
        """Render a header plus the merged tree, padded to the tallest height seen."""
        if not chains:
            return Text("No chains available", style="dim")

        tree = self.build_from_chains(chains)
        header = Text(f"{active_count} tasks active\n\n", style="bold cyan")
        tree_content = self.format_as_text(tree)

        # NOTE(review): split("\n") also counts the empty piece after each trailing
        # newline, so this is a relative height measure rather than an exact count.
        current_height = len(header.plain.split("\n")) + len(tree_content.plain.split("\n"))
        self.max_height = max(self.max_height, current_height)

        display = Text()
        display.append(header)
        display.append(tree_content)
        # Pad with blank lines to keep the height consistent across refreshes.
        display.append("\n" * max(0, self.max_height - current_height))
        return display

    def create_empty_display(self) -> "Text":
        """Render the no-active-tasks message, padded to the tallest height seen."""
        display = Text("No active tasks but updated", style="dim")
        # The message counts as one line; pad the rest.
        display.append("\n" * max(0, self.max_height - 1))
        return display
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
class ProgressCounter:
    """Live progress display: an overall progress bar next to a coroutine call tree.

    Used as an async context manager:
    - On entry it starts a ``rich.live.Live`` display, unless the
      ``DISABLE_RICH_PROGRESS`` environment variable is ``"1"``.
    - It also starts a background task that, every ``interval`` seconds, rebuilds
      the call tree of all registered, still-running tasks.
    - On exit it stops both and clears ``ProgressCounterRegistry``.

    Construction asserts that the registry has no main counter yet.
    """

    def __init__(self, main_job_name: str, total_tasks: int):
        registry = ProgressCounterRegistry()
        assert registry.get_main_counter() is None, "Only one main counter instance is allowed"
        self.total_tasks = total_tasks

        # Left panel: state of the batch at a glance.
        self.overall_progress = Progress(
            "{task.description}",
            BarColumn(),
            TextColumn("[magenta]{task.completed:n} steps"),
            TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
            TimeRemainingColumn(),
        )
        self.overall_task = self.overall_progress.add_task(main_job_name, total=self.total_tasks)
        self.overall_panel = Panel(
            self.overall_progress, title="Overall Progress", border_style="green", padding=(1, 1, 0, 1)
        )

        # Right panel: coroutine tree.
        self.coroutine_tree = CoroutineTree()
        self.tree_text = Text("Waiting for tasks...", style="dim")
        self.tree_panel = Panel(self.tree_text, title="[b]Coroutine Tree", border_style="blue", padding=(1, 2))

        # Held weakly so tasks that are garbage-collected drop out automatically.
        self._monitored_tasks: weakref.WeakSet[asyncio.Task] = weakref.WeakSet()
        self._monitor_task: asyncio.Task | None = None
        self.interval = 0.5  # seconds between tree refreshes

        # Two-column layout: progress on the left, coroutine tree on the right.
        self.progress_table = Table.grid()
        self.progress_table.add_row(self.overall_panel, self.tree_panel)

        # Setting DISABLE_RICH_PROGRESS=1 turns off the live display.
        # TODO: make this less dirty
        self.enable_display = os.getenv("DISABLE_RICH_PROGRESS", "0") != "1"

    def register_task(self, task: asyncio.Task):
        """Track *task* in the coroutine tree (held by weak reference)."""
        self._monitored_tasks.add(task)

    def increment_total_counter(self):
        """Advance the overall progress bar by one step."""
        # Use the id returned by add_task instead of assuming it is TaskID(0).
        self.overall_progress.advance(self.overall_task)

    def is_done(self):
        """Return whether the overall progress bar reports all tasks finished."""
        return self.overall_progress.finished

    async def _monitor_loop(self):
        """Refresh the coroutine tree and layout every ``self.interval`` seconds until cancelled."""
        try:
            while True:
                await asyncio.sleep(self.interval)

                # Filter out done tasks that the WeakSet has not dropped yet.
                active = [t for t in self._monitored_tasks if not t.done()]
                chains = self._collect_chains(active)
                self._rebuild_layout(active, chains)

                if self.enable_display:
                    self.live.update(self.progress_table)

        except asyncio.CancelledError:
            # Cancellation is how __aexit__ stops this loop.
            pass

    @staticmethod
    def _collect_chains(active):
        """Return the non-empty await-chain descriptions of the given tasks."""
        chains = []
        for task in active:
            descriptions = describe_coroutine(task.get_coro())
            if descriptions:
                chains.append(descriptions)
        return chains

    def _rebuild_layout(self, active, chains):
        """Rebuild both panels and the two-column table from the latest chains."""
        if not active or not chains:
            self.tree_text = self.coroutine_tree.create_empty_display()
        else:
            self.tree_text = self.coroutine_tree.create_padded_display(chains, len(active))

        self.tree_panel = Panel(self.tree_text, title="[b]Coroutine Tree", border_style="blue", padding=(1, 2))

        # Pad the progress panel to match the tree height
        # (the "- 2" accounts for the progress bar's base height).
        overall_with_padding = Group(
            self.overall_progress,
            Text("\n" * max(0, self.coroutine_tree.max_height - 2)),
        )
        self.overall_panel = Panel(
            overall_with_padding, title="Overall Progress", border_style="green", padding=(1, 1, 0, 1)
        )

        self.progress_table = Table.grid()
        self.progress_table.add_row(self.overall_panel, self.tree_panel)

    async def __aenter__(self):
        """Start the live display (if enabled) and the background tree monitor."""
        if self.enable_display:
            self.live = Live(self.progress_table, refresh_per_second=4).__enter__()
        self._monitor_task = asyncio.create_task(self._monitor_loop())
        return self

    async def __aexit__(self, exc_type, exc_value, exc_traceback):
        """Stop the monitor and display, then clear the registry for the next counter."""
        try:
            monitor = self._monitor_task
            if monitor is not None and not monitor.done():
                monitor.cancel()
                try:
                    await monitor
                except asyncio.CancelledError:
                    pass

            if self.enable_display:
                self.live.__exit__(exc_type, exc_value, exc_traceback)
        finally:
            # Always clear the registry, even if tearing down the display failed;
            # otherwise every later ProgressCounter would trip the
            # "Only one main counter" assertion.
            registry = ProgressCounterRegistry()
            assert registry.get_main_counter() is not None, "Weird state"
            # Clear both the main counter and any wrapper that points to this counter.
            registry.reset()
|
|
313
|
+
|
|
314
|
+
|
|
315
|
+
class ProgressCounterWrapper(ProgressCounter):
    """Used to wrap a preexisting counter, needed to have simple logic in case an async_map is used inside an async_map_batch.

    Task registration is forwarded to the wrapped counter. Progress increments and
    context-manager entry/exit are no-ops, because the wrapped counter is owned
    and managed by whoever created it.

    NOTE: ProgressCounter.__init__ is deliberately not called, so display
    attributes such as ``overall_progress`` do not exist on a wrapper.
    """

    def __init__(self, inner: ProgressCounter, main_function_name: str):
        self.main_function_name = main_function_name
        self.inner = inner

    def register_task(self, task: asyncio.Task):
        """Forward *task* to the wrapped counter."""
        target = self.inner
        return target.register_task(task)

    def increment_total_counter(self):
        """No-op: nested progress is not counted (for now)."""
        return None

    async def __aenter__(self):
        """Return self; the wrapped counter has already been entered."""
        return self

    async def __aexit__(self, _exc_type, _exc_value, _exc_traceback):
        """No-op; the wrapped counter is exited later by its owner."""
        return None
|
|
336
|
+
|
|
337
|
+
|
|
338
|
+
def get_progress_counter_or_wrapper(main_job_name: str, total_samples: int):
    """Return the main progress counter, or a wrapper around it for nested calls.

    - If no main counter is registered, a new ``ProgressCounter`` is created for
      ``main_job_name``/``total_samples`` and registered.
    - Otherwise a ``ProgressCounterWrapper`` around the existing main counter is
      returned, cached per ``main_job_name`` so repeated calls share one wrapper.
      ``total_samples`` is ignored in that case.
    """
    registry = ProgressCounterRegistry()
    existing = registry.get_main_counter()

    if existing is not None:
        cached = registry.get_wrapper(main_job_name)
        if cached is None:
            cached = ProgressCounterWrapper(existing, main_job_name)
            registry.set_wrapper(main_job_name, cached)
        return cached

    fresh = ProgressCounter(main_job_name, total_samples)
    registry.set_main_counter(fresh)
    return fresh
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
def gae_advantages(
    values: list[float],
    rewards: list[float],
    gae_lambda: float,
    gae_gamma: float,
) -> list[float]:
    """Compute Generalized Advantage Estimation (GAE) advantages for one trajectory.

    Walking backwards in time:
        delta_t = rewards[t] + gamma * values[t + 1] - values[t]
        A_t     = delta_t + gamma * lambda * A_{t + 1}
    The value after the final step is taken as 0.0 (no bootstrap).

    The trajectory length is ``len(values)``, and ``rewards`` is indexed in
    lockstep with it. Advantages are returned in forward time order.
    """
    n_steps = len(values)
    advantages = [0.0] * n_steps
    running = 0.0

    for t in range(n_steps - 1, -1, -1):
        next_value = values[t + 1] if t + 1 < n_steps else 0.0
        delta = rewards[t] + gae_gamma * next_value - values[t]
        running = delta + gae_gamma * gae_lambda * running
        advantages[t] = running

    return advantages
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def discounted_cumulative_rewards(
    rewards: list[float],
    gamma: float,
) -> list[float]:
    """Return the discounted reward-to-go for every step of a trajectory.

    ``returns[t] = rewards[t] + gamma * returns[t + 1]``, and the last step's
    return equals its reward. An empty trajectory yields an empty list.
    """
    n = len(rewards)
    # Guard on len() rather than truthiness so array-like inputs still work.
    if n == 0:
        return []

    returns = [0.0] * n
    returns[-1] = rewards[-1]
    for t in range(n - 2, -1, -1):
        returns[t] = rewards[t] + gamma * returns[t + 1]

    return returns
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def gae_td_returns(
    advantages: list[float],
    values: list[float],
) -> list[float]:
    """Return per-step returns ``advantages[t] + values[t]``.

    Inputs are paired with ``zip``, so the result is as long as the shorter input.
    """
    return [advantage + value for advantage, value in zip(advantages, values)]
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
import math
|
|
2
|
+
from typing import Callable
|
|
3
|
+
|
|
4
|
+
# Type of a schedule: maps the completed fraction of training
# (``completion_percentage``) to a value, e.g. a learning rate.
Scheduler = Callable[[float], float]
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class CosineSchedulerWithoutWarmup:
    """Cosine decay from ``lr`` down to ``lr / decay_factor``, with no warmup.

    Called with the completed fraction of training. It returns ``lr`` at 0.0 and
    ``lr / decay_factor`` at 1.0, following half a cosine period in between.
    """

    def __init__(self, lr=1e-5, decay_factor=10.0) -> None:
        self.max_value = lr
        self.min_value = lr / decay_factor

    def __call__(self, completion_percentage: float) -> float:
        # Half-cosine weight: 1.0 at the start, 0.0 at the end.
        weight = (math.cos(completion_percentage * math.pi) + 1.0) * 0.5
        return self.min_value + weight * (self.max_value - self.min_value)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class CombinedSchedule:
    """Piecewise schedule: ``a`` before ``change_point``, ``b`` from it onward.

    Each phase receives its own progress rescaled to [0, 1]:
    - ``a`` sees ``completion / change_point``;
    - ``b`` sees ``(completion - change_point) / (1 - change_point)``.

    Degenerate change points do not divide by zero:
    - with ``change_point <= 0`` only ``b`` is used;
    - with ``change_point >= 1``, ``b`` is evaluated at 0.0 once completion
      reaches the change point.
    """

    def __init__(self, a: "Scheduler", b: "Scheduler", change_point: float) -> None:
        self.a = a
        self.b = b
        self.change_point = change_point

    def __call__(self, completion_percentage: float) -> float:
        change_point = self.change_point

        # A zero-length first phase must never be divided into.
        if change_point > 0 and completion_percentage < change_point:
            return self.a(completion_percentage / change_point)

        remaining = 1 - change_point
        if remaining <= 0:
            # Zero-length second phase (e.g. 100% warmup): start of ``b``.
            return self.b(0.0)
        return self.b((completion_percentage - change_point) / remaining)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class CosineScheduler:
    """Linear warmup from 0 to ``lr``, then cosine decay to ``lr / decay_factor``.

    ``warmup_percentage`` is the fraction of training spent in the linear warmup.
    """

    def __init__(self, lr=1e-5, warmup_percentage=0.1, decay_factor=10.0):
        def linear_warmup(progress: float) -> float:
            # Ramp linearly from 0 at the start of warmup to lr at its end.
            return progress * lr

        decay = CosineSchedulerWithoutWarmup(lr, decay_factor)
        self.combined = CombinedSchedule(linear_warmup, decay, warmup_percentage)

    def __call__(self, completion_percentage: float) -> float:
        schedule = self.combined
        return schedule(completion_percentage)
|