skillrl 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- skillrl/__init__.py +49 -0
- skillrl/config.py +148 -0
- skillrl/core/__init__.py +36 -0
- skillrl/core/editor.py +110 -0
- skillrl/core/gate.py +94 -0
- skillrl/core/scheduler.py +88 -0
- skillrl/core/utils.py +96 -0
- skillrl/envs/__init__.py +19 -0
- skillrl/envs/base.py +52 -0
- skillrl/envs/qa.py +163 -0
- skillrl/llm/__init__.py +19 -0
- skillrl/llm/base.py +56 -0
- skillrl/llm/openai_client.py +163 -0
- skillrl/pipeline/__init__.py +22 -0
- skillrl/pipeline/aggregate.py +220 -0
- skillrl/pipeline/reflect.py +253 -0
- skillrl/pipeline/rollout.py +93 -0
- skillrl/pipeline/select.py +110 -0
- skillrl/prompts/__init__.py +53 -0
- skillrl/prompts/analyst_error.md +35 -0
- skillrl/prompts/analyst_success.md +31 -0
- skillrl/prompts/merge_failure.md +22 -0
- skillrl/prompts/merge_final.md +23 -0
- skillrl/prompts/merge_success.md +19 -0
- skillrl/prompts/ranking.md +23 -0
- skillrl/py.typed +0 -0
- skillrl/trainer.py +714 -0
- skillrl/types.py +241 -0
- skillrl-1.0.0.dist-info/METADATA +362 -0
- skillrl-1.0.0.dist-info/RECORD +33 -0
- skillrl-1.0.0.dist-info/WHEEL +5 -0
- skillrl-1.0.0.dist-info/licenses/LICENSE +21 -0
- skillrl-1.0.0.dist-info/top_level.txt +1 -0
skillrl/__init__.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
"""skillrl — A TRL-like training library for end-to-end skill optimization.
|
|
2
|
+
|
|
3
|
+
This package implements the core algorithm of Microsoft SkillOpt
|
|
4
|
+
(https://microsoft.github.io/SkillOpt/) as a clean, modular, TRL-style
|
|
5
|
+
training library. The trainable state is a natural-language *skill
|
|
6
|
+
document*; both the optimizer and the target LLM stay frozen.
|
|
7
|
+
|
|
8
|
+
Quick start
|
|
9
|
+
-----------
|
|
10
|
+
>>> from skillrl import SkillOptConfig, SkillOptTrainer
|
|
11
|
+
>>> from skillrl.envs.qa import SimpleQAEnv
|
|
12
|
+
>>> from skillrl.llm.openai_client import OpenAIChatClient
|
|
13
|
+
>>>
|
|
14
|
+
>>> cfg = SkillOptConfig(num_epochs=2, batch_size=8, edit_budget=4)
|
|
15
|
+
>>> env = SimpleQAEnv(train_items=[...], val_items=[...], test_items=[...])
|
|
16
|
+
>>> client = OpenAIChatClient(model="gpt-4o-mini")
|
|
17
|
+
>>> trainer = SkillOptTrainer(
|
|
18
|
+
... config=cfg, env=env,
|
|
19
|
+
... optimizer_client=client, target_client=client,
|
|
20
|
+
... initial_skill="You are a helpful assistant.",
|
|
21
|
+
... )
|
|
22
|
+
>>> summary = trainer.train()
|
|
23
|
+
>>> print(summary["test_hard"])
|
|
24
|
+
"""
|
|
25
|
+
|
|
26
|
+
from skillrl.config import SkillOptConfig
|
|
27
|
+
from skillrl.types import (
|
|
28
|
+
Edit,
|
|
29
|
+
Patch,
|
|
30
|
+
RolloutResult,
|
|
31
|
+
GateResult,
|
|
32
|
+
GateAction,
|
|
33
|
+
RawPatch,
|
|
34
|
+
)
|
|
35
|
+
from skillrl.trainer import SkillOptTrainer
|
|
36
|
+
|
|
37
|
+
__version__ = "1.0.0"
|
|
38
|
+
|
|
39
|
+
__all__ = [
|
|
40
|
+
"__version__",
|
|
41
|
+
"SkillOptConfig",
|
|
42
|
+
"SkillOptTrainer",
|
|
43
|
+
"Edit",
|
|
44
|
+
"Patch",
|
|
45
|
+
"RolloutResult",
|
|
46
|
+
"GateResult",
|
|
47
|
+
"GateAction",
|
|
48
|
+
"RawPatch",
|
|
49
|
+
]
|
skillrl/config.py
ADDED
|
@@ -0,0 +1,148 @@
|
|
|
1
|
+
"""Training configuration for skillrl — a TRL-like dataclass.
|
|
2
|
+
|
|
3
|
+
Mirrors the paper-default protocol of SkillOpt (Yang et al., 2026) while
|
|
4
|
+
exposing only the knobs needed for the 1.0 core algorithm. Advanced
|
|
5
|
+
features such as slow update, meta skill, autonomous LR, accumulation,
|
|
6
|
+
codex/claude-code execution backends and full-rewrite update modes are
|
|
7
|
+
intentionally omitted from 1.0 — they are future extension points.
|
|
8
|
+
"""
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
from dataclasses import dataclass, field, asdict
|
|
12
|
+
from typing import Literal, Any
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
GateMetric = Literal["hard", "soft", "mixed"]
LRSchedulerMode = Literal["constant", "linear", "cosine"]


@dataclass
class SkillOptConfig:
    """Hyper-parameters consumed by :class:`~skillrl.trainer.SkillOptTrainer`.

    Out of the box the values follow the SkillOpt paper recipe: four
    epochs, 40 rollouts per step, analyst minibatches of 8, an edit
    budget of 4 decayed with a cosine schedule, and a gate scored on the
    ``hard`` metric.  Construction validates the fields and raises
    :class:`ValueError` on the first violation (see ``__post_init__``).

    Parameters
    ----------
    num_epochs:
        How many epochs to train for.
    batch_size:
        Training items rolled out in one optimization step (must be > 0).
    minibatch_size:
        Trajectories handed to a single analyst (Reflect) call
        (must be > 0).
    merge_batch_size:
        Patches merged per level of the hierarchical Aggregate stage
        (must be >= 2).
    edit_budget:
        Starting and maximum textual learning rate, i.e. the cap on
        edits applied in one step (must be > 0).
    min_edit_budget:
        Floor used by the decaying schedules; must lie in
        ``(0, edit_budget]``.
    lr_scheduler:
        Budget schedule name: ``constant``, ``linear`` or ``cosine``.
    gate_metric:
        Metric used by the validation gate: ``hard``, ``soft`` or
        ``mixed``.
    gate_mixed_weight:
        Soft-component weight ``w`` for the ``mixed`` metric, where the
        score is ``(1 - w) * hard + w * soft``; must lie in ``[0, 1]``.
    failure_only:
        When True, analyst patches are driven by failed trajectories only.
    workers:
        Thread workers for parallel rollouts and analyst calls.
    seed:
        Random seed.
    out_root:
        Directory receiving history, per-step artifacts, candidate
        skills and ``best_skill.md``; an existing run found there is
        resumed automatically.
    selection_split:
        Split name evaluated by the validation gate.
    test_split:
        Split name used for the final held-out report.
    eval_test:
        Whether the final test evaluation runs once training ends.
    save_every_step:
        When True, ``skill_vXXXX.md`` snapshots are written to
        ``<out_root>/skills/`` after each step.
    optimizer_max_completion_tokens:
        Token budget for analyst, merge and ranking calls.
    verbose:
        Emit detailed per-stage progress output.
    extras:
        Arbitrary additional settings forwarded to environments and
        custom prompt templates.
    """

    # Training schedule.
    num_epochs: int = 4
    batch_size: int = 40
    minibatch_size: int = 8
    merge_batch_size: int = 8

    # Optimizer: the "textual learning rate" is an edit count.
    edit_budget: int = 4
    min_edit_budget: int = 2
    lr_scheduler: LRSchedulerMode = "cosine"

    # Validation gate.
    gate_metric: GateMetric = "hard"
    gate_mixed_weight: float = 0.5

    # Reflect stage.
    failure_only: bool = False

    # Concurrency and runtime.
    workers: int = 8
    seed: int = 42
    out_root: str = "outputs/skillrl_run"

    # Evaluation splits.
    selection_split: str = "val"
    test_split: str = "test"
    eval_test: bool = True

    # Persistence and observability.
    save_every_step: bool = True
    optimizer_max_completion_tokens: int = 4096
    verbose: bool = True

    # Free-form extras; a factory keeps instances from sharing one dict.
    extras: dict[str, Any] = field(default_factory=dict)

    def __post_init__(self) -> None:
        """Validate the fields, raising ``ValueError`` on the first failure.

        The helpers run in a fixed order so that the error reported for
        a config with several bad fields stays stable.
        """
        self._validate_batching()
        self._validate_budget()
        self._validate_gate()

    def _validate_batching(self) -> None:
        """Check the rollout, minibatch and merge batch sizes."""
        if self.batch_size <= 0:
            raise ValueError(f"batch_size must be > 0, got {self.batch_size}")
        if self.minibatch_size <= 0:
            raise ValueError(f"minibatch_size must be > 0, got {self.minibatch_size}")
        if self.merge_batch_size < 2:
            raise ValueError(f"merge_batch_size must be >= 2, got {self.merge_batch_size}")

    def _validate_budget(self) -> None:
        """Check the edit budget bounds and the schedule name."""
        if self.edit_budget <= 0:
            raise ValueError(f"edit_budget must be > 0, got {self.edit_budget}")
        floor = self.min_edit_budget
        if floor <= 0 or floor > self.edit_budget:
            raise ValueError(
                f"min_edit_budget must be in (0, edit_budget], "
                f"got {self.min_edit_budget} (edit_budget={self.edit_budget})"
            )
        known_schedules = ("constant", "linear", "cosine")
        if self.lr_scheduler not in known_schedules:
            raise ValueError(
                f"lr_scheduler must be one of constant/linear/cosine, "
                f"got {self.lr_scheduler!r}"
            )

    def _validate_gate(self) -> None:
        """Check the gate metric name and the mixed-metric weight."""
        known_metrics = ("hard", "soft", "mixed")
        if self.gate_metric not in known_metrics:
            raise ValueError(
                f"gate_metric must be one of hard/soft/mixed, got {self.gate_metric!r}"
            )
        if not (0.0 <= self.gate_mixed_weight <= 1.0):
            raise ValueError(
                f"gate_mixed_weight must be in [0, 1], got {self.gate_mixed_weight}"
            )

    def to_dict(self) -> dict:
        """Return the configuration as a plain dict via ``dataclasses.asdict``."""
        return asdict(self)
|
skillrl/core/__init__.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
"""skillrl core algorithms — pure-functional building blocks.
|
|
2
|
+
|
|
3
|
+
This subpackage hosts the *non-LLM* primitives used by the trainer:
|
|
4
|
+
|
|
5
|
+
* :mod:`skillrl.core.editor` — apply edits / patches to a skill doc.
|
|
6
|
+
* :mod:`skillrl.core.scheduler` — edit-budget (textual LR) schedulers.
|
|
7
|
+
* :mod:`skillrl.core.gate` — validation gating (hard/soft/mixed).
|
|
8
|
+
* :mod:`skillrl.core.utils` — JSON extraction, scoring, hashing.
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from skillrl.core.editor import apply_edit, apply_patch, apply_patch_with_report
|
|
12
|
+
from skillrl.core.scheduler import (
|
|
13
|
+
LRScheduler,
|
|
14
|
+
ConstantScheduler,
|
|
15
|
+
LinearScheduler,
|
|
16
|
+
CosineScheduler,
|
|
17
|
+
build_scheduler,
|
|
18
|
+
)
|
|
19
|
+
from skillrl.core.gate import evaluate_gate, select_gate_score
|
|
20
|
+
from skillrl.core.utils import compute_score, extract_json, skill_hash
|
|
21
|
+
|
|
22
|
+
__all__ = [
|
|
23
|
+
"apply_edit",
|
|
24
|
+
"apply_patch",
|
|
25
|
+
"apply_patch_with_report",
|
|
26
|
+
"LRScheduler",
|
|
27
|
+
"ConstantScheduler",
|
|
28
|
+
"LinearScheduler",
|
|
29
|
+
"CosineScheduler",
|
|
30
|
+
"build_scheduler",
|
|
31
|
+
"evaluate_gate",
|
|
32
|
+
"select_gate_score",
|
|
33
|
+
"compute_score",
|
|
34
|
+
"extract_json",
|
|
35
|
+
"skill_hash",
|
|
36
|
+
]
|
skillrl/core/editor.py
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
1
|
+
"""Skill-document edit application — the Update stage (⑤).
|
|
2
|
+
|
|
3
|
+
Analogous to ``optimizer.step()`` in neural network training: a ranked
|
|
4
|
+
:class:`~skillrl.types.Patch` is applied sequentially onto the current
|
|
5
|
+
skill document, producing a candidate skill that is then validated by
|
|
6
|
+
the gate.
|
|
7
|
+
"""
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
from typing import Any
|
|
11
|
+
|
|
12
|
+
from skillrl.types import Edit, Patch
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def _edit_fields(edit: Edit | dict) -> tuple[str, str, str]:
    """Normalize an edit into an ``(op, content, target)`` triple.

    ``Edit`` instances are read through their attributes; anything else
    is treated as a mapping and read with ``.get``.  Missing or falsy
    ``content`` / ``target`` values become ``""``, and ``content`` is
    stripped of surrounding whitespace.  For mappings every field is
    additionally coerced with ``str``.
    """
    if not isinstance(edit, Edit):
        raw_op = edit.get("op", "")
        raw_content = edit.get("content", "")
        raw_target = edit.get("target", "")
        return (
            str(raw_op or ""),
            str(raw_content or "").strip(),
            str(raw_target or ""),
        )

    body = edit.content or ""
    return edit.op, body.strip(), edit.target or ""
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def _apply_edit_with_report(skill: str, edit: Edit | dict) -> tuple[str, dict]:
    """Apply one edit to ``skill`` and describe what happened.

    Returns ``(new_skill, report)``.  ``report`` carries the op, the
    first 200 characters of the target and of the content, and a
    ``status`` string.  Supported ops:

    * ``append`` - add the content after the right-stripped skill.
    * ``insert_after`` - insert the content after the line holding the
      first occurrence of the target; falls back to an append when the
      target is empty or absent.
    * ``replace`` / ``delete`` - act on the first occurrence of the
      target; skipped (skill unchanged) when the target is empty or
      not present.

    Any other op leaves the skill untouched with status
    ``skipped_unknown_op``.
    """
    op, content, target = _edit_fields(edit)
    report: dict[str, Any] = {
        "op": op,
        "target": target[:200],
        "content_preview": content[:200],
        "status": "unknown",
    }

    def finish(status: str, result: str) -> tuple[str, dict]:
        # Record the outcome and hand back the (skill, report) pair.
        report["status"] = status
        return result, report

    def appended() -> str:
        # Content goes after a blank line at the end of the document.
        return skill.rstrip() + "\n\n" + content + "\n"

    if op == "append":
        return finish("applied_append", appended())

    if op == "insert_after":
        if target and target in skill:
            after_target = skill.index(target) + len(target)
            line_end = skill.find("\n", after_target)
            # Insert past the end of the target's line, or at the very
            # end when the target sits on the last, unterminated line.
            cut = len(skill) if line_end == -1 else line_end + 1
            return finish(
                "applied_insert_after",
                skill[:cut] + "\n" + content + "\n" + skill[cut:],
            )
        return finish("applied_insert_after_fallback_append", appended())

    if op == "replace":
        if not target:
            return finish("skipped_replace_missing_target", skill)
        if target not in skill:
            return finish("skipped_replace_target_not_found", skill)
        return finish("applied_replace", skill.replace(target, content, 1))

    if op == "delete":
        if not target:
            return finish("skipped_delete_missing_target", skill)
        if target not in skill:
            return finish("skipped_delete_target_not_found", skill)
        return finish("applied_delete", skill.replace(target, "", 1))

    return finish("skipped_unknown_op", skill)
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def apply_edit(skill: str, edit: Edit | dict) -> str:
    """Return ``skill`` with a single edit applied; the report is discarded."""
    return _apply_edit_with_report(skill, edit)[0]
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def apply_patch_with_report(
    skill: str,
    patch: Patch | dict,
) -> tuple[str, list[dict]]:
    """Apply every edit of ``patch`` in order and collect per-edit reports.

    ``patch`` is either a ``Patch`` (its ``edits`` attribute is used) or
    a mapping with an optional ``"edits"`` entry.  Each report gets a
    1-based ``index``.  An edit that raises does not abort the patch:
    the skill is left as it was before that edit and an ``"error"``
    report containing the exception text is recorded instead.

    Returns ``(final_skill, reports)``.
    """
    if isinstance(patch, Patch):
        pending: list[Edit | dict] = list(patch.edits)
    else:
        pending = list(patch.get("edits", []) or [])

    log: list[dict] = []
    position = 0
    for edit in pending:
        position += 1
        try:
            skill, entry = _apply_edit_with_report(skill, edit)
            entry["index"] = position
        except Exception as exc:  # noqa: BLE001
            # Best effort: one malformed edit must not sink the whole patch.
            entry = {
                "index": position,
                "op": "",
                "target": "",
                "content_preview": "",
                "status": "error",
                "error": str(exc),
            }
        log.append(entry)
    return skill, log
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def apply_patch(skill: str, patch: Patch | dict) -> str:
    """Return ``skill`` after applying the patch's edits one after another.

    Thin wrapper over :func:`apply_patch_with_report` that drops the
    per-edit reports.
    """
    return apply_patch_with_report(skill, patch)[0]
|
skillrl/core/gate.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
1
|
+
"""Validation gating — Stage ⑥ of the skillrl pipeline.
|
|
2
|
+
|
|
3
|
+
A candidate skill is *accepted* only when its score on the held-out
|
|
4
|
+
selection split *strictly improves* upon the running ``current_score``.
|
|
5
|
+
The metric is configurable:
|
|
6
|
+
|
|
7
|
+
* ``hard`` : exact-match accuracy (paper default).
|
|
8
|
+
* ``soft`` : continuous per-item reward.
|
|
9
|
+
* ``mixed`` : ``(1 - w) * hard + w * soft``.
|
|
10
|
+
|
|
11
|
+
The gate also tracks the global ``best_skill`` separately, so the
|
|
12
|
+
trainer can deploy the all-time-best skill at the end of training.
|
|
13
|
+
"""
|
|
14
|
+
from __future__ import annotations
|
|
15
|
+
|
|
16
|
+
from typing import Literal
|
|
17
|
+
|
|
18
|
+
from skillrl.types import GateResult
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
GateMetric = Literal["hard", "soft", "mixed"]


def select_gate_score(
    hard: float,
    soft: float,
    metric: GateMetric = "hard",
    mixed_weight: float = 0.5,
) -> float:
    """Reduce a ``(hard, soft)`` score pair to the single gate score.

    * ``"hard"`` returns ``float(hard)``.
    * ``"soft"`` returns ``float(soft)``.
    * ``"mixed"`` returns ``(1 - w) * hard + w * soft`` with
      ``w = mixed_weight``.

    Raises
    ------
    ValueError
        If ``metric`` is not one of the three names, or if ``metric`` is
        ``"mixed"`` and ``mixed_weight`` lies outside ``[0, 1]``.
    """
    if metric == "hard":
        chosen = hard
    elif metric == "soft":
        chosen = soft
    elif metric == "mixed":
        w = float(mixed_weight)
        if not (0.0 <= w <= 1.0):
            raise ValueError(f"mixed_weight must be in [0, 1], got {w}")
        return (1.0 - w) * float(hard) + w * float(soft)
    else:
        raise ValueError(f"Unknown gate metric: {metric!r}")
    return float(chosen)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def evaluate_gate(
    *,
    candidate_skill: str,
    cand_hard: float,
    current_skill: str,
    current_score: float,
    best_skill: str,
    best_score: float,
    best_step: int,
    global_step: int,
    cand_soft: float = 0.0,
    metric: GateMetric = "hard",
    mixed_weight: float = 0.5,
) -> GateResult:
    """Gate a candidate skill and return the resulting current/best state.

    The candidate's gate score is derived from ``cand_hard`` /
    ``cand_soft`` via :func:`select_gate_score`, then compared with
    strict ``>``:

    * above ``best_score`` - ``"accept_new_best"``: the candidate becomes
      both current and best, and ``best_step`` is set to ``global_step``;
    * otherwise above ``current_score`` - ``"accept"``: the candidate
      becomes current, the best entry is carried over;
    * otherwise - ``"reject"``: every input state value is carried over.
    """
    cand_score = select_gate_score(cand_hard, cand_soft, metric, mixed_weight)

    # Start from "nothing changes" and overwrite what the verdict moves.
    action = "reject"
    next_skill, next_score = current_skill, current_score
    top_skill, top_score, top_step = best_skill, best_score, best_step

    if cand_score > best_score:
        action = "accept_new_best"
        next_skill, next_score = candidate_skill, cand_score
        top_skill, top_score, top_step = candidate_skill, cand_score, global_step
    elif cand_score > current_score:
        action = "accept"
        next_skill, next_score = candidate_skill, cand_score

    return GateResult(
        action=action,
        current_skill=next_skill,
        current_score=next_score,
        best_skill=top_skill,
        best_score=top_score,
        best_step=top_step,
    )
|
skillrl/core/scheduler.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
"""Edit-budget (textual learning rate) schedulers.
|
|
2
|
+
|
|
3
|
+
In skillrl the *learning rate* is the maximum number of edits applied
|
|
4
|
+
to the skill document at each optimization step. Exposing the same
|
|
5
|
+
PyTorch-style API (``step()``, ``state_dict()`` / ``load_state_dict``)
|
|
6
|
+
keeps the abstraction familiar.
|
|
7
|
+
"""
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import math
|
|
11
|
+
from abc import ABC, abstractmethod
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class LRScheduler(ABC):
    """Abstract edit-budget scheduler with a small PyTorch-like surface.

    Subclasses implement :meth:`_compute_lr`.  The base class keeps an
    internal step counter and offers ``step`` / ``get_lr`` plus
    ``state_dict`` / ``load_state_dict`` for checkpointing that counter.
    """

    def __init__(self, max_lr: int, min_lr: int, total_steps: int) -> None:
        """Store the budget bounds and horizon.

        All three arguments are coerced with ``int``; ``total_steps`` is
        raised to at least 1.
        """
        self.max_lr = int(max_lr)
        self.min_lr = int(min_lr)
        horizon = int(total_steps)
        self.total_steps = horizon if horizon > 1 else 1
        self._current_step = 0

    @abstractmethod
    def _compute_lr(self, step: int) -> int:
        """Return the edit budget for the given 1-indexed step."""

    def step(self) -> int:
        """Advance the internal counter by one and return that step's budget."""
        upcoming = self._current_step + 1
        self._current_step = upcoming
        return self._compute_lr(upcoming)

    def get_lr(self, step: int) -> int:
        """Return the budget for ``step`` without touching the counter."""
        return self._compute_lr(step)

    def state_dict(self) -> dict:
        """Return the serializable state: just the current step counter."""
        return {"current_step": self._current_step}

    def load_state_dict(self, state: dict) -> None:
        """Restore the counter; a missing or falsy entry resets it to 0."""
        saved = state.get("current_step", 0)
        self._current_step = int(saved or 0)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class ConstantScheduler(LRScheduler):
    """Scheduler whose budget never changes: always ``max_lr``."""

    def _compute_lr(self, step: int) -> int:
        """Return ``max_lr`` regardless of ``step``."""
        return self.max_lr
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class LinearScheduler(LRScheduler):
    """Linear decay from ``max_lr`` to ``min_lr`` over ``total_steps``."""

    def _compute_lr(self, step: int) -> int:
        """Return the linearly interpolated budget for ``step``.

        With a horizon of one step (or less) the budget is ``max_lr``.
        Otherwise the progress fraction ``step / total_steps`` is taken
        with ``step`` clamped to ``[0, total_steps]``, the budget is
        interpolated between ``max_lr`` and ``min_lr``, rounded with the
        built-in ``round`` and floored at ``min_lr``.
        """
        if self.total_steps <= 1:
            return self.max_lr
        # Clamp on both sides: without the lower bound a negative step
        # (reachable through the public ``get_lr``) yields a negative
        # fraction and a budget above ``max_lr``.
        bounded_step = min(max(step, 0), self.total_steps)
        t = bounded_step / self.total_steps
        lr = self.max_lr + (self.min_lr - self.max_lr) * t
        return max(self.min_lr, round(lr))
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class CosineScheduler(LRScheduler):
    """Half-cosine anneal of the budget from ``max_lr`` down to ``min_lr``."""

    def _compute_lr(self, step: int) -> int:
        """Return the cosine-annealed budget for ``step``.

        A horizon of one step (or less) always yields ``max_lr``.
        Steps past ``total_steps`` are treated as the final step.  The
        result is rounded with the built-in ``round`` and floored at
        ``min_lr``.
        """
        horizon = self.total_steps
        if horizon <= 1:
            return self.max_lr

        progress = min(step, horizon) / horizon
        span = self.max_lr - self.min_lr
        wave = 1 + math.cos(math.pi * progress)
        budget = round(self.min_lr + 0.5 * span * wave)
        return budget if budget > self.min_lr else self.min_lr
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
# Schedule name -> scheduler class; the key order is what the
# "Available: [...]" part of the error message shows.
_REGISTRY: dict[str, type[LRScheduler]] = {
    "constant": ConstantScheduler,
    "linear": LinearScheduler,
    "cosine": CosineScheduler,
}


def build_scheduler(
    mode: str = "cosine",
    max_lr: int = 4,
    min_lr: int = 2,
    total_steps: int = 8,
) -> LRScheduler:
    """Instantiate the scheduler registered under ``mode``.

    ``max_lr``, ``min_lr`` and ``total_steps`` are forwarded unchanged
    as keyword arguments to the scheduler class.

    Raises
    ------
    ValueError
        If ``mode`` is not a registered schedule name.
    """
    scheduler_cls = _REGISTRY.get(mode)
    if scheduler_cls is None:
        raise ValueError(
            f"Unknown scheduler mode {mode!r}. Available: {list(_REGISTRY)}"
        )
    return scheduler_cls(max_lr=max_lr, min_lr=min_lr, total_steps=total_steps)
|
skillrl/core/utils.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
1
|
+
"""Misc utilities — JSON extraction, scoring, hashing."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
import hashlib
|
|
5
|
+
import json
|
|
6
|
+
import re
|
|
7
|
+
from typing import Any, Iterable
|
|
8
|
+
|
|
9
|
+
from skillrl.types import RolloutResult
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
# ── Robust JSON extraction from LLM output ──────────────────────────────
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
# A Markdown fenced block, optionally tagged ``json``; the body is
# captured lazily so only the first fence is consumed.
_FENCE_RE = re.compile(r"```(?:json)?\s*(.+?)```", re.DOTALL | re.IGNORECASE)

# Sentinel distinguishing "did not parse" from a parsed JSON ``null``.
_PARSE_FAILED = object()


def _try_loads(payload: str) -> Any:
    """Return ``json.loads(payload)``, or ``_PARSE_FAILED`` if it is not JSON."""
    try:
        return json.loads(payload)
    except (ValueError, RecursionError):
        # JSONDecodeError is a ValueError; RecursionError covers
        # pathologically nested input.
        return _PARSE_FAILED


def extract_json(text: str) -> dict | None:
    """Best-effort extraction of a JSON object embedded in *text*.

    Tries (in order):

    1. parsing the entire string as JSON,
    2. parsing the contents of the first fenced code block,
    3. parsing the substring between the first ``{`` and last ``}``,
    4. decoding a complete JSON object starting at each ``{`` in turn,
       which recovers an object followed by prose that itself contains
       braces (where step 3 fails).

    In steps 1 and 2, a payload that parses to something other than a
    dict ends the search with ``None``.

    Returns ``None`` on failure — callers must handle this gracefully.
    """
    if text is None:
        return None
    s = text.strip()
    if not s:
        return None

    # 1. raw JSON
    parsed = _try_loads(s)
    if parsed is not _PARSE_FAILED:
        return parsed if isinstance(parsed, dict) else None

    # 2. first fenced block
    fence = _FENCE_RE.search(s)
    if fence:
        parsed = _try_loads(fence.group(1).strip())
        if parsed is not _PARSE_FAILED:
            return parsed if isinstance(parsed, dict) else None

    # 3. widest brace slice
    first = s.find("{")
    last = s.rfind("}")
    if first != -1 and last > first:
        parsed = _try_loads(s[first : last + 1])
        if parsed is not _PARSE_FAILED:
            return parsed if isinstance(parsed, dict) else None

    # 4. last resort: the first ``{`` that starts a complete object wins.
    decoder = json.JSONDecoder()
    start = first
    while start != -1:
        try:
            candidate, _end = decoder.raw_decode(s, start)
        except (ValueError, RecursionError):
            pass
        else:
            if isinstance(candidate, dict):
                return candidate
        start = s.find("{", start + 1)

    return None
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
# ── Aggregate rollout scoring ───────────────────────────────────────────
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def compute_score(results: Iterable[Any]) -> tuple[float, float]:
    """Average the ``hard`` and ``soft`` scores of a rollout batch.

    Each element is either a ``RolloutResult`` (read via ``.hard`` /
    ``.soft``) or a mapping with optional ``hard`` / ``soft`` entries,
    where a missing or falsy entry counts as zero.

    Returns ``(hard_mean, soft_mean)``; an empty batch gives
    ``(0.0, 0.0)``.
    """
    batch = list(results)
    count = len(batch)
    if count == 0:
        return 0.0, 0.0

    hard_sum = 0.0
    soft_sum = 0.0
    for entry in batch:
        if isinstance(entry, RolloutResult):
            hard_value = float(entry.hard)
            soft_value = float(entry.soft)
        else:
            hard_value = float(entry.get("hard", 0) or 0)
            soft_value = float(entry.get("soft", 0.0) or 0.0)
        hard_sum += hard_value
        soft_sum += soft_value
    return hard_sum / count, soft_sum / count
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
# ── Skill content hashing (cache key for selection eval) ────────────────
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def skill_hash(content: str) -> str:
    """Return a short, deterministic fingerprint of a skill document.

    The value is the first 16 hex characters of the SHA-256 digest of
    the UTF-8 encoded text; ``None`` or an empty string hash as ``""``.
    """
    text = content or ""
    digest = hashlib.sha256(text.encode("utf-8"))
    return digest.hexdigest()[:16]
|
skillrl/envs/__init__.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
"""Environment / dataset abstractions used by the trainer.
|
|
2
|
+
|
|
3
|
+
A ``SkillEnv`` defines:
|
|
4
|
+
|
|
5
|
+
* a train / val / test split,
|
|
6
|
+
* how to roll out a single item with a given skill (calling the target
|
|
7
|
+
LLM via the supplied client),
|
|
8
|
+
* how to score the rollout (``hard``: 0/1 exact-match, ``soft``: [0,1]
|
|
9
|
+
partial credit).
|
|
10
|
+
|
|
11
|
+
Bring-your-own-env: subclass :class:`SkillEnv` and pass the instance to
|
|
12
|
+
:class:`~skillrl.SkillOptTrainer`. See :class:`~skillrl.envs.qa.SimpleQAEnv`
|
|
13
|
+
for a complete reference implementation.
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
from skillrl.envs.base import SkillEnv
|
|
17
|
+
from skillrl.envs.qa import SimpleQAEnv
|
|
18
|
+
|
|
19
|
+
__all__ = ["SkillEnv", "SimpleQAEnv"]
|