mlx-raclate 0.1.0b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,353 @@
+ # Copyright © 2023-2024 Apple Inc.
+ import json
+ from functools import partial
+
+ from transformers import AutoTokenizer
+
+ REPLACEMENT_CHAR = "\ufffd"
+
+
+ def _remove_space(x):
+     if x and x[0] == " ":
+         return x[1:]
+     return x
+
+
+ class StreamingDetokenizer:
+     """The streaming detokenizer interface so that we can detokenize one token at a time.
+
+     Example usage is as follows:
+
+         detokenizer = ...
+
+         # Reset the tokenizer state
+         detokenizer.reset()
+
+         for token in generate(...):
+             detokenizer.add_token(token.item())
+
+             # Contains the whole text so far. Some tokens may not be included
+             # since it contains whole words usually.
+             detokenizer.text
+
+             # Contains the printable segment (usually a word) since the last
+             # time it was accessed
+             detokenizer.last_segment
+
+             # Contains all the tokens added so far
+             detokenizer.tokens
+
+         # Make sure that we detokenize any remaining tokens
+         detokenizer.finalize()
+
+         # Now detokenizer.text should match tokenizer.decode(detokenizer.tokens)
+     """
+
+     __slots__ = ("text", "tokens", "offset")
+
+     def reset(self):
+         raise NotImplementedError()
+
+     def add_token(self, token):
+         raise NotImplementedError()
+
+     def finalize(self):
+         raise NotImplementedError()
+
+     @property
+     def last_segment(self):
+         """Return the last segment of readable text since last time this property was accessed."""
+         text = self.text
+         if text and text[-1] != REPLACEMENT_CHAR:
+             segment = text[self.offset :]
+             self.offset = len(text)
+             return segment
+         return ""
+
+
+ class NaiveStreamingDetokenizer(StreamingDetokenizer):
+     """NaiveStreamingDetokenizer relies on the underlying tokenizer
+     implementation and should work with every tokenizer.
+
+     Its complexity is O(T^2) where T is the longest line since it will
+     repeatedly detokenize the same tokens until a new line is generated.
+     """
+
+     def __init__(self, tokenizer):
+         self._tokenizer = tokenizer
+         self._tokenizer.decode([0])
+         self.reset()
+
+     def reset(self):
+         self.offset = 0
+         self._tokens = []
+         self._text = ""
+         self._current_tokens = []
+         self._current_text = ""
+
+     def add_token(self, token):
+         self._current_tokens.append(token)
+
+     def finalize(self):
+         self._tokens.extend(self._current_tokens)
+         self._text += self._tokenizer.decode(self._current_tokens)
+         self._current_tokens = []
+         self._current_text = ""
+
+     @property
+     def text(self):
+         if self._current_tokens:
+             self._current_text = self._tokenizer.decode(self._current_tokens)
+             if (
+                 self._tokenizer.clean_up_tokenization_spaces
+                 and self._current_text[-1] == " "
+             ):
+                 self._current_text = self._current_text[:-1]
+             if self._current_text and self._current_text[-1] == "\n":
+                 self._tokens.extend(self._current_tokens)
+                 self._text += self._current_text
+                 self._current_tokens.clear()
+                 self._current_text = ""
+         return self._text + self._current_text
+
+     @property
+     def tokens(self):
+         return self._tokens
+
+
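+ # A minimal sketch (illustrative, not part of the package API) of the flushing
+ # behaviour above: tokens stay buffered in ``_current_tokens`` until a newline
+ # arrives. ``_StubTokenizer`` is a hypothetical stand-in that implements only
+ # the two members the class actually touches.
+ def _demo_naive_streaming():
+     class _StubTokenizer:
+         clean_up_tokenization_spaces = False
+
+         def decode(self, tokens):
+             vocab = {0: "Hello", 1: " world", 2: "\n"}
+             return "".join(vocab[t] for t in tokens)
+
+     detok = NaiveStreamingDetokenizer(_StubTokenizer())
+     detok.add_token(0)
+     detok.add_token(1)
+     assert detok.text == "Hello world"  # still buffered, no newline yet
+     detok.add_token(2)
+     assert detok.text == "Hello world\n"  # the newline flushes the buffer
+     detok.finalize()
+     assert detok.text == "Hello world\n"
+
+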
+ class SPMStreamingDetokenizer(StreamingDetokenizer):
+     """A streaming detokenizer for SPM (SentencePiece) models, used e.g. by
+     Llama and Gemma.
+
+     It adds tokens to the text if the next token starts with the special SPM
+     underscore which results in linear complexity.
+     """
+
+     def __init__(self, tokenizer, trim_space=True):
+         self.trim_space = trim_space
+
+         # Extract the tokens in a list from id to text
+         self.tokenmap = [""] * (max(tokenizer.vocab.values()) + 1)
+         for value, tokenid in tokenizer.vocab.items():
+             self.tokenmap[tokenid] = value
+
+         # Replace bytes with their value
+         for i in range(len(self.tokenmap)):
+             if self.tokenmap[i].startswith("<0x"):
+                 self.tokenmap[i] = chr(int(self.tokenmap[i][3:5], 16))
+
+         self.reset()
+
+     def reset(self):
+         self.offset = 0
+         self._unflushed = ""
+         self.text = ""
+         self.tokens = []
+
+     def add_token(self, token):
+         # Record the token so that ``self.tokens`` mirrors the generated stream.
+         self.tokens.append(token)
+         v = self.tokenmap[token]
+         if v[0] == "\u2581":
+             if self.text or not self.trim_space:
+                 self.text += self._unflushed.replace("\u2581", " ")
+             else:
+                 self.text = _remove_space(self._unflushed.replace("\u2581", " "))
+             self._unflushed = v
+         else:
+             self._unflushed += v
+
+     def finalize(self):
+         if self.text or not self.trim_space:
+             self.text += self._unflushed.replace("\u2581", " ")
+         else:
+             self.text = _remove_space(self._unflushed.replace("\u2581", " "))
+         self._unflushed = ""
+
+
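+ # A walk-through sketch of the "\u2581" (▁) marker logic above (illustrative;
+ # the three-token vocabulary below is hypothetical). A token starting with ▁
+ # begins a new word, so the previously buffered word can be flushed as-is.
+ def _demo_spm_streaming():
+     class _StubSPMTokenizer:
+         vocab = {"\u2581Hello": 0, "\u2581world": 1, "!": 2}
+
+     detok = SPMStreamingDetokenizer(_StubSPMTokenizer())
+     detok.add_token(0)  # "▁Hello" is buffered, nothing flushed yet
+     detok.add_token(1)  # "▁world" starts a new word: "Hello" is flushed
+     assert detok.text == "Hello"
+     detok.add_token(2)  # "!" extends the buffered word
+     detok.finalize()
+     assert detok.text == "Hello world!"
+
+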
+ class BPEStreamingDetokenizer(StreamingDetokenizer):
+     """A streaming detokenizer for OpenAI style BPE models (byte-pair
+     encoding, used e.g. by GPT-2).
+
+     It adds tokens to the text if the next token starts with a space similar to
+     the SPM detokenizer.
+     """
+
+     _byte_decoder = None
+     _space_matches = (".", "?", "!", ",", "n't", "'m", "'s", "'ve", "'re")
+
+     def __init__(self, tokenizer):
+         self.clean_spaces = tokenizer.clean_up_tokenization_spaces
+
+         # Extract the tokens in a list from id to text
+         self.tokenmap = [None] * len(tokenizer.vocab)
+         for value, tokenid in tokenizer.vocab.items():
+             self.tokenmap[tokenid] = value
+
+         self.reset()
+
+         # Make the BPE byte decoder from
+         # https://github.com/openai/gpt-2/blob/master/src/encoder.py
+         self.make_byte_decoder()
+
+         self._added_ids = set(tokenizer.added_tokens_decoder.keys())
+
+     def reset(self):
+         self.offset = 0
+         self._unflushed = ""
+         self.text = ""
+         self.tokens = []
+
+     def _maybe_trim_space(self, current_text):
+         if len(current_text) == 0:
+             return current_text
+         elif current_text[0] != " ":
+             return current_text
+         elif not self.text:
+             return current_text[1:]
+         elif self.clean_spaces and current_text[1:].startswith(self._space_matches):
+             return current_text[1:]
+         return current_text
+
+     def add_token(self, token):
+         # Record the token so that ``self.tokens`` mirrors the generated stream.
+         self.tokens.append(token)
+         v = self.tokenmap[token]
+         is_added = token in self._added_ids
+         if is_added or self._byte_decoder[v[0]] == 32:
+             current_text = bytearray(
+                 self._byte_decoder[c] for c in self._unflushed
+             ).decode("utf-8")
+             self.text += self._maybe_trim_space(current_text)
+             if is_added:
+                 self.text += v
+                 self._unflushed = ""
+             else:
+                 self._unflushed = v
+         else:
+             self._unflushed += v
+
+     def finalize(self):
+         current_text = bytearray(self._byte_decoder[c] for c in self._unflushed).decode(
+             "utf-8"
+         )
+         self.text += self._maybe_trim_space(current_text)
+         self._unflushed = ""
+
+     @classmethod
+     def make_byte_decoder(cls):
+         """See https://github.com/openai/gpt-2/blob/master/src/encoder.py for the rationale."""
+         if cls._byte_decoder is not None:
+             return
+
+         char_to_bytes = {}
+         limits = [
+             0,
+             ord("!"),
+             ord("~") + 1,
+             ord("¡"),
+             ord("¬") + 1,
+             ord("®"),
+             ord("ÿ") + 1,
+         ]
+         n = 0
+         for i, (start, stop) in enumerate(zip(limits, limits[1:])):
+             if i % 2 == 0:
+                 for b in range(start, stop):
+                     char_to_bytes[chr(2**8 + n)] = b
+                     n += 1
+             else:
+                 for b in range(start, stop):
+                     char_to_bytes[chr(b)] = b
+         cls._byte_decoder = char_to_bytes
+
+
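+ # A quick check of the byte-decoder construction above (illustrative):
+ # printable characters map back to themselves, while the 68 non-printable
+ # bytes are represented by characters shifted past 256, so the GPT-2
+ # leading-space marker "Ġ" (U+0120 = 256 + 32) decodes back to byte 32,
+ # i.e. a space. That is why ``add_token`` tests ``_byte_decoder[v[0]] == 32``.
+ def _demo_byte_decoder():
+     BPEStreamingDetokenizer.make_byte_decoder()
+     decoder = BPEStreamingDetokenizer._byte_decoder
+     assert decoder["A"] == ord("A")
+     assert decoder["\u0120"] == 32  # "Ġ" marks a leading space
+     assert len(decoder) == 256  # every byte value is covered exactly once
+
+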
+ class TokenizerWrapper:
+     """A wrapper that combines an HF tokenizer and a detokenizer.
+
+     Accessing any attribute other than the ``detokenizer`` is forwarded to the
+     huggingface tokenizer.
+     """
+
+     def __init__(self, tokenizer, detokenizer_class=NaiveStreamingDetokenizer):
+         self._tokenizer = tokenizer
+         self._detokenizer = detokenizer_class(tokenizer)
+
+     def __getattr__(self, attr):
+         if attr == "detokenizer":
+             return self._detokenizer
+         elif attr.startswith("_"):
+             return self.__getattribute__(attr)
+         else:
+             return getattr(self._tokenizer, attr)
+
+     def __setattr__(self, attr, value):
+         if attr == "detokenizer":
+             raise AttributeError("Cannot set the detokenizer.")
+         elif attr.startswith("_"):
+             super().__setattr__(attr, value)
+         else:
+             setattr(self._tokenizer, attr, value)
+
+
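+ # A small sketch of the forwarding rules above (illustrative; the stub
+ # tokenizer is hypothetical and implements only what the wrapper and the
+ # naive detokenizer touch).
+ def _demo_wrapper_forwarding():
+     class _StubTokenizer:
+         clean_up_tokenization_spaces = False
+         eos_token = "</s>"
+
+         def decode(self, tokens):
+             return ""
+
+     wrapper = TokenizerWrapper(_StubTokenizer())
+     assert isinstance(wrapper.detokenizer, NaiveStreamingDetokenizer)
+     assert wrapper.eos_token == "</s>"  # any other attribute is forwarded
+
+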
+ def _match(a, b):
+     if type(a) != type(b):
+         return False
+     if isinstance(a, dict):
+         return len(a) == len(b) and all(k in b and _match(a[k], b[k]) for k in a)
+     if isinstance(a, list):
+         return len(a) == len(b) and all(_match(ai, bi) for ai, bi in zip(a, b))
+
+     return a == b
+
+
+ def _is_spm_decoder(decoder):
+     _target_description = {
+         "type": "Sequence",
+         "decoders": [
+             {"type": "Replace", "pattern": {"String": "▁"}, "content": " "},
+             {"type": "ByteFallback"},
+             {"type": "Fuse"},
+             {"type": "Strip", "content": " ", "start": 1, "stop": 0},
+         ],
+     }
+     return _match(_target_description, decoder)
+
+
+ def _is_spm_decoder_no_space(decoder):
+     _target_description = {
+         "type": "Sequence",
+         "decoders": [
+             {"type": "Replace", "pattern": {"String": "▁"}, "content": " "},
+             {"type": "ByteFallback"},
+             {"type": "Fuse"},
+         ],
+     }
+     return _match(_target_description, decoder)
+
+
+ def _is_bpe_decoder(decoder):
+     return isinstance(decoder, dict) and decoder.get("type", None) == "ByteLevel"
+
+
+ def load_tokenizer(model_path, tokenizer_config_extra={}):
+     """Load a huggingface tokenizer and try to infer the type of streaming
+     detokenizer to use.
+
+     Note: to use a fast streaming detokenizer, pass a local file path rather
+     than a Hugging Face repo ID, so that tokenizer.json can be inspected.
+     """
+     detokenizer_class = NaiveStreamingDetokenizer
+
+     tokenizer_file = model_path / "tokenizer.json"
+     if tokenizer_file.exists():
+         with open(tokenizer_file, "r") as fid:
+             tokenizer_content = json.load(fid)
+         if "decoder" in tokenizer_content:
+             if _is_spm_decoder(tokenizer_content["decoder"]):
+                 detokenizer_class = SPMStreamingDetokenizer
+             elif _is_spm_decoder_no_space(tokenizer_content["decoder"]):
+                 detokenizer_class = partial(SPMStreamingDetokenizer, trim_space=False)
+             elif _is_bpe_decoder(tokenizer_content["decoder"]):
+                 detokenizer_class = BPEStreamingDetokenizer
+
+     return TokenizerWrapper(
+         AutoTokenizer.from_pretrained(model_path, **tokenizer_config_extra),
+         detokenizer_class,
+     )
@@ -0,0 +1,249 @@
+ import argparse
+ import time
+
+ from mlx_raclate.utils.utils import load, PIPELINES
+ from mlx_raclate.tuner.datasets import load_dataset, DatasetArgs
+ from mlx_raclate.tuner.trainer import Trainer, TrainingArgs
+
+ train_tested = {
+     "text-classification": [
+         {"model": "Qwen/Qwen3-Embedding-0.6B", "special_model_config": {}, "special_trainer_config": {"use_chat_template": True}, "special_training_args": {"max_length": 8192}},
+         {"model": "answerdotai/ModernBERT-base", "special_model_config": {}, "special_training_args": {}},
+         {"model": "LiquidAI/LFM2-350M", "special_model_config": {"use_chat_template": True}, "special_training_args": {}},
+         {"model": "google/t5gemma-b-b-ul2", "special_model_config": {}, "special_training_args": {"max_length": 8192}},  # failed
+         {"model": "google/embeddinggemma-300m", "special_model_config": {}, "special_training_args": {}},  # failed
+     ],
+ }
+
+ # LFM2 CHAT TEMPLATE
+ FORCED_CHAT_TEMPLATE = """
+ {{- bos_token -}}
+ {%- set system_prompt = "" -%}
+ {%- set ns = namespace(system_prompt="") -%}
+ {%- if messages[0]["role"] == "system" -%}
+     {%- set ns.system_prompt = messages[0]["content"] -%}
+     {%- set messages = messages[1:] -%}
+ {%- endif -%}
+ {%- if tools -%}
+     {%- set ns.system_prompt = ns.system_prompt + ("\n" if ns.system_prompt else "") + "List of tools: <|tool_list_start|>[" -%}
+     {%- for tool in tools -%}
+         {%- if tool is not string -%}
+             {%- set tool = tool | tojson -%}
+         {%- endif -%}
+         {%- set ns.system_prompt = ns.system_prompt + tool -%}
+         {%- if not loop.last -%}
+             {%- set ns.system_prompt = ns.system_prompt + ", " -%}
+         {%- endif -%}
+     {%- endfor -%}
+     {%- set ns.system_prompt = ns.system_prompt + "]<|tool_list_end|>" -%}
+ {%- endif -%}
+ {%- if ns.system_prompt -%}
+     {{- "<|im_start|>system\n" + ns.system_prompt + "<|im_end|>\n" -}}
+ {%- endif -%}
+ {%- for message in messages -%}
+     {{- "<|im_start|>" + message["role"] + "\n" -}}
+     {%- set content = message["content"] -%}
+     {%- if content is not string -%}
+         {%- set content = content | tojson -%}
+     {%- endif -%}
+     {%- if message["role"] == "tool" -%}
+         {%- set content = "<|tool_response_start|>" + content + "<|tool_response_end|>" -%}
+     {%- endif -%}
+     {{- content + "<|im_end|>\n" -}}
+ {%- endfor -%}
+ {%- if add_generation_prompt -%}
+     {{- "<|im_start|>assistant\n" -}}
+ {%- endif -%}
+ """
+
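+ # Quick sanity check of the template above (illustrative; assumes the
+ # ``jinja2`` package, which transformers itself uses to render chat
+ # templates):
+ #
+ #     from jinja2 import Template
+ #
+ #     rendered = Template(FORCED_CHAT_TEMPLATE).render(
+ #         bos_token="<s>",
+ #         messages=[{"role": "user", "content": "Hi"}],
+ #         add_generation_prompt=True,
+ #     )
+ #     # "<s><|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\n"
+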
+ DEFAULT_MODEL_PATH: str = "./trained_models/Qwen3-Embedding-0.6B_text-classification_20251219_001137/checkpoint-39939"  # "Qwen/Qwen3-Embedding-0.6B" "answerdotai/ModernBERT-base" "google/t5gemma-b-b-ul2"
+ DEFAULT_DATASET: str = "data/wines"  # can be a local path or an HF identifier, e.g. "argilla/synthetic-domain-text-classification" "data/20251205_1125"
+ DEFAULT_TASK_TYPE: str = "text-classification"
+ DEFAULT_BATCH_SIZE: int = 8
+ DEFAULT_GRADIENT_ACCUMULATION_STEPS: int = 8
+ DEFAULT_TRAIN_EPOCHS: int = 2
+ DEFAULT_WEIGHT_DECAY: float = 0.01
+ DEFAULT_LR: float = 2e-5  # 3e-5 for ModernBERT, 5e-5 for T5Gemma, 1e-5 for Qwen
+ DEFAULT_LR_SCHEDULER_TYPE: str = "linear_schedule"
+ DEFAULT_MIN_LR: float = 2e-6
+ DEFAULT_WARMUP_RATIO: float = 0.03
+ DEFAULT_WARMUP_STEPS: int = 0
+ DEFAULT_SAVE_STEPS: int = 5000
+ DEFAULT_LOGGING_STEPS: int = 64
+ DEFAULT_EVAL_BATCH_SIZE: int = 8
+
+ def init_args():
+     parser = argparse.ArgumentParser(description="Train or evaluate a classification model using MLX Raclate.")
+     # Dataset Init Params
+     parser.add_argument("--dataset", type=str, default=DEFAULT_DATASET, help="Local path or HF identifier of the dataset to use for training/evaluation.")
+     parser.add_argument("--text_field", type=str, default=None, help="Name of the text field in the dataset (if different from default).")
+     parser.add_argument("--text_pair_field", type=str, default=None, help="Name of the text pair field in the dataset (if applicable and different from default).")
+     parser.add_argument("--label_field", type=str, default=None, help="Name of the label field in the dataset (if different from default).")
+     parser.add_argument("--negative_field", type=str, default=None, help="Name of the negative samples field in the dataset (if applicable and different from default).")
+     parser.add_argument("--create_test", action='store_true', help="Create a test split out of the training set if one is not already present in the dataset (the validation set is not affected).")
+
+     # Trainer / End Model Init Params
+     parser.add_argument("--model_path", type=str, default=DEFAULT_MODEL_PATH, help="Path to the pre-trained model or model identifier from a model hub.")
+     parser.add_argument("--task_type", type=str, default=DEFAULT_TASK_TYPE, help="Type of task (default: text-classification).")
+     parser.add_argument("--is_regression", default=False, action='store_true', help="Set this flag if the task is regression.")
+     parser.add_argument("--use_late_interaction", default=False, action='store_true', help="Set this flag to use late interaction for retrieval tasks (if applicable).")
+     parser.add_argument("--eval_only", dest="train", action='store_false', help="Set this flag to skip training and only evaluate.")
+     parser.add_argument("--use_chat_template", default=False, action='store_true', help="Use chat template for decoder models when there are text pairs.")
+     parser.add_argument("--force_separator", type=str, default=None, help="Force a specific separator between text pairs for decoder models, if not using chat template.")
+
+     # Training Params
+     parser.add_argument("--batch_size", type=int, default=DEFAULT_BATCH_SIZE, help="Batch size for training.")
+     parser.add_argument("--gradient_accumulation_steps", type=int, default=DEFAULT_GRADIENT_ACCUMULATION_STEPS, help="Number of gradient accumulation steps.")
+     parser.add_argument("--num_train_epochs", type=int, default=DEFAULT_TRAIN_EPOCHS, help="Number of training epochs.")
+     parser.add_argument("--max_length", type=int, default=None, help="Maximum sequence length for the model inputs. If not specified, the model's default max length will be used.")
+     parser.add_argument("--freeze_embeddings", default=False, action='store_true', help="Set this flag to freeze embedding layers during training.")
+     parser.add_argument("--weight_decay", type=float, default=DEFAULT_WEIGHT_DECAY, help="Weight decay for the optimizer.")
+     # Optimizer and Scheduler Params (AdamW + schedulers)
+     parser.add_argument("--lr", type=float, default=DEFAULT_LR, help="Initial learning rate for the optimizer.")
+     parser.add_argument("--lr_scheduler_type", type=str, default=DEFAULT_LR_SCHEDULER_TYPE, help="Type of learning rate scheduler to use.")
+     parser.add_argument("--min_lr", type=float, default=DEFAULT_MIN_LR, help="Minimum learning rate for the scheduler.")
+     parser.add_argument("--warmup_ratio", type=float, default=DEFAULT_WARMUP_RATIO, help="Warmup ratio for the learning rate scheduler.")
+     parser.add_argument("--warmup_steps", type=int, default=DEFAULT_WARMUP_STEPS, help="Number of warmup steps for the learning rate scheduler (if set, steps override ratio).")
+     parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Maximum gradient norm for gradient clipping (default: 1.0).")
+     parser.add_argument("--resume_from_step", type=int, default=0, help="Step to resume training from (if applicable); warmup is skipped when the resume step falls after the warmup period.")
+     # Other Training Params
+     parser.add_argument("--logging_steps", type=int, default=DEFAULT_LOGGING_STEPS, help="Number of steps between logging training metrics.")
+     parser.add_argument("--save_steps", type=int, default=DEFAULT_SAVE_STEPS, help="Number of steps between model checkpoints.")
+     parser.add_argument("--eval_batch_size", type=int, default=DEFAULT_EVAL_BATCH_SIZE, help="Batch size for evaluation.")
+     parser.add_argument("--output_dir", type=str, default=None, help="Directory to save model checkpoints and logs.")
+     parser.set_defaults(train=True)
+     return parser.parse_args()
+
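+ # Example invocation (illustrative; the script name is hypothetical, the
+ # flags are the ones defined above):
+ #
+ #     python train.py --dataset data/wines \
+ #         --model_path answerdotai/ModernBERT-base \
+ #         --task_type text-classification \
+ #         --batch_size 8 --lr 3e-5
+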
+ def main():
+     args = init_args()
+
+     # Dataset Params
+     dataset: str = args.dataset
+     text_field: str = args.text_field
+     text_pair_field: str = args.text_pair_field
+     label_field: str = args.label_field
+     negative_field: str = args.negative_field
+     create_test: bool = args.create_test
+
+     # Trainer / End Model Params
+     model_path: str = args.model_path
+     task_type: str = args.task_type
+     is_regression: bool = args.is_regression
+     use_late_interaction: bool = args.use_late_interaction
+     train: bool = args.train
+     use_chat_template: bool = args.use_chat_template
+     force_separator: str = args.force_separator
+
+     # Training Params
+     batch_size: int = args.batch_size
+     gradient_accumulation_steps: int = args.gradient_accumulation_steps
+     num_train_epochs: int = args.num_train_epochs
+     weight_decay: float = args.weight_decay
+     learning_rate: float = args.lr
+     lr_scheduler_type: str = args.lr_scheduler_type
+     min_lr: float = args.min_lr
+     warmup_ratio: float = args.warmup_ratio
+     warmup_steps: int = args.warmup_steps
+     logging_steps: int = args.logging_steps
+     save_steps: int = args.save_steps
+     eval_batch_size: int = args.eval_batch_size
+     resume_from_step: int = args.resume_from_step
+     max_length: int = args.max_length
+     freeze_embeddings: bool = args.freeze_embeddings
+     max_grad_norm: float = args.max_grad_norm
+
+     print(f"Training mode: {train}")
+
+     if task_type not in PIPELINES:
+         raise ValueError(f"Task type {task_type} not supported. Choose from {list(PIPELINES)}")
+
+     output_dir: str = args.output_dir if args.output_dir else model_path.split("/")[-1] + "_" + task_type + "_" + time.strftime("%Y%m%d_%H%M%S")
+
+     # Load datasets
+     dataset_args = DatasetArgs(
+         data=dataset,
+         task_type=task_type,
+         text_field=text_field,
+         text_pair_field=text_pair_field,
+         label_field=label_field,
+         negative_field=negative_field,
+         test=create_test,
+     )
+
+     train_dataset, valid_dataset, test_dataset, id2label, label2id = load_dataset(dataset_args)
+
+     model_config = {}
+     if task_type == "text-classification" and is_regression:
+         model_config = {"is_regression": True}
+     if use_late_interaction and task_type in ["sentence-transformers", "sentence-similarity"]:
+         model_config["use_late_interaction"] = True
+     if id2label:
+         model_config["id2label"] = id2label
+     if label2id:
+         model_config["label2id"] = label2id
+
+     # Load model and tokenizer
+     model, tokenizer = load(
+         model_path,
+         model_config=model_config,
+         pipeline=task_type,
+         train=train,
+     )
+
+     # Sanity-check the chat template
+     if use_chat_template:
+         messages = [
+             {"role": "user", "content": "test_prompt"},
+             {"role": "assistant", "content": "test_response"},
+         ]
+         if not getattr(tokenizer, "chat_template", None) and FORCED_CHAT_TEMPLATE:
+             tokenizer.chat_template = FORCED_CHAT_TEMPLATE
+
+         templated = tokenizer.apply_chat_template(messages, tokenize=False)
+         print("Chat template working:", templated)
+
+     # Training arguments
+     training_args = TrainingArgs(
+         batch_size=batch_size,
+         gradient_accumulation_steps=gradient_accumulation_steps,
+         max_length=max_length if max_length else model.config.max_position_embeddings,
+         resume_from_step=resume_from_step,  # warmup is ignored when it falls before this step; schedulers only start after it
+         num_train_epochs=num_train_epochs,
+         learning_rate=learning_rate,
+         weight_decay=weight_decay,
+         freeze_embeddings=freeze_embeddings,
+         warmup_ratio=warmup_ratio,  # can use warmup_steps or warmup_ratio
+         warmup_steps=warmup_steps,  # if both are set, warmup_steps wins
+         lr_scheduler_type=lr_scheduler_type,  # defaults to "constant"; "cosine_decay" and "linear_schedule" are also available
+         min_lr=min_lr,
+         max_grad_norm=max_grad_norm,
+         save_steps=save_steps,
+         logging_steps=logging_steps,  # adjusted to a multiple of gradient_accumulation_steps inside Trainer
+         eval_batch_size=eval_batch_size,
+         output_dir=output_dir,
+         save_total_limit=None,
+         grad_checkpoint=True,
+         push_to_hub=False,
+     )
+
+     # Initialize trainer
+     trainer = Trainer(
+         model=model,
+         tokenizer=tokenizer,
+         task_type=task_type,
+         training_args=training_args,
+         train_dataset=train_dataset,
+         eval_dataset=valid_dataset,
+         use_chat_template=use_chat_template if task_type == "text-classification" else False,
+         force_separator=force_separator if task_type == "text-classification" else None,
+         label2id=label2id,
+     )
+
+     # Train or evaluate
+     if train:
+         trainer.train()
+     if test_dataset:
+         trainer.test(test_dataset)
+
+
+ if __name__ == "__main__":
+     main()