@mlx-node/trl 0.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37) hide show
  1. package/dist/data/dataset.d.ts +22 -0
  2. package/dist/data/dataset.d.ts.map +1 -0
  3. package/dist/data/dataset.js +142 -0
  4. package/dist/data/sft-dataset.d.ts +156 -0
  5. package/dist/data/sft-dataset.d.ts.map +1 -0
  6. package/dist/data/sft-dataset.js +415 -0
  7. package/dist/index.d.ts +33 -0
  8. package/dist/index.d.ts.map +1 -0
  9. package/dist/index.js +47 -0
  10. package/dist/trainers/grpo-config.d.ts +42 -0
  11. package/dist/trainers/grpo-config.d.ts.map +1 -0
  12. package/dist/trainers/grpo-config.js +220 -0
  13. package/dist/trainers/grpo-entropy.d.ts +33 -0
  14. package/dist/trainers/grpo-entropy.d.ts.map +1 -0
  15. package/dist/trainers/grpo-entropy.js +18 -0
  16. package/dist/trainers/grpo-trainer.d.ts +602 -0
  17. package/dist/trainers/grpo-trainer.d.ts.map +1 -0
  18. package/dist/trainers/grpo-trainer.js +1439 -0
  19. package/dist/trainers/sft-config.d.ts +32 -0
  20. package/dist/trainers/sft-config.d.ts.map +1 -0
  21. package/dist/trainers/sft-config.js +186 -0
  22. package/dist/trainers/sft-trainer.d.ts +141 -0
  23. package/dist/trainers/sft-trainer.d.ts.map +1 -0
  24. package/dist/trainers/sft-trainer.js +502 -0
  25. package/dist/trainers/training-logger.d.ts +375 -0
  26. package/dist/trainers/training-logger.d.ts.map +1 -0
  27. package/dist/trainers/training-logger.js +542 -0
  28. package/dist/types.d.ts +54 -0
  29. package/dist/types.d.ts.map +1 -0
  30. package/dist/types.js +1 -0
  31. package/dist/utils/path-security.d.ts +51 -0
  32. package/dist/utils/path-security.d.ts.map +1 -0
  33. package/dist/utils/path-security.js +69 -0
  34. package/dist/utils/xml-parser.d.ts +6 -0
  35. package/dist/utils/xml-parser.d.ts.map +1 -0
  36. package/dist/utils/xml-parser.js +184 -0
  37. package/package.json +29 -0
@@ -0,0 +1,220 @@
1
+ import { parse as parseToml } from '@std/toml';
2
+ import { readFileSync } from 'node:fs';
3
+ import { resolve as resolvePath } from 'node:path';
4
/**
 * Error type thrown by this module for configuration problems: config files
 * that cannot be read or parsed, unknown configuration keys, and values that
 * cannot be coerced to the type expected for their key.
 */
export class ConfigError extends Error {
  /**
   * @param message Human-readable description of the configuration problem.
   */
  constructor(message) {
    super(message);
    // Replace the inherited "Error" name so `error.name` identifies
    // configuration failures.
    this.name = 'ConfigError';
  }
}
10
/**
 * Baseline value for every recognised configuration key.
 *
 * The object is frozen so the shared instance cannot be modified; all values
 * are primitives, so the shallow freeze covers the whole structure.
 */
const DEFAULT_CONFIG = Object.freeze({
  // model / output / run names
  model_name: 'Qwen/Qwen2.5-1.5B-Instruct',
  output_dir: 'outputs/Qwen-1.5B-MLX-GRPO',
  run_name: 'Qwen-1.5B-MLX-GRPO-gsm8k',

  // learning rate, batching, epochs, warmup, clipping and logging cadence
  learning_rate: 1e-6,
  batch_size: 1,
  gradient_accumulation_steps: 4,
  num_epochs: 1,
  max_train_samples: 0,
  warmup_ratio: 0.1,
  max_grad_norm: 0.1,
  logging_steps: 1,

  // generation counts, length limits and sampling/loss coefficients
  num_generations: 64,
  max_prompt_length: 512,
  max_completion_length: 1024,
  max_new_tokens: 512,
  temperature: 0.7,
  clip_eps: 0.2,
  kl_coeff: 0.0,

  // Adam betas, weight decay and learning-rate scheduler name
  adam_beta1: 0.9,
  adam_beta2: 0.999,
  weight_decay: 0.0,
  lr_scheduler_type: 'cosine',

  // save / eval step counts and seed
  save_steps: 100,
  eval_steps: 50,
  eval_samples: 200,
  seed: 0,

  // boolean toggles and remaining eval_* settings
  use_compile: true,
  quantize_for_rollouts: true,
  eval_every_updates: 25,
  eval_subset_size: 200,
  eval_max_new_tokens: 128,
  log_jsonl: true,
});
44
/**
 * Expected runtime type ('string', 'number' or 'boolean') for each
 * configuration key.
 *
 * The own keys of this table define which names count as configuration keys,
 * and the mapped value selects the coercion applied to raw input for that key.
 */
const CONFIG_VALUE_TYPES = {
  // string-valued keys
  model_name: 'string',
  output_dir: 'string',
  run_name: 'string',
  lr_scheduler_type: 'string',

  // boolean-valued keys
  use_compile: 'boolean',
  quantize_for_rollouts: 'boolean',
  log_jsonl: 'boolean',

  // number-valued keys
  learning_rate: 'number',
  batch_size: 'number',
  gradient_accumulation_steps: 'number',
  num_epochs: 'number',
  max_train_samples: 'number',
  warmup_ratio: 'number',
  max_grad_norm: 'number',
  logging_steps: 'number',
  num_generations: 'number',
  max_prompt_length: 'number',
  max_completion_length: 'number',
  max_new_tokens: 'number',
  temperature: 'number',
  clip_eps: 'number',
  kl_coeff: 'number',
  adam_beta1: 'number',
  adam_beta2: 'number',
  weight_decay: 'number',
  save_steps: 'number',
  eval_steps: 'number',
  eval_samples: 'number',
  seed: 'number',
  eval_every_updates: 'number',
  eval_subset_size: 'number',
  eval_max_new_tokens: 'number',
};
78
/**
 * Numeric configuration keys whose values must be whole numbers.
 * Number coercion rejects fractional values for any key listed here.
 */
const INTEGER_KEYS = new Set([
  // batching / epoch / logging counts
  'batch_size',
  'gradient_accumulation_steps',
  'num_epochs',
  'max_train_samples',
  'logging_steps',
  // generation counts and length limits
  'num_generations',
  'max_prompt_length',
  'max_completion_length',
  'max_new_tokens',
  // save / eval counts and seed
  'save_steps',
  'eval_steps',
  'eval_samples',
  'seed',
  'eval_every_updates',
  'eval_subset_size',
  'eval_max_new_tokens',
]);
96
/**
 * Produce a fresh, mutable shallow copy of the frozen default configuration.
 *
 * @returns A new plain object carrying every default key/value pair.
 */
function cloneDefaults() {
  return Object.assign({}, DEFAULT_CONFIG);
}
99
/**
 * Report whether `value` names a recognised configuration key.
 *
 * @param value Candidate key name.
 * @returns `true` when `value` is an own property of CONFIG_VALUE_TYPES.
 */
function isConfigKey(value) {
  // Own-property check only: inherited names such as "toString" or
  // "constructor" must not be accepted as configuration keys.
  const hasOwn = Object.prototype.hasOwnProperty;
  return hasOwn.call(CONFIG_VALUE_TYPES, value);
}
102
/**
 * Interpret a raw configuration value as a boolean.
 *
 * Booleans pass through unchanged. Strings are trimmed and lower-cased, then
 * matched against the accepted spellings: true/1/yes/on and false/0/no/off.
 *
 * @param value Raw value to interpret.
 * @param key Configuration key the value belongs to (used in the error text).
 * @returns The boolean represented by `value`.
 * @throws {ConfigError} When `value` is neither a boolean nor an accepted string.
 */
function coerceBoolean(value, key) {
  if (typeof value === 'boolean') {
    return value;
  }
  if (typeof value === 'string') {
    switch (value.trim().toLowerCase()) {
      case 'true':
      case '1':
      case 'yes':
      case 'on':
        return true;
      case 'false':
      case '0':
      case 'no':
      case 'off':
        return false;
      default:
        // Unrecognised spelling: fall through to the error below.
        break;
    }
  }
  throw new ConfigError(`Invalid boolean for ${key}: ${String(value)}`);
}
114
/**
 * Interpret a raw configuration value as a finite number.
 *
 * Only numbers, bigints and non-blank strings are accepted. Any other input
 * type (boolean, null, array, Date, object, symbol, ...) is rejected instead
 * of being passed through `Number()`, which would silently turn `true` into
 * 1, `null`/`[]` into 0, or a Date into a millisecond timestamp.
 *
 * @param value Raw value to interpret.
 * @param key Configuration key the value belongs to; keys listed in
 *   INTEGER_KEYS additionally require a whole number.
 * @returns The parsed finite number.
 * @throws {ConfigError} When `value` has an unsupported type, is a blank
 *   string, does not parse to a finite number, or is fractional for an
 *   integer-only key.
 */
function coerceNumber(value, key) {
  let parsed;
  if (typeof value === 'number') {
    parsed = value;
  }
  else if (typeof value === 'string') {
    // Number('') and Number('   ') evaluate to 0, so blank strings need an
    // explicit rejection.
    if (value.trim() === '') {
      throw new ConfigError(`Invalid number for ${key}: empty string`);
    }
    parsed = Number(value);
  }
  else if (typeof value === 'bigint') {
    parsed = Number(value);
  }
  else {
    throw new ConfigError(`Invalid number for ${key}: ${String(value)}`);
  }
  // Rejects NaN (unparseable strings) as well as +/-Infinity.
  if (!Number.isFinite(parsed)) {
    throw new ConfigError(`Invalid number for ${key}: ${String(value)}`);
  }
  if (INTEGER_KEYS.has(key) && !Number.isInteger(parsed)) {
    throw new ConfigError(`Expected integer for ${key}, received ${parsed}`);
  }
  return parsed;
}
127
/**
 * Validate that a raw configuration value is a string.
 *
 * @param value Raw value to check.
 * @param key Configuration key the value belongs to (used in the error text).
 * @returns `value` unchanged.
 * @throws {ConfigError} When `value` is not a string.
 */
function coerceString(value, key) {
  if (typeof value !== 'string') {
    throw new ConfigError(`Invalid string for ${key}: ${String(value)}`);
  }
  return value;
}
132
/**
 * Coerce a raw value to the type registered for `key` in CONFIG_VALUE_TYPES.
 *
 * @param key Configuration key; selects the coercion to apply.
 * @param value Raw value to coerce.
 * @returns The coerced boolean, number or string.
 * @throws {ConfigError} When the value cannot be coerced to the expected type.
 */
function coerceValue(key, value) {
  switch (CONFIG_VALUE_TYPES[key]) {
    case 'boolean':
      return coerceBoolean(value, key);
    case 'number':
      return coerceNumber(value, key);
    default:
      // Every remaining key is handled as a string.
      return coerceString(value, key);
  }
}
142
/**
 * Assign `value` to `config[key]`.
 *
 * Single place where dynamically keyed writes to a config object happen.
 *
 * @param config Object to mutate.
 * @param key Property name to write.
 * @param value Value to store under `key`.
 */
function setConfigValue(config, key, value) {
  config[key] = value;
}
149
/**
 * Reduce a parsed TOML table to recognised configuration keys with coerced
 * values.
 *
 * Entries whose key is not a configuration key are skipped without error.
 *
 * @param record Parsed top-level TOML table.
 * @returns A new object holding only recognised keys and their coerced values.
 * @throws {ConfigError} When a recognised key holds a value of the wrong type.
 */
function normalizeTomlRecord(record) {
  const recognised = {};
  for (const [name, raw] of Object.entries(record)) {
    if (isConfigKey(name)) {
      setConfigValue(recognised, name, coerceValue(name, raw));
    }
  }
  return recognised;
}
159
/**
 * Return the default configuration.
 *
 * @returns A new mutable copy of the defaults; changing it does not affect
 *   later calls.
 */
export function getDefaultConfig() {
  const defaults = cloneDefaults();
  return defaults;
}
162
/**
 * Overlay `update` on top of `base` and return the result as a new object.
 *
 * Neither argument is mutated. Entries of `update` whose value is `undefined`
 * are skipped, so the corresponding `base` value is kept. Values are stored
 * as given, without type coercion.
 *
 * @param base Configuration to start from.
 * @param update Optional overrides; a falsy value yields a plain copy of `base`.
 * @returns A new configuration object.
 * @throws {ConfigError} When `update` contains a defined value under an
 *   unknown key.
 */
export function mergeConfig(base, update) {
  const merged = { ...base };
  if (update) {
    for (const [name, next] of Object.entries(update)) {
      if (next === undefined) {
        continue;
      }
      if (!isConfigKey(name)) {
        throw new ConfigError(`Unknown configuration key: ${name}`);
      }
      setConfigValue(merged, name, next);
    }
  }
  return merged;
}
177
/**
 * Load a TOML configuration file and merge it over the defaults.
 *
 * @param filePath Path to the TOML file; resolved to an absolute path before
 *   reading.
 * @returns The default configuration overlaid with the recognised, coerced
 *   values from the file.
 * @throws {ConfigError} When the file cannot be read, is not valid TOML, does
 *   not parse to a table, or holds a wrongly typed value for a recognised key.
 */
export function loadTomlConfig(filePath) {
  const absolutePath = resolvePath(filePath);
  // Thrown values are not guaranteed to be Error instances.
  const describe = (error) => (error instanceof Error ? error.message : String(error));

  let text;
  try {
    text = readFileSync(absolutePath, 'utf8');
  }
  catch (error) {
    throw new ConfigError(`Failed to read config at ${absolutePath}: ${describe(error)}`);
  }

  let table;
  try {
    table = parseToml(text);
  }
  catch (error) {
    throw new ConfigError(`Failed to parse TOML at ${absolutePath}: ${describe(error)}`);
  }

  const isTable = table !== null && typeof table === 'object' && !Array.isArray(table);
  if (!isTable) {
    throw new ConfigError(`Expected table at ${absolutePath}`);
  }

  const fromFile = normalizeTomlRecord(table);
  return mergeConfig(getDefaultConfig(), fromFile);
}
202
/**
 * Apply `key=value` override strings on top of a configuration.
 *
 * Each entry is split at its first `=`; key and value are trimmed, and the
 * value is coerced to the type registered for the key. When a key appears
 * more than once, the last entry wins. `config` is not mutated.
 *
 * @param config Configuration to start from.
 * @param overrides List of `key=value` strings.
 * @returns A new configuration object with the overrides applied.
 * @throws {ConfigError} When an entry has no `=`, names an unknown key, or
 *   carries a value that cannot be coerced.
 */
export function applyOverrides(config, overrides) {
  if (!overrides.length) {
    return { ...config };
  }
  const collected = {};
  for (const entry of overrides) {
    const separator = entry.indexOf('=');
    if (separator < 0) {
      throw new ConfigError(`Invalid override "${entry}", expected key=value format`);
    }
    const name = entry.slice(0, separator).trim();
    if (!isConfigKey(name)) {
      throw new ConfigError(`Unknown configuration key in override: ${name}`);
    }
    const text = entry.slice(separator + 1).trim();
    setConfigValue(collected, name, coerceValue(name, text));
  }
  return mergeConfig(config, collected);
}
@@ -0,0 +1,33 @@
1
+ /**
2
+ * Entropy filtering utilities for GRPO training
3
+ *
4
+ * Reference: trl/trl/trainer/grpo_trainer.py:get_high_entropy_mask
5
+ *
6
+ * Implements selective training on high-entropy (uncertain) tokens,
7
+ * which is a key optimization in GRPO to focus learning on challenging predictions.
8
+ *
9
+ * This module re-exports the Rust implementation for optimal performance.
10
+ * All entropy filtering operations are implemented in Rust (node/src/grpo_entropy.rs).
11
+ */
12
/**
 * Options for entropy-based token filtering in GRPO training.
 */
export interface EntropyFilteringConfig {
  /** Enables entropy filtering when `true` (default: false). */
  enabled: boolean;
  /**
   * Quantile threshold used to select high-entropy tokens (default: 0.8).
   *
   * - 0.0: all non-pad tokens
   * - 0.5: top 50% highest entropy
   * - 0.8: top 20% highest entropy (recommended)
   * - 1.0: only highest entropy token
   */
  topEntropyQuantile: number;
}
29
/**
 * Default entropy filtering configuration, used as the baseline
 * {@link EntropyFilteringConfig} value.
 */
export declare const DEFAULT_ENTROPY_CONFIG: EntropyFilteringConfig;
33
+ //# sourceMappingURL=grpo-entropy.d.ts.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"grpo-entropy.d.ts","sourceRoot":"","sources":["../../src/trainers/grpo-entropy.ts"],"names":[],"mappings":"AAAA;;;;;;;;;;GAUG;AAKH;;GAEG;AACH,MAAM,WAAW,sBAAsB;IACrC;;OAEG;IACH,OAAO,EAAE,OAAO,CAAC;IAEjB;;;;;;OAMG;IACH,kBAAkB,EAAE,MAAM,CAAC;CAC5B;AAED;;GAEG;AACH,eAAO,MAAM,sBAAsB,EAAE,sBAGpC,CAAC"}
@@ -0,0 +1,18 @@
1
+ /**
2
+ * Entropy filtering utilities for GRPO training
3
+ *
4
+ * Reference: trl/trl/trainer/grpo_trainer.py:get_high_entropy_mask
5
+ *
6
+ * Implements selective training on high-entropy (uncertain) tokens,
7
+ * which is a key optimization in GRPO to focus learning on challenging predictions.
8
+ *
9
+ * This module re-exports the Rust implementation for optimal performance.
10
+ * All entropy filtering operations are implemented in Rust (node/src/grpo_entropy.rs).
11
+ */
12
/**
 * Default entropy filtering configuration: filtering is disabled and the
 * quantile threshold is 0.8.
 */
export const DEFAULT_ENTROPY_CONFIG = {
  // Filtering is off unless a caller enables it.
  enabled: false,
  // Quantile threshold applied when filtering is enabled.
  topEntropyQuantile: 0.8,
};