freesolo 0.2.45__tar.gz → 0.2.46__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {freesolo-0.2.45 → freesolo-0.2.46}/.github/workflows/publish-packages.yml +6 -1
- {freesolo-0.2.45 → freesolo-0.2.46}/PKG-INFO +1 -1
- {freesolo-0.2.45 → freesolo-0.2.46}/pypi/freesolo/training/__init__.py +4 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/pypi/freesolo/training/grpo/README.md +16 -2
- {freesolo-0.2.45 → freesolo-0.2.46}/pypi/freesolo/training/grpo/datums.py +122 -25
- {freesolo-0.2.45 → freesolo-0.2.46}/pypi/freesolo/training/train_grpo.py +41 -11
- {freesolo-0.2.45 → freesolo-0.2.46}/pypi/freesolo/training/train_sft.py +11 -9
- {freesolo-0.2.45 → freesolo-0.2.46}/pypi/freesolo/training/types.py +15 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/pypi/freesolo/utils/checkpoints.py +67 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/pyproject.toml +1 -1
- {freesolo-0.2.45 → freesolo-0.2.46}/tests/functionality/test_grpo_datums_and_sampling.py +96 -8
- {freesolo-0.2.45 → freesolo-0.2.46}/tests/functionality/test_records_rewards_and_config.py +5 -5
- {freesolo-0.2.45 → freesolo-0.2.46}/tests/functionality/test_training_efficiency_fixes.py +78 -2
- {freesolo-0.2.45 → freesolo-0.2.46}/uv.lock +1 -1
- {freesolo-0.2.45 → freesolo-0.2.46}/.github/workflows/python-checks.yml +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/.github/workflows/sync-package-function-usage.yml +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/.github/workflows/version-consistency.yml +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/.gitignore +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/AGENTS.md +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/README.md +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/examples/PROMPT.md +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/examples/README.md +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/examples/TRAINING_CONTRACT.md +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/examples/data/support_eval.jsonl +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/examples/data/support_train.jsonl +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/examples/environment.py +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/examples/evaluation_custom_scorer.py +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/examples/evaluation_from_files.py +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/examples/gepa_prompt_example.py +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/examples/support_dataset.py +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/examples/tracing_manual_span.py +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/examples/tracing_multistep_agent.py +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/examples/training_sft_grpo.py +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/npm/README.md +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/npm/bun.lock +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/npm/dist/core.d.ts +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/npm/dist/core.d.ts.map +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/npm/dist/core.js +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/npm/dist/evaluation.d.ts +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/npm/dist/evaluation.d.ts.map +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/npm/dist/evaluation.js +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/npm/dist/index.d.ts +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/npm/dist/index.d.ts.map +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/npm/dist/index.js +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/npm/dist/tracing.d.ts +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/npm/dist/tracing.d.ts.map +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/npm/dist/tracing.js +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/npm/package.json +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/npm/src/core.ts +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/npm/src/evaluation.ts +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/npm/src/index.ts +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/npm/src/tracing.ts +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/npm/tests/evaluation.test.ts +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/npm/tests/tracing.test.ts +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/npm/tsconfig.json +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/package.json +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/pypi/.gitignore +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/pypi/freesolo/README.md +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/pypi/freesolo/__init__.py +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/pypi/freesolo/contracts/README.md +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/pypi/freesolo/contracts/__init__.py +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/pypi/freesolo/contracts/markdown.py +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/pypi/freesolo/contracts/types.py +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/pypi/freesolo/datasets/README.md +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/pypi/freesolo/datasets/__init__.py +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/pypi/freesolo/datasets/core.py +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/pypi/freesolo/datasets/records.py +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/pypi/freesolo/datasets/types.py +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/pypi/freesolo/environments/README.md +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/pypi/freesolo/environments/__init__.py +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/pypi/freesolo/environments/base.py +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/pypi/freesolo/environments/evaluation.py +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/pypi/freesolo/environments/types.py +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/pypi/freesolo/evaluation/README.md +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/pypi/freesolo/evaluation/__init__.py +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/pypi/freesolo/evaluation/client.py +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/pypi/freesolo/evaluation/judges/__init__.py +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/pypi/freesolo/evaluation/judges/base.py +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/pypi/freesolo/evaluation/responses.py +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/pypi/freesolo/evaluation/results.py +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/pypi/freesolo/evaluation/types.py +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/pypi/freesolo/gepa/README.md +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/pypi/freesolo/gepa/__init__.py +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/pypi/freesolo/gepa/adapter.py +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/pypi/freesolo/gepa/reflection.py +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/pypi/freesolo/gepa/setup.py +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/pypi/freesolo/gepa/types.py +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/pypi/freesolo/py.typed +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/pypi/freesolo/tracing/README.md +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/pypi/freesolo/tracing/__init__.py +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/pypi/freesolo/tracing/otel.py +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/pypi/freesolo/tracing/sanitize.py +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/pypi/freesolo/training/README.md +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/pypi/freesolo/training/grpo/__init__.py +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/pypi/freesolo/training/grpo/config.py +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/pypi/freesolo/training/grpo/rewards.py +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/pypi/freesolo/training/grpo/sampling.py +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/pypi/freesolo/training/storage.py +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/pypi/freesolo/training/wandb_series.py +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/pypi/freesolo/utils/README.md +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/pypi/freesolo/utils/__init__.py +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/pypi/freesolo/utils/core.py +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/pypi/freesolo/utils/hosting.py +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/pypi/freesolo/utils/judge.py +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/pypi/freesolo/utils/openai.py +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/pypi/freesolo/utils/oracle.py +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/pypi/freesolo/utils/storage.py +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/pypi/freesolo/utils/upload.py +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/pypi/freesolo/utils/wandb.py +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/ruff.toml +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/tests/end_to_end_testing/test_environment_evaluation_flow.py +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/tests/end_to_end_testing/test_examples.py +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/tests/functionality/test_contracts_and_judges.py +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/tests/functionality/test_core_utils.py +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/tests/functionality/test_datasets.py +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/tests/functionality/test_environment_evaluation_edges.py +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/tests/functionality/test_evaluation_client.py +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/tests/functionality/test_gepa_adapter.py +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/tests/functionality/test_hosting_and_deployment_clients.py +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/tests/functionality/test_openai_and_oracle_tokens.py +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/tests/functionality/test_package_metadata.py +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/tests/functionality/test_storage_sync.py +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/tests/functionality/test_tracing_opentelemetry.py +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/tests/functionality/test_train_sft.py +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/tests/functionality/test_upload.py +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/tests/functionality/test_utils_checkpoints.py +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/tests/functionality/test_wandb_series.py +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/tests/functionality/test_wandb_utils.py +0 -0
- {freesolo-0.2.45 → freesolo-0.2.46}/tests/security/test_sanitize_and_contract_security.py +0 -0
|
@@ -264,7 +264,12 @@ jobs:
|
|
|
264
264
|
echo "::error::NPM_TOKEN is not configured; refusing to skip publish."
|
|
265
265
|
exit 1
|
|
266
266
|
fi
|
|
267
|
-
bun publish
|
|
267
|
+
# bun publish does not pick up NODE_AUTH_TOKEN or ~/.npmrc auth, so
|
|
268
|
+
# publish the bun-built package with npm and a project npmrc.
|
|
269
|
+
umask 077
|
|
270
|
+
printf '//registry.npmjs.org/:_authToken=%s\n' "$NODE_AUTH_TOKEN" > .npmrc
|
|
271
|
+
npm publish --access public
|
|
272
|
+
rm -f .npmrc
|
|
268
273
|
|
|
269
274
|
- name: No npm package changes
|
|
270
275
|
if: github.event_name == 'push' && steps.changes.outputs.npm_changed == 'false'
|
|
@@ -4,6 +4,7 @@ from types import ModuleType
|
|
|
4
4
|
from .grpo.config import GrpoConfig
|
|
5
5
|
from .types import (
|
|
6
6
|
DEFAULT_TRAINING_LORA_RANK,
|
|
7
|
+
SUPPORTED_TRAINING_MODELS,
|
|
7
8
|
SUPPORTED_TRAINING_RENDERERS,
|
|
8
9
|
TRAINING_BASE_MODEL,
|
|
9
10
|
TRAINING_RENDERER_NAME,
|
|
@@ -12,6 +13,7 @@ from .types import (
|
|
|
12
13
|
TrainSftOptions,
|
|
13
14
|
resolve_sft_config,
|
|
14
15
|
resolve_tinker_base_url,
|
|
16
|
+
resolve_training_model,
|
|
15
17
|
resolve_training_renderer,
|
|
16
18
|
tinker_checkpoint_run_config,
|
|
17
19
|
tinker_run_config,
|
|
@@ -48,6 +50,7 @@ __all__ = [
|
|
|
48
50
|
"REWARD_METADATA_MEAN_TEMPLATE",
|
|
49
51
|
"REWARD_METADATA_RATE_TEMPLATE",
|
|
50
52
|
"SFT_WANDB_SERIES",
|
|
53
|
+
"SUPPORTED_TRAINING_MODELS",
|
|
51
54
|
"SUPPORTED_TRAINING_RENDERERS",
|
|
52
55
|
"TRAINING_BASE_MODEL",
|
|
53
56
|
"TRAINING_RENDERER_NAME",
|
|
@@ -57,6 +60,7 @@ __all__ = [
|
|
|
57
60
|
"TrainSftOptions",
|
|
58
61
|
"resolve_sft_config",
|
|
59
62
|
"resolve_tinker_base_url",
|
|
63
|
+
"resolve_training_model",
|
|
60
64
|
"resolve_training_renderer",
|
|
61
65
|
"tinker_checkpoint_run_config",
|
|
62
66
|
"tinker_run_config",
|
|
@@ -95,8 +95,22 @@ class RepoEnvironment(EnvironmentSingleTurn):
|
|
|
95
95
|
or sampling helpers directly from generated repos.
|
|
96
96
|
- Do not block GRPO on SFT. Pass `sft_state_path` or `sft_log_dir` only for a
|
|
97
97
|
deliberate warm-start comparison.
|
|
98
|
-
-
|
|
99
|
-
|
|
98
|
+
- Advantages use group reward-decoupled normalization (arXiv:2601.05242):
|
|
99
|
+
each reward component is z-normalized within the rollout group
|
|
100
|
+
independently and the weighted normalized advantages are summed, then the
|
|
101
|
+
whole batch is normalized once more. Components come from
|
|
102
|
+
`RewardResult.metrics` entries (one named component per contract reward
|
|
103
|
+
function). To weight components unequally set `RewardMetric.weight`
|
|
104
|
+
(default 1.0) — it multiplies the component's normalized advantage;
|
|
105
|
+
pre-scaling raw scores does nothing because z-normalization cancels
|
|
106
|
+
scale. Only keys reported by every result in the group are compared, and
|
|
107
|
+
when those carry no signal (mixed metric coverage, or shared components
|
|
108
|
+
that tie) the combined scores are z-normalized instead. `advantage_clip`
|
|
109
|
+
clamps the batch-normalized values.
|
|
110
|
+
- If every reward component in a group is constant, the group carries no
|
|
111
|
+
training signal and GRPO skips it. Design rewards with enough diversity to
|
|
112
|
+
create trainable groups; component-level resolution means groups whose
|
|
113
|
+
combined totals tie can still train when individual components differ.
|
|
100
114
|
- Log reward diagnostics such as nonzero rate, unique reward count, uniform
|
|
101
115
|
groups, trainable groups, invalid output rate, and representative completions.
|
|
102
116
|
- Keep public eval semantics stricter and stable; training reward shaping may be
|
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
import asyncio
|
|
4
|
+
import math
|
|
4
5
|
from dataclasses import dataclass, field
|
|
5
6
|
from typing import Any
|
|
6
7
|
|
|
@@ -69,6 +70,7 @@ async def build_grpo_batch_datums(
|
|
|
69
70
|
return token_trace_groups, episodes, reward_results
|
|
70
71
|
|
|
71
72
|
result = GrpoBatchResult()
|
|
73
|
+
pending_groups: list[tuple[list[list[TokenTrace]], list[float]]] = []
|
|
72
74
|
scored_examples = await asyncio.gather(
|
|
73
75
|
*[sample_and_score(example) for example in batch]
|
|
74
76
|
)
|
|
@@ -78,32 +80,29 @@ async def build_grpo_batch_datums(
|
|
|
78
80
|
"Environment score_episodes() must return one RewardResult per sampled episode"
|
|
79
81
|
)
|
|
80
82
|
result.reward_results.extend(reward_results)
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
for token_traces, reward_result in zip(
|
|
85
|
-
token_trace_groups,
|
|
86
|
-
reward_results,
|
|
87
|
-
strict=True,
|
|
88
|
-
):
|
|
89
|
-
reward = float(reward_result.score)
|
|
90
|
-
rewards.append(reward)
|
|
91
|
-
result.rewards.append(reward)
|
|
92
|
-
rescored_traces.append((token_traces, reward))
|
|
93
|
-
|
|
83
|
+
result.rewards.extend(
|
|
84
|
+
float(reward_result.score) for reward_result in reward_results
|
|
85
|
+
)
|
|
94
86
|
result.response_count += len(episodes)
|
|
95
87
|
result.group_count += 1
|
|
96
|
-
|
|
88
|
+
|
|
89
|
+
group_advantages = decoupled_group_advantages(reward_results)
|
|
90
|
+
if not any(value != 0.0 for value in group_advantages):
|
|
97
91
|
result.uniform_group_count += 1
|
|
98
92
|
continue
|
|
99
93
|
|
|
100
94
|
result.trainable_group_count += 1
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
95
|
+
pending_groups.append((token_trace_groups, group_advantages))
|
|
96
|
+
|
|
97
|
+
normalized = batch_normalized_advantages(
|
|
98
|
+
[value for _, advantages in pending_groups for value in advantages],
|
|
99
|
+
advantage_clip=advantage_clip,
|
|
100
|
+
)
|
|
101
|
+
cursor = 0
|
|
102
|
+
for token_trace_groups, _ in pending_groups:
|
|
103
|
+
for token_traces in token_trace_groups:
|
|
104
|
+
advantage = normalized[cursor]
|
|
105
|
+
cursor += 1
|
|
107
106
|
for token_trace in token_traces:
|
|
108
107
|
token_advantages = [
|
|
109
108
|
advantage * mask_value for mask_value in token_trace.advantage_mask
|
|
@@ -278,12 +277,110 @@ def _resolve_max_episode_turns(value: int) -> int:
|
|
|
278
277
|
return value
|
|
279
278
|
|
|
280
279
|
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
280
|
+
_ADVANTAGE_EPSILON = 1e-6
|
|
281
|
+
|
|
282
|
+
|
|
283
|
+
def _metric_weight(metric: Any) -> float:
    """Return the advantage weight declared on a reward metric.

    Reads ``metric.weight``. Anything that is not a finite, strictly
    positive real number (a missing attribute, ``None``, a bool, zero, a
    negative value, NaN or infinity) falls back to the default of 1.0.
    """
    weight = getattr(metric, "weight", None)
    if (
        isinstance(weight, (int, float))
        and not isinstance(weight, bool)
        and math.isfinite(weight)
        and weight > 0
    ):
        return float(weight)
    return 1.0


def _component_scores(reward_result: Any) -> dict[str, tuple[float, float]]:
    """Map component name -> (score, weight) for one RewardResult.

    Every entry of ``reward_result.metrics`` with a non-empty string name
    and a real (non-bool) numeric score becomes one component. When no
    entry qualifies, the result's combined ``score`` is returned as a
    single component named ``"score"`` with weight 1.0.
    """
    components: dict[str, tuple[float, float]] = {}
    for metric in getattr(reward_result, "metrics", ()) or ():
        name = getattr(metric, "name", None)
        score = getattr(metric, "score", None)
        if (
            isinstance(name, str)
            and name
            and isinstance(score, (int, float))
            and not isinstance(score, bool)
        ):
            components[name] = (float(score), _metric_weight(metric))
    if components:
        return components
    return {"score": (float(reward_result.score), 1.0)}


def _group_mean_std(values: list[float]) -> tuple[float, float] | None:
    """Return ``(mean, population std)`` or None when there is no usable variance.

    None is returned when the values are empty, contain NaN/infinity, or
    are all identical.
    """
    if not values or not all(math.isfinite(value) for value in values):
        # A NaN/inf score would turn every advantage derived from it into
        # NaN and, after batch normalization, poison the whole batch.
        return None
    if min(values) == max(values):
        # Detect constant components exactly. Relying on ``std == 0.0`` is
        # not enough: rounding in the mean (e.g. three scores of 0.1) yields
        # a tiny non-zero std and spurious advantages of magnitude 1.
        return None
    mean = sum(values) / len(values)
    std = math.sqrt(sum((value - mean) ** 2 for value in values) / len(values))
    if not math.isfinite(std) or std == 0.0:
        return None
    return mean, std


def decoupled_group_advantages(reward_results: list[Any]) -> list[float]:
    """Group advantages with reward-decoupled normalization (arXiv:2601.05242).

    Joint normalization of a summed multi-reward score collapses distinct
    reward combinations into identical advantages. Instead, each reward
    component is z-normalized within the group independently and the
    weighted normalized advantages are summed, preserving per-reward
    resolution and making components comparable regardless of their raw
    scales. Components come from RewardResult.metrics (one per contract
    reward function); RewardMetric.weight multiplies the component's
    normalized advantage (default 1.0 — raw score scale carries no weight,
    z-normalization cancels it). Only keys reported by every result in the
    group are compared: filling a missing component with an invented value
    lets components cancel (a metric-less error-path result's fallback
    score against zero-filled metric rows can zero out a group that has
    real reward differences). When the shared components carry no variance,
    the combined scores are z-normalized instead so reward differences
    still train.

    Components whose values are all identical, or that contain NaN or
    infinity, contribute nothing. A group where nothing contributes gets
    all-zero advantages, which callers treat as "no training signal".

    Args:
        reward_results: One result per rollout in the group. Each exposes a
            numeric ``score`` and optionally ``metrics`` entries with
            ``name``, ``score`` and ``weight`` attributes.

    Returns:
        One advantage per rollout, in input order. Empty input yields ``[]``.
    """
    component_maps = [
        _component_scores(reward_result) for reward_result in reward_results
    ]
    shared_keys = sorted(
        set.intersection(*(set(mapping) for mapping in component_maps))
        if component_maps
        else set()
    )
    advantages = [0.0] * len(reward_results)
    for key in shared_keys:
        values = [mapping[key][0] for mapping in component_maps]
        stats = _group_mean_std(values)
        if stats is None:
            continue
        mean, std = stats
        # The weight is a property of the reward function, not the rollout;
        # the first rollout's entry is canonical when they disagree.
        weight = component_maps[0][key][1]
        for index, value in enumerate(values):
            advantages[index] += weight * (value - mean) / std
    if not any(value != 0.0 for value in advantages):
        # No shared component carried signal (or they cancelled exactly):
        # fall back to z-normalizing the combined scores.
        scores = [float(reward_result.score) for reward_result in reward_results]
        stats = _group_mean_std(scores)
        if stats is not None:
            mean, std = stats
            advantages = [(value - mean) / std for value in scores]
    return advantages
|
|
363
|
+
|
|
364
|
+
|
|
365
|
+
def batch_normalized_advantages(
    advantages: list[float],
    *,
    advantage_clip: float,
) -> list[float]:
    """Normalize advantages across the whole batch, then optionally clamp.

    The values are centered on their mean and divided by their population
    standard deviation plus ``_ADVANTAGE_EPSILON``, which keeps the
    division defined when every value is identical.

    Args:
        advantages: Flattened per-rollout advantages for the batch.
        advantage_clip: Symmetric clamp bound applied after normalization.
            A value of zero or below disables clamping.

    Returns:
        The normalized (and possibly clamped) advantages in input order;
        an empty list for empty input.
    """
    count = len(advantages)
    if count == 0:
        return []
    center = sum(advantages) / count
    spread = math.sqrt(sum((item - center) ** 2 for item in advantages) / count)
    # Computed once; identical to adding the epsilon per element.
    scale = spread + _ADVANTAGE_EPSILON
    rescaled = [(item - center) / scale for item in advantages]
    if advantage_clip <= 0:
        return rescaled
    lower, upper = -advantage_clip, advantage_clip
    return [max(lower, min(upper, item)) for item in rescaled]
|
|
287
384
|
|
|
288
385
|
|
|
289
386
|
def build_importance_sampling_trace_datum(
|
|
@@ -32,22 +32,26 @@ from freesolo.training.grpo.sampling import (
|
|
|
32
32
|
from freesolo.training.storage import attach_stored_training_run
|
|
33
33
|
from freesolo.training.types import (
|
|
34
34
|
DEFAULT_TRAINING_LORA_RANK,
|
|
35
|
-
TRAINING_BASE_MODEL,
|
|
36
35
|
TrainGrpoOptions,
|
|
37
36
|
resolve_tinker_base_url,
|
|
37
|
+
resolve_training_model,
|
|
38
38
|
tinker_checkpoint_run_config,
|
|
39
39
|
tinker_run_config,
|
|
40
40
|
)
|
|
41
41
|
from freesolo.utils.checkpoints import (
|
|
42
42
|
CheckpointUtils,
|
|
43
|
+
ensure_log_dir_base_model,
|
|
43
44
|
get_last_tinker_checkpoint,
|
|
45
|
+
has_training_state,
|
|
44
46
|
next_training_position,
|
|
47
|
+
read_log_dir_base_model,
|
|
45
48
|
resolve_checkpoint_sampler_path,
|
|
46
49
|
resolve_checkpoint_state_path,
|
|
47
50
|
resolve_sft_sampler_path,
|
|
48
51
|
resolve_sft_state_path,
|
|
49
52
|
resolve_training_position,
|
|
50
53
|
save_tinker_checkpoint,
|
|
54
|
+
sft_state_path_in_log_dir,
|
|
51
55
|
write_training_progress,
|
|
52
56
|
)
|
|
53
57
|
from freesolo.utils.core import load_dotenv_if_available, required_path
|
|
@@ -70,6 +74,7 @@ def _parse_args() -> argparse.Namespace:
|
|
|
70
74
|
parser.add_argument("--sft-state-path")
|
|
71
75
|
parser.add_argument("--reward-command")
|
|
72
76
|
parser.add_argument("--base-url")
|
|
77
|
+
parser.add_argument("--base-model")
|
|
73
78
|
return parser.parse_args()
|
|
74
79
|
|
|
75
80
|
|
|
@@ -84,7 +89,9 @@ async def train_grpo_async(
|
|
|
84
89
|
sft_state_path: str | None = None,
|
|
85
90
|
reward_command: str | None = None,
|
|
86
91
|
base_url: str | None = None,
|
|
92
|
+
base_model: str | None = None,
|
|
87
93
|
) -> int:
|
|
94
|
+
resolved_base_model = resolve_training_model(base_model)
|
|
88
95
|
try:
|
|
89
96
|
import numpy
|
|
90
97
|
import tinker
|
|
@@ -112,12 +119,37 @@ async def train_grpo_async(
|
|
|
112
119
|
if not examples:
|
|
113
120
|
raise RuntimeError(f"No GRPO records found in {dataset_path}")
|
|
114
121
|
|
|
115
|
-
tokenizer = get_tokenizer(
|
|
122
|
+
tokenizer = get_tokenizer(resolved_base_model)
|
|
116
123
|
renderer = renderers.get_renderer(grpo_config.renderer_name, tokenizer)
|
|
117
124
|
resolved_tinker_base_url = resolve_tinker_base_url(base_url)
|
|
118
125
|
service_client = tinker.ServiceClient(base_url=resolved_tinker_base_url or None)
|
|
119
126
|
log_dir = Path(log_dir)
|
|
120
127
|
log_dir.mkdir(parents=True, exist_ok=True)
|
|
128
|
+
if sft_log_dir is not None:
|
|
129
|
+
sft_base_model = read_log_dir_base_model(sft_log_dir)
|
|
130
|
+
if sft_base_model is None and has_training_state(sft_log_dir):
|
|
131
|
+
raise RuntimeError(
|
|
132
|
+
f"sft_log_dir {sft_log_dir} holds training state but no "
|
|
133
|
+
"base_model.json marker, so its model cannot be verified. "
|
|
134
|
+
"Train SFT for this model first."
|
|
135
|
+
)
|
|
136
|
+
if sft_base_model is not None and sft_base_model != resolved_base_model:
|
|
137
|
+
raise RuntimeError(
|
|
138
|
+
f"sft_log_dir {sft_log_dir} holds checkpoints for base model "
|
|
139
|
+
f"{sft_base_model!r}; GRPO cannot initialize "
|
|
140
|
+
f"{resolved_base_model!r} from them. Train SFT for this model "
|
|
141
|
+
"first or point sft_log_dir at a matching run."
|
|
142
|
+
)
|
|
143
|
+
if sft_state_path and not sft_state_path_in_log_dir(sft_state_path, sft_log_dir):
|
|
144
|
+
print(
|
|
145
|
+
f"[freesolo] GRPO initializing from explicit sft_state_path with "
|
|
146
|
+
f"base_model {resolved_base_model!r}; the path is not recorded in "
|
|
147
|
+
"the validated sft_log_dir, so its remote tinker:// lineage "
|
|
148
|
+
"cannot be verified locally - make sure that state came from "
|
|
149
|
+
"the same base model.",
|
|
150
|
+
file=sys.stderr,
|
|
151
|
+
)
|
|
152
|
+
ensure_log_dir_base_model(log_dir, resolved_base_model)
|
|
121
153
|
run_name = f"freesolo-grpo-{log_dir.name}"
|
|
122
154
|
run_config = {
|
|
123
155
|
"phase": "grpo",
|
|
@@ -128,7 +160,7 @@ async def train_grpo_async(
|
|
|
128
160
|
"sft_log_dir": str(sft_log_dir) if sft_log_dir is not None else None,
|
|
129
161
|
"sft_checkpoint_name": sft_checkpoint_name,
|
|
130
162
|
"sft_state_path": sft_state_path,
|
|
131
|
-
"base_model":
|
|
163
|
+
"base_model": resolved_base_model,
|
|
132
164
|
"renderer": grpo_config.renderer_name,
|
|
133
165
|
"lora_rank": DEFAULT_TRAINING_LORA_RANK,
|
|
134
166
|
"batch_size": grpo_config.batch_size,
|
|
@@ -188,7 +220,7 @@ async def train_grpo_async(
|
|
|
188
220
|
else:
|
|
189
221
|
training_client = await asyncio.to_thread(
|
|
190
222
|
service_client.create_lora_training_client,
|
|
191
|
-
base_model=
|
|
223
|
+
base_model=resolved_base_model,
|
|
192
224
|
rank=DEFAULT_TRAINING_LORA_RANK,
|
|
193
225
|
)
|
|
194
226
|
kl_reference_sampling_client = None
|
|
@@ -207,7 +239,7 @@ async def train_grpo_async(
|
|
|
207
239
|
else:
|
|
208
240
|
kl_reference_sampling_client = await asyncio.to_thread(
|
|
209
241
|
service_client.create_sampling_client,
|
|
210
|
-
base_model=
|
|
242
|
+
base_model=resolved_base_model,
|
|
211
243
|
)
|
|
212
244
|
|
|
213
245
|
# renderer stop sequences may be token ids or strings; the grpo config
|
|
@@ -377,6 +409,7 @@ async def train_grpo_async(
|
|
|
377
409
|
training_client = await _warm_restart_training_client(
|
|
378
410
|
service_client=service_client,
|
|
379
411
|
state_path=recovery_state_path,
|
|
412
|
+
base_model=resolved_base_model,
|
|
380
413
|
)
|
|
381
414
|
global_step, epoch, batch_index = recovery_position
|
|
382
415
|
sampler_state = SamplerState(
|
|
@@ -497,11 +530,6 @@ async def train_grpo_async(
|
|
|
497
530
|
|
|
498
531
|
def train_grpo(**kwargs: Unpack[TrainGrpoOptions]) -> int:
|
|
499
532
|
load_dotenv_if_available()
|
|
500
|
-
if "base_model" in kwargs:
|
|
501
|
-
raise TypeError(
|
|
502
|
-
"train_grpo() does not accept base_model; Freesolo training is pinned "
|
|
503
|
-
f"to {TRAINING_BASE_MODEL}"
|
|
504
|
-
)
|
|
505
533
|
return asyncio.run(train_grpo_async(**kwargs))
|
|
506
534
|
|
|
507
535
|
|
|
@@ -590,6 +618,7 @@ async def _warm_restart_training_client(
|
|
|
590
618
|
*,
|
|
591
619
|
service_client: Any,
|
|
592
620
|
state_path: str | None,
|
|
621
|
+
base_model: str,
|
|
593
622
|
) -> Any:
|
|
594
623
|
"""Open a fresh Tinker training session from the last durable state.
|
|
595
624
|
|
|
@@ -606,7 +635,7 @@ async def _warm_restart_training_client(
|
|
|
606
635
|
)
|
|
607
636
|
return await asyncio.to_thread(
|
|
608
637
|
service_client.create_lora_training_client,
|
|
609
|
-
base_model=
|
|
638
|
+
base_model=base_model,
|
|
610
639
|
rank=DEFAULT_TRAINING_LORA_RANK,
|
|
611
640
|
)
|
|
612
641
|
|
|
@@ -745,6 +774,7 @@ async def main_async() -> int:
|
|
|
745
774
|
sft_state_path=args.sft_state_path,
|
|
746
775
|
reward_command=args.reward_command,
|
|
747
776
|
base_url=args.base_url,
|
|
777
|
+
base_model=args.base_model,
|
|
748
778
|
)
|
|
749
779
|
|
|
750
780
|
|
|
@@ -16,16 +16,17 @@ from freesolo.datasets import load_dataset
|
|
|
16
16
|
from freesolo.environments.base import load_environment
|
|
17
17
|
from freesolo.training.storage import attach_stored_training_run
|
|
18
18
|
from freesolo.training.types import (
|
|
19
|
-
TRAINING_BASE_MODEL,
|
|
20
19
|
SftConfig,
|
|
21
20
|
TrainSftOptions,
|
|
22
21
|
resolve_sft_config,
|
|
23
22
|
resolve_tinker_base_url,
|
|
23
|
+
resolve_training_model,
|
|
24
24
|
tinker_checkpoint_run_config,
|
|
25
25
|
tinker_run_config,
|
|
26
26
|
)
|
|
27
27
|
from freesolo.utils.checkpoints import (
|
|
28
28
|
checkpoint_step,
|
|
29
|
+
ensure_log_dir_base_model,
|
|
29
30
|
get_last_tinker_checkpoint,
|
|
30
31
|
resolve_checkpoint_sampler_path,
|
|
31
32
|
resolve_checkpoint_state_path,
|
|
@@ -42,6 +43,7 @@ def _parse_args() -> argparse.Namespace:
|
|
|
42
43
|
parser.add_argument("--environment")
|
|
43
44
|
parser.add_argument("--log-dir", default="./logs/sft")
|
|
44
45
|
parser.add_argument("--base-url")
|
|
46
|
+
parser.add_argument("--base-model")
|
|
45
47
|
parser.add_argument("--max-length", type=int, required=True)
|
|
46
48
|
return parser.parse_args()
|
|
47
49
|
|
|
@@ -51,17 +53,13 @@ def train_sft(**kwargs: Unpack[TrainSftOptions]) -> int:
|
|
|
51
53
|
raise TypeError(
|
|
52
54
|
"train_sft() missing required keyword-only argument: 'dataset_path'"
|
|
53
55
|
)
|
|
54
|
-
if "base_model" in kwargs:
|
|
55
|
-
raise TypeError(
|
|
56
|
-
"train_sft() does not accept base_model; Freesolo training is pinned "
|
|
57
|
-
f"to {TRAINING_BASE_MODEL}"
|
|
58
|
-
)
|
|
59
56
|
return _train_sft(
|
|
60
57
|
contract_path=kwargs.get("contract_path", "TRAINING_CONTRACT.md"),
|
|
61
58
|
dataset_path=kwargs["dataset_path"],
|
|
62
59
|
environment=kwargs.get("environment"),
|
|
63
60
|
log_dir=kwargs.get("log_dir", "./logs/sft"),
|
|
64
61
|
base_url=kwargs.get("base_url"),
|
|
62
|
+
base_model=kwargs.get("base_model"),
|
|
65
63
|
sft_config=kwargs.get("sft_config"),
|
|
66
64
|
)
|
|
67
65
|
|
|
@@ -73,9 +71,11 @@ def _train_sft(
|
|
|
73
71
|
environment: str | None = None,
|
|
74
72
|
log_dir: str | Path = "./logs/sft",
|
|
75
73
|
base_url: str | None = None,
|
|
74
|
+
base_model: str | None = None,
|
|
76
75
|
sft_config: SftConfig | None = None,
|
|
77
76
|
) -> int:
|
|
78
77
|
load_dotenv_if_available()
|
|
78
|
+
resolved_base_model = resolve_training_model(base_model)
|
|
79
79
|
try:
|
|
80
80
|
import tinker
|
|
81
81
|
from tinker_cookbook import checkpoint_utils, renderers
|
|
@@ -114,12 +114,13 @@ def _train_sft(
|
|
|
114
114
|
f"No assistant turns found in SFT records from {dataset_path}"
|
|
115
115
|
)
|
|
116
116
|
|
|
117
|
-
tokenizer = get_tokenizer(
|
|
117
|
+
tokenizer = get_tokenizer(resolved_base_model)
|
|
118
118
|
renderer = renderers.get_renderer(resolved_sft_config.renderer_name, tokenizer)
|
|
119
119
|
resolved_tinker_base_url = resolve_tinker_base_url(base_url)
|
|
120
120
|
service_client = tinker.ServiceClient(base_url=resolved_tinker_base_url or None)
|
|
121
121
|
log_dir = Path(log_dir)
|
|
122
122
|
log_dir.mkdir(parents=True, exist_ok=True)
|
|
123
|
+
ensure_log_dir_base_model(log_dir, resolved_base_model)
|
|
123
124
|
run_name = f"freesolo-sft-{log_dir.name}"
|
|
124
125
|
run_config = {
|
|
125
126
|
"phase": "sft",
|
|
@@ -127,7 +128,7 @@ def _train_sft(
|
|
|
127
128
|
"dataset_path": str(dataset_path),
|
|
128
129
|
"environment": environment,
|
|
129
130
|
"log_dir": str(log_dir),
|
|
130
|
-
"base_model":
|
|
131
|
+
"base_model": resolved_base_model,
|
|
131
132
|
"renderer": resolved_sft_config.renderer_name,
|
|
132
133
|
"batch_size": batch_size,
|
|
133
134
|
"learning_rate": learning_rate,
|
|
@@ -177,7 +178,7 @@ def _train_sft(
|
|
|
177
178
|
)
|
|
178
179
|
else:
|
|
179
180
|
training_client = service_client.create_lora_training_client(
|
|
180
|
-
base_model=
|
|
181
|
+
base_model=resolved_base_model,
|
|
181
182
|
rank=lora_rank,
|
|
182
183
|
)
|
|
183
184
|
start_step = 0
|
|
@@ -317,6 +318,7 @@ def main() -> int:
|
|
|
317
318
|
environment=args.environment,
|
|
318
319
|
log_dir=args.log_dir,
|
|
319
320
|
base_url=args.base_url,
|
|
321
|
+
base_model=args.base_model,
|
|
320
322
|
sft_config=SftConfig(max_length=args.max_length),
|
|
321
323
|
)
|
|
322
324
|
|
|
@@ -6,12 +6,23 @@ from typing import Any, TypedDict
|
|
|
6
6
|
from typing_extensions import Required
|
|
7
7
|
|
|
8
8
|
# Base model selected by resolve_training_model() when the caller passes None.
TRAINING_BASE_MODEL = "Qwen/Qwen3.6-35B-A3B"
# Closed set of base models resolve_training_model() accepts. The default is
# referenced by name so the two constants cannot drift apart.
SUPPORTED_TRAINING_MODELS = (TRAINING_BASE_MODEL, "Qwen/Qwen3.5-4B")
# Renderer returned by resolve_training_renderer() when the caller passes None.
TRAINING_RENDERER_NAME = "qwen3_5_disable_thinking"
# Renderer names listed as supported; the default is referenced by name.
SUPPORTED_TRAINING_RENDERERS = ("qwen3_5", TRAINING_RENDERER_NAME)
# Default LoRA rank (consumers are outside this view).
DEFAULT_TRAINING_LORA_RANK = 64
# Environment variable name; presumably read by resolve_tinker_base_url() --
# TODO confirm against that helper.
TINKER_BASE_URL_ENV = "TINKER_BASE_URL"
|
|
13
14
|
|
|
14
15
|
|
|
16
|
+
def resolve_training_model(base_model: str | None) -> str:
    """Return the base model to train, validating any explicit choice.

    ``None`` selects ``TRAINING_BASE_MODEL``. Any other value is converted
    with ``str()`` and stripped of surrounding whitespace, then checked
    against ``SUPPORTED_TRAINING_MODELS``.

    Raises:
        ValueError: If the normalized name is not a supported model. The
            message quotes the value exactly as the caller passed it.
    """
    if base_model is None:
        return TRAINING_BASE_MODEL
    candidate = str(base_model).strip()
    if candidate in SUPPORTED_TRAINING_MODELS:
        return candidate
    supported = ", ".join(SUPPORTED_TRAINING_MODELS)
    raise ValueError(f"base_model must be one of: {supported}; got {base_model!r}")
|
|
24
|
+
|
|
25
|
+
|
|
15
26
|
def resolve_training_renderer(renderer_name: str | None) -> str:
|
|
16
27
|
if renderer_name is None:
|
|
17
28
|
return TRAINING_RENDERER_NAME
|
|
@@ -124,6 +135,7 @@ class TrainSftOptions(TypedDict, total=False):
|
|
|
124
135
|
environment: str | None
|
|
125
136
|
log_dir: str | Path
|
|
126
137
|
base_url: str | None
|
|
138
|
+
base_model: str | None
|
|
127
139
|
sft_config: SftConfig | None
|
|
128
140
|
|
|
129
141
|
|
|
@@ -137,10 +149,12 @@ class TrainGrpoOptions(TypedDict, total=False):
|
|
|
137
149
|
sft_state_path: str | None
|
|
138
150
|
reward_command: str | None
|
|
139
151
|
base_url: str | None
|
|
152
|
+
base_model: str | None
|
|
140
153
|
|
|
141
154
|
|
|
142
155
|
__all__ = [
|
|
143
156
|
"DEFAULT_TRAINING_LORA_RANK",
|
|
157
|
+
"SUPPORTED_TRAINING_MODELS",
|
|
144
158
|
"SUPPORTED_TRAINING_RENDERERS",
|
|
145
159
|
"TRAINING_BASE_MODEL",
|
|
146
160
|
"TRAINING_RENDERER_NAME",
|
|
@@ -150,6 +164,7 @@ __all__ = [
|
|
|
150
164
|
"TrainSftOptions",
|
|
151
165
|
"resolve_sft_config",
|
|
152
166
|
"resolve_tinker_base_url",
|
|
167
|
+
"resolve_training_model",
|
|
153
168
|
"resolve_training_renderer",
|
|
154
169
|
"tinker_checkpoint_run_config",
|
|
155
170
|
"tinker_run_config",
|
|
@@ -201,6 +201,20 @@ def resolve_checkpoint_state_path(record: object | None) -> str | None:
|
|
|
201
201
|
return str(state_path) if state_path else None
|
|
202
202
|
|
|
203
203
|
|
|
204
|
+
def sft_state_path_in_log_dir(
    state_path: str,
    sft_log_dir: str | Path | None,
) -> bool:
    """Report whether ``state_path`` is one of an SFT log dir's checkpoints.

    Returns ``False`` when ``sft_log_dir`` is ``None``. Otherwise every record
    from ``read_checkpoint_records(sft_log_dir)`` is resolved with
    ``resolve_checkpoint_state_path`` and compared to ``state_path`` by exact
    string equality; the first match returns ``True``.

    NOTE(review): the log dir is expected to have passed base-model marker
    validation already, so a match ties the state path's lineage to that
    dir's model -- confirm at the call sites.
    """
    if sft_log_dir is None:
        return False
    for record in read_checkpoint_records(sft_log_dir):
        if resolve_checkpoint_state_path(record) == state_path:
            return True
    return False
|
|
216
|
+
|
|
217
|
+
|
|
204
218
|
def resolve_checkpoint_sampler_path(record: object | None) -> str | None:
    """Return the record's ``"sampler_path"`` as a string, or ``None``.

    The value is looked up with ``checkpoint_value``; any falsy result
    (missing, ``None``, empty) yields ``None``.
    """
    value = checkpoint_value(record, "sampler_path")
    if not value:
        return None
    return str(value)
|
|
@@ -282,3 +296,56 @@ def save_tinker_checkpoint(
|
|
|
282
296
|
loop_state=loop_state,
|
|
283
297
|
kind=kind,
|
|
284
298
|
)
|
|
299
|
+
|
|
300
|
+
|
|
301
|
+
_BASE_MODEL_MARKER = "base_model.json"
|
|
302
|
+
|
|
303
|
+
|
|
304
|
+
def read_log_dir_base_model(log_dir: str | Path) -> str | None:
|
|
305
|
+
marker = Path(log_dir) / _BASE_MODEL_MARKER
|
|
306
|
+
if not marker.is_file():
|
|
307
|
+
return None
|
|
308
|
+
try:
|
|
309
|
+
recorded = json.loads(marker.read_text(encoding="utf-8")).get("base_model")
|
|
310
|
+
except (OSError, ValueError):
|
|
311
|
+
return None
|
|
312
|
+
return recorded if isinstance(recorded, str) and recorded else None
|
|
313
|
+
|
|
314
|
+
|
|
315
|
+
def has_training_state(log_dir: str | Path) -> bool:
    """Report whether ``log_dir`` already contains training artifacts.

    Returns ``True`` as soon as one of ``checkpoints.jsonl``, the file named
    by ``TRAINING_PROGRESS_FILENAME``, or ``kl_reference.json`` exists
    directly under ``log_dir``; ``False`` if none of them do.
    """
    root = Path(log_dir)
    state_files = (
        "checkpoints.jsonl",
        TRAINING_PROGRESS_FILENAME,
        "kl_reference.json",
    )
    for filename in state_files:
        if (root / filename).exists():
            return True
    return False
|
|
325
|
+
|
|
326
|
+
|
|
327
|
+
def ensure_log_dir_base_model(log_dir: str | Path, base_model: str) -> None:
    """Bind ``log_dir`` to a single base model, refusing mismatched reuse.

    Rationale: resuming from a checkpoint loads optimizer/weight state from
    the log dir without checking which model produced it, so continuing a
    dir created for another base model would silently mix weights and
    tokenizer.

    Behavior:
        * No marker and no training state: write the marker for
          ``base_model``.
        * No marker but training state present: refuse, because the existing
          state cannot be attributed to any model.
        * Marker present and equal to ``base_model``: do nothing.
        * Marker present and different: refuse.

    Raises:
        RuntimeError: In either of the two refusal cases above.
    """
    recorded = read_log_dir_base_model(log_dir)
    if recorded is None:
        # Unmarked dir: only safe to claim when nothing has been trained in it.
        if has_training_state(log_dir):
            raise RuntimeError(
                f"log_dir {log_dir} holds training state but no base_model.json "
                "marker, so its model cannot be verified. Use a fresh log_dir."
            )
        marker_path = Path(log_dir) / _BASE_MODEL_MARKER
        marker_text = json.dumps({"base_model": base_model}) + "\n"
        marker_path.write_text(marker_text, encoding="utf-8")
        return
    if recorded != base_model:
        raise RuntimeError(
            f"log_dir {log_dir} holds training state for base model "
            f"{recorded!r}; refusing to continue it with {base_model!r}. "
            "Use a fresh log_dir to train a different model."
        )
|