mlx-raclate 0.1.0b1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
- mlx_raclate/__init__.py +1 -0
- mlx_raclate/models/__init__.py +0 -0
- mlx_raclate/models/base.py +225 -0
- mlx_raclate/models/gemma3_text.py +913 -0
- mlx_raclate/models/lfm2.py +671 -0
- mlx_raclate/models/modernbert.py +900 -0
- mlx_raclate/models/qwen3.py +582 -0
- mlx_raclate/models/t5gemma_encoder.py +857 -0
- mlx_raclate/py.typed +0 -0
- mlx_raclate/tuner/TUNER.md +305 -0
- mlx_raclate/tuner/__init__.py +0 -0
- mlx_raclate/tuner/collators.py +291 -0
- mlx_raclate/tuner/datasets.py +247 -0
- mlx_raclate/tuner/model_card_utils.py +206 -0
- mlx_raclate/tuner/trainer.py +648 -0
- mlx_raclate/tuner/utils.py +292 -0
- mlx_raclate/utils/__init__.py +0 -0
- mlx_raclate/utils/server.py +390 -0
- mlx_raclate/utils/tokenizer_utils.py +353 -0
- mlx_raclate/utils/train.py +249 -0
- mlx_raclate/utils/utils.py +625 -0
- mlx_raclate-0.1.0b1.dist-info/METADATA +216 -0
- mlx_raclate-0.1.0b1.dist-info/RECORD +25 -0
- mlx_raclate-0.1.0b1.dist-info/WHEEL +4 -0
- mlx_raclate-0.1.0b1.dist-info/licenses/LICENSE +19 -0
mlx_raclate/utils/tokenizer_utils.py
@@ -0,0 +1,353 @@
# Copyright © 2023-2024 Apple Inc.
import json
from functools import partial

from transformers import AutoTokenizer

REPLACEMENT_CHAR = "\ufffd"

def _remove_space(x):
    if x and x[0] == " ":
        return x[1:]
    return x


class StreamingDetokenizer:
    """The streaming detokenizer interface so that we can detokenize one token at a time.

    Example usage is as follows:

        detokenizer = ...

        # Reset the tokenizer state
        detokenizer.reset()

        for token in generate(...):
            detokenizer.add_token(token.item())

            # Contains the whole text so far. Some tokens may not be included
            # since it contains whole words usually.
            detokenizer.text

            # Contains the printable segment (usually a word) since the last
            # time it was accessed
            detokenizer.last_segment

            # Contains all the tokens added so far
            detokenizer.tokens

        # Make sure that we detokenize any remaining tokens
        detokenizer.finalize()

        # Now detokenizer.text should match tokenizer.decode(detokenizer.tokens)
    """

    __slots__ = ("text", "tokens", "offset")

    def reset(self):
        raise NotImplementedError()

    def add_token(self, token):
        raise NotImplementedError()

    def finalize(self):
        raise NotImplementedError()

    @property
    def last_segment(self):
        """Return the last segment of readable text since last time this property was accessed."""
        text = self.text
        if text and text[-1] != REPLACEMENT_CHAR:
            segment = text[self.offset :]
            self.offset = len(text)
            return segment
        return ""


class NaiveStreamingDetokenizer(StreamingDetokenizer):
    """NaiveStreamingDetokenizer relies on the underlying tokenizer
    implementation and should work with every tokenizer.

    Its complexity is O(T^2) where T is the longest line since it will
    repeatedly detokenize the same tokens until a new line is generated.
    """

    def __init__(self, tokenizer):
        self._tokenizer = tokenizer
        self._tokenizer.decode([0])
        self.reset()

    def reset(self):
        self.offset = 0
        self._tokens = []
        self._text = ""
        self._current_tokens = []
        self._current_text = ""

    def add_token(self, token):
        self._current_tokens.append(token)

    def finalize(self):
        self._tokens.extend(self._current_tokens)
        self._text += self._tokenizer.decode(self._current_tokens)
        self._current_tokens = []
        self._current_text = ""

    @property
    def text(self):
        if self._current_tokens:
            self._current_text = self._tokenizer.decode(self._current_tokens)
            if (
                self._tokenizer.clean_up_tokenization_spaces
                and self._current_text[-1] == " "
            ):
                self._current_text = self._current_text[:-1]
            if self._current_text and self._current_text[-1] == "\n":
                self._tokens.extend(self._current_tokens)
                self._text += self._current_text
                self._current_tokens.clear()
                self._current_text = ""
        return self._text + self._current_text

    @property
    def tokens(self):
        return self._tokens


class SPMStreamingDetokenizer(StreamingDetokenizer): ## SPM = SentencePiece Model, relevant for Llama and Gemma
    """A streaming detokenizer for SPM models.

    It adds tokens to the text if the next token starts with the special SPM
    underscore which results in linear complexity.
    """

    def __init__(self, tokenizer, trim_space=True):
        self.trim_space = trim_space

        # Extract the tokens in a list from id to text
        self.tokenmap = [""] * (max(tokenizer.vocab.values()) + 1)
        for value, tokenid in tokenizer.vocab.items():
            self.tokenmap[tokenid] = value

        # Replace bytes with their value
        for i in range(len(self.tokenmap)):
            if self.tokenmap[i].startswith("<0x"):
                self.tokenmap[i] = chr(int(self.tokenmap[i][3:5], 16))

        self.reset()

    def reset(self):
        self.offset = 0
        self._unflushed = ""
        self.text = ""
        self.tokens = []

    def add_token(self, token):
        self.tokens.append(token) ### append token to tokens list, not sure why it wasn't there in the first place as reset() sets an empty list
        v = self.tokenmap[token]
        if v[0] == "\u2581":
            if self.text or not self.trim_space:
                self.text += self._unflushed.replace("\u2581", " ")
            else:
                self.text = _remove_space(self._unflushed.replace("\u2581", " "))
            self._unflushed = v
        else:
            self._unflushed += v

    def finalize(self):
        if self.text or not self.trim_space:
            self.text += self._unflushed.replace("\u2581", " ")
        else:
            self.text = _remove_space(self._unflushed.replace("\u2581", " "))
        self._unflushed = ""


class BPEStreamingDetokenizer(StreamingDetokenizer): ## BPE = Byte Pair Encoding, relevant for GPT-2
    """A streaming detokenizer for OpenAI style BPE models.

    It adds tokens to the text if the next token starts with a space similar to
    the SPM detokenizer.
    """

    _byte_decoder = None
    _space_matches = (".", "?", "!", ",", "n't", "'m", "'s", "'ve", "'re")

    def __init__(self, tokenizer):

        self.clean_spaces = tokenizer.clean_up_tokenization_spaces

        # Extract the tokens in a list from id to text
        self.tokenmap = [None] * len(tokenizer.vocab)
        for value, tokenid in tokenizer.vocab.items():
            self.tokenmap[tokenid] = value

        self.reset()

        # Make the BPE byte decoder from
        # https://github.com/openai/gpt-2/blob/master/src/encoder.py
        self.make_byte_decoder()

        self._added_ids = set(tokenizer.added_tokens_decoder.keys())

    def reset(self):
        self.offset = 0
        self._unflushed = ""
        self.text = ""
        self.tokens = []

    def _maybe_trim_space(self, current_text):
        if len(current_text) == 0:
            return current_text
        elif current_text[0] != " ":
            return current_text
        elif not self.text:
            return current_text[1:]
        elif self.clean_spaces and current_text[1:].startswith(self._space_matches):
            return current_text[1:]
        return current_text

    def add_token(self, token):
        self.tokens.append(token) ### append token to tokens list, not sure why it wasn't there in the first place as reset() sets an empty list
        v = self.tokenmap[token]
        is_added = token in self._added_ids
        if is_added or self._byte_decoder[v[0]] == 32:
            current_text = bytearray(
                self._byte_decoder[c] for c in self._unflushed
            ).decode("utf-8")
            self.text += self._maybe_trim_space(current_text)
            if is_added:
                self.text += v
                self._unflushed = ""
            else:
                self._unflushed = v
        else:
            self._unflushed += v

    def finalize(self):
        current_text = bytearray(self._byte_decoder[c] for c in self._unflushed).decode(
            "utf-8"
        )
        self.text += self._maybe_trim_space(current_text)
        self._unflushed = ""

    @classmethod
    def make_byte_decoder(cls):
        """See https://github.com/openai/gpt-2/blob/master/src/encoder.py for the rationale."""
        if cls._byte_decoder is not None:
            return

        char_to_bytes = {}
        limits = [
            0,
            ord("!"),
            ord("~") + 1,
            ord("¡"),
            ord("¬") + 1,
            ord("®"),
            ord("ÿ") + 1,
        ]
        n = 0
        for i, (start, stop) in enumerate(zip(limits, limits[1:])):
            if i % 2 == 0:
                for b in range(start, stop):
                    char_to_bytes[chr(2**8 + n)] = b
                    n += 1
            else:
                for b in range(start, stop):
                    char_to_bytes[chr(b)] = b
        cls._byte_decoder = char_to_bytes


class TokenizerWrapper:
    """A wrapper that combines an HF tokenizer and a detokenizer.

    Accessing any attribute other than the ``detokenizer`` is forwarded to the
    huggingface tokenizer.
    """

    def __init__(self, tokenizer, detokenizer_class=NaiveStreamingDetokenizer):
        self._tokenizer = tokenizer
        self._detokenizer = detokenizer_class(tokenizer)

    def __getattr__(self, attr):
        if attr == "detokenizer":
            return self._detokenizer
        elif attr.startswith("_"):
            return self.__getattribute__(attr)
        else:
            return getattr(self._tokenizer, attr)

    def __setattr__(self, attr, value):
        if attr == "detokenizer":
            raise AttributeError("Cannot set the detokenizer.")
        elif attr.startswith("_"):
            super().__setattr__(attr, value)
        else:
            setattr(self._tokenizer, attr, value)


def _match(a, b):
    if type(a) != type(b):
        return False
    if isinstance(a, dict):
        return len(a) == len(b) and all(k in b and _match(a[k], b[k]) for k in a)
    if isinstance(a, list):
        return len(a) == len(b) and all(_match(ai, bi) for ai, bi in zip(a, b))

    return a == b


def _is_spm_decoder(decoder):
    _target_description = {
        "type": "Sequence",
        "decoders": [
            {"type": "Replace", "pattern": {"String": "▁"}, "content": " "},
            {"type": "ByteFallback"},
            {"type": "Fuse"},
            {"type": "Strip", "content": " ", "start": 1, "stop": 0},
        ],
    }
    return _match(_target_description, decoder)


def _is_spm_decoder_no_space(decoder):
    _target_description = {
        "type": "Sequence",
        "decoders": [
            {"type": "Replace", "pattern": {"String": "▁"}, "content": " "},
            {"type": "ByteFallback"},
            {"type": "Fuse"},
        ],
    }
    return _match(_target_description, decoder)


def _is_bpe_decoder(decoder):
    return isinstance(decoder, dict) and decoder.get("type", None) == "ByteLevel"


def load_tokenizer(model_path, tokenizer_config_extra={}):
    """Load a huggingface tokenizer and try to infer the type of streaming
    detokenizer to use.

    Note, to use a fast streaming tokenizer, pass a local file path rather than
    a Hugging Face repo ID.
    """
    detokenizer_class = NaiveStreamingDetokenizer

    tokenizer_file = model_path / "tokenizer.json"
    if tokenizer_file.exists():
        with open(tokenizer_file, "r") as fid:
            tokenizer_content = json.load(fid)
        if "decoder" in tokenizer_content:
            if _is_spm_decoder(tokenizer_content["decoder"]):
                detokenizer_class = SPMStreamingDetokenizer
            elif _is_spm_decoder_no_space(tokenizer_content["decoder"]):
                detokenizer_class = partial(SPMStreamingDetokenizer, trim_space=False)
            elif _is_bpe_decoder(tokenizer_content["decoder"]):
                detokenizer_class = BPEStreamingDetokenizer

    return TokenizerWrapper(
        AutoTokenizer.from_pretrained(model_path, **tokenizer_config_extra),
        detokenizer_class,
    )
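For orientation, here is a minimal sketch of how the streaming interface above is meant to be driven. It is not part of the wheel; the local model directory and the token ids are placeholders standing in for a real checkpoint and a real generation loop.

# Hypothetical driver loop for the detokenizers defined in tokenizer_utils.py.
from pathlib import Path
from mlx_raclate.utils.tokenizer_utils import load_tokenizer

tokenizer = load_tokenizer(Path("./some_local_model"))  # a local path lets the SPM/BPE fast paths be inferred from tokenizer.json
detok = tokenizer.detokenizer

detok.reset()
for token_id in [101, 102, 103]:  # stand-in for ids produced by a generation loop
    detok.add_token(token_id)
    print(detok.last_segment, end="", flush=True)  # emit each printable segment as soon as it completes
detok.finalize()
print(detok.last_segment)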
mlx_raclate/utils/train.py
@@ -0,0 +1,249 @@
import argparse
import time

from mlx_raclate.utils.utils import load, PIPELINES
from mlx_raclate.tuner.datasets import load_dataset, DatasetArgs
from mlx_raclate.tuner.trainer import Trainer, TrainingArgs

train_tested = {
    "text-classification": [
        {"model": "Qwen/Qwen3-Embedding-0.6B", "special_model_config" : {}, "special_trainer_config" : {"use_chat_template": True}, "special_training_args" : {"max_length":8192}},
        {"model": "answerdotai/ModernBERT-base", "special_model_config" : {}, "special_training_args" : {}},
        {"model": "LiquidAI/LFM2-350M", "special_model_config" : {"use_chat_template": True}, "special_training_args" : {}},
        {"model": "google/t5gemma-b-b-ul2", "special_model_config" : {}, "special_training_args" : {"max_length":8192}}, # failed
        {"model": "google/embeddinggemma-300m", "special_model_config" : {}, "special_training_args" : {}} # failed
    ],
}

# LFM2 CHAT TEMPLATE
FORCED_CHAT_TEMPLATE = """
{{- bos_token -}}
{%- set system_prompt = "" -%}
{%- set ns = namespace(system_prompt="") -%}
{%- if messages[0]["role"] == "system" -%}
{%- set ns.system_prompt = messages[0]["content"] -%}
{%- set messages = messages[1:] -%}
{%- endif -%}
{%- if tools -%}
{%- set ns.system_prompt = ns.system_prompt + ("\n" if ns.system_prompt else "") + "List of tools: <|tool_list_start|>[" -%}
{%- for tool in tools -%}
{%- if tool is not string -%}
{%- set tool = tool | tojson -%}
{%- endif -%}
{%- set ns.system_prompt = ns.system_prompt + tool -%}
{%- if not loop.last -%}
{%- set ns.system_prompt = ns.system_prompt + ", " -%}
{%- endif -%}
{%- endfor -%}
{%- set ns.system_prompt = ns.system_prompt + "]<|tool_list_end|>" -%}
{%- endif -%}
{%- if ns.system_prompt -%}
{{- "<|im_start|>system\n" + ns.system_prompt + "<|im_end|>\n" -}}
{%- endif -%}
{%- for message in messages -%}
{{- "<|im_start|>" + message["role"] + "\n" -}}
{%- set content = message["content"] -%}
{%- if content is not string -%}
{%- set content = content | tojson -%}
{%- endif -%}
{%- if message["role"] == "tool" -%}
{%- set content = "<|tool_response_start|>" + content + "<|tool_response_end|>" -%}
{%- endif -%}
{{- content + "<|im_end|>\n" -}}
{%- endfor -%}
{%- if add_generation_prompt -%}
{{- "<|im_start|>assistant\n" -}}
{%- endif -%}
"""

DEFAULT_MODEL_PATH : str = "./trained_models/Qwen3-Embedding-0.6B_text-classification_20251219_001137/checkpoint-39939" # "Qwen/Qwen3-Embedding-0.6B" "answerdotai/ModernBERT-base" "google/t5gemma-b-b-ul2"
DEFAULT_DATASET : str = "data/wines" # can be a local path or HF "argilla/synthetic-domain-text-classification" "data/20251205_1125"
DEFAULT_TASK_TYPE : str = "text-classification"
DEFAULT_BATCH_SIZE : int = 8
DEFAULT_GRADIENT_ACCUMULATION_STEPS : int = 8
DEFAULT_TRAIN_EPOCHS : int = 2
DEFAULT_WEIGHT_DECAY : float = 0.01
DEFAULT_LR : float = 2e-5 # 3e-5 for ModernBERT, 5e-5 for T5Gemma, 1e-5 for Qwen
DEFAULT_LR_SCHEDULER_TYPE : str = "linear_schedule"
DEFAULT_MIN_LR : float = 2e-6
DEFAULT_WARMUP_RATIO : float = 0.03
DEFAULT_WARMUP_STEPS : int = 0
DEFAULT_SAVE_STEPS : int = 5000
DEFAULT_LOGGING_STEPS : int = 64
DEFAULT_EVAL_BATCH_SIZE : int = 8

def init_args():
    parser = argparse.ArgumentParser(description="Train or evaluate a classification model using MLX Raclate.")
    # Dataset Init Params
    parser.add_argument("--dataset", type=str, default=DEFAULT_DATASET, help="Local path or HF identifier of the dataset to use for training/evaluation.")
    parser.add_argument("--text_field", type=str, default=None, help="Name of the text field in the dataset (if different from default).")
    parser.add_argument("--text_pair_field", type=str, default=None, help="Name of the text pair field in the dataset (if applicable and different from default).")
    parser.add_argument("--label_field", type=str, default=None, help="Name of the label field in the dataset (if different from default).")
    parser.add_argument("--negative_field", type=str, default=None, help="Name of the negative samples field in the dataset (if applicable and different from default).")
    parser.add_argument("--create_test", action='store_true', help="Set this flag to create a test split, if not already present in the dataset, out of the training set (validation set not affected).")

    # Trainer / End Model Init Params
    parser.add_argument("--model_path", type=str, default=DEFAULT_MODEL_PATH, help="Path to the pre-trained model or model identifier from a model hub.")
    parser.add_argument("--task_type", type=str, default=DEFAULT_TASK_TYPE, help="Type of task (default: text-classification).")
    parser.add_argument("--is_regression", default=False, action='store_true', help="Set this flag if the task is regression.")
    parser.add_argument("--use_late_interaction", default=False, action='store_true', help="Set this flag to use late interaction for retrieval tasks (if applicable).")
    parser.add_argument("--eval_only", dest="train", action='store_false', help="Set this flag to skip training and only evaluate.")
    parser.add_argument("--use_chat_template", default=False, action='store_true', help="Use chat template for decoder models when there are text pairs.")
    parser.add_argument("--force_separator", type=str, default=None, help="Force a specific separator between text pairs for decoder models, if not using chat template.")

    # Training Params
    parser.add_argument("--batch_size", type=int, default=DEFAULT_BATCH_SIZE, help="Batch size for training.")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=DEFAULT_GRADIENT_ACCUMULATION_STEPS, help="Number of gradient accumulation steps.")
    parser.add_argument("--num_train_epochs", type=int, default=DEFAULT_TRAIN_EPOCHS, help="Number of training epochs.")
    parser.add_argument("--max_length", type=int, default=None, help="Maximum sequence length for the model inputs. If not specified, the model's default max length will be used.")
    parser.add_argument("--freeze_embeddings", default=False, action='store_true', help="Set this flag to freeze embedding layers during training.")
    parser.add_argument("--weight_decay", type=float, default=DEFAULT_WEIGHT_DECAY, help="Weight decay for the optimizer.")
    # Optimizer and Scheduler Params (AdamW + schedulers)
    parser.add_argument("--lr", type=float, default=DEFAULT_LR, help="Initial learning rate for the optimizer.")
    parser.add_argument("--lr_scheduler_type", type=str, default=DEFAULT_LR_SCHEDULER_TYPE, help="Type of learning rate scheduler to use.")
    parser.add_argument("--min_lr", type=float, default=DEFAULT_MIN_LR, help="Minimum learning rate for the scheduler.")
    parser.add_argument("--warmup_ratio", type=float, default=DEFAULT_WARMUP_RATIO, help="Warmup ratio for learning rate scheduler.")
    parser.add_argument("--warmup_steps", type=int, default=DEFAULT_WARMUP_STEPS, help="Number of warmup steps for learning rate scheduler (if set, steps override ratio).")
    parser.add_argument("--max_grad_norm", type=float, default=1, help="Maximum gradient norm for gradient clipping (Default: 1).")
    parser.add_argument("--resume_from_step", type=int, default=0, help="Step number to resume training from (if applicable). Will override warmup if steps are after warmup period.")
    # Other Training Params
    parser.add_argument("--logging_steps", type=int, default=DEFAULT_LOGGING_STEPS, help="Number of steps between logging training metrics.")
    parser.add_argument("--save_steps", type=int, default=DEFAULT_SAVE_STEPS, help="Number of steps between model checkpoints.")
    parser.add_argument("--eval_batch_size", type=int, default=DEFAULT_EVAL_BATCH_SIZE, help="Batch size for evaluation.")
    parser.add_argument("--output_dir", type=str, default=None, help="Directory to save model checkpoints and logs.")
    parser.set_defaults(train=True)
    return parser.parse_args()

def main():
    args = init_args()

    # Dataset Params
    dataset : str = args.dataset
    text_field : str = args.text_field
    text_pair_field : str = args.text_pair_field
    label_field : str = args.label_field
    negative_field : str = args.negative_field
    create_test : bool = args.create_test

    # Trainer / End Model Params
    model_path : str = args.model_path
    task_type : str = args.task_type
    is_regression : bool = args.is_regression
    use_late_interaction : bool = args.use_late_interaction
    train : bool = args.train
    use_chat_template : bool = args.use_chat_template
    force_separator : str = args.force_separator

    # Training Params
    batch_size : int = args.batch_size
    gradient_accumulation_steps : int = args.gradient_accumulation_steps
    num_train_epochs : int = args.num_train_epochs
    weight_decay : float = args.weight_decay
    learning_rate : float = args.lr
    lr_scheduler_type : str = args.lr_scheduler_type
    min_lr : float = args.min_lr
    warmup_ratio : float = args.warmup_ratio
    warmup_steps : int = args.warmup_steps
    logging_steps : int = args.logging_steps
    save_steps : int = args.save_steps
    eval_batch_size : int = args.eval_batch_size
    resume_from_step : int = args.resume_from_step
    max_length : int = args.max_length
    freeze_embeddings : bool = args.freeze_embeddings
    max_grad_norm : float = args.max_grad_norm

    print(f"Training Mode : {train}")

    if task_type not in PIPELINES:
        raise ValueError(f"Task type {task_type} not supported. Choose from {PIPELINES.items()}")

    output_dir : str = args.output_dir if args.output_dir else model_path.split("/")[-1] + "_" + task_type + "_" + time.strftime("%Y%m%d_%H%M%S")

    # Load datasets
    dataset_args = DatasetArgs(
        data=dataset,
        task_type=task_type,
        text_field=text_field,
        text_pair_field=text_pair_field,
        label_field=label_field,
        negative_field=negative_field,
        test=create_test
    )

    train_dataset, valid_dataset, test_dataset, id2label, label2id = load_dataset(dataset_args)

    model_config={}
    if task_type == "text-classification" and is_regression:
        model_config={"is_regression":True}
    if use_late_interaction and task_type in ["sentence-transformers","sentence-similarity"]:
        model_config["use_late_interaction"] = True
    if id2label:
        model_config["id2label"] = id2label
    if label2id:
        model_config["label2id"] = label2id

    # Load model and tokenizer
    model, tokenizer = load(
        model_path,
        model_config=model_config,
        pipeline=task_type,
        train=train,
    )

    # testing chat template
    if use_chat_template:
        messages = [
            {"role": "user", "content": "test_prompt"},
            {"role": "assistant", "content": "test_response"}
        ]
        if not getattr(tokenizer, "chat_template", None) and FORCED_CHAT_TEMPLATE:
            tokenizer.chat_template = FORCED_CHAT_TEMPLATE

        templated = tokenizer.apply_chat_template(messages, tokenize=False)
        print("Chat template working:", templated)

    # Training arguments
    training_args = TrainingArgs(
        batch_size=batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        max_length=max_length if max_length else model.config.max_position_embeddings,
        resume_from_step=resume_from_step, # warmup is ignored if it ends before this step, and schedulers only start after it
        num_train_epochs=num_train_epochs,
        learning_rate=learning_rate,
        weight_decay=weight_decay,
        freeze_embeddings=freeze_embeddings,
        warmup_ratio=warmup_ratio, # can use warmup_steps or warmup_ratio
        warmup_steps=warmup_steps, # if both are set, warmup_steps is used
        lr_scheduler_type=lr_scheduler_type, # defaults to "constant"; can also use "cosine_decay" or "linear_schedule"
        min_lr=min_lr,
        max_grad_norm=max_grad_norm,
        save_steps=save_steps,
        logging_steps=logging_steps, # will be adjusted to a multiple of gradient_accumulation_steps inside Trainer
        eval_batch_size=eval_batch_size,
        output_dir=output_dir,
        save_total_limit=None,
        grad_checkpoint=True,
        push_to_hub=False,
    )

    # Initialize trainer
    trainer = Trainer(
        model=model,
        tokenizer=tokenizer,
        task_type=task_type,
        training_args=training_args,
        train_dataset=train_dataset,
        eval_dataset=valid_dataset,
        use_chat_template=use_chat_template if task_type == "text-classification" else False,
        force_separator=force_separator if task_type == "text-classification" else None,
        label2id=label2id
    )

    # Train or evaluate
    if train:
        trainer.train()
    if test_dataset:
        trainer.test(test_dataset)

if __name__ == "__main__":
    main()
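For reference, a hypothetical invocation of this training entry point. The flag names come from the parser above, and the model and dataset identifiers are ones already mentioned in the file; whether a console script is installed is not stated, so the module is run directly:

python -m mlx_raclate.utils.train --dataset argilla/synthetic-domain-text-classification --model_path answerdotai/ModernBERT-base --task_type text-classification --lr 3e-5 --num_train_epochs 2 --batch_size 8 --create_test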